f4554600ae23f92b14b54db4129f17da31df8af1
1 // Copyright 2013 Yangqing Jia
// Fillers are random number generators that fill a blob using the specified
// algorithm. The expectation is that they are only used at initialization
// time and do not involve any GPUs.
7 #ifndef CAFFE_FILLER_HPP
8 #define CAFFE_FILLER_HPP
#include <mkl.h>

#include <cmath>
#include <string>

#include "caffe/common.hpp"
#include "caffe/blob.hpp"
#include "caffe/syncedmem.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/proto/caffe.pb.h"
19 namespace caffe {
21 template <typename Dtype>
22 class Filler {
23 public:
24 explicit Filler(const FillerParameter& param) : filler_param_(param) {}
25 virtual ~Filler() {}
26 virtual void Fill(Blob<Dtype>* blob) = 0;
27 protected:
28 FillerParameter filler_param_;
29 }; // class Filler
32 template <typename Dtype>
33 class ConstantFiller : public Filler<Dtype> {
34 public:
35 explicit ConstantFiller(const FillerParameter& param)
36 : Filler<Dtype>(param) {}
37 virtual void Fill(Blob<Dtype>* blob) {
38 Dtype* data = blob->mutable_cpu_data();
39 const int count = blob->count();
40 const Dtype value = this->filler_param_.value();
41 CHECK(count);
42 for (int i = 0; i < count; ++i) {
43 data[i] = value;
44 }
45 };
46 };
48 template <typename Dtype>
49 class UniformFiller : public Filler<Dtype> {
50 public:
51 explicit UniformFiller(const FillerParameter& param)
52 : Filler<Dtype>(param) {}
53 virtual void Fill(Blob<Dtype>* blob) {
54 DCHECK(blob->count());
55 caffe_vRngUniform<Dtype>(blob->count(), blob->mutable_cpu_data(),
56 Dtype(this->filler_param_.min()),
57 Dtype(this->filler_param_.max()));
58 }
59 };
61 template <typename Dtype>
62 class GaussianFiller : public Filler<Dtype> {
63 public:
64 explicit GaussianFiller(const FillerParameter& param)
65 : Filler<Dtype>(param) {}
66 virtual void Fill(Blob<Dtype>* blob) {
67 Dtype* data = blob->mutable_cpu_data();
68 DCHECK(blob->count());
69 caffe_vRngGaussian<Dtype>(blob->count(), blob->mutable_cpu_data(),
70 Dtype(this->filler_param_.mean()),
71 Dtype(this->filler_param_.std()));
72 }
73 };
75 template <typename Dtype>
76 class PositiveUnitballFiller : public Filler<Dtype> {
77 public:
78 explicit PositiveUnitballFiller(const FillerParameter& param)
79 : Filler<Dtype>(param) {}
80 virtual void Fill(Blob<Dtype>* blob) {
81 Dtype* data = blob->mutable_cpu_data();
82 DCHECK(blob->count());
83 caffe_vRngUniform<Dtype>(blob->count(), blob->mutable_cpu_data(), 0, 1);
84 // We expect the filler to not be called very frequently, so we will
85 // just use a simple implementation
86 int dim = blob->count() / blob->num();
87 DCHECK(dim);
88 for (int i = 0; i < blob->num(); ++i) {
89 Dtype sum = 0;
90 for (int j = 0; j < dim; ++j) {
91 sum += data[i * dim + j];
92 }
93 for (int j = 0; j < dim; ++j) {
94 data[i * dim + j] /= sum;
95 }
96 }
97 }
98 };
100 template <typename Dtype>
101 class XavierFiller : public Filler<Dtype> {
102 public:
103 explicit XavierFiller(const FillerParameter& param)
104 : Filler<Dtype>(param) {}
105 virtual void Fill(Blob<Dtype>* blob) {
107 }
108 };
111 // A function to get a specific filler from the specification given in
112 // FillerParameter. Ideally this would be replaced by a factory pattern,
113 // but we will leave it this way for now.
114 template <typename Dtype>
115 Filler<Dtype>* GetFiller(const FillerParameter& param) {
116 const std::string& type = param.type();
117 if (type == "constant") {
118 return new ConstantFiller<Dtype>(param);
119 } else if (type == "uniform") {
120 return new UniformFiller<Dtype>(param);
121 } else if (type == "gaussian") {
122 return new GaussianFiller<Dtype>(param);
123 } else if (type == "positive_unitball") {
124 return new PositiveUnitballFiller<Dtype>(param);
125 } else {
126 CHECK(false) << "Unknown filler name: " << param.type();
127 }
128 return (Filler<Dtype>*)(NULL);
129 }
131 } // namespace caffe
#endif  // CAFFE_FILLER_HPP