]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/caffe-jacinto.git/blob - src/caffe/vision_layers.hpp
29d2eb363837f0c5bab4bc80dad159fee58bdb96
[jacinto-ai/caffe-jacinto.git] / src / caffe / vision_layers.hpp
1 // Copyright 2013 Yangqing Jia
3 #ifndef CAFFE_VISION_LAYERS_HPP_
4 #define CAFFE_VISION_LAYERS_HPP_
6 #include <leveldb/db.h>
8 #include <vector>
10 #include "caffe/layer.hpp"
12 namespace caffe {
14 // The neuron layer is a specific type of layers that just works on single
15 // celements.
16 template <typename Dtype>
17 class NeuronLayer : public Layer<Dtype> {
18  public:
19   explicit NeuronLayer(const LayerParameter& param)
20      : Layer<Dtype>(param) {}
21   virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
22       vector<Blob<Dtype>*>* top);
23 };
26 template <typename Dtype>
27 class ReLULayer : public NeuronLayer<Dtype> {
28  public:
29   explicit ReLULayer(const LayerParameter& param)
30       : NeuronLayer<Dtype>(param) {}
32  protected:
33   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
34       vector<Blob<Dtype>*>* top);
35   virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
36       vector<Blob<Dtype>*>* top);
38   virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
39       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
40   virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
41       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
42 };
45 template <typename Dtype>
46 class DropoutLayer : public NeuronLayer<Dtype> {
47  public:
48   explicit DropoutLayer(const LayerParameter& param)
49       : NeuronLayer<Dtype>(param) {}
50   virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
51       vector<Blob<Dtype>*>* top);
53  protected:
54   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
55       vector<Blob<Dtype>*>* top);
56   virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
57       vector<Blob<Dtype>*>* top);
59   virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
60       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
61   virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
62       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
63   shared_ptr<SyncedMemory> rand_vec_;
64   float threshold_;
65   float scale_;
66   unsigned int uint_thres_;
67 };
70 template <typename Dtype>
71 class InnerProductLayer : public Layer<Dtype> {
72  public:
73   explicit InnerProductLayer(const LayerParameter& param)
74       : Layer<Dtype>(param) {}
75   virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
76       vector<Blob<Dtype>*>* top);
78  protected:
79   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
80       vector<Blob<Dtype>*>* top);
81   virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
82       vector<Blob<Dtype>*>* top);
84   virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
85       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
86   virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
87       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
88   int M_;
89   int K_;
90   int N_;
91   bool biasterm_;
92   shared_ptr<SyncedMemory> bias_multiplier_;
93 };
95 template <typename Dtype>
96 class PaddingLayer : public Layer<Dtype> {
97  public:
98   explicit PaddingLayer(const LayerParameter& param)
99       : Layer<Dtype>(param) {}
100   virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
101       vector<Blob<Dtype>*>* top);
103  protected:
104   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
105       vector<Blob<Dtype>*>* top);
106   virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
107       vector<Blob<Dtype>*>* top);
108   virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
109       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
110   virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
111       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
112   unsigned int PAD_;
113   int NUM_;
114   int CHANNEL_;
115   int HEIGHT_IN_;
116   int WIDTH_IN_;
117   int HEIGHT_OUT_;
118   int WIDTH_OUT_;
119 };
121 template <typename Dtype>
122 class LRNLayer : public Layer<Dtype> {
123  public:
124   explicit LRNLayer(const LayerParameter& param)
125       : Layer<Dtype>(param) {}
126   virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
127       vector<Blob<Dtype>*>* top);
129  protected:
130   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
131       vector<Blob<Dtype>*>* top);
132   virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
133       vector<Blob<Dtype>*>* top);
134   virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
135       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
136   virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
137       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
138   // scale_ stores the intermediate summing results
139   Blob<Dtype> scale_;
140   int size_;
141   int pre_pad_;
142   Dtype alpha_;
143   Dtype beta_;
144   int num_;
145   int channels_;
146   int height_;
147   int width_;
148 };
150 template <typename Dtype>
151 class Im2colLayer : public Layer<Dtype> {
152  public:
153   explicit Im2colLayer(const LayerParameter& param)
154       : Layer<Dtype>(param) {}
155   virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
156       vector<Blob<Dtype>*>* top);
158  protected:
159   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
160       vector<Blob<Dtype>*>* top);
161   virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
162       vector<Blob<Dtype>*>* top);
163   virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
164       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
165   virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
166       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
167   int KSIZE_;
168   int STRIDE_;
169   int CHANNELS_;
170   int HEIGHT_;
171   int WIDTH_;
172 };
174 template <typename Dtype>
175 class PoolingLayer : public Layer<Dtype> {
176  public:
177   explicit PoolingLayer(const LayerParameter& param)
178       : Layer<Dtype>(param) {}
179   virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
180       vector<Blob<Dtype>*>* top);
182  protected:
183   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
184       vector<Blob<Dtype>*>* top);
185   virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
186       vector<Blob<Dtype>*>* top);
187   virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
188       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
189   virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
190       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
191   int KSIZE_;
192   int STRIDE_;
193   int CHANNELS_;
194   int HEIGHT_;
195   int WIDTH_;
196   int POOLED_HEIGHT_;
197   int POOLED_WIDTH_;
198 };
200 template <typename Dtype>
201 class ConvolutionLayer : public Layer<Dtype> {
202  public:
203   explicit ConvolutionLayer(const LayerParameter& param)
204       : Layer<Dtype>(param) {}
205   virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
206       vector<Blob<Dtype>*>* top);
208  protected:
209   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
210       vector<Blob<Dtype>*>* top);
211   virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
212       vector<Blob<Dtype>*>* top);
213   virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
214       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
215   virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
216       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
217   Blob<Dtype> col_bob_;
219   int KSIZE_;
220   int STRIDE_;
221   int NUM_;
222   int CHANNELS_;
223   int HEIGHT_;
224   int WIDTH_;
225   int NUM_OUTPUT_;
226   int GROUP_;
227   Blob<Dtype> col_buffer_;
228   shared_ptr<SyncedMemory> bias_multiplier_;
229   bool biasterm_;
230   int M_;
231   int K_;
232   int N_;
233 };
235 template <typename Dtype>
236 class DataLayer : public Layer<Dtype> {
237  public:
238   explicit DataLayer(const LayerParameter& param)
239       : Layer<Dtype>(param) {}
240   virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
241       vector<Blob<Dtype>*>* top);
243  protected:
244   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
245       vector<Blob<Dtype>*>* top);
246   virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
247       vector<Blob<Dtype>*>* top);
248   virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
249       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
250   virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
251       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
253   shared_ptr<leveldb::DB> db_;
254   shared_ptr<leveldb::Iterator> iter_;
255   int datum_size_;
256 };
259 template <typename Dtype>
260 class SoftmaxLayer : public Layer<Dtype> {
261  public:
262   explicit SoftmaxLayer(const LayerParameter& param)
263       : Layer<Dtype>(param) {}
264   virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
265       vector<Blob<Dtype>*>* top);
267  protected:
268   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
269       vector<Blob<Dtype>*>* top);
270   // virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
271   //     vector<Blob<Dtype>*>* top);
272   virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
273       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
274   // virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
275   //     const bool propagate_down, vector<Blob<Dtype>*>* bottom);
277   // sum_multiplier is just used to carry out sum using blas
278   Blob<Dtype> sum_multiplier_;
279   // scale is an intermediate blob to hold temporary results.
280   Blob<Dtype> scale_;
281 };
283 template <typename Dtype>
284 class MultinomialLogisticLossLayer : public Layer<Dtype> {
285  public:
286   explicit MultinomialLogisticLossLayer(const LayerParameter& param)
287       : Layer<Dtype>(param) {}
288   virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
289       vector<Blob<Dtype>*>* top);
291  protected:
292   // The loss layer will do nothing during forward - all computation are
293   // carried out in the backward pass.
294   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
295       vector<Blob<Dtype>*>* top) { return; }
296   virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
297       vector<Blob<Dtype>*>* top) { return; }
298   virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
299       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
300   // virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
301   //     const bool propagate_down, vector<Blob<Dtype>*>* bottom);
302 };
304 template <typename Dtype>
305 class EuclideanLossLayer : public Layer<Dtype> {
306  public:
307   explicit EuclideanLossLayer(const LayerParameter& param)
308       : Layer<Dtype>(param), difference_() {}
309   virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
310       vector<Blob<Dtype>*>* top);
312  protected:
313   // The loss layer will do nothing during forward - all computation are
314   // carried out in the backward pass.
315   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
316       vector<Blob<Dtype>*>* top) { return; }
317   virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
318       vector<Blob<Dtype>*>* top) { return; }
319   virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
320       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
321   // virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
322   //     const bool propagate_down, vector<Blob<Dtype>*>* bottom);
323   Blob<Dtype> difference_;
324 };
327 }  // namespace caffe
329 #endif  // CAFFE_VISION_LAYERS_HPP_