]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/caffe-jacinto.git/blob - src/caffeine/filler.hpp
padding layer cuda code, need debug
[jacinto-ai/caffe-jacinto.git] / src / caffeine / filler.hpp
1 // Fillers are random number generators that fills a blob using the specified
2 // algorithm. The expectation is that they are only going to be used during
3 // initialization time and will not involve any GPUs.
5 #ifndef CAFFEINE_FILLER_HPP
6 #define CAFFEINE_FILLER_HPP
8 #include <mkl.h>
10 #include "caffeine/common.hpp"
11 #include "caffeine/blob.hpp"
12 #include "caffeine/syncedmem.hpp"
13 #include "caffeine/proto/layer_param.pb.h"
15 namespace caffeine {
17 template <typename Dtype>
18 class Filler {
19  public:
20   Filler(const FillerParameter& param) : filler_param_(param) {};
21   virtual ~Filler() {};
22   virtual void Fill(Blob<Dtype>* blob) = 0;
23  protected:
24   FillerParameter filler_param_;
25 };  // class Filler
27 template <typename Dtype>
28 class FillerFactory {
30 };
32 template <typename Dtype>
33 class ConstantFiller : public Filler<Dtype> {
34  public:
35   ConstantFiller(const FillerParameter& param) : Filler<Dtype>(param) {};
36   virtual void Fill(Blob<Dtype>* blob) {
37     Dtype* data = blob->mutable_cpu_data();
38     const int count = blob->count();
39     const Dtype value = this->filler_param_.value();
40     CHECK(count);
41     for (int i = 0; i < count; ++i) {
42       data[i] = value;
43     }
44   };
45 };
47 template <typename Dtype>
48 class UniformFiller : public Filler<Dtype> {
49  public:
50   UniformFiller(const FillerParameter& param) : Filler<Dtype>(param) {};
51   virtual void Fill(Blob<Dtype>* blob) {
52     void* data = (void*)(blob->mutable_cpu_data());
53     const int count = blob->count();
54     const Dtype value = this->filler_param_.value();
55     CHECK(count);
56     switch(sizeof(Dtype)) {
57     case sizeof(float):
58       VSL_CHECK(vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffeine::vsl_stream(),
59           count, (float*)data, this->filler_param_.min(),
60           this->filler_param_.max()));
61       break;
62     case sizeof(double):
63       VSL_CHECK(vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffeine::vsl_stream(),
64           count, (double*)data, this->filler_param_.min(),
65           this->filler_param_.max()));
66       break;
67     default:
68       CHECK(false) << "Unknown dtype.";
69     }
70   };
71 };
73 template <typename Dtype>
74 class GaussianFiller : public Filler<Dtype> {
75  public:
76   GaussianFiller(const FillerParameter& param) : Filler<Dtype>(param) {};
77   virtual void Fill(Blob<Dtype>* blob) {
78     void* data = (void*)(blob->mutable_cpu_data());
79     const int count = blob->count();
80     const Dtype value = this->filler_param_.value();
81     CHECK(count);
82     switch(sizeof(Dtype)) {
83     case sizeof(float):
84       VSL_CHECK(vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER,
85           Caffeine::vsl_stream(), count, (float*)data,
86           this->filler_param_.mean(), this->filler_param_.std()));
87       break;
88     case sizeof(double):
89       VSL_CHECK(vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER,
90           Caffeine::vsl_stream(), count, (double*)data,
91           this->filler_param_.mean(), this->filler_param_.std()));
92       break;
93     default:
94       CHECK(false) << "Unknown dtype.";
95     }
96   };
97 };
99 // A function to get a specific filler from the specification given in
100 // FillerParameter. Ideally this would be replaced by a factory pattern,
101 // but we will leave it this way for now.
102 template <typename Dtype>
103 Filler<Dtype>* GetFiller(const FillerParameter& param) {
104   const std::string& type = param.type();
105   if (type == "constant") {
106     return new ConstantFiller<Dtype>(param);
107   } else if (type == "uniform") {
108     return new UniformFiller<Dtype>(param);
109   } else if (type == "gaussian") {
110     return new GaussianFiller<Dtype>(param);
111   } else {
112     CHECK(false) << "Unknown filler name: " << param.type();
113   }
114   return (Filler<Dtype>*)(NULL);
117 }  // namespace caffeine
119 #endif  // CAFFEINE_FILLER_HPP_