574c707db108284cc2d213d3fd08f1a875d249cc
1 // nnet3/nnet-component-itf.cc
3 // Copyright 2015 Johns Hopkins University (author: Daniel Povey)
4 // 2015 Guoguo Chen
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 // http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
21 #include <iterator>
22 #include <sstream>
23 #include <iomanip>
24 #include "nnet3/nnet-component-itf.h"
25 #include "nnet3/nnet-simple-component.h"
26 #include "nnet3/nnet-general-component.h"
27 #include "nnet3/nnet-convolutional-component.h"
28 #include "nnet3/nnet-parse.h"
29 #include "nnet3/nnet-computation-graph.h"
32 // \file This file contains some more-generic component code: things in base classes.
33 // See nnet-component.cc for the code of the actual Components.
35 namespace kaldi {
36 namespace nnet3 {
38 ComponentPrecomputedIndexes* ComponentPrecomputedIndexes::ReadNew(std::istream &is,
39 bool binary) {
40 std::string token;
41 ReadToken(is, binary, &token); // e.g. "<DistributePrecomputedComponentIndexes>".
42 token.erase(0, 1); // erase "<".
43 token.erase(token.length()-1); // erase ">".
44 ComponentPrecomputedIndexes *ans = NewComponentPrecomputedIndexesOfType(token);
45 if (!ans)
46 KALDI_ERR << "Unknown ComponentPrecomputedIndexes type " << token;
47 ans->Read(is, binary);
48 return ans;
49 }
51 ComponentPrecomputedIndexes* ComponentPrecomputedIndexes::NewComponentPrecomputedIndexesOfType(
52 const std::string &cpi_type) {
53 ComponentPrecomputedIndexes *ans = NULL;
54 if (cpi_type == "DistributeComponentPrecomputedIndexes") {
55 ans = new DistributeComponentPrecomputedIndexes();
56 } else if (cpi_type == "StatisticsExtractionComponentPrecomputedIndexes") {
57 ans = new StatisticsExtractionComponentPrecomputedIndexes();
58 } else if (cpi_type == "StatisticsPoolingComponentPrecomputedIndexes") {
59 ans = new StatisticsPoolingComponentPrecomputedIndexes();
60 } else if (cpi_type == "BackpropTruncationComponentPrecomputedIndexes") {
61 ans = new BackpropTruncationComponentPrecomputedIndexes();
62 } else if (cpi_type == "TimeHeightConvolutionComponentPrecomputedIndexes") {
63 ans = new TimeHeightConvolutionComponent::PrecomputedIndexes();
64 }
65 if (ans != NULL) {
66 KALDI_ASSERT(cpi_type == ans->Type());
67 }
68 return ans;
69 }
71 // static
72 Component* Component::ReadNew(std::istream &is, bool binary) {
73 std::string token;
74 ReadToken(is, binary, &token); // e.g. "<SigmoidComponent>".
75 token.erase(0, 1); // erase "<".
76 token.erase(token.length()-1); // erase ">".
77 Component *ans = NewComponentOfType(token);
78 if (!ans)
79 KALDI_ERR << "Unknown component type " << token;
80 ans->Read(is, binary);
81 return ans;
82 }
85 // static
86 Component* Component::NewComponentOfType(const std::string &component_type) {
87 Component *ans = NULL;
88 if (component_type == "SigmoidComponent") {
89 ans = new SigmoidComponent();
90 } else if (component_type == "TanhComponent") {
91 ans = new TanhComponent();
92 } else if (component_type == "SoftmaxComponent") {
93 ans = new SoftmaxComponent();
94 } else if (component_type == "LogSoftmaxComponent") {
95 ans = new LogSoftmaxComponent();
96 } else if (component_type == "RectifiedLinearComponent") {
97 ans = new RectifiedLinearComponent();
98 } else if (component_type == "NormalizeComponent") {
99 ans = new NormalizeComponent();
100 } else if (component_type == "PnormComponent") {
101 ans = new PnormComponent();
102 } else if (component_type == "AffineComponent") {
103 ans = new AffineComponent();
104 } else if (component_type == "NaturalGradientAffineComponent") {
105 ans = new NaturalGradientAffineComponent();
106 } else if (component_type == "PerElementScaleComponent") {
107 ans = new PerElementScaleComponent();
108 } else if (component_type == "NaturalGradientPerElementScaleComponent") {
109 ans = new NaturalGradientPerElementScaleComponent();
110 } else if (component_type == "PerElementOffsetComponent") {
111 ans = new PerElementOffsetComponent();
112 } else if (component_type == "SumGroupComponent") {
113 ans = new SumGroupComponent();
114 } else if (component_type == "FixedAffineComponent") {
115 ans = new FixedAffineComponent();
116 } else if (component_type == "FixedScaleComponent") {
117 ans = new FixedScaleComponent();
118 } else if (component_type == "FixedBiasComponent") {
119 ans = new FixedBiasComponent();
120 } else if (component_type == "NoOpComponent") {
121 ans = new NoOpComponent();
122 } else if (component_type == "ClipGradientComponent") {
123 ans = new ClipGradientComponent();
124 } else if (component_type == "ElementwiseProductComponent") {
125 ans = new ElementwiseProductComponent();
126 } else if (component_type == "ConvolutionComponent") {
127 ans = new ConvolutionComponent();
128 } else if (component_type == "MaxpoolingComponent") {
129 ans = new MaxpoolingComponent();
130 } else if (component_type == "PermuteComponent") {
131 ans = new PermuteComponent();
132 } else if (component_type == "DistributeComponent") {
133 ans = new DistributeComponent();
134 } else if (component_type == "CompositeComponent") {
135 ans = new CompositeComponent();
136 } else if (component_type == "RepeatedAffineComponent") {
137 ans = new RepeatedAffineComponent();
138 } else if (component_type == "BlockAffineComponent") {
139 ans = new BlockAffineComponent();
140 } else if (component_type == "NaturalGradientRepeatedAffineComponent") {
141 ans = new NaturalGradientRepeatedAffineComponent();
142 } else if (component_type == "StatisticsExtractionComponent") {
143 ans = new StatisticsExtractionComponent();
144 } else if (component_type == "StatisticsPoolingComponent") {
145 ans = new StatisticsPoolingComponent();
146 } else if (component_type == "ConstantFunctionComponent") {
147 ans = new ConstantFunctionComponent();
148 } else if (component_type == "ConstantComponent") {
149 ans = new ConstantComponent();
150 } else if (component_type == "DropoutComponent") {
151 ans = new DropoutComponent();
152 } else if (component_type == "DropoutMaskComponent") {
153 ans = new DropoutMaskComponent();
154 } else if (component_type == "BackpropTruncationComponent") {
155 ans = new BackpropTruncationComponent();
156 } else if (component_type == "LstmNonlinearityComponent") {
157 ans = new LstmNonlinearityComponent();
158 } else if (component_type == "BatchNormComponent") {
159 ans = new BatchNormComponent();
160 } else if (component_type == "TimeHeightConvolutionComponent") {
161 ans = new TimeHeightConvolutionComponent();
162 } else if (component_type == "SumBlockComponent") {
163 ans = new SumBlockComponent();
164 }
165 if (ans != NULL) {
166 KALDI_ASSERT(component_type == ans->Type());
167 }
168 return ans;
169 }
171 std::string Component::Info() const {
172 std::stringstream stream;
173 stream << Type() << ", input-dim=" << InputDim()
174 << ", output-dim=" << OutputDim();
175 return stream.str();
176 }
178 void Component::GetInputIndexes(const MiscComputationInfo &misc_info,
179 const Index &output_index,
180 std::vector<Index> *input_indexes) const {
181 input_indexes->resize(1);
182 (*input_indexes)[0] = output_index;
183 }
185 bool Component::IsComputable(const MiscComputationInfo &misc_info,
186 const Index &output_index,
187 const IndexSet &input_index_set,
188 std::vector<Index> *used_inputs) const {
189 // the default Component dependency is for an output index to map directly to
190 // the same input index, which is required to compute the output.
191 if (!input_index_set(output_index))
192 return false;
193 if (used_inputs) {
194 used_inputs->clear();
195 used_inputs->push_back(output_index);
196 }
197 return true;
198 }
201 void UpdatableComponent::InitLearningRatesFromConfig(ConfigLine *cfl) {
202 cfl->GetValue("learning-rate", &learning_rate_);
203 cfl->GetValue("learning-rate-factor", &learning_rate_factor_);
204 max_change_ = 0.0;
205 cfl->GetValue("max-change", &max_change_);
206 if (learning_rate_ < 0.0 || learning_rate_factor_ < 0.0 || max_change_ < 0.0)
207 KALDI_ERR << "Bad initializer " << cfl->WholeLine();
208 }
211 std::string UpdatableComponent::ReadUpdatableCommon(std::istream &is,
212 bool binary) {
213 std::ostringstream opening_tag;
214 opening_tag << '<' << this->Type() << '>';
215 std::string token;
216 ReadToken(is, binary, &token);
217 if (token == opening_tag.str()) {
218 // if the first token is the opening tag, then
219 // ignore it and get the next tag.
220 ReadToken(is, binary, &token);
221 }
222 if (token == "<LearningRateFactor>") {
223 ReadBasicType(is, binary, &learning_rate_factor_);
224 ReadToken(is, binary, &token);
225 } else {
226 learning_rate_factor_ = 1.0;
227 }
228 if (token == "<IsGradient>") {
229 ReadBasicType(is, binary, &is_gradient_);
230 ReadToken(is, binary, &token);
231 } else {
232 is_gradient_ = false;
233 }
234 if (token == "<MaxChange>") {
235 ReadBasicType(is, binary, &max_change_);
236 ReadToken(is, binary, &token);
237 } else {
238 max_change_ = 0.0;
239 }
240 if (token == "<LearningRate>") {
241 ReadBasicType(is, binary, &learning_rate_);
242 return "";
243 } else {
244 return token;
245 }
246 }
248 void UpdatableComponent::WriteUpdatableCommon(std::ostream &os,
249 bool binary) const {
250 std::ostringstream opening_tag;
251 opening_tag << '<' << this->Type() << '>';
252 std::string token;
253 WriteToken(os, binary, opening_tag.str());
254 if (learning_rate_factor_ != 1.0) {
255 WriteToken(os, binary, "<LearningRateFactor>");
256 WriteBasicType(os, binary, learning_rate_factor_);
257 }
258 if (is_gradient_) {
259 WriteToken(os, binary, "<IsGradient>");
260 WriteBasicType(os, binary, is_gradient_);
261 }
262 if (max_change_ > 0.0) {
263 WriteToken(os, binary, "<MaxChange>");
264 WriteBasicType(os, binary, max_change_);
265 }
266 WriteToken(os, binary, "<LearningRate>");
267 WriteBasicType(os, binary, learning_rate_);
268 }
271 std::string UpdatableComponent::Info() const {
272 std::stringstream stream;
273 stream << Type() << ", input-dim=" << InputDim()
274 << ", output-dim=" << OutputDim() << ", learning-rate="
275 << LearningRate();
276 if (is_gradient_)
277 stream << ", is-gradient=true";
278 if (learning_rate_factor_ != 1.0)
279 stream << ", learning-rate-factor=" << learning_rate_factor_;
280 if (max_change_ > 0.0)
281 stream << ", max-change=" << max_change_;
282 return stream.str();
283 }
285 void NonlinearComponent::StoreStatsInternal(
286 const CuMatrixBase<BaseFloat> &out_value,
287 const CuMatrixBase<BaseFloat> *deriv) {
288 KALDI_ASSERT(out_value.NumCols() == InputDim());
289 // Check we have the correct dimensions.
290 if (value_sum_.Dim() != InputDim() ||
291 (deriv != NULL && deriv_sum_.Dim() != InputDim())) {
292 std::lock_guard<std::mutex> lock(mutex_);
293 if (value_sum_.Dim() != InputDim()) {
294 value_sum_.Resize(InputDim());
295 count_ = 0.0;
296 }
297 if (deriv != NULL && deriv_sum_.Dim() != InputDim()) {
298 deriv_sum_.Resize(InputDim());
299 count_ = 0.0;
300 value_sum_.SetZero();
301 }
302 }
303 count_ += out_value.NumRows();
304 CuVector<BaseFloat> temp(InputDim());
305 temp.AddRowSumMat(1.0, out_value, 0.0);
306 value_sum_.AddVec(1.0, temp);
307 if (deriv != NULL) {
308 temp.AddRowSumMat(1.0, *deriv, 0.0);
309 deriv_sum_.AddVec(1.0, temp);
310 }
311 }
313 void NonlinearComponent::ZeroStats() {
314 value_sum_.SetZero();
315 deriv_sum_.SetZero();
316 count_ = 0.0;
317 num_dims_self_repaired_ = 0.0;
318 num_dims_processed_ = 0.0;
319 }
321 std::string NonlinearComponent::Info() const {
322 std::stringstream stream;
323 if (InputDim() == OutputDim()) {
324 stream << Type() << ", dim=" << InputDim();
325 } else {
326 stream << Type() << ", input-dim=" << InputDim()
327 << ", output-dim=" << OutputDim();
328 }
330 if (self_repair_lower_threshold_ != BaseFloat(kUnsetThreshold))
331 stream << ", self-repair-lower-threshold=" << self_repair_lower_threshold_;
332 if (self_repair_upper_threshold_ != BaseFloat(kUnsetThreshold))
333 stream << ", self-repair-upper-threshold=" << self_repair_upper_threshold_;
334 if (self_repair_scale_ != 0.0)
335 stream << ", self-repair-scale=" << self_repair_scale_;
336 if (count_ > 0 && value_sum_.Dim() == dim_) {
337 stream << ", count=" << std::setprecision(3) << count_
338 << std::setprecision(6);
339 stream << ", self-repaired-proportion="
340 << (num_dims_processed_ > 0 ?
341 num_dims_self_repaired_ / num_dims_processed_ : 0);
342 Vector<double> value_avg_dbl(value_sum_);
343 Vector<BaseFloat> value_avg(value_avg_dbl);
344 value_avg.Scale(1.0 / count_);
345 stream << ", value-avg=" << SummarizeVector(value_avg);
346 if (deriv_sum_.Dim() == dim_) {
347 Vector<double> deriv_avg_dbl(deriv_sum_);
348 Vector<BaseFloat> deriv_avg(deriv_avg_dbl);
349 deriv_avg.Scale(1.0 / count_);
350 stream << ", deriv-avg=" << SummarizeVector(deriv_avg);
351 }
352 }
353 return stream.str();
354 }
356 void NonlinearComponent::Scale(BaseFloat scale) {
357 value_sum_.Scale(scale);
358 deriv_sum_.Scale(scale);
359 count_ *= scale;
360 num_dims_self_repaired_ *= scale;
361 num_dims_processed_ *= scale;
362 }
364 void NonlinearComponent::Add(BaseFloat alpha, const Component &other_in) {
365 const NonlinearComponent *other =
366 dynamic_cast<const NonlinearComponent*>(&other_in);
367 KALDI_ASSERT(other != NULL);
368 if (value_sum_.Dim() == 0 && other->value_sum_.Dim() != 0)
369 value_sum_.Resize(other->value_sum_.Dim());
370 if (deriv_sum_.Dim() == 0 && other->deriv_sum_.Dim() != 0)
371 deriv_sum_.Resize(other->deriv_sum_.Dim());
372 if (other->value_sum_.Dim() != 0)
373 value_sum_.AddVec(alpha, other->value_sum_);
374 if (other->deriv_sum_.Dim() != 0)
375 deriv_sum_.AddVec(alpha, other->deriv_sum_);
376 count_ += alpha * other->count_;
377 num_dims_self_repaired_ += alpha * other->num_dims_self_repaired_;
378 num_dims_processed_ += alpha * other->num_dims_processed_;
379 }
381 void NonlinearComponent::Read(std::istream &is, bool binary) {
382 std::ostringstream ostr_beg, ostr_end;
383 ostr_beg << "<" << Type() << ">"; // e.g. "<SigmoidComponent>"
384 ostr_end << "</" << Type() << ">"; // e.g. "</SigmoidComponent>"
385 ExpectOneOrTwoTokens(is, binary, ostr_beg.str(), "<Dim>");
386 ReadBasicType(is, binary, &dim_); // Read dimension.
387 ExpectToken(is, binary, "<ValueAvg>");
388 value_sum_.Read(is, binary);
389 ExpectToken(is, binary, "<DerivAvg>");
390 deriv_sum_.Read(is, binary);
391 ExpectToken(is, binary, "<Count>");
392 ReadBasicType(is, binary, &count_);
393 value_sum_.Scale(count_);
394 deriv_sum_.Scale(count_);
396 std::string token;
397 ReadToken(is, binary, &token);
398 if (token == "<NumDimsSelfRepaired>") {
399 ReadBasicType(is, binary, &num_dims_self_repaired_);
400 ReadToken(is, binary, &token);
401 }
402 if (token == "<NumDimsProcessed>") {
403 ReadBasicType(is, binary, &num_dims_processed_);
404 ReadToken(is, binary, &token);
405 }
406 if (token == "<SelfRepairLowerThreshold>") {
407 ReadBasicType(is, binary, &self_repair_lower_threshold_);
408 ReadToken(is, binary, &token);
409 }
410 if (token == "<SelfRepairUpperThreshold>") {
411 ReadBasicType(is, binary, &self_repair_upper_threshold_);
412 ReadToken(is, binary, &token);
413 }
414 if (token == "<SelfRepairScale>") {
415 ReadBasicType(is, binary, &self_repair_scale_);
416 ReadToken(is, binary, &token);
417 }
418 if (token != ostr_end.str()) {
419 KALDI_ERR << "Expected token " << ostr_end.str()
420 << ", got " << token;
421 }
422 }
424 void NonlinearComponent::Write(std::ostream &os, bool binary) const {
425 std::ostringstream ostr_beg, ostr_end;
426 ostr_beg << "<" << Type() << ">"; // e.g. "<SigmoidComponent>"
427 ostr_end << "</" << Type() << ">"; // e.g. "</SigmoidComponent>"
428 WriteToken(os, binary, ostr_beg.str());
429 WriteToken(os, binary, "<Dim>");
430 WriteBasicType(os, binary, dim_);
431 // Write the values and derivatives in a count-normalized way, for
432 // greater readability in text form.
433 WriteToken(os, binary, "<ValueAvg>");
434 Vector<BaseFloat> temp(value_sum_);
435 if (count_ != 0.0) temp.Scale(1.0 / count_);
436 temp.Write(os, binary);
437 WriteToken(os, binary, "<DerivAvg>");
439 temp.Resize(deriv_sum_.Dim(), kUndefined);
440 temp.CopyFromVec(deriv_sum_);
441 if (count_ != 0.0) temp.Scale(1.0 / count_);
442 temp.Write(os, binary);
443 WriteToken(os, binary, "<Count>");
444 WriteBasicType(os, binary, count_);
445 WriteToken(os, binary, "<NumDimsSelfRepaired>");
446 WriteBasicType(os, binary, num_dims_self_repaired_);
447 WriteToken(os, binary, "<NumDimsProcessed>");
448 WriteBasicType(os, binary, num_dims_processed_);
449 if (self_repair_lower_threshold_ != kUnsetThreshold) {
450 WriteToken(os, binary, "<SelfRepairLowerThreshold>");
451 WriteBasicType(os, binary, self_repair_lower_threshold_);
452 }
453 if (self_repair_upper_threshold_ != kUnsetThreshold) {
454 WriteToken(os, binary, "<SelfRepairUpperThreshold>");
455 WriteBasicType(os, binary, self_repair_upper_threshold_);
456 }
457 if (self_repair_scale_ != 0.0) {
458 WriteToken(os, binary, "<SelfRepairScale>");
459 WriteBasicType(os, binary, self_repair_scale_);
460 }
461 WriteToken(os, binary, ostr_end.str());
462 }
464 NonlinearComponent::NonlinearComponent():
465 dim_(-1), count_(0.0),
466 num_dims_self_repaired_(0.0), num_dims_processed_(0.0),
467 self_repair_lower_threshold_(kUnsetThreshold),
468 self_repair_upper_threshold_(kUnsetThreshold),
469 self_repair_scale_(0.0) { }
471 NonlinearComponent::NonlinearComponent(const NonlinearComponent &other):
472 dim_(other.dim_), value_sum_(other.value_sum_), deriv_sum_(other.deriv_sum_),
473 count_(other.count_),
474 num_dims_self_repaired_(other.num_dims_self_repaired_),
475 num_dims_processed_(other.num_dims_processed_),
476 self_repair_lower_threshold_(other.self_repair_lower_threshold_),
477 self_repair_upper_threshold_(other.self_repair_upper_threshold_),
478 self_repair_scale_(other.self_repair_scale_) { }
480 void NonlinearComponent::InitFromConfig(ConfigLine *cfl) {
481 bool ok = cfl->GetValue("dim", &dim_);
482 cfl->GetValue("self-repair-lower-threshold", &self_repair_lower_threshold_);
483 cfl->GetValue("self-repair-upper-threshold", &self_repair_upper_threshold_);
484 cfl->GetValue("self-repair-scale", &self_repair_scale_);
485 if (!ok || cfl->HasUnusedValues() || dim_ <= 0)
486 KALDI_ERR << "Invalid initializer for layer of type "
487 << Type() << ": \"" << cfl->WholeLine() << "\"";
488 }
492 } // namespace nnet3
493 } // namespace kaldi