// Copyright 2013 Yangqing Jia

#ifndef CAFFE_VISION_LAYERS_HPP_
#define CAFFE_VISION_LAYERS_HPP_

#include <leveldb/db.h>

#include <vector>

#include "caffe/layer.hpp"

namespace caffe {

// The neuron layer is a specific type of layer that operates on single
// elements, applying the same function independently to each input value.
template <typename Dtype>
class NeuronLayer : public Layer<Dtype> {
 public:
  explicit NeuronLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}
  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
};
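
// Rectified Linear Unit (ReLU) non-linearity: computes y = max(0, x) for each
// input element.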
template <typename Dtype>
class ReLULayer : public NeuronLayer<Dtype> {
 public:
  explicit ReLULayer(const LayerParameter& param)
      : NeuronLayer<Dtype>(param) {}

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);

  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
  virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
};
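
// Dropout: during training, each input element is zeroed out with probability
// threshold_ and the surviving elements are scaled by scale_ so that the
// expected activation stays the same; rand_vec_ stores the random mask.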
template <typename Dtype>
class DropoutLayer : public NeuronLayer<Dtype> {
 public:
  explicit DropoutLayer(const LayerParameter& param)
      : NeuronLayer<Dtype>(param) {}
  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);

  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
  virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
  shared_ptr<SyncedMemory> rand_vec_;
  float threshold_;
  float scale_;
  unsigned int uint_thres_;
};
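
// Fully connected (inner product) layer: computes a matrix product between
// the input and a learned weight matrix, plus an optional bias when
// biasterm_ is set. M_ is the number of input rows (the batch size), K_ the
// input dimension, and N_ the number of outputs.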
template <typename Dtype>
class InnerProductLayer : public Layer<Dtype> {
 public:
  explicit InnerProductLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}
  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);

  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
  virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
  int M_;
  int K_;
  int N_;
  bool biasterm_;
  shared_ptr<SyncedMemory> bias_multiplier_;
};
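
// Padding layer: pads the input with a border of PAD_ zeros around the
// spatial (height and width) dimensions.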
template <typename Dtype>
class PaddingLayer : public Layer<Dtype> {
 public:
  explicit PaddingLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}
  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
  virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
  unsigned int PAD_;
  int NUM_;
  int CHANNEL_;
  int HEIGHT_IN_;
  int WIDTH_IN_;
  int HEIGHT_OUT_;
  int WIDTH_OUT_;
};
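
// Local Response Normalization (LRN): normalizes each value by the summed
// activity of size_ neighboring channels, scaled by alpha_ and raised to the
// power beta_, as described by Krizhevsky et al. (2012).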
template <typename Dtype>
class LRNLayer : public Layer<Dtype> {
 public:
  explicit LRNLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}
  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
  virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
  // scale_ stores the intermediate summing results.
  Blob<Dtype> scale_;
  int size_;
  int pre_pad_;
  Dtype alpha_;
  Dtype beta_;
  int num_;
  int channels_;
  int height_;
  int width_;
};
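
// Im2col: rearranges image patches of size KSIZE_ x KSIZE_ (taken with stride
// STRIDE_) into columns, so that a convolution can be carried out as a single
// matrix multiplication.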
template <typename Dtype>
class Im2colLayer : public Layer<Dtype> {
 public:
  explicit Im2colLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}
  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
  virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
  int KSIZE_;
  int STRIDE_;
  int CHANNELS_;
  int HEIGHT_;
  int WIDTH_;
};
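
// Pooling layer: summarizes each KSIZE_ x KSIZE_ window of the input (with
// stride STRIDE_) into a single value, e.g. the maximum or the average,
// producing a POOLED_HEIGHT_ x POOLED_WIDTH_ output per channel.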
template <typename Dtype>
class PoolingLayer : public Layer<Dtype> {
 public:
  explicit PoolingLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}
  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
  virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
  int KSIZE_;
  int STRIDE_;
  int CHANNELS_;
  int HEIGHT_;
  int WIDTH_;
  int POOLED_HEIGHT_;
  int POOLED_WIDTH_;
};
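
// Convolution layer: implemented by unrolling the input with im2col into
// col_buffer_ and performing a matrix multiplication with the filter weights,
// with an optional bias term (biasterm_). GROUP_ splits the input and output
// channels into independent groups.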
template <typename Dtype>
class ConvolutionLayer : public Layer<Dtype> {
 public:
  explicit ConvolutionLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}
  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
  virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom);

  int KSIZE_;
  int STRIDE_;
  int NUM_;
  int CHANNELS_;
  int HEIGHT_;
  int WIDTH_;
  int NUM_OUTPUT_;
  int GROUP_;
  Blob<Dtype> col_buffer_;
  shared_ptr<SyncedMemory> bias_multiplier_;
  bool biasterm_;
  int M_;
  int K_;
  int N_;
};
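
// Data layer: reads (datum, label) pairs sequentially from a leveldb database
// through db_ and iter_ and produces them as top blobs; it takes no bottom
// blobs.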
template <typename Dtype>
class DataLayer : public Layer<Dtype> {
 public:
  explicit DataLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}
  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
  virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom);

  shared_ptr<leveldb::DB> db_;
  shared_ptr<leveldb::Iterator> iter_;
  int datum_size_;
};
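
// Softmax layer: computes the softmax of the input, producing a probability
// distribution over channels for each input.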
template <typename Dtype>
class SoftmaxLayer : public Layer<Dtype> {
 public:
  explicit SoftmaxLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}
  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
  virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom);

  // sum_multiplier_ is used to carry out summation using BLAS.
  Blob<Dtype> sum_multiplier_;
  // scale_ is an intermediate blob used to hold temporary results.
  Blob<Dtype> scale_;
};
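
// Multinomial logistic loss: computes the negative log-likelihood of the
// ground-truth labels under the predicted probability distribution.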
template <typename Dtype>
class MultinomialLogisticLossLayer : public Layer<Dtype> {
 public:
  explicit MultinomialLogisticLossLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}
  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);

 protected:
  // The loss layer does nothing during the forward pass - all computation is
  // carried out in the backward pass.
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top) { return; }
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top) { return; }
  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
  // virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
  //     const bool propagate_down, vector<Blob<Dtype>*>* bottom);
};
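
// Euclidean loss: computes the squared L2 distance between the two bottom
// blobs (predictions and targets); difference_ caches their elementwise
// difference for use in the gradient computation.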
template <typename Dtype>
class EuclideanLossLayer : public Layer<Dtype> {
 public:
  explicit EuclideanLossLayer(const LayerParameter& param)
      : Layer<Dtype>(param), difference_() {}
  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);

 protected:
  // The loss layer does nothing during the forward pass - all computation is
  // carried out in the backward pass.
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top) { return; }
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top) { return; }
  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
  // virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
  //     const bool propagate_down, vector<Blob<Dtype>*>* bottom);
  Blob<Dtype> difference_;
};
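
// Accuracy layer: compares the predicted class (the arg max of the
// predictions) against the ground-truth label and reports the fraction of
// correct predictions; intended for use in test networks.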
template <typename Dtype>
class AccuracyLayer : public Layer<Dtype> {
 public:
  explicit AccuracyLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}
  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);

 protected:
  // The accuracy is computed during the forward pass.
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
  // The accuracy layer should not be used to compute backward operations.
  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
    NOT_IMPLEMENTED;
    return Dtype(0.);
  }
};

}  // namespace caffe

#endif  // CAFFE_VISION_LAYERS_HPP_