]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - processor-sdk/kaldi.git/commitdiff
Added DropoutComponent in nnet3 (#1032)
authorpegahgh <pegahgh@gmail.com>
Tue, 4 Oct 2016 03:23:57 +0000 (23:23 -0400)
committerDaniel Povey <dpovey@gmail.com>
Tue, 4 Oct 2016 03:23:57 +0000 (23:23 -0400)
21 files changed:
egs/swbd/s5c/local/chain/tuning/run_tdnn_7d.sh
src/cudamatrix/cu-kernels-ansi.h
src/cudamatrix/cu-kernels.cu
src/cudamatrix/cu-kernels.h
src/cudamatrix/cu-matrix-test.cc
src/cudamatrix/cu-matrix.cc
src/cudamatrix/cu-matrix.h
src/matrix/kaldi-matrix.cc
src/matrix/kaldi-matrix.h
src/nnet2/nnet-component.cc
src/nnet3/nnet-component-itf.cc
src/nnet3/nnet-component-itf.h
src/nnet3/nnet-component-test.cc
src/nnet3/nnet-nnet.cc
src/nnet3/nnet-nnet.h
src/nnet3/nnet-simple-component.cc
src/nnet3/nnet-simple-component.h
src/nnet3/nnet-test-utils.cc
src/nnet3/nnet-utils.cc
src/nnet3/nnet-utils.h
src/nnet3bin/nnet3-copy.cc

index 0768bd786aeafd91d9af9d413c2d4fe0792a4088..5bcfea82ec3d58ca23bf682a6aa5f4cc471ec54e 100644 (file)
@@ -60,7 +60,6 @@ If you want to use GPUs (and have them), go to src/, and configure and make on a
 where "nvcc" is installed.
 EOF
 fi
-
 # The iVector-extraction and feature-dumping parts are the same as the standard
 # nnet3 setup, and you can skip them by setting "--stage 8" if you have already
 # run those things.
@@ -76,7 +75,6 @@ ali_dir=exp/tri4_ali_nodup$suffix
 treedir=exp/chain/tri5_7d_tree$suffix
 lang=data/lang_chain_2y
 
-
 # if we are using the speed-perturbed data we need to generate
 # alignments for it.
 local/nnet3/run_ivector_common.sh --stage $stage \
index 03dd91b793fc0f8e1a747eff9bbcd9ab1440157a..4642048989ea0292545417a4ccd8d037f4323214 100644 (file)
@@ -125,7 +125,7 @@ void cudaF_add_mat(dim3 Gr, dim3 Bl, float alpha, const float *src, float *dst,
 void cudaF_add_mat_blocks(dim3 Gr, dim3 Bl, float alpha, const float *src,
                           int32_cuda num_row_blocks, int32_cuda num_col_blocks,
                           float *dst, MatrixDim d, int src_stride, int A_trans);
-void cudaF_add_mat_mat_div_mat(dim3 Gr, dim3 Bl, const float *A, const float *B,
+void cudaF_set_mat_mat_div_mat(dim3 Gr, dim3 Bl, const float *A, const float *B,
                                const float *C, float *dst, MatrixDim d,
                                int stride_a, int stride_b, int stride_c);
 void cudaF_add_vec_to_cols(dim3 Gr, dim3 Bl, float alpha, const float *col,
@@ -391,7 +391,7 @@ void cudaD_add_mat_blocks(dim3 Gr, dim3 Bl, double alpha, const double *src,
                           int32_cuda num_row_blocks, int32_cuda num_col_blocks,
                           double *dst, MatrixDim d, int src_stride,
                           int A_trans);
-void cudaD_add_mat_mat_div_mat(dim3 Gr, dim3 Bl, const double *A,
+void cudaD_set_mat_mat_div_mat(dim3 Gr, dim3 Bl, const double *A,
                                const double *B, const double *C, double *dst,
                                MatrixDim d, int stride_a, int stride_b,
                                int stride_c);
index 6f098c87fb5af70da9ac25fe63de4ea618000e8a..ba8688fe2be3ba8de18bb726246f7242ff078911 100644 (file)
@@ -584,7 +584,7 @@ static void _add_mat_blocks_trans(Real alpha, const Real* src,
 
 template<typename Real>
 __global__
-static void _add_mat_mat_div_mat(const Real* A, const Real* B, const Real* C,
+static void _set_mat_mat_div_mat(const Real* A, const Real* B, const Real* C,
                                  Real* dst, MatrixDim d, int stride_a,
                                  int stride_b, int stride_c) {
   int32_cuda i = blockIdx.x * blockDim.x + threadIdx.x;
@@ -2863,10 +2863,10 @@ void cudaF_add_mat_blocks(dim3 Gr, dim3 Bl, float alpha, const float* src,
   }
 }
 
-void cudaF_add_mat_mat_div_mat(dim3 Gr, dim3 Bl, const float *A, const float *B,
+void cudaF_set_mat_mat_div_mat(dim3 Gr, dim3 Bl, const float *A, const float *B,
                                const float *C, float *dst, MatrixDim d,
                                int stride_a, int stride_b, int stride_c) {
-  _add_mat_mat_div_mat<<<Gr,Bl>>>(A,B,C,dst,d, stride_a, stride_b, stride_c);
+  _set_mat_mat_div_mat<<<Gr,Bl>>>(A,B,C,dst,d, stride_a, stride_b, stride_c);
 }
 
 void cudaF_sy_add_tr2(dim3 Gr, dim3 Bl, float alpha, float beta, const float* T,
@@ -3505,11 +3505,11 @@ void cudaD_add_mat_blocks(dim3 Gr, dim3 Bl, double alpha, const double* src,
   }
 }
 
-void cudaD_add_mat_mat_div_mat(dim3 Gr, dim3 Bl, const double *A,
+void cudaD_set_mat_mat_div_mat(dim3 Gr, dim3 Bl, const double *A,
                                const double *B, const double *C, double *dst,
                                MatrixDim d, int stride_a, int stride_b,
                                int stride_c) {
-  _add_mat_mat_div_mat<<<Gr,Bl>>>(A,B,C,dst,d,stride_a,stride_b,stride_c);
+  _set_mat_mat_div_mat<<<Gr,Bl>>>(A,B,C,dst,d,stride_a,stride_b,stride_c);
 }
 
 void cudaD_sy_add_tr2(dim3 Gr, dim3 Bl, double alpha, double beta,
index 748418e5f2f43f2c0d104384c959944b0e71c93d..a6e81db5d6c905c26c01ac1aac319bf2a116f80f 100644 (file)
@@ -337,11 +337,11 @@ inline void cuda_add_mat_blocks(dim3 Gr, dim3 Bl, float alpha, const float *src,
   cudaF_add_mat_blocks(Gr, Bl, alpha, src, num_row_blocks, num_col_blocks, dst,
                        d, src_stride, A_trans);
 }
-inline void cuda_add_mat_mat_div_mat(dim3 Gr, dim3 Bl, const float *A,
+inline void cuda_set_mat_mat_div_mat(dim3 Gr, dim3 Bl, const float *A,
                                      const float *B, const float *C, float *dst,
                                      MatrixDim d, int stride_a, int stride_b,
                                      int stride_c) {
-  cudaF_add_mat_mat_div_mat(Gr, Bl, A, B, C, dst, d, stride_a, stride_b,
+  cudaF_set_mat_mat_div_mat(Gr, Bl, A, B, C, dst, d, stride_a, stride_b,
                             stride_c);
 }
 inline void cuda_add_vec_to_cols(dim3 Gr, dim3 Bl, float alpha,
@@ -872,11 +872,11 @@ inline void cuda_add_mat_blocks(dim3 Gr, dim3 Bl, double alpha,
   cudaD_add_mat_blocks(Gr, Bl, alpha, src, num_row_blocks, num_col_blocks, dst,
                        d, src_stride, A_trans);
 }
-inline void cuda_add_mat_mat_div_mat(dim3 Gr, dim3 Bl, const double *A,
+inline void cuda_set_mat_mat_div_mat(dim3 Gr, dim3 Bl, const double *A,
                                      const double *B, const double *C,
                                      double *dst, MatrixDim d, int stride_a,
                                      int stride_b, int stride_c) {
-  cudaD_add_mat_mat_div_mat(Gr, Bl, A, B, C, dst, d, stride_a, stride_b,
+  cudaD_set_mat_mat_div_mat(Gr, Bl, A, B, C, dst, d, stride_a, stride_b,
                             stride_c);
 }
 inline void cuda_add_vec_to_cols(dim3 Gr, dim3 Bl, double alpha,
index 739cd8eeb22b0f2a1bd29b24f5045a62eb31f1aa..da587e450e3465bf9ea2c855dfdab7e8b2fa15be 100644 (file)
@@ -1077,7 +1077,7 @@ template<typename Real> static void UnitTestCuMatrixAddMatMatElements() {
   KALDI_ASSERT(M.Sum() != 0.0);
 }
 
-template<typename Real> static void UnitTestCuMatrixAddMatMatDivMat() {
+template<typename Real> static void UnitTestCuMatrixSetMatMatDivMat() {
   // M = a * b / c (by element; when c = 0, M = a)
   MatrixIndexT dimM = 100 + Rand() % 255, dimN = 100 + Rand() % 255;
   CuMatrix<Real> M(dimM, dimN), A(dimM, dimN), B(dimM, dimN), C(dimM, dimN);
@@ -1087,13 +1087,13 @@ template<typename Real> static void UnitTestCuMatrixAddMatMatDivMat() {
   B.SetRandn();
   C.SetRandn();
 
-  M.AddMatMatDivMat(A,B,C);
+  M.SetMatMatDivMat(A,B,C);
   ref.AddMatMatElements(1.0, A, B, 0.0);
   ref.DivElements(C);
   AssertEqual(M, ref);
 
   C.SetZero();
-  M.AddMatMatDivMat(A,B,C);
+  M.SetMatMatDivMat(A,B,C);
   AssertEqual(M, A);
 }
 
@@ -2665,7 +2665,7 @@ template<typename Real> void CudaMatrixUnitTest() {
   UnitTestCuMatrixAddDiagVecMat<Real>();
   UnitTestCuMatrixAddMatDiagVec<Real>();
   UnitTestCuMatrixAddMatMatElements<Real>();
-  UnitTestCuMatrixAddMatMatDivMat<Real>();
+  UnitTestCuMatrixSetMatMatDivMat<Real>();
   UnitTestCuTanh<Real>();
   UnitTestCuCholesky<Real>();
   UnitTestCuDiffTanh<Real>();
index 8d0ad950f2fe0a1368f0bbee0b66a10a6a6b87eb..afe884b2b762d2cc1a16e228cc319779bf28aa50 100644 (file)
@@ -989,7 +989,7 @@ void CuMatrixBase<Real>::AddMatBlocks(Real alpha, const CuMatrixBase<Real> &A,
 /// dst = a * b / c (by element; when c = 0, dst = a)
 /// dst can be an alias of a, b or c safely and get expected result.
 template<typename Real>
-void CuMatrixBase<Real>::AddMatMatDivMat(const CuMatrixBase<Real> &A,
+void CuMatrixBase<Real>::SetMatMatDivMat(const CuMatrixBase<Real> &A,
                     const CuMatrixBase<Real> &B, const CuMatrixBase<Real> &C) {
 #if HAVE_CUDA == 1
   if (CuDevice::Instantiate().Enabled()) {
@@ -1002,7 +1002,7 @@ void CuMatrixBase<Real>::AddMatMatDivMat(const CuMatrixBase<Real> &A,
     dim3 dimGrid, dimBlock;
     GetBlockSizesForSimpleMatrixOperation(NumRows(), NumCols(),
                                           &dimGrid, &dimBlock);
-    cuda_add_mat_mat_div_mat(dimGrid, dimBlock, A.data_, B.data_, C.data_,
+    cuda_set_mat_mat_div_mat(dimGrid, dimBlock, A.data_, B.data_, C.data_,
                              data_, Dim(), A.Stride(), B.Stride(), C.Stride());
     CU_SAFE_CALL(cudaGetLastError());
 
@@ -1010,7 +1010,7 @@ void CuMatrixBase<Real>::AddMatMatDivMat(const CuMatrixBase<Real> &A,
   } else
 #endif
   {
-    Mat().AddMatMatDivMat(A.Mat(), B.Mat(), C.Mat());
+    Mat().SetMatMatDivMat(A.Mat(), B.Mat(), C.Mat());
   }
 }
 
index f72484f18e71a9259323ac949e463b93967897bf..38a6c25071bfcf6cb2c2f62aa5a06a0f502985d0 100644 (file)
@@ -429,7 +429,7 @@ class CuMatrixBase {
   void AddVecVec(Real alpha, const CuVectorBase<Real> &x, const CuVectorBase<Real> &y);
   /// *this = a * b / c (by element; when c = 0, *this = a)
   /// *this can be an alias of a, b or c safely and get expected result.
-  void AddMatMatDivMat(const CuMatrixBase<Real> &A, const CuMatrixBase<Real> &B, const CuMatrixBase<Real> &C);
+  void SetMatMatDivMat(const CuMatrixBase<Real> &A, const CuMatrixBase<Real> &B, const CuMatrixBase<Real> &C);
 
   /// *this = beta * *this + alpha * M M^T, for symmetric matrices.  It only
   /// updates the lower triangle of *this.  It will leave the matrix asymmetric;
index 817de100656b592336856e2a4b02a321cedd3ea9..cb7d3be0ceebcf789a79163397e0bb07125af7b5 100644 (file)
@@ -179,9 +179,9 @@ void MatrixBase<Real>::AddMatMat(const Real alpha,
 }
 
 template<typename Real>
-void MatrixBase<Real>::AddMatMatDivMat(const MatrixBase<Real>& A,
-                                     const MatrixBase<Real>& B,
-                                   const MatrixBase<Real>& C) {
+void MatrixBase<Real>::SetMatMatDivMat(const MatrixBase<Real>& A,
+                                       const MatrixBase<Real>& B,
+                                       const MatrixBase<Real>& C) {
   KALDI_ASSERT(A.NumRows() == B.NumRows() && A.NumCols() == B.NumCols());
   KALDI_ASSERT(A.NumRows() == C.NumRows() && A.NumCols() == C.NumCols());
   for (int32 r = 0; r < A.NumRows(); r++) { // each frame...
index 5b4216002fbb5847c9f6fb1cbbcd0f678f2848f6..e254fcad11846e459efb1d54e1276621cf89b1eb 100644 (file)
@@ -579,8 +579,8 @@ class MatrixBase {
                  const Real beta);
 
   /// *this = a * b / c (by element; when c = 0, *this = a)
-  void AddMatMatDivMat(const MatrixBase<Real>& A,
-                        const MatrixBase<Real>& B,
+  void SetMatMatDivMat(const MatrixBase<Real>& A,
+                       const MatrixBase<Real>& B,
                        const MatrixBase<Real>& C);
 
   /// A version of AddMatMat specialized for when the second argument
index 498cc809e5f6278b9971e018a8c717574a03f8b9..9608a5475e0854cc8a65e63bdb5686b15d3b61c3 100644 (file)
@@ -3593,7 +3593,7 @@ void DropoutComponent::Backprop(const ChunkInfo &,  //in_info,
                                 CuMatrix<BaseFloat> *in_deriv) const  {
   KALDI_ASSERT(SameDim(in_value, out_value) && SameDim(in_value, out_deriv));
   in_deriv->Resize(out_deriv.NumRows(), out_deriv.NumCols());
-  in_deriv->AddMatMatDivMat(out_deriv, out_value, in_value);
+  in_deriv->SetMatMatDivMat(out_deriv, out_value, in_value);
 }
 
 Component* DropoutComponent::Copy() const {
index cdb43473090f52087877e9c8ceb8bb0493d0cafe..168a2a5350a37ac3558253e4ee4ba60edf517534 100644 (file)
@@ -141,6 +141,8 @@ Component* Component::NewComponentOfType(const std::string &component_type) {
     ans = new StatisticsPoolingComponent();
   } else if (component_type == "ConstantFunctionComponent") {
     ans = new ConstantFunctionComponent();
+  } else if (component_type == "DropoutComponent") {
+    ans = new DropoutComponent();
   }
   if (ans != NULL) {
     KALDI_ASSERT(component_type == ans->Type());
index 90b463ec57818d38ae84fc6672c0d00efc88d048..164f9d056e7f692f3b4672dffe805aee48c3d018 100644 (file)
@@ -350,6 +350,16 @@ class Component {
 };
 
 
+class RandomComponent: public Component {
+ public:
+  // This function is required in testing code and in other places we need
+  // consistency in the random number generation (e.g. when optimizing
+  // validation-set performance), but check where else we call srand().  You'll
+  // need to call srand as well as making this call.
+  void ResetGenerator() { random_generator_.SeedGpu(); }
+ protected:
+  CuRand<BaseFloat> random_generator_;
+};
 
 /**
  * Class UpdatableComponent is a Component which has trainable parameters; it
index 51760d67557f61b29f5adc207c88b802b827c701..3cc6af1c70dbf944a47452b369ad2b965daea14a 100644 (file)
 
 namespace kaldi {
 namespace nnet3 {
-
+// Reset seeds for test time for RandomComponent
+static void ResetSeed(int32 rand_seed, const Component &c) {
+  RandomComponent *rand_component = 
+    const_cast<RandomComponent*>(dynamic_cast<const RandomComponent*>(&c));
+  
+  if (rand_component != NULL) {
+    srand(rand_seed);
+    rand_component->ResetGenerator();
+  }
+}
 // returns true if two are string are equal except for what looks like it might
 // be a difference last digit of a floating point number, e.g. accept
 // 1.234 to be the same as 1.235.  Not very rigorous.
@@ -188,6 +197,8 @@ void TestNnetComponentUpdatable(Component *c) {
 void TestSimpleComponentPropagateProperties(const Component &c) {
   int32 properties = c.Properties();
   Component *c_copy = NULL, *c_copy_scaled = NULL;
+  int32 rand_seed = Rand();
   if (RandInt(0, 1) == 0)
     c_copy = c.Copy();  // This will test backprop with an updatable component.
   if (RandInt(0, 1) == 0 &&
@@ -223,10 +234,14 @@ void TestSimpleComponentPropagateProperties(const Component &c) {
   if ((properties & kPropagateAdds) && (properties & kPropagateInPlace)) {
     KALDI_ERR << "kPropagateAdds and kPropagateInPlace flags are incompatible.";
   }
-
+  
+  ResetSeed(rand_seed, c);
   c.Propagate(NULL, input_data, &output_data1);
+
+  ResetSeed(rand_seed, c);
   c.Propagate(NULL, input_data, &output_data2);
   if (properties & kPropagateInPlace) {
+    ResetSeed(rand_seed, c);
     c.Propagate(NULL, output_data3, &output_data3);
     if (!output_data1.ApproxEqual(output_data3)) {
       KALDI_ERR << "Test of kPropagateInPlace flag for component of type "
@@ -238,12 +253,14 @@ void TestSimpleComponentPropagateProperties(const Component &c) {
   AssertEqual(output_data1, output_data2);
 
   if (c_copy_scaled) {
+    ResetSeed(rand_seed, *c_copy_scaled);
     c_copy_scaled->Propagate(NULL, input_data, &output_data4);
     output_data4.Scale(2.0);  // we scaled the parameters by 0.5 above, and the
     // output is supposed to be linear in the parameter value.
     AssertEqual(output_data1, output_data4);
   }
   if (properties & kLinearInInput) {
+    ResetSeed(rand_seed, c);
     c.Propagate(NULL, input_data_scaled, &output_data5);
     output_data5.Scale(0.5);
     AssertEqual(output_data1, output_data5);
@@ -302,14 +319,16 @@ bool TestSimpleComponentDataDerivative(const Component &c,
 
   int32 input_dim = c.InputDim(),
       output_dim = c.OutputDim(),
-      num_rows = RandInt(1, 100);
+      num_rows = RandInt(1, 100),
+      rand_seed = Rand();
   int32 properties = c.Properties();
   CuMatrix<BaseFloat> input_data(num_rows, input_dim, kSetZero, input_stride_type),
       output_data(num_rows, output_dim, kSetZero, output_stride_type),
       output_deriv(num_rows, output_dim, kSetZero, output_stride_type);
   input_data.SetRandn();
   output_deriv.SetRandn();
-
+  ResetSeed(rand_seed, c);
   c.Propagate(NULL, input_data, &output_data);
 
   CuMatrix<BaseFloat> input_deriv(num_rows, input_dim, kSetZero, input_stride_type),
@@ -334,6 +353,8 @@ bool TestSimpleComponentDataDerivative(const Component &c,
     predicted_objf_change(i) = TraceMatMat(perturbed_input_data, input_deriv,
                                            kTrans);
     perturbed_input_data.AddMat(1.0, input_data);
+
+    ResetSeed(rand_seed, c);
     c.Propagate(NULL, perturbed_input_data, &perturbed_output_data);
     measured_objf_change(i) = TraceMatMat(output_deriv, perturbed_output_data,
                                           kTrans) - original_objf;
@@ -503,7 +524,7 @@ int main() {
   TestStringsApproxEqual();
   for (kaldi::int32 loop = 0; loop < 2; loop++) {
 #if HAVE_CUDA == 1
-    CuDevice::Instantiate().SetDebugStrideMode(true);
+    //CuDevice::Instantiate().SetDebugStrideMode(true);
     if (loop == 0)
       CuDevice::Instantiate().SelectGpuId("no");
     else
index acd322eb515c7e66e026498477632ce52c9dfa1d..c84df89177dcddf80e4f5427b76780337bf63390 100644 (file)
@@ -897,7 +897,15 @@ void Nnet::RemoveOrphanNodes(bool remove_orphan_inputs) {
   RemoveSomeNodes(orphan_nodes);
 }
 
-
+void Nnet::ResetGenerators() {
+  // resets random-number generators for all random
+  // components.
+  for (int32 c = 0; c < NumComponents(); c++) {
+    RandomComponent *rc = dynamic_cast<RandomComponent*>(GetComponent(c));
+    if (rc != NULL)
+      rc->ResetGenerator();
+  }
+}
 
 } // namespace nnet3
 } // namespace kaldi
index fc10a8bb09f802335f8821a420602334bce0e767..16e8333d5b1a270f67fef40831bb45e3d9de1e11 100644 (file)
@@ -236,7 +236,6 @@ class Nnet {
   // Assignment operator
   Nnet& operator =(const Nnet &nnet);
 
-
   // Removes nodes that are never needed to compute any output.
   void RemoveOrphanNodes(bool remove_orphan_inputs = false);
 
@@ -247,6 +246,10 @@ class Nnet {
   // as it could ruin the graph structure if done carelessly.
   void RemoveSomeNodes(const std::vector<int32> &nodes_to_remove);
 
+  void ResetGenerators(); // resets random-number generators for all
+  // random components.  You must also set srand() for this to be
+  // effective.
+  
  private:
 
   void Destroy();
@@ -323,7 +326,6 @@ class Nnet {
 };
 
 
-
 } // namespace nnet3
 } // namespace kaldi
 
index ec9f226cf9f166344bc06a811e199ee5c9af8915..6940ba8302aacd12c479d21d04620ccf0cd626dd 100644 (file)
@@ -86,6 +86,85 @@ void PnormComponent::Write(std::ostream &os, bool binary) const {
 }
 
 
+void DropoutComponent::Init(int32 dim, BaseFloat dropout_proportion) {
+  dropout_proportion_ = dropout_proportion;
+  dim_ = dim;
+}
+
+void DropoutComponent::InitFromConfig(ConfigLine *cfl) {
+  int32 dim = 0;
+  BaseFloat dropout_proportion = 0.0;
+  bool ok = cfl->GetValue("dim", &dim) &&
+    cfl->GetValue("dropout-proportion", &dropout_proportion);
+  if (!ok || cfl->HasUnusedValues() || dim <= 0 || 
+      dropout_proportion < 0.0 || dropout_proportion > 1.0)
+    KALDI_ERR << "Invalid initializer for layer of type " 
+              << Type() << ": \"" << cfl->WholeLine() << "\"";   
+  Init(dim, dropout_proportion);
+}
+
+std::string DropoutComponent::Info() const {
+  std::ostringstream stream;
+  stream << Type() << ", dim = " << dim_ 
+         << ", dropout-proportion = " << dropout_proportion_;
+  return stream.str();
+}
+
+void DropoutComponent::Propagate(const ComponentPrecomputedIndexes *indexes,
+                                 const CuMatrixBase<BaseFloat> &in,
+                                 CuMatrixBase<BaseFloat> *out) const {
+  KALDI_ASSERT(out->NumRows() == in.NumRows() && out->NumCols() == in.NumCols()
+               && in.NumCols() == dim_);
+
+  BaseFloat dropout = dropout_proportion_;
+  KALDI_ASSERT(dropout >= 0.0 && dropout <= 1.0);
+
+  // This const_cast is only safe assuming you don't attempt  
+  // to use multi-threaded code with the GPU.
+  const_cast<CuRand<BaseFloat>&>(random_generator_).RandUniform(out); 
+
+  out->Add(-dropout); // now, a proportion "dropout" will be <0.0 
+  out->ApplyHeaviside(); // apply the function (x>0?1:0).  Now, a proportion "dropout" will 
+                         // be zero and (1 - dropout) will be 1.0.
+
+  out->MulElements(in);
+}
+
+
+void DropoutComponent::Backprop(const std::string &debug_info,
+                                const ComponentPrecomputedIndexes *indexes,
+                                const CuMatrixBase<BaseFloat> &in_value,
+                                const CuMatrixBase<BaseFloat> &out_value,
+                                const CuMatrixBase<BaseFloat> &out_deriv,
+                                Component *to_update,
+                                CuMatrixBase<BaseFloat> *in_deriv) const {
+  KALDI_ASSERT(in_value.NumRows() == out_value.NumRows() &&
+               in_value.NumCols() == out_value.NumCols());
+
+  KALDI_ASSERT(in_value.NumRows() == out_deriv.NumRows() &&
+               in_value.NumCols() == out_deriv.NumCols());
+  in_deriv->SetMatMatDivMat(out_deriv, out_value, in_value);
+}
+
+
+void DropoutComponent::Read(std::istream &is, bool binary) {
+  ExpectOneOrTwoTokens(is, binary, "<DropoutComponent>", "<Dim>");
+  ReadBasicType(is, binary, &dim_);
+  ExpectToken(is, binary, "<DropoutProportion>");
+  ReadBasicType(is, binary, &dropout_proportion_);
+  ExpectToken(is, binary, "</DropoutComponent>");
+}
+
+void DropoutComponent::Write(std::ostream &os, bool binary) const {
+  WriteToken(os, binary, "<DropoutComponent>");
+  WriteToken(os, binary, "<Dim>");
+  WriteBasicType(os, binary, dim_);
+  WriteToken(os, binary, "<DropoutProportion>");
+  WriteBasicType(os, binary, dropout_proportion_);
+  WriteToken(os, binary, "</DropoutComponent>");
+}
+
 void SumReduceComponent::Init(int32 input_dim, int32 output_dim)  {
   input_dim_ = input_dim;
   output_dim_ = output_dim;
index 5d4b43eb8205f03c00933c5da18ab547fbc45663..8060e4c92f0c5912640941df016477b222675a51 100644 (file)
@@ -80,6 +80,59 @@ class PnormComponent: public Component {
   int32 output_dim_;
 };
 
+// This component randomly zeros dropout_proportion of the input 
+// and the derivatives are backpropagated through the nonzero inputs.
+// Typically this component used during training but not in test time.
+// The idea is described under the name Dropout, in the paper 
+// "Dropout: A Simple Way to Prevent Neural Networks from Overfitting".
+class DropoutComponent : public RandomComponent {
+ public:
+  void Init(int32 dim, BaseFloat dropout_proportion = 0.0);
+
+  DropoutComponent(int32 dim, BaseFloat dropout = 0.0) { Init(dim, dropout); }
+
+  DropoutComponent(): dim_(0), dropout_proportion_(0.0) { }
+
+  virtual int32 Properties() const {
+    return kLinearInInput|kBackpropInPlace|kSimpleComponent|kBackpropNeedsInput|kBackpropNeedsOutput;
+  }
+  virtual std::string Type() const { return "DropoutComponent"; }
+
+  virtual void InitFromConfig(ConfigLine *cfl);
+
+  virtual int32 InputDim() const { return dim_; }
+
+  virtual int32 OutputDim() const { return dim_; }
+
+  virtual void Read(std::istream &is, bool binary);
+
+  // Write component to stream
+  virtual void Write(std::ostream &os, bool binary) const;
+
+  virtual void Propagate(const ComponentPrecomputedIndexes *indexes,
+                         const CuMatrixBase<BaseFloat> &in,
+                         CuMatrixBase<BaseFloat> *out) const;
+  virtual void Backprop(const std::string &debug_info,
+                        const ComponentPrecomputedIndexes *indexes,
+                        const CuMatrixBase<BaseFloat> &in_value,
+                        const CuMatrixBase<BaseFloat> &out_value,
+                        const CuMatrixBase<BaseFloat> &out_deriv,
+                        Component *to_update,
+                        CuMatrixBase<BaseFloat> *in_deriv) const;
+  virtual Component* Copy() const { return new DropoutComponent(dim_,
+                                                                dropout_proportion_); }
+  virtual std::string Info() const;
+  
+  void SetDropoutProportion(BaseFloat dropout_proportion) { dropout_proportion_ = dropout_proportion; }
+
+ private:
+  int32 dim_;
+  /// dropout-proportion is the proportion that is dropped out,
+  /// e.g. if 0.1, we set 10% to zero value.
+  BaseFloat dropout_proportion_;
+  
+};
+
 class ElementwiseProductComponent: public Component {
  public:
   void Init(int32 input_dim, int32 output_dim);
index dc2696e4e12851e42cdda2b2bc3f77087e63b9f3..e02ae4974c96eb4268198ee19ae0429465788cb9 100644 (file)
@@ -926,7 +926,7 @@ void ComputeExampleComputationRequestSimple(
 static void GenerateRandomComponentConfig(std::string *component_type,
                                           std::string *config) {
 
-  int32 n = RandInt(0, 28);
+  int32 n = RandInt(0, 29);
   BaseFloat learning_rate = 0.001 * RandInt(1, 3);
 
   std::ostringstream os;
@@ -1219,6 +1219,12 @@ static void GenerateRandomComponentConfig(std::string *component_type,
         os << " self-repair-target=" << RandUniform();
       break;
     }
+    case 29: {
+      *component_type = "DropoutComponent";
+      os << "dim=" << RandInt(1, 200)
+         << " dropout-proportion=" << RandUniform();
+      break;
+    }
     default:
       KALDI_ERR << "Error generating random component";
   }
index 0cb7d1fe9b35183f570e8859b4c4b7a95fe7df0e..955e200d0727fe3235c89949e4ae405a0dd334bb 100644 (file)
@@ -496,6 +496,15 @@ std::string NnetInfo(const Nnet &nnet) {
   return ostr.str();
 }
 
+void SetDropoutProportion(BaseFloat dropout_proportion,
+                          Nnet *nnet) {
+  for (int32 c = 0; c < nnet->NumComponents(); c++) {
+    Component *comp = nnet->GetComponent(c);
+    DropoutComponent *dc = dynamic_cast<DropoutComponent*>(comp);
+    if (dc != NULL)
+      dc->SetDropoutProportion(dropout_proportion);
+  }
+}
 
 void FindOrphanComponents(const Nnet &nnet, std::vector<int32> *components) {
   int32 num_components = nnet.NumComponents(), num_nodes = nnet.NumNodes();
index bd14a8c097f9f8752b41a5441c5e1656c72c56ef..9606bd5d5b78b60571765065d56462503512d49b 100644 (file)
@@ -174,6 +174,9 @@ void ConvertRepeatedToBlockAffine(Nnet *nnet);
 /// Info() function (we need this in the CTC code).
 std::string NnetInfo(const Nnet &nnet);
 
+/// This function sets the dropout proportion in all dropout component to 
+/// dropout_proportion value.
+void SetDropoutProportion(BaseFloat dropout_proportion, Nnet *nnet);
 
 /// This function finds a list of components that are never used, and outputs
 /// the integer comopnent indexes (you can use these to index
index d21967585947308155ca9c649d6ff95c8996481b..c419e0e0f9119a785732dd9b3e3f5b265a3f7197 100644 (file)
@@ -41,7 +41,8 @@ int main(int argc, char *argv[]) {
         " nnet3-copy --binary=false 0.raw text.raw\n";
 
     bool binary_write = true;
-    BaseFloat learning_rate = -1;
+    BaseFloat learning_rate = -1,
+      dropout = 0.0;
     std::string nnet_config, edits_config, edits_str;
 
     ParseOptions po(usage);
@@ -61,6 +62,8 @@ int main(int argc, char *argv[]) {
                 "Can be used as an inline alternative to edits-config; semicolons "
                 "will be converted to newlines before parsing.  E.g. "
                 "'--edits=remove-orphans'.");
+    po.Register("set-dropout-proportion", &dropout, "Set dropout proportion "
+                "in all DropoutComponent to this value.");
     po.Read(argc, argv);
 
     if (po.NumArgs() != 2) {
@@ -81,6 +84,9 @@ int main(int argc, char *argv[]) {
 
     if (learning_rate >= 0)
       SetLearningRate(learning_rate, &nnet);
+    
+    if (dropout > 0)
+      SetDropoutProportion(dropout, &nnet);
 
     if (!edits_config.empty()) {
       Input ki(edits_config);