1 // nnet3/nnet-simple-component.h
3 // Copyright 2011-2013 Karel Vesely
4 // 2012-2015 Johns Hopkins University (author: Daniel Povey)
5 // 2013 Xiaohui Zhang
6 // 2014-2015 Vijayaditya Peddinti
7 // 2014-2015 Guoguo Chen
8 // 2015 Daniel Galvez
9 // 2015 Tom Ko
11 // See ../../COPYING for clarification regarding multiple authors
12 //
13 // Licensed under the Apache License, Version 2.0 (the "License");
14 // you may not use this file except in compliance with the License.
15 // You may obtain a copy of the License at
16 //
17 // http://www.apache.org/licenses/LICENSE-2.0
18 //
19 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
20 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
21 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
22 // MERCHANTABLITY OR NON-INFRINGEMENT.
23 // See the Apache 2 License for the specific language governing permissions and
24 // limitations under the License.
26 #ifndef KALDI_NNET3_NNET_SIMPLE_COMPONENT_H_
27 #define KALDI_NNET3_NNET_SIMPLE_COMPONENT_H_
29 #include "nnet3/nnet-common.h"
30 #include "nnet3/nnet-component-itf.h"
31 #include "nnet3/natural-gradient-online.h"
32 #include <iostream>
34 namespace kaldi {
35 namespace nnet3 {
37 /// @file This file contains declarations of components that are "simple", meaning
38 /// they don't care about the indexes they are operating on, produce one
39 /// output for one input, and return the kSimpleComponent flag in their
40 /// Properties(): for example, tanh and affine components. In
41 /// nnet-general-component.h there are components that don't fit this pattern.
43 // This "nnet3" version of the p-norm component only supports the 2-norm.
44 class PnormComponent: public Component {
45 public:
46 void Init(int32 input_dim, int32 output_dim);
47 explicit PnormComponent(int32 input_dim, int32 output_dim) {
48 Init(input_dim, output_dim);
49 }
50 virtual int32 Properties() const {
51 return kSimpleComponent|kLinearInInput|kBackpropNeedsInput|kBackpropNeedsOutput;
52 }
53 PnormComponent(): input_dim_(0), output_dim_(0) { }
54 virtual std::string Type() const { return "PnormComponent"; }
55 virtual void InitFromConfig(ConfigLine *cfl);
56 virtual int32 InputDim() const { return input_dim_; }
57 virtual int32 OutputDim() const { return output_dim_; }
58 virtual void Propagate(const ComponentPrecomputedIndexes *indexes,
59 const CuMatrixBase<BaseFloat> &in,
60 CuMatrixBase<BaseFloat> *out) const;
61 virtual void Backprop(const std::string &debug_info,
62 const ComponentPrecomputedIndexes *indexes,
63 const CuMatrixBase<BaseFloat> &in_value,
64 const CuMatrixBase<BaseFloat> &out_value,
65 const CuMatrixBase<BaseFloat> &out_deriv,
66 Component *to_update,
67 CuMatrixBase<BaseFloat> *in_deriv) const;
68 virtual Component* Copy() const { return new PnormComponent(input_dim_,
69 output_dim_); }
71 virtual void Read(std::istream &is, bool binary); // This Read function
72 // requires that the Component has the correct type.
74 /// Write component to stream
75 virtual void Write(std::ostream &os, bool binary) const;
77 protected:
78 int32 input_dim_;
79 int32 output_dim_;
80 };
82 class ElementwiseProductComponent: public Component {
83 public:
84 void Init(int32 input_dim, int32 output_dim);
85 explicit ElementwiseProductComponent(int32 input_dim, int32 output_dim) {
86 Init(input_dim, output_dim);
87 }
88 virtual int32 Properties() const {
89 return kSimpleComponent|kBackpropNeedsInput;
90 }
91 ElementwiseProductComponent(): input_dim_(0), output_dim_(0) { }
92 virtual std::string Type() const { return "ElementwiseProductComponent"; }
93 virtual void InitFromConfig(ConfigLine *cfl);
94 virtual int32 InputDim() const { return input_dim_; }
95 virtual int32 OutputDim() const { return output_dim_; }
96 virtual void Propagate(const ComponentPrecomputedIndexes *indexes,
97 const CuMatrixBase<BaseFloat> &in,
98 CuMatrixBase<BaseFloat> *out) const;
99 virtual void Backprop(const std::string &debug_info,
100 const ComponentPrecomputedIndexes *indexes,
101 const CuMatrixBase<BaseFloat> &in_value,
102 const CuMatrixBase<BaseFloat> &out_value,
103 const CuMatrixBase<BaseFloat> &out_deriv,
104 Component *to_update,
105 CuMatrixBase<BaseFloat> *in_deriv) const;
106 virtual Component* Copy() const { return new ElementwiseProductComponent(input_dim_,
107 output_dim_); }
109 virtual void Read(std::istream &is, bool binary); // This Read function
110 // requires that the Component has the correct type.
112 /// Write component to stream
113 virtual void Write(std::ostream &os, bool binary) const;
115 protected:
116 int32 input_dim_;
117 int32 output_dim_;
118 };
120 class NormalizeComponent: public Component {
121 public:
122 void Init(int32 input_dim, BaseFloat target_rms, bool add_log_stddev);
123 explicit NormalizeComponent(int32 input_dim,
124 BaseFloat target_rms = 1.0,
125 bool add_log_stddev = false) {
126 Init(input_dim, target_rms, add_log_stddev);
127 }
128 explicit NormalizeComponent(const NormalizeComponent &other);
129 virtual int32 Properties() const {
130 return (add_log_stddev_ ?
131 kSimpleComponent|kBackpropNeedsInput|kBackpropAdds :
132 kSimpleComponent|kBackpropNeedsInput|kPropagateInPlace|
133 kBackpropAdds|kBackpropInPlace);
134 }
135 NormalizeComponent(): target_rms_(1.0), add_log_stddev_(false) { }
136 virtual std::string Type() const { return "NormalizeComponent"; }
137 virtual void InitFromConfig(ConfigLine *cfl);
138 virtual Component* Copy() const { return new NormalizeComponent(*this); }
139 virtual void Propagate(const ComponentPrecomputedIndexes *indexes,
140 const CuMatrixBase<BaseFloat> &in,
141 CuMatrixBase<BaseFloat> *out) const;
142 virtual void Backprop(const std::string &debug_info,
143 const ComponentPrecomputedIndexes *indexes,
144 const CuMatrixBase<BaseFloat> &in_value,
145 const CuMatrixBase<BaseFloat> &, // out_value
146 const CuMatrixBase<BaseFloat> &out_deriv,
147 Component *to_update,
148 CuMatrixBase<BaseFloat> *in_deriv) const;
150 virtual void Read(std::istream &is, bool binary);
151 virtual void Write(std::ostream &os, bool binary) const;
152 virtual int32 InputDim() const { return input_dim_; }
153 virtual int32 OutputDim() const {
154 return (input_dim_ + (add_log_stddev_ ? 1 : 0));
155 }
156 virtual std::string Info() const;
157 private:
158 NormalizeComponent &operator = (const NormalizeComponent &other); // Disallow.
159 enum { kExpSquaredNormFloor = -66 };
160 static const BaseFloat kSquaredNormFloor;
161 int32 input_dim_;
162 BaseFloat target_rms_; // The target rms for outputs.
163 // about 0.7e-20. We need a value that's exactly representable in
164 // float and whose inverse square root is also exactly representable
165 // in float (hence, an even power of two).
167 bool add_log_stddev_; // If true, log(max(epsi, sqrt(row_in^T row_in / D)))
168 // is an extra dimension of the output.
169 };
172 class SigmoidComponent: public NonlinearComponent {
173 public:
174 explicit SigmoidComponent(int32 dim): NonlinearComponent(dim) { }
175 explicit SigmoidComponent(const SigmoidComponent &other): NonlinearComponent(other) { }
176 SigmoidComponent() { }
177 virtual std::string Type() const { return "SigmoidComponent"; }
178 virtual int32 Properties() const {
179 return kSimpleComponent|kBackpropNeedsOutput|kPropagateInPlace|kStoresStats;
180 }
181 virtual Component* Copy() const { return new SigmoidComponent(*this); }
182 virtual void Propagate(const ComponentPrecomputedIndexes *indexes,
183 const CuMatrixBase<BaseFloat> &in,
184 CuMatrixBase<BaseFloat> *out) const;
185 virtual void Backprop(const std::string &debug_info,
186 const ComponentPrecomputedIndexes *indexes,
187 const CuMatrixBase<BaseFloat> &, //in_value
188 const CuMatrixBase<BaseFloat> &out_value,
189 const CuMatrixBase<BaseFloat> &out_deriv,
190 Component *to_update,
191 CuMatrixBase<BaseFloat> *in_deriv) const;
192 virtual void StoreStats(const CuMatrixBase<BaseFloat> &out_value);
193 private:
194 SigmoidComponent &operator = (const SigmoidComponent &other); // Disallow.
195 };
197 class TanhComponent: public NonlinearComponent {
198 public:
199 explicit TanhComponent(int32 dim): NonlinearComponent(dim) { }
200 explicit TanhComponent(const TanhComponent &other): NonlinearComponent(other) { }
201 TanhComponent() { }
202 virtual std::string Type() const { return "TanhComponent"; }
203 virtual Component* Copy() const { return new TanhComponent(*this); }
204 virtual int32 Properties() const {
205 return kSimpleComponent|kBackpropNeedsOutput|kPropagateInPlace|kStoresStats;
206 }
207 virtual void Propagate(const ComponentPrecomputedIndexes *indexes,
208 const CuMatrixBase<BaseFloat> &in,
209 CuMatrixBase<BaseFloat> *out) const;
210 virtual void Backprop(const std::string &debug_info,
211 const ComponentPrecomputedIndexes *indexes,
212 const CuMatrixBase<BaseFloat> &, //in_value
213 const CuMatrixBase<BaseFloat> &out_value,
214 const CuMatrixBase<BaseFloat> &out_deriv,
215 Component *to_update,
216 CuMatrixBase<BaseFloat> *in_deriv) const;
217 virtual void StoreStats(const CuMatrixBase<BaseFloat> &out_value);
218 private:
219 TanhComponent &operator = (const TanhComponent &other); // Disallow.
220 };
223 class RectifiedLinearComponent: public NonlinearComponent {
224 public:
225 explicit RectifiedLinearComponent(int32 dim): NonlinearComponent(dim) { }
226 explicit RectifiedLinearComponent(const RectifiedLinearComponent &other): NonlinearComponent(other) { }
227 RectifiedLinearComponent() { }
228 virtual std::string Type() const { return "RectifiedLinearComponent"; }
229 virtual Component* Copy() const { return new RectifiedLinearComponent(*this); }
230 virtual int32 Properties() const {
231 return kSimpleComponent|kLinearInInput|kBackpropNeedsOutput|kPropagateInPlace|
232 kStoresStats;
233 }
234 virtual void Propagate(const ComponentPrecomputedIndexes *indexes,
235 const CuMatrixBase<BaseFloat> &in,
236 CuMatrixBase<BaseFloat> *out) const;
237 virtual void Backprop(const std::string &debug_info,
238 const ComponentPrecomputedIndexes *indexes,
239 const CuMatrixBase<BaseFloat> &, //in_value
240 const CuMatrixBase<BaseFloat> &out_value,
241 const CuMatrixBase<BaseFloat> &out_deriv,
242 Component *to_update,
243 CuMatrixBase<BaseFloat> *in_deriv) const;
244 virtual void StoreStats(const CuMatrixBase<BaseFloat> &out_value);
245 private:
246 RectifiedLinearComponent &operator = (const RectifiedLinearComponent &other); // Disallow.
247 };
249 /**
250 This component is a fixed (non-trainable) nonlinearity that sums its inputs
251 to produce outputs. Currently the only supported configuration is that its
252 input-dim is interpreted as consisting of n blocks, and the output is just a
253 summation over the n blocks, where n = input-dim / output-dim, so for instance
254 output[n] = input[n] + input[block-size + n] + .... .
255 Later if needed we can add a configuration variable that allows you to sum
256 over 'interleaved' input.
257 */
258 class SumReduceComponent: public Component {
259 public:
260 void Init(int32 input_dim, int32 output_dim);
261 explicit SumReduceComponent(int32 input_dim, int32 output_dim) {
262 Init(input_dim, output_dim);
263 }
264 virtual int32 Properties() const {
265 return kSimpleComponent|kLinearInInput;
266 }
267 SumReduceComponent(): input_dim_(0), output_dim_(0) { }
268 virtual std::string Type() const { return "SumReduceComponent"; }
269 virtual void InitFromConfig(ConfigLine *cfl);
270 virtual int32 InputDim() const { return input_dim_; }
271 virtual int32 OutputDim() const { return output_dim_; }
272 virtual void Propagate(const ComponentPrecomputedIndexes *indexes,
273 const CuMatrixBase<BaseFloat> &in,
274 CuMatrixBase<BaseFloat> *out) const;
275 virtual void Backprop(const std::string &debug_info,
276 const ComponentPrecomputedIndexes *indexes,
277 const CuMatrixBase<BaseFloat> &, // in_value
278 const CuMatrixBase<BaseFloat> &, // out_value,
279 const CuMatrixBase<BaseFloat> &out_deriv,
280 Component *, // to_update
281 CuMatrixBase<BaseFloat> *in_deriv) const;
282 virtual Component* Copy() const { return new SumReduceComponent(input_dim_,
283 output_dim_); }
285 virtual void Read(std::istream &is, bool binary); // This Read function
286 // requires that the Component has the correct type.
288 /// Write component to stream
289 virtual void Write(std::ostream &os, bool binary) const;
291 protected:
292 int32 input_dim_;
293 int32 output_dim_;
294 };
297 class FixedAffineComponent;
298 class FixedScaleComponent;
299 class PerElementScaleComponent;
300 class PerElementOffsetComponent;
302 // Affine means a linear function plus an offset.
303 // Note: although this class can be instantiated, it also
304 // functions as a base-class for more specialized versions of
305 // AffineComponent.
306 class AffineComponent: public UpdatableComponent {
307 friend class SoftmaxComponent; // Friend declaration relates to mixing up.
308 public:
310 virtual int32 InputDim() const { return linear_params_.NumCols(); }
311 virtual int32 OutputDim() const { return linear_params_.NumRows(); }
313 virtual std::string Info() const;
314 virtual void InitFromConfig(ConfigLine *cfl);
316 AffineComponent() { } // use Init to really initialize.
317 virtual std::string Type() const { return "AffineComponent"; }
318 virtual int32 Properties() const {
319 return kSimpleComponent|kUpdatableComponent|kLinearInParameters|
320 kBackpropNeedsInput|kBackpropAdds;
321 }
324 virtual void Propagate(const ComponentPrecomputedIndexes *indexes,
325 const CuMatrixBase<BaseFloat> &in,
326 CuMatrixBase<BaseFloat> *out) const;
327 virtual void Backprop(const std::string &debug_info,
328 const ComponentPrecomputedIndexes *indexes,
329 const CuMatrixBase<BaseFloat> &in_value,
330 const CuMatrixBase<BaseFloat> &, // out_value
331 const CuMatrixBase<BaseFloat> &out_deriv,
332 Component *to_update,
333 CuMatrixBase<BaseFloat> *in_deriv) const;
335 virtual void Read(std::istream &is, bool binary);
336 virtual void Write(std::ostream &os, bool binary) const;
338 virtual Component* Copy() const;
341 // Some functions from base-class UpdatableComponent.
342 virtual void Scale(BaseFloat scale);
343 virtual void Add(BaseFloat alpha, const Component &other);
344 virtual void SetZero(bool treat_as_gradient);
345 virtual void PerturbParams(BaseFloat stddev);
346 virtual BaseFloat DotProduct(const UpdatableComponent &other) const;
347 virtual int32 NumParameters() const;
348 virtual void Vectorize(VectorBase<BaseFloat> *params) const;
349 virtual void UnVectorize(const VectorBase<BaseFloat> ¶ms);
351 // Some functions that are specific to this class.
353 // This new function is used when mixing up:
354 virtual void SetParams(const VectorBase<BaseFloat> &bias,
355 const MatrixBase<BaseFloat> &linear);
356 const CuVector<BaseFloat> &BiasParams() { return bias_params_; }
357 const CuMatrix<BaseFloat> &LinearParams() { return linear_params_; }
358 explicit AffineComponent(const AffineComponent &other);
359 // The next constructor is used in converting from nnet1.
360 AffineComponent(const CuMatrixBase<BaseFloat> &linear_params,
361 const CuVectorBase<BaseFloat> &bias_params,
362 BaseFloat learning_rate);
363 void Init(int32 input_dim, int32 output_dim,
364 BaseFloat param_stddev, BaseFloat bias_stddev);
365 void Init(std::string matrix_filename);
367 // This function resizes the dimensions of the component, setting the
368 // parameters to zero, while leaving any other configuration values the same.
369 virtual void Resize(int32 input_dim, int32 output_dim);
371 // The following functions are used for collapsing multiple layers
372 // together. They return a pointer to a new Component equivalent to
373 // the sequence of two components. We haven't implemented this for
374 // FixedLinearComponent yet.
375 Component *CollapseWithNext(const AffineComponent &next) const ;
376 Component *CollapseWithNext(const FixedAffineComponent &next) const;
377 Component *CollapseWithNext(const FixedScaleComponent &next) const;
378 Component *CollapseWithPrevious(const FixedAffineComponent &prev) const;
380 protected:
381 friend class NaturalGradientAffineComponent;
382 // This function Update() is for extensibility; child classes may override
383 // this, e.g. for natural gradient update.
384 virtual void Update(
385 const std::string &debug_info,
386 const CuMatrixBase<BaseFloat> &in_value,
387 const CuMatrixBase<BaseFloat> &out_deriv) {
388 UpdateSimple(in_value, out_deriv);
389 }
390 // UpdateSimple is used when *this is a gradient. Child classes may override
391 // this if needed, but typically won't need to.
392 virtual void UpdateSimple(
393 const CuMatrixBase<BaseFloat> &in_value,
394 const CuMatrixBase<BaseFloat> &out_deriv);
396 const AffineComponent &operator = (const AffineComponent &other); // Disallow.
397 CuMatrix<BaseFloat> linear_params_;
398 CuVector<BaseFloat> bias_params_;
399 };
401 class RepeatedAffineComponent;
403 /// This class implements an affine transform using a block diagonal matrix
404 /// e.g., one whose weight matrix is all zeros except for blocks on the
405 /// diagonal. All these blocks have the same dimensions.
406 /// input-dim: num cols of block diagonal matrix.
407 /// output-dim: num rows of block diagonal matrix.
408 /// num-blocks: number of blocks in diagonal of the matrix.
409 /// num-blocks must divide both input-dim and output-dim
410 class BlockAffineComponent : public UpdatableComponent {
411 public:
412 virtual int32 InputDim() const { return linear_params_.NumCols() * num_blocks_; }
413 virtual int32 OutputDim() const { return linear_params_.NumRows(); }
415 virtual std::string Info() const;
416 virtual void InitFromConfig(ConfigLine *cfl);
418 BlockAffineComponent() { }
419 virtual std::string Type() const { return "BlockAffineComponent"; }
420 virtual int32 Properties() const {
421 return kSimpleComponent|kUpdatableComponent|kLinearInParameters|
422 kBackpropNeedsInput|kBackpropAdds;
423 }
425 virtual void Propagate(const ComponentPrecomputedIndexes *indexes,
426 const CuMatrixBase<BaseFloat> &in,
427 CuMatrixBase<BaseFloat> *out) const;
429 virtual void Backprop(const std::string &debug_info,
430 const ComponentPrecomputedIndexes *indexes,
431 const CuMatrixBase<BaseFloat> &in_value,
432 const CuMatrixBase<BaseFloat> &, // out_value
433 const CuMatrixBase<BaseFloat> &out_deriv,
434 Component *to_update,
435 CuMatrixBase<BaseFloat> *in_deriv) const;
437 virtual void Read(std::istream &is, bool binary);
438 virtual void Write(std::ostream &os, bool binary) const;
440 virtual Component* Copy() const;
442 // Functions from base-class UpdatableComponent.
443 virtual void Scale(BaseFloat scale);
444 virtual void Add(BaseFloat alpha, const Component &other);
445 virtual void SetZero(bool treat_as_gradient);
446 virtual void PerturbParams(BaseFloat stddev);
447 virtual BaseFloat DotProduct(const UpdatableComponent &other) const;
448 virtual int32 NumParameters() const;
449 virtual void Vectorize(VectorBase<BaseFloat> *params) const;
450 virtual void UnVectorize(const VectorBase<BaseFloat> ¶ms);
452 // BlockAffine-specific functions.
453 void Init(int32 input_dim, int32 output_dim, int32 num_blocks,
454 BaseFloat param_stddev, BaseFloat bias_mean,
455 BaseFloat bias_stddev);
456 explicit BlockAffineComponent(const BlockAffineComponent &other);
457 explicit BlockAffineComponent(const RepeatedAffineComponent &rac);
458 protected:
459 // The matrix linear_params_ has a block structure, with num_blocks_ blocks of
460 // equal size. The blocks are stored in linear_params_ as
461 // [ M
462 // N
463 // O ] but we actually treat it as the matrix:
464 // [ M 0 0
465 // 0 N 0
466 // 0 0 O ]
467 CuMatrix<BaseFloat> linear_params_;
468 CuVector<BaseFloat> bias_params_;
469 int32 num_blocks_;
470 private:
471 const BlockAffineComponent &operator = (const BlockAffineComponent &other); // Disallow.
472 };
474 class RepeatedAffineComponent: public UpdatableComponent {
475 public:
477 virtual int32 InputDim() const { return linear_params_.NumCols() * num_repeats_; }
478 virtual int32 OutputDim() const { return linear_params_.NumRows() * num_repeats_; }
480 virtual std::string Info() const;
481 virtual void InitFromConfig(ConfigLine *cfl);
483 RepeatedAffineComponent() { } // use Init to really initialize.
484 virtual std::string Type() const { return "RepeatedAffineComponent"; }
485 virtual int32 Properties() const {
486 return kSimpleComponent|kUpdatableComponent|kLinearInParameters|
487 kBackpropNeedsInput|kBackpropAdds|kInputContiguous|kOutputContiguous;
488 }
489 virtual void Propagate(const ComponentPrecomputedIndexes *indexes,
490 const CuMatrixBase<BaseFloat> &in,
491 CuMatrixBase<BaseFloat> *out) const;
492 virtual void Backprop(const std::string &debug_info,
493 const ComponentPrecomputedIndexes *indexes,
494 const CuMatrixBase<BaseFloat> &in_value,
495 const CuMatrixBase<BaseFloat> &, // out_value
496 const CuMatrixBase<BaseFloat> &out_deriv,
497 Component *to_update,
498 CuMatrixBase<BaseFloat> *in_deriv) const;
500 virtual void Read(std::istream &is, bool binary);
501 virtual void Write(std::ostream &os, bool binary) const;
503 virtual Component* Copy() const;
505 // Some functions from base-class UpdatableComponent.
506 virtual void Scale(BaseFloat scale);
507 virtual void Add(BaseFloat alpha, const Component &other);
508 virtual void SetZero(bool treat_as_gradient);
509 virtual void PerturbParams(BaseFloat stddev);
510 virtual BaseFloat DotProduct(const UpdatableComponent &other) const;
511 virtual int32 NumParameters() const;
512 virtual void Vectorize(VectorBase<BaseFloat> *params) const;
513 virtual void UnVectorize(const VectorBase<BaseFloat> ¶ms);
515 // Some functions that are specific to this class.
516 const CuVector<BaseFloat> &BiasParams() { return bias_params_; }
517 const CuMatrix<BaseFloat> &LinearParams() { return linear_params_; }
518 explicit RepeatedAffineComponent(const RepeatedAffineComponent &other);
520 void Init(int32 input_dim, int32 output_dim, int32 num_repeats,
521 BaseFloat param_stddev, BaseFloat bias_mean,
522 BaseFloat bias_stddev);
523 friend BlockAffineComponent::BlockAffineComponent(const RepeatedAffineComponent &rac);
524 protected:
525 // This function Update(), called from backprop, is broken out for
526 // extensibility to natural gradient update.
527 virtual void Update(
528 const CuMatrixBase<BaseFloat> &in_value,
529 const CuMatrixBase<BaseFloat> &out_deriv);
531 // This function does nothing here but is redefined in child-class
532 // NaturalGradientRepeatedAffineComponent. This help avoid repeated code.
533 virtual void SetNaturalGradientConfigs() { }
535 const RepeatedAffineComponent &operator = (const RepeatedAffineComponent &other); // Disallow.
536 CuMatrix<BaseFloat> linear_params_;
537 CuVector<BaseFloat> bias_params_;
538 int32 num_repeats_;
539 };
541 class NaturalGradientRepeatedAffineComponent: public RepeatedAffineComponent {
542 public:
543 // Use Init() to really initialize.
544 NaturalGradientRepeatedAffineComponent() { }
546 // Most of the public functions are inherited from RepeatedAffineComponent.
547 virtual std::string Type() const {
548 return "NaturalGradientRepeatedAffineComponent";
549 }
551 virtual Component* Copy() const;
553 // Copy constructor
554 explicit NaturalGradientRepeatedAffineComponent(
555 const NaturalGradientRepeatedAffineComponent &other);
556 private:
557 virtual void Update(
558 const CuMatrixBase<BaseFloat> &in_value,
559 const CuMatrixBase<BaseFloat> &out_deriv);
561 const NaturalGradientRepeatedAffineComponent &operator=(
562 const NaturalGradientRepeatedAffineComponent &other); // Disallow.
564 // Applies the default configuration to preconditioner_in_.
565 virtual void SetNaturalGradientConfigs();
567 // For efficiency reasons we only apply the natural gradient to the input
568 // side, i.e. not to the space of output derivatives-- we believe the input
569 // side is the more important side. We don't make the natural-gradient
570 // configurable; we just give it a reasonable configuration.
571 // Instead of using the individual data-points, for efficiency reasons we use
572 // the distribution of per-minibatch summed derivatives over each dimension of
573 // the output space, as the source for the Fisher matrix.
574 OnlineNaturalGradient preconditioner_in_;
575 };
577 class SoftmaxComponent: public NonlinearComponent {
578 public:
579 explicit SoftmaxComponent(int32 dim): NonlinearComponent(dim) { }
580 explicit SoftmaxComponent(const SoftmaxComponent &other):
581 NonlinearComponent(other) { }
582 SoftmaxComponent() { }
583 virtual std::string Type() const { return "SoftmaxComponent"; }
584 virtual int32 Properties() const {
585 return kSimpleComponent|kBackpropNeedsOutput|kStoresStats;
586 }
587 virtual void Propagate(const ComponentPrecomputedIndexes *indexes,
588 const CuMatrixBase<BaseFloat> &in,
589 CuMatrixBase<BaseFloat> *out) const;
590 virtual void Backprop(const std::string &debug_info,
591 const ComponentPrecomputedIndexes *indexes,
592 const CuMatrixBase<BaseFloat> &in_value,
593 const CuMatrixBase<BaseFloat> &out_value,
594 const CuMatrixBase<BaseFloat> &out_deriv,
595 Component *to_update,
596 CuMatrixBase<BaseFloat> *in_deriv) const;
597 virtual void StoreStats(const CuMatrixBase<BaseFloat> &out_value);
599 virtual Component* Copy() const { return new SoftmaxComponent(*this); }
600 private:
601 SoftmaxComponent &operator = (const SoftmaxComponent &other); // Disallow.
602 };
604 class LogSoftmaxComponent: public NonlinearComponent {
605 public:
606 explicit LogSoftmaxComponent(int32 dim): NonlinearComponent(dim) { }
607 explicit LogSoftmaxComponent(const LogSoftmaxComponent &other):
608 NonlinearComponent(other) { }
609 LogSoftmaxComponent() { }
610 virtual std::string Type() const { return "LogSoftmaxComponent"; }
611 virtual int32 Properties() const {
612 return kSimpleComponent|kBackpropNeedsOutput|kStoresStats;
613 }
614 virtual void Propagate(const ComponentPrecomputedIndexes *indexes,
615 const CuMatrixBase<BaseFloat> &in,
616 CuMatrixBase<BaseFloat> *out) const;
617 virtual void Backprop(const std::string &debug_info,
618 const ComponentPrecomputedIndexes *indexes,
619 const CuMatrixBase<BaseFloat> &in_value,
620 const CuMatrixBase<BaseFloat> &out_value,
621 const CuMatrixBase<BaseFloat> &out_deriv,
622 Component *to_update,
623 CuMatrixBase<BaseFloat> *in_deriv) const;
625 virtual Component* Copy() const { return new LogSoftmaxComponent(*this); }
626 private:
627 LogSoftmaxComponent &operator = (const LogSoftmaxComponent &other); // Disallow.
628 };
630 /// Keywords: natural gradient descent, NG-SGD, naturalgradient. For
631 /// the top-level of the natural gradient code look here, and also in
632 /// nnet-precondition-online.h.
633 /// NaturalGradientAffineComponent is
634 /// a version of AffineComponent that has a non-(multiple of unit) learning-rate
635 /// matrix. See nnet-precondition-online.h for a description of the technique.
636 /// It is described, under the name Online NG-SGD, in the paper "Parallel
637 /// training of DNNs with Natural Gradient and Parameter Averaging" (ICLR
638 /// workshop, 2015) by Daniel Povey, Xiaohui Zhang and Sanjeev Khudanpur.
639 class NaturalGradientAffineComponent: public AffineComponent {
640 public:
641 virtual std::string Type() const { return "NaturalGradientAffineComponent"; }
642 virtual void Read(std::istream &is, bool binary);
643 virtual void Write(std::ostream &os, bool binary) const;
644 void Init(int32 input_dim, int32 output_dim,
645 BaseFloat param_stddev, BaseFloat bias_stddev, BaseFloat bias_mean,
646 int32 rank_in, int32 rank_out, int32 update_period,
647 BaseFloat num_samples_history, BaseFloat alpha,
648 BaseFloat max_change_per_sample);
649 void Init(int32 rank_in, int32 rank_out, int32 update_period,
650 BaseFloat num_samples_history,
651 BaseFloat alpha, BaseFloat max_change_per_sample,
652 std::string matrix_filename);
653 // this constructor does not really initialize, use Init() or Read().
654 NaturalGradientAffineComponent();
655 virtual void Resize(int32 input_dim, int32 output_dim);
656 virtual void InitFromConfig(ConfigLine *cfl);
657 virtual std::string Info() const;
658 virtual Component* Copy() const;
659 virtual void Scale(BaseFloat scale);
660 virtual void Add(BaseFloat alpha, const Component &other);
661 // copy constructor
662 explicit NaturalGradientAffineComponent(
663 const NaturalGradientAffineComponent &other);
664 virtual void ZeroStats();
666 private:
667 // disallow assignment operator.
668 NaturalGradientAffineComponent &operator= (
669 const NaturalGradientAffineComponent&);
671 // Configs for preconditioner. The input side tends to be better conditioned ->
672 // smaller rank needed, so make them separately configurable.
673 int32 rank_in_;
674 int32 rank_out_;
675 int32 update_period_;
676 BaseFloat num_samples_history_;
677 BaseFloat alpha_;
679 OnlineNaturalGradient preconditioner_in_;
681 OnlineNaturalGradient preconditioner_out_;
683 // If > 0, max_change_per_sample_ is the maximum amount of parameter
684 // change (in L2 norm) that we allow per sample, averaged over the minibatch.
685 // This was introduced in order to control instability.
686 // Instead of the exact L2 parameter change, for
687 // efficiency purposes we limit a bound on the exact
688 // change. The limit is applied via a constant <= 1.0
689 // for each minibatch, A suitable value might be, for
690 // example, 10 or so; larger if there are more
691 // parameters.
692 BaseFloat max_change_per_sample_;
694 // update_count_ records how many updates we have done.
695 double update_count_;
697 // active_scaling_count_ records how many updates we have done,
698 // where the scaling factor is active (not 1.0).
699 double active_scaling_count_;
701 // max_change_scale_stats_ records the sum of scaling factors
702 // in each update, so we can compute the averaged scaling factor
703 // in Info().
704 double max_change_scale_stats_;
706 // Sets the configs rank, alpha and eta in the preconditioner objects,
707 // from the class variables.
708 void SetNaturalGradientConfigs();
710 virtual void Update(
711 const std::string &debug_info,
712 const CuMatrixBase<BaseFloat> &in_value,
713 const CuMatrixBase<BaseFloat> &out_deriv);
714 };
717 /// FixedAffineComponent is an affine transform that is supplied
718 /// at network initialization time and is not trainable.
719 class FixedAffineComponent: public Component {
720 public:
721 FixedAffineComponent() { }
722 virtual std::string Type() const { return "FixedAffineComponent"; }
723 virtual std::string Info() const;
725 /// matrix should be of size input-dim+1 to output-dim, last col is offset
726 void Init(const CuMatrixBase<BaseFloat> &matrix);
728 // The ConfigLine cfl contains just the option matrix=<string>,
729 // where the string is the filename of a Kaldi-format matrix to read.
730 virtual void InitFromConfig(ConfigLine *cfl);
732 virtual int32 Properties() const { return kSimpleComponent|kBackpropAdds; }
733 virtual int32 InputDim() const { return linear_params_.NumCols(); }
734 virtual int32 OutputDim() const { return linear_params_.NumRows(); }
736 virtual void Propagate(const ComponentPrecomputedIndexes *indexes,
737 const CuMatrixBase<BaseFloat> &in,
738 CuMatrixBase<BaseFloat> *out) const;
739 virtual void Backprop(const std::string &debug_info,
740 const ComponentPrecomputedIndexes *indexes,
741 const CuMatrixBase<BaseFloat> &in_value,
742 const CuMatrixBase<BaseFloat> &, // out_value
743 const CuMatrixBase<BaseFloat> &out_deriv,
744 Component *to_update,
745 CuMatrixBase<BaseFloat> *in_deriv) const;
748 virtual Component* Copy() const;
749 virtual void Read(std::istream &is, bool binary);
750 virtual void Write(std::ostream &os, bool binary) const;
752 // Function to provide access to linear_params_.
753 const CuMatrix<BaseFloat> &LinearParams() const { return linear_params_; }
754 protected:
755 friend class AffineComponent;
756 CuMatrix<BaseFloat> linear_params_;
757 CuVector<BaseFloat> bias_params_;
759 KALDI_DISALLOW_COPY_AND_ASSIGN(FixedAffineComponent);
760 };
762 /// SumGroupComponent is used to sum up groups of posteriors.
763 /// It's used to introduce a kind of Gaussian-mixture-model-like
764 /// idea into neural nets. This is basically a degenerate case of
765 /// MixtureProbComponent; we had to implement it separately to
766 /// be efficient for CUDA (we can use this one regardless whether
767 /// we have CUDA or not; it's the normal case we want anyway).
768 ///
769 /// There are two forms of initialization in a config file: one
770 /// where the number of elements are specified for each group
771 /// individually as a vector, and one where only the total input
772 /// dimension and the output dimension (number of groups) is specified.
773 /// The second is used when all groups have the same size.
774 class SumGroupComponent: public Component {
775 public:
776 virtual int32 InputDim() const { return input_dim_; }
777 virtual int32 OutputDim() const { return output_dim_; }
778 void Init(const std::vector<int32> &sizes); // the vector is of the input dim
779 // (>= 1) for each output dim.
780 void Init(int32 input_dim, int32 output_dim);
781 void GetSizes(std::vector<int32> *sizes) const; // Get a vector saying, for
782 // each output-dim, how many
783 // inputs were summed over.
784 virtual void InitFromConfig(ConfigLine *cfl);
785 SumGroupComponent() { }
786 virtual std::string Type() const { return "SumGroupComponent"; }
787 virtual int32 Properties() const { return kSimpleComponent|kLinearInInput; }
788 virtual void Propagate(const ComponentPrecomputedIndexes *indexes,
789 const CuMatrixBase<BaseFloat> &in,
790 CuMatrixBase<BaseFloat> *out) const;
791 virtual void Backprop(const std::string &debug_info,
792 const ComponentPrecomputedIndexes *indexes,
793 const CuMatrixBase<BaseFloat> &in_value,
794 const CuMatrixBase<BaseFloat> &, // out_value
795 const CuMatrixBase<BaseFloat> &out_deriv,
796 Component *to_update,
797 CuMatrixBase<BaseFloat> *in_deriv) const;
798 virtual Component* Copy() const;
799 virtual void Read(std::istream &is, bool binary);
800 virtual void Write(std::ostream &os, bool binary) const;
802 private:
803 KALDI_DISALLOW_COPY_AND_ASSIGN(SumGroupComponent);
804 // Note: Int32Pair is just struct{ int32 first; int32 second }; it's defined
805 // in cu-matrixdim.h as extern "C" which is needed for the CUDA interface.
806 CuArray<Int32Pair> indexes_; // for each output index, the (start, end) input
807 // index.
808 CuArray<int32> reverse_indexes_; // for each input index, the output index.
809 int32 input_dim_;
810 int32 output_dim_;
811 };
814 /// FixedScaleComponent applies a fixed per-element scale; it's similar
815 /// to the Rescale component in the nnet1 setup (and only needed for nnet1
816 /// model conversion).
817 class FixedScaleComponent: public Component {
818 public:
819 FixedScaleComponent() { }
820 virtual std::string Type() const { return "FixedScaleComponent"; }
821 virtual std::string Info() const;
822 virtual int32 Properties() const {
823 return kSimpleComponent|kLinearInInput|kPropagateInPlace|kBackpropInPlace;
824 }
826 void Init(const CuVectorBase<BaseFloat> &scales);
828 // The ConfigLine cfl contains only the option scales=<string>,
829 // where the string is the filename of a Kaldi-format matrix to read.
830 virtual void InitFromConfig(ConfigLine *cfl);
832 virtual int32 InputDim() const { return scales_.Dim(); }
833 virtual int32 OutputDim() const { return scales_.Dim(); }
835 virtual void Propagate(const ComponentPrecomputedIndexes *indexes,
836 const CuMatrixBase<BaseFloat> &in,
837 CuMatrixBase<BaseFloat> *out) const;
838 virtual void Backprop(const std::string &debug_info,
839 const ComponentPrecomputedIndexes *indexes,
840 const CuMatrixBase<BaseFloat> &, // in_value
841 const CuMatrixBase<BaseFloat> &, // out_value
842 const CuMatrixBase<BaseFloat> &out_deriv,
843 Component *, // to_update
844 CuMatrixBase<BaseFloat> *in_deriv) const;
845 virtual Component* Copy() const;
846 virtual void Read(std::istream &is, bool binary);
847 virtual void Write(std::ostream &os, bool binary) const;
849 protected:
850 friend class AffineComponent; // necessary for collapse
851 CuVector<BaseFloat> scales_;
852 KALDI_DISALLOW_COPY_AND_ASSIGN(FixedScaleComponent);
853 };
856 /// FixedBiasComponent applies a fixed per-element bias; it's similar
857 /// to the AddShift component in the nnet1 setup (and only needed for nnet1
858 /// model conversion.
859 class FixedBiasComponent: public Component {
860 public:
861 FixedBiasComponent() { }
862 virtual std::string Type() const { return "FixedBiasComponent"; }
863 virtual std::string Info() const;
865 virtual int32 Properties() const {
866 return kSimpleComponent|kPropagateInPlace|kBackpropInPlace;
867 }
869 void Init(const CuVectorBase<BaseFloat> &scales);
871 // The ConfigLine cfl contains only the option bias=<string>,
872 // where the string is the filename of a Kaldi-format matrix to read.
873 virtual void InitFromConfig(ConfigLine *cfl);
874 virtual int32 InputDim() const { return bias_.Dim(); }
875 virtual int32 OutputDim() const { return bias_.Dim(); }
876 using Component::Propagate; // to avoid name hiding
877 virtual void Propagate(const ComponentPrecomputedIndexes *indexes,
878 const CuMatrixBase<BaseFloat> &in,
879 CuMatrixBase<BaseFloat> *out) const;
880 virtual void Backprop(const std::string &debug_info,
881 const ComponentPrecomputedIndexes *indexes,
882 const CuMatrixBase<BaseFloat> &, // in_value,
883 const CuMatrixBase<BaseFloat> &, // out_value
884 const CuMatrixBase<BaseFloat> &out_deriv,
885 Component *, // to_update
886 CuMatrixBase<BaseFloat> *in_deriv) const;
887 virtual Component* Copy() const;
888 virtual void Read(std::istream &is, bool binary);
889 virtual void Write(std::ostream &os, bool binary) const;
891 protected:
892 CuVector<BaseFloat> bias_;
893 KALDI_DISALLOW_COPY_AND_ASSIGN(FixedBiasComponent);
894 };
896 // NoOpComponent just duplicates its input. We don't anticipate this being used
897 // very often, but it may sometimes make your life easier
898 class NoOpComponent: public NonlinearComponent {
899 public:
900 explicit NoOpComponent(int32 dim): NonlinearComponent(dim) { }
901 explicit NoOpComponent(const NoOpComponent &other): NonlinearComponent(other) { }
902 NoOpComponent() { }
903 virtual std::string Type() const { return "NoOpComponent"; }
904 virtual int32 Properties() const {
905 return kSimpleComponent|kLinearInInput|kPropagateInPlace;
906 }
907 virtual Component* Copy() const { return new NoOpComponent(*this); }
908 virtual void Propagate(const ComponentPrecomputedIndexes *indexes,
909 const CuMatrixBase<BaseFloat> &in,
910 CuMatrixBase<BaseFloat> *out) const;
911 virtual void Backprop(const std::string &debug_info,
912 const ComponentPrecomputedIndexes *indexes,
913 const CuMatrixBase<BaseFloat> &, //in_value
914 const CuMatrixBase<BaseFloat> &, // out_value,
915 const CuMatrixBase<BaseFloat> &out_deriv,
916 Component *to_update,
917 CuMatrixBase<BaseFloat> *in_deriv) const;
918 private:
919 NoOpComponent &operator = (const NoOpComponent &other); // Disallow.
920 };
922 // ClipGradientComponent just duplicates its input, but clips gradients
923 // during backpropagation if they cross a predetermined threshold.
924 // This component will be used to prevent gradient explosion problem in
925 // recurrent neural networks
926 class ClipGradientComponent: public Component {
927 public:
928 ClipGradientComponent(int32 dim, BaseFloat clipping_threshold,
929 bool norm_based_clipping, int32 num_clipped,
930 int32 count) {
931 Init(dim, clipping_threshold, norm_based_clipping, num_clipped, count);}
933 ClipGradientComponent(): dim_(0), clipping_threshold_(-1),
934 norm_based_clipping_(false), num_clipped_(0), count_(0) { }
936 virtual int32 InputDim() const { return dim_; }
937 virtual int32 OutputDim() const { return dim_; }
938 virtual void InitFromConfig(ConfigLine *cfl);
939 void Init(int32 dim, BaseFloat clipping_threshold, bool norm_based_clipping,
940 int32 num_clipped, int32 count);
942 virtual std::string Type() const { return "ClipGradientComponent"; }
944 virtual int32 Properties() const {
945 return kSimpleComponent|kLinearInInput|kPropagateInPlace|kBackpropInPlace;
946 }
948 virtual void ZeroStats();
950 virtual Component* Copy() const {
951 return new ClipGradientComponent(dim_,
952 clipping_threshold_,
953 norm_based_clipping_,
954 num_clipped_,
955 count_);}
957 virtual void Propagate(const ComponentPrecomputedIndexes *indexes,
958 const CuMatrixBase<BaseFloat> &in,
959 CuMatrixBase<BaseFloat> *out) const;
960 virtual void Backprop(const std::string &debug_info,
961 const ComponentPrecomputedIndexes *indexes,
962 const CuMatrixBase<BaseFloat> &, //in_value
963 const CuMatrixBase<BaseFloat> &, // out_value,
964 const CuMatrixBase<BaseFloat> &out_deriv,
965 Component *to_update,
966 CuMatrixBase<BaseFloat> *in_deriv) const;
968 virtual void Scale(BaseFloat scale);
969 virtual void Add(BaseFloat alpha, const Component &other);
970 virtual void Read(std::istream &is, bool binary); // This Read function
971 // requires that the Component has the correct type.
972 /// Write component to stream
973 virtual void Write(std::ostream &os, bool binary) const;
974 virtual std::string Info() const;
975 private:
976 int32 dim_; // input/output dimension
977 BaseFloat clipping_threshold_; // threshold to be used for clipping
978 // could correspond to max-row-norm (if
979 // norm_based_clipping_ == true) or
980 // max-absolute-value (otherwise)
981 bool norm_based_clipping_; // if true the max-row-norm will be clipped
982 // else element-wise absolute value clipping is
983 // done
986 ClipGradientComponent &operator =
987 (const ClipGradientComponent &other); // Disallow.
989 protected:
990 // variables to store stats
991 // An element corresponds to rows of derivative matrix, when
992 // norm_based_clipping_ is true,
993 // else it corresponds to each element of the derivative matrix
994 // Note: no stats are stored when norm_based_clipping_ is false
995 int32 num_clipped_; // number of elements which were clipped
996 int32 count_; // number of elements which were processed
998 };
1000 /** PermuteComponent changes the order of the columns (i.e. the feature or
1001 activation dimensions). Output dimension i is mapped to input dimension
1002 column_map_[i], so it's like doing:
1003 for each row:
1004 for each feature/activation dimension i:
1005 output(row, i) = input(row, column_map_[i]).
1007 */
1008 class PermuteComponent: public Component {
1009 public:
1010 PermuteComponent() {}
1011 PermuteComponent(const std::vector<int32> &column_map) { Init(column_map); }
1013 virtual int32 InputDim() const { return column_map_.Dim(); }
1014 virtual int32 OutputDim() const { return column_map_.Dim(); }
1015 virtual void InitFromConfig(ConfigLine *cfl);
1016 void Init(const std::vector<int32> &column_map);
1018 virtual std::string Type() const { return "PermuteComponent"; }
1020 virtual int32 Properties() const {
1021 return kSimpleComponent|kLinearInInput;
1022 }
1024 virtual void ZeroStats() {}
1026 virtual Component* Copy() const;
1028 virtual void Propagate(const ComponentPrecomputedIndexes *indexes,
1029 const CuMatrixBase<BaseFloat> &in,
1030 CuMatrixBase<BaseFloat> *out) const;
1031 virtual void Backprop(const std::string &debug_info,
1032 const ComponentPrecomputedIndexes *indexes,
1033 const CuMatrixBase<BaseFloat> &, //in_value
1034 const CuMatrixBase<BaseFloat> &, // out_value,
1035 const CuMatrixBase<BaseFloat> &out_deriv,
1036 Component *to_update,
1037 CuMatrixBase<BaseFloat> *in_deriv) const;
1039 virtual void Scale(BaseFloat scale) {}
1040 virtual void Add(BaseFloat alpha, const Component &other) {}
1041 virtual void Read(std::istream &is, bool binary); // This Read function
1042 // requires that the Component has the correct type.
1043 /// Write component to stream
1044 virtual void Write(std::ostream &os, bool binary) const;
1045 virtual std::string Info() const;
1046 private:
1047 // computes the reverse column map. Must not be called if column_map_.Dim()
1048 // == 0
1049 void ComputeReverseColumnMap();
1050 CuArray<int32> column_map_;
1051 // the following is a derived variable, not written to disk.
1052 // It is used in backprop.
1053 CuArray<int32> reverse_column_map_;
1054 PermuteComponent &operator =
1055 (const PermuteComponent &other); // Disallow.
1056 };
1061 // PerElementScaleComponent scales each dimension of its input with a separate
1062 // trainable scale; it's like a linear component with a diagonal matrix.
1063 class PerElementScaleComponent: public UpdatableComponent {
1064 public:
1065 virtual int32 InputDim() const { return scales_.Dim(); }
1066 virtual int32 OutputDim() const { return scales_.Dim(); }
1068 virtual std::string Info() const;
1069 virtual void InitFromConfig(ConfigLine *cfl);
1071 PerElementScaleComponent() { } // use Init to really initialize.
1072 virtual std::string Type() const { return "PerElementScaleComponent"; }
1073 virtual int32 Properties() const {
1074 return kSimpleComponent|kUpdatableComponent|kLinearInInput|
1075 kLinearInParameters|kBackpropNeedsInput|kPropagateInPlace;
1076 }
1078 virtual void Propagate(const ComponentPrecomputedIndexes *indexes,
1079 const CuMatrixBase<BaseFloat> &in,
1080 CuMatrixBase<BaseFloat> *out) const;
1081 virtual void Backprop(const std::string &debug_info,
1082 const ComponentPrecomputedIndexes *indexes,
1083 const CuMatrixBase<BaseFloat> &in_value,
1084 const CuMatrixBase<BaseFloat> &, // out_value
1085 const CuMatrixBase<BaseFloat> &out_deriv,
1086 Component *to_update,
1087 CuMatrixBase<BaseFloat> *in_deriv) const;
1089 virtual void Read(std::istream &is, bool binary);
1090 virtual void Write(std::ostream &os, bool binary) const;
1092 virtual Component* Copy() const;
1095 // Some functions from base-class UpdatableComponent.
1096 virtual void Scale(BaseFloat scale);
1097 virtual void Add(BaseFloat alpha, const Component &other);
1098 virtual void SetZero(bool treat_as_gradient);
1099 virtual void PerturbParams(BaseFloat stddev);
1100 virtual BaseFloat DotProduct(const UpdatableComponent &other) const;
1101 virtual int32 NumParameters() const;
1102 virtual void Vectorize(VectorBase<BaseFloat> *params) const;
1103 virtual void UnVectorize(const VectorBase<BaseFloat> ¶ms);
1105 // Some functions that are specific to this class.
1106 explicit PerElementScaleComponent(const PerElementScaleComponent &other);
1108 void Init(int32 dim, BaseFloat param_mean, BaseFloat param_stddev);
1109 void Init(std::string vector_filename);
1111 protected:
1112 friend class AffineComponent; // necessary for collapse
1113 // This function Update() is for extensibility; child classes may override
1114 // this, e.g. for natural gradient update.
1115 virtual void Update(
1116 const std::string &debug_info,
1117 const CuMatrixBase<BaseFloat> &in_value,
1118 const CuMatrixBase<BaseFloat> &out_deriv) {
1119 UpdateSimple(in_value, out_deriv);
1120 }
1121 // UpdateSimple is used when *this is a gradient. Child classes may override
1122 // this if needed, but typically won't need to.
1123 virtual void UpdateSimple(
1124 const CuMatrixBase<BaseFloat> &in_value,
1125 const CuMatrixBase<BaseFloat> &out_deriv);
1127 const PerElementScaleComponent &operator
1128 = (const PerElementScaleComponent &other); // Disallow.
1129 CuVector<BaseFloat> scales_;
1130 };
1133 // PerElementOffsetComponent offsets each dimension of its input with a separate
1134 // trainable bias; it's like an affine component with fixed weight matrix which is always equal to I.
1135 class PerElementOffsetComponent: public UpdatableComponent {
1136 public:
1137 virtual int32 InputDim() const { return offsets_.Dim(); }
1138 virtual int32 OutputDim() const { return offsets_.Dim(); }
1140 virtual std::string Info() const;
1141 virtual void InitFromConfig(ConfigLine *cfl);
1143 PerElementOffsetComponent() { } // use Init to really initialize.
1144 virtual std::string Type() const { return "PerElementOffsetComponent"; }
1145 virtual int32 Properties() const {
1146 return kSimpleComponent|kUpdatableComponent|
1147 kBackpropInPlace|kPropagateInPlace;
1148 }
1150 virtual void Propagate(const ComponentPrecomputedIndexes *indexes,
1151 const CuMatrixBase<BaseFloat> &in,
1152 CuMatrixBase<BaseFloat> *out) const;
1153 virtual void Backprop(const std::string &debug_info,
1154 const ComponentPrecomputedIndexes *indexes,
1155 const CuMatrixBase<BaseFloat> &, // in_value
1156 const CuMatrixBase<BaseFloat> &, // out_value
1157 const CuMatrixBase<BaseFloat> &out_deriv,
1158 Component *to_update,
1159 CuMatrixBase<BaseFloat> *in_deriv) const;
1161 virtual void Read(std::istream &is, bool binary);
1162 virtual void Write(std::ostream &os, bool binary) const;
1164 virtual Component* Copy() const;
1167 // Some functions from base-class UpdatableComponent.
1168 virtual void Scale(BaseFloat scale);
1169 virtual void Add(BaseFloat alpha, const Component &other);
1170 virtual void SetZero(bool treat_as_gradient);
1171 virtual void PerturbParams(BaseFloat stddev);
1172 virtual BaseFloat DotProduct(const UpdatableComponent &other) const;
1173 virtual int32 NumParameters() const;
1174 virtual void Vectorize(VectorBase<BaseFloat> *params) const;
1175 virtual void UnVectorize(const VectorBase<BaseFloat> ¶ms);
1177 // Some functions that are specific to this class.
1178 explicit PerElementOffsetComponent(const PerElementOffsetComponent &other);
1180 void Init(int32 dim, BaseFloat param_mean,
1181 BaseFloat param_stddev);
1182 void Init(std::string vector_filename);
1184 protected:
1185 const PerElementOffsetComponent &operator
1186 = (const PerElementOffsetComponent &other); // Disallow.
1187 CuVector<BaseFloat> offsets_;
1188 };
1191 // ConstantFunctionComponent returns constant function of its input,
1192 // i.e. its output does not depend on its input. It is the same as
1193 // an affine component with the linear term fixed at zero.
1194 // It is optionally trainable, and optionally you can use natural
1195 // gradient. The input is required only because the framework
1196 // requires components to have an input.
1197 class ConstantFunctionComponent: public UpdatableComponent {
1198 public:
1199 virtual int32 InputDim() const { return input_dim_; }
1200 virtual int32 OutputDim() const { return output_.Dim(); }
1202 virtual std::string Info() const;
1203 // possible parameter values with their defaults:
1204 // input-dim=-1 is-updatable=true use-natural-gradient=true output-dim=-1
1205 // output-mean=0 output-stddev=0
1206 virtual void InitFromConfig(ConfigLine *cfl);
1208 ConstantFunctionComponent();
1210 ConstantFunctionComponent(const ConstantFunctionComponent &other);
1212 virtual std::string Type() const { return "ConstantFunctionComponent"; }
1213 virtual int32 Properties() const {
1214 return kSimpleComponent|
1215 (is_updatable_ ? kUpdatableComponent|kLinearInParameters : 0) |
1216 (InputDim() == OutputDim() ? kPropagateInPlace|kBackpropInPlace: 0) |
1217 kBackpropAdds;
1218 }
1219 virtual void Propagate(const ComponentPrecomputedIndexes *indexes,
1220 const CuMatrixBase<BaseFloat> &in,
1221 CuMatrixBase<BaseFloat> *out) const;
1222 virtual void Backprop(const std::string &debug_info,
1223 const ComponentPrecomputedIndexes *indexes,
1224 const CuMatrixBase<BaseFloat> &, // in_value
1225 const CuMatrixBase<BaseFloat> &, // out_value
1226 const CuMatrixBase<BaseFloat> &out_deriv,
1227 Component *to_update,
1228 CuMatrixBase<BaseFloat> *in_deriv) const;
1230 virtual void Read(std::istream &is, bool binary);
1231 virtual void Write(std::ostream &os, bool binary) const;
1233 virtual Component* Copy() const;
1235 // Some functions from base-class UpdatableComponent.
1236 virtual void Scale(BaseFloat scale);
1237 virtual void Add(BaseFloat alpha, const Component &other);
1238 virtual void SetZero(bool treat_as_gradient);
1239 virtual void PerturbParams(BaseFloat stddev);
1240 virtual BaseFloat DotProduct(const UpdatableComponent &other) const;
1241 virtual int32 NumParameters() const;
1242 virtual void Vectorize(VectorBase<BaseFloat> *params) const;
1243 virtual void UnVectorize(const VectorBase<BaseFloat> ¶ms);
1244 private:
1245 int32 input_dim_;
1246 // the output value-- a vector.
1247 CuVector<BaseFloat> output_;
1249 bool is_updatable_;
1250 // if true, and if updatable, do natural-gradient update.
1251 bool use_natural_gradient_;
1252 OnlineNaturalGradient preconditioner_;
1254 const ConstantFunctionComponent &operator
1255 = (const ConstantFunctionComponent &other); // Disallow.
1256 };
1260 // NaturalGradientPerElementScaleComponent is like PerElementScaleComponent but
1261 // it uses a natural gradient update for the per-element scales, and enforces a
1262 // maximum amount of change per minibatch, for stability.
1263 class NaturalGradientPerElementScaleComponent: public PerElementScaleComponent {
1264 public:
1266 virtual std::string Info() const;
1268 virtual void InitFromConfig(ConfigLine *cfl);
1270 NaturalGradientPerElementScaleComponent() { } // use Init to really initialize.
1271 virtual std::string Type() const {
1272 return "NaturalGradientPerElementScaleComponent";
1273 }
1275 virtual void Read(std::istream &is, bool binary);
1276 virtual void Write(std::ostream &os, bool binary) const;
1278 virtual Component* Copy() const;
1280 // Some functions that are specific to this class:
1281 explicit NaturalGradientPerElementScaleComponent(
1282 const NaturalGradientPerElementScaleComponent &other);
1284 void Init(int32 dim, BaseFloat param_mean,
1285 BaseFloat param_stddev, int32 rank, int32 update_period,
1286 BaseFloat num_samples_history, BaseFloat alpha,
1287 BaseFloat max_change_per_minibatch);
1288 void Init(std::string vector_filename,
1289 int32 rank, int32 update_period, BaseFloat num_samples_history,
1290 BaseFloat alpha, BaseFloat max_change_per_minibatch);
1292 private:
1293 // configuration value for imposing max-change...
1294 BaseFloat max_change_per_minibatch_;
1296 // unlike the NaturalGradientAffineComponent, there is only one dimension to
1297 // consider as the parameters are a vector not a matrix, so we only need one
1298 // preconditioner.
1299 // The preconditioner stores its own configuration values; we write and read
1300 // these, but not the preconditioner object itself.
1301 OnlineNaturalGradient preconditioner_;
1303 // Override of the parent-class Update() function, called only
1304 // if this->is_gradient_ = false; this implements the natural
1305 // gradient update.
1306 virtual void Update(
1307 const std::string &debug_info,
1308 const CuMatrixBase<BaseFloat> &in_value,
1309 const CuMatrixBase<BaseFloat> &out_deriv);
1311 const NaturalGradientPerElementScaleComponent &operator
1312 = (const NaturalGradientPerElementScaleComponent &other); // Disallow.
1313 };
1315 /**
1316 * ConvolutionalComponent implements 2d-convolution.
1317 * It uses 3D filters on 3D inputs, but the 3D filters hop only over
1318 * 2 dimensions as it has same size as the input along the 3rd dimension.
1319 * Input : A matrix where each row is a vectorized 3D-tensor.
1320 * The 3D tensor has dimensions
1321 * x: (e.g. time)
1322 * y: (e.g. frequency)
1323 * z: (e.g. channels like features/delta/delta-delta)
1324 *
1325 * The component supports input vectorizations of type zyx and yzx.
1326 * The default vectorization type is zyx.
1327 * e.g. for input vectorization of type zyx the input is vectorized by
1328 * spanning axes z, y and x of the tensor in that order.
1329 * Given 3d tensor A with sizes (2, 2, 2) along the three dimensions
1330 * the zyx vectorized input looks like
1331 * A(0,0,0) A(0,0,1) A(0,1,0) A(0,1,1) A(1,0,0) A(1,0,1) A(1,1,0) A(1,1,1)
1332 *
1333 *
1334 * Output : The output is also a 3D tensor vectorized in the zyx format.
1335 * The channel axis (z) in the output corresponds to the output of
1336 * different filters. The first channel corresponds to the first filter
1337 * i.e., first row of the filter_params_ matrix.
1338 *
1339 * Note: The component has to support yzx input vectorization as the binaries
1340 * like add-deltas generate yz vectorized output. These input vectors are
1341 * concatenated using the Append descriptor across time steps to form a yzx
1342 * vectorized 3D tensor input.
1343 * e.g. Append(Offset(input, -1), input, Offset(input, 1))
1344 *
1345 *
1346 * For information on the hyperparameters and parameters of this component see
1347 * the variable declarations.
1348 *
1349 * Propagation:
1350 * ------------
1351 * Convolution operation consists of a dot-products between the filter tensor
1352 * and input tensor patch, for various shifts of filter tensor along the x and y
1353 * axes input tensor. (Note: there is no shift along z-axis as the filter and
1354 * input tensor have same size along this axis).
1355 *
1356 * For a particular shift (i,j) of the filter tensor
1357 * along input tensor dimensions x and y, the elements of the input tensor which
1358 * overlap with the filter form the input tensor patch. This patch is vectorized
1359 * in zyx format. All the patches corresponding to various samples in the
1360 * mini-batch are stacked into a matrix, where each row corresponds to one
1361 * patch. Let this matrix be represented by X_{i,j}. The dot products with
1362 * various filters are computed simultaneously by computing the matrix product
1363 * with the filter_params_ matrix (W)
1364 * Y_{i,j} = X_{i,j}*W^T.
1365 * Each row of W corresponds to one filter 3D tensor vectorized in zyx format.
1366 *
1367 * All the matrix products corresponding to various shifts (i,j) of the
1368 * filter tensor are computed simultaneously using the AddMatMatBatched
1369 * call of CuMatrixBase class.
1370 *
1371 * BackPropagation:
1372 * ----------------
1373 * Backpropagation to compute the input derivative (\nabla X_{i,j})
1374 * consists of the a series of matrix products.
1375 * \nablaX_{i,j} = \nablaY_{i,j}*W where \nablaY_{i,j} corresponds to the
1376 * output derivative for a particular shift of the filter.
1377 *
1378 * Once again these matrix products are computed simultaneously.
1379 *
1380 * Update:
1381 * -------
1382 * The weight gradient is computed as
1383 * \nablaW = \Sum_{i,j} (X_{i,j}^T *\nablaY_{i,j})
1384 *
1385 */
1386 class ConvolutionComponent: public UpdatableComponent {
1387 public:
1388 enum TensorVectorizationType {
1389 kYzx = 0,
1390 kZyx = 1
1391 };
1393 ConvolutionComponent();
1394 // constructor using another component
1395 ConvolutionComponent(const ConvolutionComponent &component);
1396 // constructor using parameters
1397 ConvolutionComponent(
1398 const CuMatrixBase<BaseFloat> &filter_params,
1399 const CuVectorBase<BaseFloat> &bias_params,
1400 int32 input_x_dim, int32 input_y_dim, int32 input_z_dim,
1401 int32 filt_x_dim, int32 filt_y_dim,
1402 int32 filt_x_step, int32 filt_y_step,
1403 TensorVectorizationType input_vectorization,
1404 BaseFloat learning_rate);
1406 virtual int32 InputDim() const;
1407 virtual int32 OutputDim() const;
1409 virtual std::string Info() const;
1410 virtual void InitFromConfig(ConfigLine *cfl);
1411 virtual std::string Type() const { return "ConvolutionComponent"; }
1412 virtual int32 Properties() const {
1413 return kSimpleComponent|kUpdatableComponent|kBackpropNeedsInput|
1414 kBackpropAdds|kPropagateAdds;
1415 }
1417 virtual void Propagate(const ComponentPrecomputedIndexes *indexes,
1418 const CuMatrixBase<BaseFloat> &in,
1419 CuMatrixBase<BaseFloat> *out) const;
1420 virtual void Backprop(const std::string &debug_info,
1421 const ComponentPrecomputedIndexes *indexes,
1422 const CuMatrixBase<BaseFloat> &in_value,
1423 const CuMatrixBase<BaseFloat> &, // out_value,
1424 const CuMatrixBase<BaseFloat> &out_deriv,
1425 Component *to_update_in,
1426 CuMatrixBase<BaseFloat> *in_deriv) const;
1427 void Update(const std::string &debug_info,
1428 const CuMatrixBase<BaseFloat> &in_value,
1429 const CuMatrixBase<BaseFloat> &out_deriv,
1430 const std::vector<CuSubMatrix<BaseFloat> *>& out_deriv_batch);
1434 virtual void Read(std::istream &is, bool binary);
1435 virtual void Write(std::ostream &os, bool binary) const;
1437 virtual Component* Copy() const;
1439 // Some functions from base-class UpdatableComponent.
1440 virtual void Scale(BaseFloat scale);
1441 virtual void Add(BaseFloat alpha, const Component &other);
1442 virtual void SetZero(bool treat_as_gradient);
1443 virtual void PerturbParams(BaseFloat stddev);
1444 virtual BaseFloat DotProduct(const UpdatableComponent &other) const;
1445 virtual int32 NumParameters() const;
1446 virtual void Vectorize(VectorBase<BaseFloat> *params) const;
1447 virtual void UnVectorize(const VectorBase<BaseFloat> ¶ms);
1449 // Some functions that are specific to this class.
1450 void SetParams(const VectorBase<BaseFloat> &bias,
1451 const MatrixBase<BaseFloat> &filter);
1452 const CuVector<BaseFloat> &BiasParams() { return bias_params_; }
1453 const CuMatrix<BaseFloat> &LinearParams() { return filter_params_; }
1454 void Init(int32 input_x_dim, int32 input_y_dim, int32 input_z_dim,
1455 int32 filt_x_dim, int32 filt_y_dim,
1456 int32 filt_x_step, int32 filt_y_step, int32 num_filters,
1457 TensorVectorizationType input_vectorization,
1458 BaseFloat param_stddev, BaseFloat bias_stddev);
1459 // there is no filt_z_dim parameter as the length of the filter along
1460 // z-dimension is same as the input
1461 void Init(int32 input_x_dim, int32 input_y_dim, int32 input_z_dim,
1462 int32 filt_x_dim, int32 filt_y_dim,
1463 int32 filt_x_step, int32 filt_y_step,
1464 TensorVectorizationType input_vectorization,
1465 std::string matrix_filename);
1467 // resize the component, setting the parameters to zero, while
1468 // leaving any other configuration values the same
1469 void Resize(int32 input_dim, int32 output_dim);
1471 void Update(const std::string &debug_info,
1472 const CuMatrixBase<BaseFloat> &in_value,
1473 const CuMatrixBase<BaseFloat> &out_deriv);
1476 private:
1477 int32 input_x_dim_; // size of the input along x-axis
1478 // (e.g. number of time steps)
1480 int32 input_y_dim_; // size of input along y-axis
1481 // (e.g. number of mel-frequency bins)
1483 int32 input_z_dim_; // size of input along z-axis
1484 // (e.g. number of channels is 3 if the input has
1485 // features + delta + delta-delta features
1487 int32 filt_x_dim_; // size of the filter along x-axis
1489 int32 filt_y_dim_; // size of the filter along y-axis
1491 // there is no filt_z_dim_ as it is always assumed to be
1492 // the same as input_z_dim_
1494 int32 filt_x_step_; // the number of steps taken along x-axis of input
1495 // before computing the next dot-product
1496 // of filter and input
1498 int32 filt_y_step_; // the number of steps taken along y-axis of input
1499 // before computing the next dot-product of the filter
1500 // and input
1502 // there is no filt_z_step_ as only dot product is possible along this axis
1504 TensorVectorizationType input_vectorization_; // type of vectorization of the
1505 // input 3D tensor. Accepts zyx and yzx formats
1507 CuMatrix<BaseFloat> filter_params_;
1508 // the filter (or kernel) matrix is a matrix of vectorized 3D filters
1509 // where each row in the matrix corresponds to one filter.
1510 // The 3D filter tensor is vectorizedin zyx format.
1511 // The first row of the matrix corresponds to the first filter and so on.
1512 // Keep in mind the vectorization type and order of filters when using file
1513 // based initialization.
1515 CuVector<BaseFloat> bias_params_;
1516 // the filter-specific bias vector (i.e., there is a seperate bias added
1517 // to the output of each filter).
1518 bool is_gradient_;
1520 void InputToInputPatches(const CuMatrixBase<BaseFloat>& in,
1521 CuMatrix<BaseFloat> *patches) const;
1522 void InderivPatchesToInderiv(const CuMatrix<BaseFloat>& in_deriv_patches,
1523 CuMatrixBase<BaseFloat> *in_deriv) const;
1524 const ConvolutionComponent &operator = (const ConvolutionComponent &other); // Disallow.
1525 };
1528 /*
1529 * MaxPoolingComponent :
 * The max-pooling component was first used in ConvNets to select a
 * representative activation within an area; it inspired the Maxout nonlinearity.
1532 * Each output element of this component is the maximum of a block of
1533 * input elements where the block has a 3D dimension (pool_x_size_,
1534 * pool_y_size_, pool_z_size_).
1535 * Blocks could overlap if the shift value on any axis is smaller
1536 * than its corresponding pool size (e.g. pool_x_step_ < pool_x_size_).
 * If the shift values are equal to their pool sizes, there is no
 * overlap; if they are all equal to 1, the blocks overlap to
 * the greatest possible extent.
1540 *
 * This component is designed to be used after a ConvolutionComponent,
 * so that its input matrix is the output of a 2d-convolutional layer.
1543 * This component implements 3d-maxpooling which performs
1544 * max pooling along the three axes.
1545 * Input : A matrix where each row is a vectorized 3D-tensor.
1546 * The 3D tensor has dimensions
1547 * x: (e.g. time)
1548 * y: (e.g. frequency)
1549 * z: (e.g. channels like number of filters in the ConvolutionComponent)
1550 *
1551 * The component assumes input vectorizations of type zyx
1552 * which is the default output vectorization type of a ConvolutionComponent.
1553 * e.g. for input vectorization of type zyx the input is vectorized by
1554 * spanning axes z, y and x of the tensor in that order.
 * Given a 3D tensor A with size 2 along each of the three axes, indexed as
 * A(x, y, z), the zyx-vectorized input (z varying fastest) looks like
 * A(0,0,0) A(0,0,1) A(0,1,0) A(0,1,1) A(1,0,0) A(1,0,1) A(1,1,0) A(1,1,1)
1558 *
1559 * Output : The output is also a 3D tensor vectorized in the zyx format.
1560 *
1561 * For information on the hyperparameters and parameters of this component see
1562 * the variable declarations.
1563 *
1564 *
1565 */
1567 class MaxpoolingComponent: public Component {
1568 public:
1570 MaxpoolingComponent(): input_x_dim_(0), input_y_dim_(0), input_z_dim_(0),
1571 pool_x_size_(0), pool_y_size_(0), pool_z_size_(0),
1572 pool_x_step_(0), pool_y_step_(0), pool_z_step_(0) { }
1573 // constructor using another component
1574 MaxpoolingComponent(const MaxpoolingComponent &component);
1576 virtual int32 InputDim() const;
1577 virtual int32 OutputDim() const;
1578 virtual void Check() const;
1580 virtual std::string Info() const;
1581 virtual void InitFromConfig(ConfigLine *cfl);
1582 virtual std::string Type() const { return "MaxpoolingComponent"; }
1583 virtual int32 Properties() const {
1584 return kSimpleComponent|kBackpropNeedsInput|kBackpropNeedsOutput|
1585 kBackpropAdds;
1586 }
1588 virtual void Propagate(const ComponentPrecomputedIndexes *indexes,
1589 const CuMatrixBase<BaseFloat> &in,
1590 CuMatrixBase<BaseFloat> *out) const;
1591 virtual void Backprop(const std::string &debug_info,
1592 const ComponentPrecomputedIndexes *indexes,
1593 const CuMatrixBase<BaseFloat> &in_value,
1594 const CuMatrixBase<BaseFloat> &out_value,
1595 const CuMatrixBase<BaseFloat> &out_deriv,
1596 Component *, // to_update,
1597 CuMatrixBase<BaseFloat> *in_deriv) const;
1599 virtual void Read(std::istream &is, bool binary); // This Read function
1600 // requires that the Component has the correct type.
1602 /// Write component to stream
1603 virtual void Write(std::ostream &os, bool binary) const;
1604 virtual Component* Copy() const { return new MaxpoolingComponent(*this); }
1606 void InputToInputPatches(const CuMatrixBase<BaseFloat>& in,
1607 CuMatrix<BaseFloat> *patches) const;
1608 void InderivPatchesToInderiv(const CuMatrix<BaseFloat>& in_deriv_patches,
1609 CuMatrixBase<BaseFloat> *in_deriv) const;
1611 protected:
1612 int32 input_x_dim_; // size of the input along x-axis
1613 // (e.g. number of time steps)
1614 int32 input_y_dim_; // size of input along y-axis
1615 // (e.g. number of mel-frequency bins)
1616 int32 input_z_dim_; // size of input along z-axis
1617 // (e.g. number of filters in the ConvolutionComponent)
1619 int32 pool_x_size_; // size of the pooling window along x-axis
1620 int32 pool_y_size_; // size of the pooling window along y-axis
1621 int32 pool_z_size_; // size of the pooling window along z-axis
1623 int32 pool_x_step_; // the number of steps taken along x-axis of input
1624 // before computing the next pool
1625 int32 pool_y_step_; // the number of steps taken along y-axis of input
1626 // before computing the next pool
1627 int32 pool_z_step_; // the number of steps taken along z-axis of input
1628 // before computing the next pool
1630 };
1633 /**
1634 CompositeComponent is a component representing a sequence of
1635 [simple] components. The config line would be something like the following
1636 (imagine this is all on one line):
1638 component name=composite1 type=CompositeComponent max-rows-process=2048 num-components=3 \
1639 component1='type=BlockAffineComponent input-dim=1000 output-dim=10000 num-blocks=100' \
1640 component2='type=RectifiedLinearComponent dim=10000' \
1641 component3='type=BlockAffineComponent input-dim=10000 output-dim=1000 num-blocks=100'
The reason you might want to use this component, instead of directly using
the same sequence of components in the config file, is to save GPU memory (at
the expense of more compute). Doing it like this means we have to re-do parts
of the forward pass in the backprop phase, but we avoid holding on to large
amounts of memory for long periods (and you can make the memory usage very
small by making max-rows-process small). We inherit from UpdatableComponent
just in case one or more of the components in the sequence are updatable.
1651 It is an error to nest a CompositeComponent inside a CompositeComponent.
1652 The same effect can be accomplished by specifying a smaller max-rows-process
1653 in a single CompositeComponent.
1654 */
1655 class CompositeComponent: public UpdatableComponent {
1656 public:
1657 virtual int32 InputDim() const;
1658 virtual int32 OutputDim() const;
1660 virtual std::string Info() const;
1662 virtual void InitFromConfig(ConfigLine *cfl);
1664 virtual Component* Copy() const;
1666 CompositeComponent() { } // use Init() or InitFromConfig() to really initialize.
1668 // Initialize from this list of components; takes ownership of the pointers.
1669 void Init(const std::vector<Component*> &components,
1670 int32 max_rows_process);
1672 virtual std::string Type() const { return "CompositeComponent"; }
1674 // The properties depend on the properties of the constituent components. As
1675 // a special case, we never return kStoresStats in the properties: by default
1676 // we store things like activation stats (e.g. for nonlinear components like
1677 // ReLU) as part of the backprop. This means we may wastefully store stats
1678 // even when not requested, but it does save time as a separate StoreStats()
1679 // call would involve propagating the internals.
1680 virtual int32 Properties() const;
1682 virtual void Propagate(const ComponentPrecomputedIndexes *indexes,
1683 const CuMatrixBase<BaseFloat> &in,
1684 CuMatrixBase<BaseFloat> *out) const;
1685 virtual void Backprop(const std::string &debug_info,
1686 const ComponentPrecomputedIndexes *indexes,
1687 const CuMatrixBase<BaseFloat> &in_value,
1688 const CuMatrixBase<BaseFloat> &, // out_value
1689 const CuMatrixBase<BaseFloat> &out_deriv,
1690 Component *to_update,
1691 CuMatrixBase<BaseFloat> *in_deriv) const;
1693 // note, we don't implement StoreStats() as it would be inefficient. Instead,
1694 // by default we call StoreStats() on all members that have the flag set,
1695 // inside the Backprop.
1696 virtual void ZeroStats();
1698 virtual void Read(std::istream &is, bool binary);
1699 virtual void Write(std::ostream &os, bool binary) const;
1701 // Don't implement Copy() at this level: implement it in the child class.
1703 // Some functions from base-class UpdatableComponent.
1704 virtual void SetLearningRate(BaseFloat lrate);
1705 virtual void Scale(BaseFloat scale);
1706 virtual void Add(BaseFloat alpha, const Component &other);
1707 virtual void SetZero(bool treat_as_gradient);
1708 virtual void PerturbParams(BaseFloat stddev);
1709 virtual BaseFloat DotProduct(const UpdatableComponent &other) const;
1710 virtual int32 NumParameters() const;
1711 virtual void Vectorize(VectorBase<BaseFloat> *params) const;
1712 virtual void UnVectorize(const VectorBase<BaseFloat> ¶ms);
1714 // note: we dont implement the StoreStats function as it would be quite
1715 // expensive; instead, by default we call StoreStats() for any components that
1716 // want to store stats, as part of the backprop pass. This is not 100% ideal
1717 // but it will usually do what you want. We can revisit this later if needed.
1719 // Functions to iterate over the internal components
1721 int32 NumComponents() const { return components_.size();}
1722 /// Gets the ith component in this component.
1723 /// The ordering is the same as in the config line. The caller
1724 /// does not own the received component.
1725 const Component* GetComponent(int32 i) const;
1726 /// Sets the ith component. After this call, CompositeComponent owns
1727 /// the reference to the argument component. Frees the previous
1728 /// ith component.
1729 void SetComponent(int32 i, Component *component);
1731 virtual ~CompositeComponent() { DeletePointers(&components_); }
1732 private:
1733 // returns the stride type, kDefaultStride or kStrideEqualNumCols,
1734 // at the output of the i'th component.
1735 inline MatrixStrideType GetStrideType(int32 i) const;
1737 // returns true if at least one of 'components_' returns the kUpdatable flag
1738 // in its flags.
1739 bool IsUpdatable() const;
1741 // the maximum number of
1742 int32 max_rows_process_;
1743 std::vector<Component*> components_;
1745 };
1748 } // namespace nnet3
1749 } // namespace kaldi
1752 #endif