1 // Copyright 2013 Yangqing Jia
3 #ifndef CAFFE_LAYER_H_
4 #define CAFFE_LAYER_H_
6 #include <vector>
7 #include "caffe/blob.hpp"
8 #include "caffe/common.hpp"
9 #include "caffe/proto/caffe.pb.h"
11 using std::vector;
13 namespace caffe {
15 template <typename Dtype>
16 class Layer {
17 public:
18 // You should not implement your own constructor. Any set up code should go
19 // to SetUp(), where the dimensions of the bottom blobs are provided to the
20 // layer.
21 explicit Layer(const LayerParameter& param)
22 : layer_param_(param) {}
23 virtual ~Layer() {}
24 // SetUp: your function should implement this.
25 virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
26 vector<Blob<Dtype>*>* top) = 0;
28 // Forward and backward wrappers. You should implement the cpu and
29 // gpu specific implementations instead, and should not change these
30 // functions.
31 inline void Forward(const vector<Blob<Dtype>*>& bottom,
32 vector<Blob<Dtype>*>* top);
33 inline Dtype Backward(const vector<Blob<Dtype>*>& top,
34 const bool propagate_down,
35 vector<Blob<Dtype>*>* bottom);
37 // Returns the vector of parameters.
38 vector<Blob<Dtype> >& params() {
39 return blobs_;
40 }
42 // Writes the layer parameter to a protocol buffer
43 void ToProto(LayerParameter* param, bool write_diff = false);
45 protected:
46 // The protobuf that stores the layer parameters
47 LayerParameter layer_param_;
48 // The vector that stores the parameters as a set of blobs.
49 vector<Blob<Dtype> > blobs_;
51 // Forward functions
52 virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
53 vector<Blob<Dtype>*>* top) = 0;
54 // If no gpu code is provided, we will simply use cpu code.
55 virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
56 vector<Blob<Dtype>*>* top) {
57 LOG(WARNING) << "Using CPU code as backup.";
58 Forward_cpu(bottom, top);
59 };
61 // Backward functions: the backward function will compute the gradients for
62 // any parameters and also for the bottom blobs if propagate_down is true.
63 // It will return the loss produced from this layer.
64 virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
65 const bool propagate_down,
66 vector<Blob<Dtype>*>* bottom) = 0;
67 virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
68 const bool propagate_down,
69 vector<Blob<Dtype>*>* bottom) {
70 LOG(WARNING) << "Using CPU code as backup.";
71 return Backward_cpu(top, propagate_down, bottom);
72 };
73 }; // class Layer
// Out-of-line definitions of the Forward() and Backward() wrappers. They only
// dispatch to the *_cpu or *_gpu hook according to Caffe::mode(); implement
// those hooks in subclasses instead of changing these functions.
78 template <typename Dtype>
79 inline void Layer<Dtype>::Forward(const vector<Blob<Dtype>*>& bottom,
80 vector<Blob<Dtype>*>* top) {
81 switch (Caffe::mode()) {
82 case Caffe::CPU:
83 Forward_cpu(bottom, top);
84 break;
85 case Caffe::GPU:
86 Forward_gpu(bottom, top);
87 break;
88 default:
89 LOG(FATAL) << "Unknown caffe mode.";
90 }
91 };
93 template <typename Dtype>
94 inline Dtype Layer<Dtype>::Backward(const vector<Blob<Dtype>*>& top,
95 const bool propagate_down,
96 vector<Blob<Dtype>*>* bottom) {
97 switch (Caffe::mode()) {
98 case Caffe::CPU:
99 return Backward_cpu(top, propagate_down, bottom);
100 case Caffe::GPU:
101 return Backward_gpu(top, propagate_down, bottom);
102 default:
103 LOG(FATAL) << "Unknown caffe mode.";
104 }
105 };
107 template <typename Dtype>
108 void Layer<Dtype>::ToProto(LayerParameter* param, bool write_diff) {
109 param->Clear();
110 param->CopyFrom(layer_param_);
111 param->clear_blobs();
112 for (int i = 0; i < blobs_.size(); ++i) {
113 blobs_[i].ToProto(param->add_blobs(), write_diff);
114 }
115 }
117 } // namespace caffe
119 #endif // CAFFE_LAYER_H_