aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorpegahgh2017-08-25 17:27:40 -0500
committerDaniel Povey2017-08-25 17:27:40 -0500
commite5a48fc0ec205d46e1757851f9d2398506bf11d9 (patch)
treef50936271e4d6472c5a5e126186d4d116ee3e0a3
parent1d1373130820c09c5a3da5f40355cd995602cd46 (diff)
downloadkaldi-e5a48fc0ec205d46e1757851f9d2398506bf11d9.tar.gz
kaldi-e5a48fc0ec205d46e1757851f9d2398506bf11d9.tar.xz
kaldi-e5a48fc0ec205d46e1757851f9d2398506bf11d9.zip
[src] fixed copy-constructor bug RE test_mode option for DropoutComponent (#1845)
-rw-r--r--src/nnet3/nnet-component-itf.h3
-rw-r--r--src/nnet3/nnet-simple-component.cc10
-rw-r--r--src/nnet3/nnet-simple-component.h8
3 files changed, 18 insertions, 3 deletions
diff --git a/src/nnet3/nnet-component-itf.h b/src/nnet3/nnet-component-itf.h
index 060238203..f94228a1c 100644
--- a/src/nnet3/nnet-component-itf.h
+++ b/src/nnet3/nnet-component-itf.h
@@ -403,6 +403,9 @@ class RandomComponent: public Component {
403 void SetTestMode(bool test_mode) { test_mode_ = test_mode; } 403 void SetTestMode(bool test_mode) { test_mode_ = test_mode; }
404 404
405 RandomComponent(): test_mode_(false) { } 405 RandomComponent(): test_mode_(false) { }
406
407 RandomComponent(const RandomComponent &other):
408 test_mode_(other.test_mode_) {}
406 protected: 409 protected:
407 CuRand<BaseFloat> random_generator_; 410 CuRand<BaseFloat> random_generator_;
408 411
diff --git a/src/nnet3/nnet-simple-component.cc b/src/nnet3/nnet-simple-component.cc
index dff3b0497..a77da760c 100644
--- a/src/nnet3/nnet-simple-component.cc
+++ b/src/nnet3/nnet-simple-component.cc
@@ -88,6 +88,16 @@ void PnormComponent::Write(std::ostream &os, bool binary) const {
88 WriteToken(os, binary, "</PnormComponent>"); 88 WriteToken(os, binary, "</PnormComponent>");
89} 89}
90 90
91DropoutComponent::DropoutComponent(const DropoutComponent &other):
92 RandomComponent(other),
93 dim_(other.dim_),
94 dropout_proportion_(other.dropout_proportion_),
95 dropout_per_frame_(other.dropout_per_frame_) { }
96
97Component* DropoutComponent::Copy() const {
98 DropoutComponent *ans = new DropoutComponent(*this);
99 return ans;
100}
91 101
92void DropoutComponent::Init(int32 dim, BaseFloat dropout_proportion, 102void DropoutComponent::Init(int32 dim, BaseFloat dropout_proportion,
93 bool dropout_per_frame) { 103 bool dropout_per_frame) {
diff --git a/src/nnet3/nnet-simple-component.h b/src/nnet3/nnet-simple-component.h
index a10df8c88..5649d4e54 100644
--- a/src/nnet3/nnet-simple-component.h
+++ b/src/nnet3/nnet-simple-component.h
@@ -99,6 +99,8 @@ class DropoutComponent : public RandomComponent {
99 DropoutComponent(): dim_(0), dropout_proportion_(0.0), 99 DropoutComponent(): dim_(0), dropout_proportion_(0.0),
100 dropout_per_frame_(false) { } 100 dropout_per_frame_(false) { }
101 101
102 DropoutComponent(const DropoutComponent &other);
103
102 virtual int32 Properties() const { 104 virtual int32 Properties() const {
103 return kLinearInInput|kBackpropInPlace|kSimpleComponent|kBackpropNeedsInput| 105 return kLinearInInput|kBackpropInPlace|kSimpleComponent|kBackpropNeedsInput|
104 kBackpropNeedsOutput|kRandomComponent; 106 kBackpropNeedsOutput|kRandomComponent;
@@ -127,9 +129,9 @@ class DropoutComponent : public RandomComponent {
127 void *memo, 129 void *memo,
128 Component *to_update, 130 Component *to_update,
129 CuMatrixBase<BaseFloat> *in_deriv) const; 131 CuMatrixBase<BaseFloat> *in_deriv) const;
130 virtual Component* Copy() const { return new DropoutComponent(dim_, 132
131 dropout_proportion_, 133 virtual Component* Copy() const;
132 dropout_per_frame_); } 134
133 virtual std::string Info() const; 135 virtual std::string Info() const;
134 136
135 void SetDropoutProportion(BaseFloat dropout_proportion) { 137 void SetDropoutProportion(BaseFloat dropout_proportion) {