]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - processor-sdk/kaldi.git/commitdiff
trunk:added PowerComponent in the p-norm nnet training recipe.
authorXiaohui Zhang <samuelzhang1104@gmail.com>
Mon, 23 Jun 2014 02:03:51 +0000 (02:03 +0000)
committerXiaohui Zhang <samuelzhang1104@gmail.com>
Mon, 23 Jun 2014 02:03:51 +0000 (02:03 +0000)
git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@4074 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8

15 files changed:
egs/wsj/s5/steps/nnet2/train_pnorm.sh
src/cudamatrix/cu-kernels-ansi.h
src/cudamatrix/cu-kernels.cu
src/cudamatrix/cu-kernels.h
src/cudamatrix/cu-matrix-test.cc
src/cudamatrix/cu-matrix.cc
src/cudamatrix/cu-matrix.h
src/matrix/kaldi-matrix.cc
src/matrix/kaldi-matrix.h
src/matrix/kaldi-vector.cc
src/matrix/kaldi-vector.h
src/matrix/matrix-lib-test.cc
src/nnet2/nnet-component-test.cc
src/nnet2/nnet-component.cc
src/nnet2/nnet-component.h

index 5d076c3a336a74ee058a60b571a58bbb919b803f..d08c2ef55b6415b8f3698dfc771aa73e86188d96 100755 (executable)
@@ -27,6 +27,7 @@ softmax_learning_rate_factor=1.0 # In the default setting keep the same learning
 combine_regularizer=1.0e-14 # Small regularizer so that parameters won't go crazy.
 pnorm_input_dim=3000 
 pnorm_output_dim=300
+first_component_power=1.0  # could set this to 0.5, often seems to improve results.
 p=2
 minibatch_size=128 # by default use a smallish minibatch size for neural net
                    # training; this controls instability which would otherwise
@@ -213,6 +214,11 @@ SpliceComponent input-dim=$ext_feat_dim left-context=$splice_width right-context
 FixedAffineComponent matrix=$lda_mat
 AffineComponentPreconditioned input-dim=$ext_lda_dim output-dim=$pnorm_input_dim alpha=$alpha max-change=$max_change learning-rate=$initial_learning_rate param-stddev=$stddev bias-stddev=$bias_stddev
 PnormComponent input-dim=$pnorm_input_dim output-dim=$pnorm_output_dim p=$p
+EOF
+  if [ $first_component_power != 1.0 ]; then
+    echo "PowerComponent dim=$pnorm_output_dim power=$first_component_power" >> $dir/nnet.config
+  fi
+  cat >>$dir/nnet.config <<EOF
 NormalizeComponent dim=$pnorm_output_dim
 AffineComponentPreconditioned input-dim=$pnorm_output_dim output-dim=$num_leaves alpha=$alpha max-change=$max_change learning-rate=$initial_learning_rate param-stddev=0 bias-stddev=0
 SoftmaxComponent dim=$num_leaves
index 001b1d1f58e81276b8c7d3ab90dfa3b519dff9b5..4d5eb6116a58ccc4065cbb51a49a5a1652e39ec1 100644 (file)
@@ -58,6 +58,7 @@ void cudaFD_copy_from_tp(dim3 Gr, dim3 Bl, float* A, const double* B, MatrixDim
 void cudaF_copy_col_from_vec(int Gr, int Bl, float* mat, const float* v, int col, MatrixDim d);
 void cudaF_apply_exp(dim3 Gr, dim3 Bl, float* mat, MatrixDim d);
 void cudaF_apply_pow(dim3 Gr, dim3 Bl, float* mat, float power, MatrixDim d);
+void cudaF_apply_pow_abs(dim3 Gr, dim3 Bl, float* mat, float power, bool include_sign,  MatrixDim d);
 void cudaF_apply_heaviside(dim3 Gr, dim3 Bl, float* mat, MatrixDim d);  
 void cudaF_apply_floor(dim3 Gr, dim3 Bl, float* mat, float floor_val, MatrixDim d);
 void cudaF_copy_cols(dim3 Gr, dim3 Bl, float* dst, const float* src, const MatrixIndexT_cuda* reorder, MatrixDim dst_dim, int src_stride);
@@ -187,6 +188,7 @@ void cudaDF_copy_from_tp(dim3 Gr, dim3 Bl, double* A, const float* B, MatrixDim
 void cudaD_copy_col_from_vec(int Gr, int Bl, double* mat, const double* v, int col, MatrixDim d);
 void cudaD_apply_exp(dim3 Gr, dim3 Bl, double* mat, MatrixDim d);
 void cudaD_apply_pow(dim3 Gr, dim3 Bl, double* mat, double power, MatrixDim d);
+void cudaD_apply_pow_abs(dim3 Gr, dim3 Bl, double* mat, double power, bool include_sign, MatrixDim d);
 void cudaD_apply_heaviside(dim3 Gr, dim3 Bl, double* mat, MatrixDim d);  
 void cudaD_apply_floor(dim3 Gr, dim3 Bl, double* mat, double floor_val, MatrixDim d);
 void cudaD_copy_cols(dim3 Gr, dim3 Bl, double* dst, const double* src, const MatrixIndexT_cuda* reorder, MatrixDim dst_dim, int src_stride);
index 032bf888e1d6a7648f47863fefa5b3112fc6d0cd..e9359281772e3a886c59cc7c8a5472aade3cb192 100644 (file)
@@ -387,7 +387,6 @@ static void _apply_log(Real* mat, MatrixDim d) {
     mat[index] = log(mat[index]);
 }
 
-
 template<typename Real>
 __global__
 static void _mul_elements(Real* mat, const Real* A, MatrixDim dst_d, int src_stride) {
@@ -1161,6 +1160,40 @@ static void _apply_pow(Real* mat, Real power, MatrixDim d) {
   }
 }
 
+template<typename Real>
+__global__
+static void _apply_pow_abs(Real* mat, Real power, bool include_sign, MatrixDim d) {
+  int i = blockIdx.x * blockDim.x + threadIdx.x;
+  int j = blockIdx.y * blockDim.y + threadIdx.y;
+  int index = i * d.stride + j;
+
+  if (i < d.rows && j < d.cols) {
+    if (include_sign == true && mat[index] < 0) {
+      if (power == 1.0) 
+        mat[index] = -std::abs(mat[index]);
+      if (power == 2.0) {
+        mat[index] = -mat[index] * mat[index];
+      } else if (power == 0.5) {
+        mat[index] = -sqrt(std::abs(mat[index]));
+      } else {
+        mat[index] = -pow(std::abs(mat[index]), power);
+      }
+    } else {
+      if (power == 1.0) 
+        mat[index] = std::abs(mat[index]);
+      if (power == 2.0) {
+        mat[index] = mat[index] * mat[index];
+      } else if (power == 0.5) {
+        mat[index] = sqrt(std::abs(mat[index]));
+      } else if (power < 0.0 && mat[index] == 0.0) {
+        mat[index] = 0.0;
+      } else {
+        mat[index] = pow(std::abs(mat[index]), power);
+      }
+    }
+  }
+}
+
 // Caution, here i/block{idx,dim}.x is the row index and j/block{idx,dim}.y is the col index.
 // this is for no reason, really, I just happened to prefer this
 // at the time. [dan]
@@ -1953,6 +1986,10 @@ void cudaF_apply_pow(dim3 Gr, dim3 Bl, float* mat, float power, MatrixDim d) {
   _apply_pow<<<Gr,Bl>>>(mat, power, d);
 }
 
+void cudaF_apply_pow_abs(dim3 Gr, dim3 Bl, float* mat, float power, bool include_sign, MatrixDim d) {
+  _apply_pow_abs<<<Gr,Bl>>>(mat, power, include_sign, d);
+}
+
 void cudaF_apply_heaviside(dim3 Gr, dim3 Bl, float* mat, MatrixDim d) {
   _apply_heaviside<<<Gr,Bl>>>(mat, d);
 
@@ -2372,6 +2409,10 @@ void cudaD_apply_pow(dim3 Gr, dim3 Bl, double* mat, double power, MatrixDim d) {
   _apply_pow<<<Gr,Bl>>>(mat, power, d);
 }
 
+void cudaD_apply_pow_abs(dim3 Gr, dim3 Bl, double* mat, double power, bool include_sign, MatrixDim d) {
+  _apply_pow_abs<<<Gr,Bl>>>(mat, power, include_sign, d);
+}
+
 void cudaD_apply_heaviside(dim3 Gr, dim3 Bl, double* mat, MatrixDim d) {
   _apply_heaviside<<<Gr,Bl>>>(mat, d);
 }
index e202076c00a7dec959e491704d0e6df56bdf0cf1..6a911f3b8c4d26f69d1380408543b51d5ce99de9 100644 (file)
@@ -85,6 +85,7 @@ inline void cuda_copy_from_mat_trans(dim3 Gr, dim3 Bl, double* mat_out, const fl
 inline void cuda_copy_col_from_vec(int Gr, int Bl, float* mat, const float* v, int col, MatrixDim d) { cudaF_copy_col_from_vec(Gr,Bl,mat,v,col,d); }
 inline void cuda_apply_exp(dim3 Gr, dim3 Bl, float* mat, MatrixDim d) { cudaF_apply_exp(Gr,Bl,mat,d); }
 inline void cuda_apply_pow(dim3 Gr, dim3 Bl, float* mat, float power, MatrixDim dim) { cudaF_apply_pow(Gr,Bl,mat,power,dim); }
+inline void cuda_apply_pow_abs(dim3 Gr, dim3 Bl, float* mat, float power, bool include_sign, MatrixDim dim) { cudaF_apply_pow_abs(Gr,Bl,mat,power,include_sign, dim); }
 inline void cuda_apply_heaviside(dim3 Gr, dim3 Bl, float* mat, MatrixDim dim) { cudaF_apply_heaviside(Gr,Bl,mat,dim); }
 inline void cuda_apply_floor(dim3 Gr, dim3 Bl, float* mat, float floor_val, MatrixDim dim) { cudaF_apply_floor(Gr,Bl,mat,floor_val,dim); }
 inline void cuda_apply_ceiling(dim3 Gr, dim3 Bl, float* mat, float ceiling_val, MatrixDim dim) { cudaF_apply_ceiling(Gr,Bl,mat,ceiling_val,dim); }
@@ -254,6 +255,7 @@ inline void cuda_copy_from_tp(dim3 Gr, dim3 Bl, double* A, const float* B, Matri
 inline void cuda_copy_col_from_vec(int Gr, int Bl, double* mat, const double* v, int col, MatrixDim d) { cudaD_copy_col_from_vec(Gr,Bl,mat,v,col,d); }
 inline void cuda_apply_exp(dim3 Gr, dim3 Bl, double* mat, MatrixDim d) { cudaD_apply_exp(Gr,Bl,mat,d); }
 inline void cuda_apply_pow(dim3 Gr, dim3 Bl, double* mat, double power, MatrixDim dim) { cudaD_apply_pow(Gr,Bl,mat,power,dim); }
+inline void cuda_apply_pow_abs(dim3 Gr, dim3 Bl, double* mat, double power, bool include_sign, MatrixDim dim) { cudaD_apply_pow_abs(Gr,Bl,mat,power,include_sign,dim); }
 inline void cuda_apply_heaviside(dim3 Gr, dim3 Bl, double* mat, MatrixDim dim) { cudaD_apply_heaviside(Gr,Bl,mat,dim); }
 inline void cuda_apply_floor(dim3 Gr, dim3 Bl, double* mat, double floor_val, MatrixDim dim) { cudaD_apply_floor(Gr,Bl,mat,floor_val,dim); }
 inline void cuda_apply_ceiling(dim3 Gr, dim3 Bl, double* mat, double ceiling_val, MatrixDim dim) { cudaD_apply_ceiling(Gr,Bl,mat,ceiling_val,dim); }
index e3b5be0e5f517b5040da11d4ed13e7eeef75f64c..e410bc651afe49f41cd77edd6f2264ec1992c74a 100644 (file)
@@ -305,6 +305,27 @@ static void UnitTestCuMatrixApplyPow() {
   }
 }
 
+template<typename Real> 
+static void UnitTestCuMatrixApplyPowAbs() {
+
+  for (int32 i = 0; i < 2; i++) {
+    BaseFloat pow = 0.5 * (rand() % 6);
+    
+    Matrix<Real> H(10 + rand() % 60, 10 + rand() % 20);
+    H.SetRandn();
+    H.Row(0).Set(0.0);
+    if (i == 2) { Matrix<Real> tmp(H, kTrans); H = tmp; }
+    
+    CuMatrix<Real> cH(H);
+
+    cH.ApplyPowAbs(pow, true);
+
+    H.ApplyPowAbs(pow, true);
+    Matrix<Real> H2(cH);
+    AssertEqual(H, H2);
+  }
+}
+
 
 template<typename Real>
 static void UnitTestCuMatrixCopyRowsFromVec() {
@@ -509,7 +530,6 @@ static void UnitTestCuMatrixApplyHeaviside() {
 }
 
 
-
 template<typename Real> 
 static void UnitTestCuMatrixMulElements() {
   for (int32 i = 0; i < 2; i++) {
@@ -1923,6 +1943,7 @@ template<typename Real> void CudaMatrixUnitTest() {
   UnitTestCuMatrixSigmoid<Real>();
   UnitTestCuMatrixSoftHinge<Real>();
   UnitTestCuMatrixApplyPow<Real>(); 
+  UnitTestCuMatrixApplyPowAbs<Real>(); 
   UnitTestCuMatrixSet<Real>();
   UnitTestCuMatrixAdd<Real>();
   UnitTestCuMatrixApplyFloor<Real>();
index 0724d1c2e1b77d1a1c152a164fd6868d3df3ab96..a1f3278139e31fd7924be2fc261fe2dd8a00ecea 100644 (file)
@@ -602,8 +602,6 @@ void CuMatrixBase<Real>::ApplyLog() {
   }
 }
 
-
-
 template<typename Real>
 void CuMatrixBase<Real>::MulElements(const CuMatrixBase<Real>& A) {
   #if HAVE_CUDA == 1
@@ -1632,6 +1630,25 @@ void CuMatrixBase<Real>::ApplyPow(Real power) {
   }
 }
 
+template<typename Real>
+void CuMatrixBase<Real>::ApplyPowAbs(Real power, bool include_sign) {
+#if HAVE_CUDA == 1
+  if (CuDevice::Instantiate().Enabled()) {
+    Timer tim;
+    dim3 dimBlock(CU2DBLOCK, CU2DBLOCK);
+    dim3 dimGrid(n_blocks(NumRows(), CU2DBLOCK),
+                 n_blocks(NumCols(), CU2DBLOCK));
+    
+    cuda_apply_pow_abs(dimGrid, dimBlock, data_, power, include_sign, Dim());
+    CU_SAFE_CALL(cudaGetLastError());
+    CuDevice::Instantiate().AccuProfile(__func__, tim.Elapsed());
+  } else
+#endif
+  {
+    Mat().ApplyPowAbs(power, include_sign);
+  }
+}
+
 template<typename Real>
 void CuMatrixBase<Real>::ApplyHeaviside() {
 #if HAVE_CUDA == 1
index 2f2fecf430348975f4125e6c6fb189d7f90b01e0..d873ce6ba00e28938bb4b0e1d2a8f1360d170a58 100644 (file)
@@ -233,6 +233,13 @@ class CuMatrixBase {
   void SymInvertPosDef(); 
 
   void ApplyPow(Real power);
+  ///< Apply power to the absolute value of each element. 
+  ///< If inlude_sign is true, the result will be multiplied with 
+  ///< the sign of the input value.
+  ///< If the power is negative and the input to the power is zero,
+  ///< The output will be set zero. If include_sign is true, it will
+  ///< multiply the result by the sign of the input.
+  void ApplyPowAbs(Real power, bool include_sign=false);
   void ApplyHeaviside(); ///< For each element, sets x = (x > 0 ? 1.0 : 0.0)
   void ApplyFloor(Real floor_val);
   void ApplyCeiling(Real ceiling_val);
index 046fa03f4a6340c56dbce16ef780322882851e4a..ea32b73cde124ae2981c4390ca4a2491e9553aac 100644 (file)
@@ -1847,6 +1847,13 @@ void MatrixBase<Real>::ApplyPow(Real power) {
   }
 }
 
+template<typename Real>
+void MatrixBase<Real>::ApplyPowAbs(Real power, bool include_sign) {
+  for (MatrixIndexT i = 0; i < num_rows_; i++) {
+    Row(i).ApplyPowAbs(power, include_sign);
+  }
+}
+
 template<typename Real>
 void MatrixBase<Real>::ApplyHeaviside() {
   MatrixIndexT num_rows = num_rows_, num_cols = num_cols_;
index 601fdcb2ded09f7907616bfa7bd22a5d49e5711e..71db6b16f7487d704108baba51e80b6c02986c34 100644 (file)
@@ -298,6 +298,12 @@ class MatrixBase {
   /// Applies power to all matrix elements
   void ApplyPow(Real power);
 
+  /// Apply power to the absolute value of each element. 
+  /// Include the sign of the input element if include_sign == true.
+  /// If the power is negative and the input to the power is zero,
+  /// The output will be set zero.
+  void ApplyPowAbs(Real power, bool include_sign=false);
+  
   /// Applies the Heaviside step function (x > 0 ? 1 : 0) to all matrix elements
   /// Note: in general you can make different choices for x = 0, but for now
   /// please leave it as it (i.e. returning zero) because it affects the
index 053cfc61025563e712a354730d332f3209b6b840..1dff4210a93d92d9d150605975fb75cbe431a6e8 100644 (file)
@@ -451,6 +451,40 @@ void VectorBase<Real>::ApplyPow(Real power) {
 }
 #endif
 
+// takes absolute value of the elements to a power.
+// Throws exception if could not (but only for power != 1 and power != 2).
+template<typename Real>
+void VectorBase<Real>::ApplyPowAbs(Real power, bool include_sign) {
+  if (power == 1.0) 
+    for (MatrixIndexT i = 0; i < dim_; i++)
+      data_[i] = (include_sign && data_[i] < 0 ? -1 : 1) * std::abs(data_[i]);
+  if (power == 2.0) {
+    for (MatrixIndexT i = 0; i < dim_; i++)
+      data_[i] = (include_sign && data_[i] < 0 ? -1 : 1) * data_[i] * data_[i];
+  } else if (power == 0.5) {
+    for (MatrixIndexT i = 0; i < dim_; i++) {
+      data_[i] = (include_sign && data_[i] < 0 ? -1 : 1) * std::sqrt(std::abs(data_[i]));
+    }
+  } else if (power < 0.0) {
+    for (MatrixIndexT i = 0; i < dim_; i++) {
+      data_[i] = (data_[i] == 0.0 ? 0.0 : pow(std::abs(data_[i]), power));
+      data_[i] *= (include_sign && data_[i] < 0 ? -1 : 1);
+      if (data_[i] == HUGE_VAL) {  // HUGE_VAL is what errno returns on error.
+        KALDI_ERR << "Could not raise element "  << i << "to power "
+                  << power << ": returned value = " << data_[i];
+      }
+    }
+  } else {
+    for (MatrixIndexT i = 0; i < dim_; i++) {
+      data_[i] = (include_sign && data_[i] < 0 ? -1 : 1) * pow(std::abs(data_[i]), power);
+      if (data_[i] == HUGE_VAL) {  // HUGE_VAL is what errno returns on error.
+        KALDI_ERR << "Could not raise element "  << i << "to power "
+                  << power << ": returned value = " << data_[i];
+      }
+    }
+  }
+}
+
 // Computes the p-th norm. Throws exception if could not.
 template<typename Real>
 Real VectorBase<Real>::Norm(Real p) const {
index 84e35e576b7b8c84e44877e95e4900be44a6106f..a9abf15962b496fba97ea5a2bf3b47203045ac2c 100644 (file)
@@ -151,6 +151,11 @@ class VectorBase {
   /// Take all  elements of vector to a power.
   void ApplyPow(Real power);
 
+  /// Take the absolute value of all elements of a vector to a power.
+  /// Include the sign of the input element if include_sign == true.
+  /// If power is negative and the input value is zero, the output is set zero.
+  void ApplyPowAbs(Real power, bool include_sign=false);
+  
   /// Compute the p-th norm of the vector.
   Real Norm(Real p) const;
   
index 44e4f526c1a7a0a5ec59d7b9e5925ad3bdef1f17..f604d31bbef113afe97ab57a65cfda329676ce46 100644 (file)
@@ -612,6 +612,7 @@ static void UnitTestVectorMin() {
 
 template<typename Real>  
 static void UnitTestReplaceValue(){
+  // for vector
   MatrixIndexT dim = 10 + rand() % 2;
   Real orig = 0.1 * (rand() % 100), changed = 0.1 * (rand() % 50);
   Vector<Real> V(dim);
@@ -947,6 +948,23 @@ template<typename Real> static void UnitTestPower() {
   }
 }
 
+template<typename Real> static void UnitTestPowerAbs() {
+  for (MatrixIndexT iter = 0;iter < 5;iter++) {
+    MatrixIndexT dimV = 10 + rand() % 10;
+    Vector<Real> V(dimV), V1(dimV), V2(dimV);
+    InitRand(&V);
+    V1.AddVecVec(1.0, V, V, 0.0);  // V1:=V.*V.
+    V2.CopyFromVec(V1);
+    KALDI_LOG << V1;
+    V2.ApplyPowAbs(0.5);
+    KALDI_LOG << V2;
+    V2.ApplyPowAbs(2.0);
+    KALDI_LOG << V2;
+    AssertEqual(V1, V2);
+  }
+}
+
+
 template<typename Real> static void UnitTestHeaviside() {
   for (MatrixIndexT iter = 0;iter < 5;iter++) {
     MatrixIndexT dimM = 10 + rand() % 10, dimN = 10 + rand() % 10;
@@ -4149,6 +4167,7 @@ template<typename Real> static void MatrixUnitTest(bool full_test) {
   UnitTestDotprod<Real>();
   // UnitTestSvdVariants<Real>();
   UnitTestPower<Real>();
+  UnitTestPowerAbs<Real>();
   UnitTestHeaviside<Real>();
   UnitTestCopySp<Real>();
   UnitTestDeterminant<Real>();
@@ -4209,7 +4228,7 @@ template<typename Real> static void MatrixUnitTest(bool full_test) {
   UnitTestTp2<Real>();
   UnitTestAddDiagMat2<Real>();
   UnitTestAddDiagMatMat<Real>();
-  UnitTestOrthogonalizeRows<Real>();
//  UnitTestOrthogonalizeRows<Real>();
   UnitTestTopEigs<Real>();
   UnitTestRandCategorical<Real>();
   UnitTestTridiag<Real>();
index 36ac370df2baa896345be57660312a048335ae53..04c3afbd1e79a9149257c2481bc48738fcbb2890 100644 (file)
@@ -801,6 +801,7 @@ int main() {
     for (int32 i = 0; i < 3; i++) {
       UnitTestGenericComponent<SigmoidComponent>();
       UnitTestGenericComponent<TanhComponent>();
+      UnitTestGenericComponent<PowerComponent>("power=1.5");
       UnitTestGenericComponent<PermuteComponent>();
       UnitTestGenericComponent<SoftmaxComponent>();
       UnitTestGenericComponent<RectifiedLinearComponent>();
index 3fc18640a3729273952c9dd9eb728ee9511e00bd..90102490b828098246838e910f636c82bdf8c9ea 100644 (file)
@@ -50,6 +50,8 @@ Component* Component::NewComponentOfType(const std::string &component_type) {
     ans = new SigmoidComponent();
   } else if (component_type == "TanhComponent") {
     ans = new TanhComponent();
+  } else if (component_type == "PowerComponent") {
+    ans = new PowerComponent();
   } else if (component_type == "SoftmaxComponent") {
     ans = new SoftmaxComponent();
   } else if (component_type == "RectifiedLinearComponent") {
@@ -788,6 +790,78 @@ void TanhComponent::Backprop(const CuMatrixBase<BaseFloat> &, // in_value
   in_deriv->MulElements(out_deriv);
 }  
 
+void PowerComponent::Init(int32 dim, BaseFloat power) {
+  dim_ = dim;
+  power_ = power;
+  KALDI_ASSERT(dim > 0 && power >= 0);
+}
+
+void PowerComponent::InitFromString(std::string args) {
+  std::string orig_args(args);
+  int32 dim;
+  BaseFloat power = 2.0;
+  ParseFromString("power", &args, &power); // Optional.
+  // Accept either "dim" or "input-dim" to specify the input dim.
+  // "input-dim" is the canonical one; "dim" simplifies the testing code.
+  bool ok = (ParseFromString("dim", &args, &dim) ||
+             ParseFromString("input-dim", &args, &dim));
+  if (!ok || !args.empty() || dim <= 0)
+    KALDI_ERR << "Invalid initializer for layer of type "
+              << Type() << ": \"" << orig_args << "\"";
+  Init(dim, power);
+}
+
+void PowerComponent::Propagate(const CuMatrixBase<BaseFloat> &in,
+                              int32, // num_chunks
+                              CuMatrix<BaseFloat> *out) const {
+  // Apply power operation to each element of the input...
+  out->Resize(in.NumRows(), in.NumCols(), kUndefined);
+  out->CopyFromMat(in);
+  out->ApplyPowAbs(power_);
+}
+
+void PowerComponent::Backprop(const CuMatrixBase<BaseFloat> &in_value,
+                             const CuMatrixBase<BaseFloat> &out_value,
+                             const CuMatrixBase<BaseFloat> &out_deriv,
+                             int32, // num_chunks
+                             Component *to_update,
+                             CuMatrix<BaseFloat> *in_deriv) const {
+  in_deriv->Resize(in_value.NumRows(), in_value.NumCols());
+  // in scalar terms: in_deriv += p * in_value^(p-1) * out_deriv
+  in_deriv->CopyFromMat(in_value); 
+  in_deriv->ApplyPowAbs(power_ - 1.0, true);
+  in_deriv->Scale(power_);
+  in_deriv->MulElements(out_deriv);
+}
+
+void PowerComponent::Read(std::istream &is, bool binary) {
+  ExpectOneOrTwoTokens(is, binary, "<PowerComponent>", "<InputDim>");
+  ReadBasicType(is, binary, &dim_);
+  ExpectToken(is, binary, "<OutputDim>");
+  ReadBasicType(is, binary, &dim_);
+  ExpectToken(is, binary, "<Power>");
+  ReadBasicType(is, binary, &power_);
+  ExpectToken(is, binary, "</PowerComponent>");
+}
+
+void PowerComponent::Write(std::ostream &os, bool binary) const {
+  WriteToken(os, binary, "<PowerComponent>");
+  WriteToken(os, binary, "<InputDim>");
+  WriteBasicType(os, binary, dim_);
+  WriteToken(os, binary, "<OutputDim>");
+  WriteBasicType(os, binary, dim_);
+  WriteToken(os, binary, "<Power>");
+  WriteBasicType(os, binary, power_);
+  WriteToken(os, binary, "</PowerComponent>");
+}
+
+std::string PowerComponent::Info() const {
+  std::stringstream stream;
+  stream << Type() << ", dim = " << dim_
+        << ", power = " << power_;
+  return stream.str();
+}
+
 void RectifiedLinearComponent::Propagate(const CuMatrixBase<BaseFloat> &in,
                               int32, // num_chunks
                               CuMatrix<BaseFloat> *out) const {
index b6cfbb94df387d69bdafe39676f2ccc17f2da7a9..cdee2e5c9219b320a5dc51f0ef2c72491bd5f343 100644 (file)
@@ -466,6 +466,44 @@ class TanhComponent: public NonlinearComponent {
   TanhComponent &operator = (const TanhComponent &other); // Disallow.
 };
 
+/// Take the absoute values of an input vector to a power.
+/// The derivative for zero input will be treated as zero.
+class PowerComponent: public NonlinearComponent {
+ public:
+  void Init(int32 dim, BaseFloat power = 2);
+  explicit PowerComponent(int32 dim, BaseFloat power = 2) {
+    Init(dim, power);
+  }
+  PowerComponent(): dim_(0), power_(2) { }
+  virtual std::string Type() const { return "PowerComponent"; }
+  virtual void InitFromString(std::string args); 
+  virtual int32 InputDim() const { return dim_; }
+  virtual int32 OutputDim() const { return dim_; }
+  virtual void Propagate(const CuMatrixBase<BaseFloat> &in,
+                         int32 num_chunks,
+                         CuMatrix<BaseFloat> *out) const;
+  virtual void Backprop(const CuMatrixBase<BaseFloat> &in_value,
+                        const CuMatrixBase<BaseFloat> &, // out_value
+                        const CuMatrixBase<BaseFloat> &out_deriv,
+                        int32 num_chunks,
+                        Component *to_update, // may be identical to "this".
+                        CuMatrix<BaseFloat> *in_deriv) const;
+  virtual bool BackpropNeedsInput() const { return true; }
+  virtual bool BackpropNeedsOutput() const { return true; }
+  virtual Component* Copy() const { return new PowerComponent(dim_, power_); }
+  virtual void Read(std::istream &is, bool binary); // This Read function
+  // requires that the Component has the correct type.
+  
+  /// Write component to stream
+  virtual void Write(std::ostream &os, bool binary) const;
+
+  virtual std::string Info() const;
+
+ private:
+  int32 dim_;
+  BaseFloat power_;
+};
+
 class RectifiedLinearComponent: public NonlinearComponent {
  public:
   explicit RectifiedLinearComponent(int32 dim): NonlinearComponent(dim) { }