solver restructuring: now all prototxt are specified in the solver protocol buffer
[jacinto-ai/caffe-jacinto.git] / src / caffe / net.cpp
1 // Copyright Yangqing Jia 2013
3 #include <map>
4 #include <set>
5 #include <string>
6 #include <vector>
8 #include "caffe/proto/caffe.pb.h"
9 #include "caffe/layer.hpp"
10 #include "caffe/net.hpp"
12 using std::pair;
13 using std::map;
14 using std::set;
16 namespace caffe {
18 template <typename Dtype>
19 Net<Dtype>::Net(const NetParameter& param,
20     const vector<Blob<Dtype>* >& bottom) {
21   // Basically, build all the layers and set up its connections.
22   name_ = param.name();
23   map<string, int> blob_name_to_idx;
24   set<string> available_blobs;
25   int num_layers = param.layers_size();
26   CHECK_EQ(bottom.size(), param.input_size())
27       << "Incorrect bottom blob size.";
28   // set the input blobs
29   for (int i = 0; i < param.input_size(); ++i) {
30     const string& blob_name = param.input(i);
31     CHECK_GT(bottom[i]->count(), 0);
32     shared_ptr<Blob<Dtype> > blob_pointer(
33         new Blob<Dtype>(bottom[i]->num(), bottom[i]->channels(),
34             bottom[i]->height(), bottom[i]->width()));
35     blobs_.push_back(blob_pointer);
36     blob_names_.push_back(blob_name);
37     blob_need_backward_.push_back(false);
38     net_input_blob_indices_.push_back(i);
39     blob_name_to_idx[blob_name] = i;
40     available_blobs.insert(blob_name);
41   }
42   // For each layer, set up their input and output
43   bottom_vecs_.resize(param.layers_size());
44   top_vecs_.resize(param.layers_size());
45   bottom_id_vecs_.resize(param.layers_size());
46   top_id_vecs_.resize(param.layers_size());
47   for (int i = 0; i < param.layers_size(); ++i) {
48     const LayerConnection& layer_connection = param.layers(i);
49     const LayerParameter& layer_param = layer_connection.layer();
50     layers_.push_back(shared_ptr<Layer<Dtype> >(GetLayer<Dtype>(layer_param)));
51     layer_names_.push_back(layer_param.name());
52     LOG(INFO) << "Creating Layer " << layer_param.name();
53     bool need_backward = false;
54     // Figure out this layer's input and output
55     for (int j = 0; j < layer_connection.bottom_size(); ++j) {
56       const string& blob_name = layer_connection.bottom(j);
57       const int blob_id = blob_name_to_idx[blob_name];
58       if (available_blobs.find(blob_name) == available_blobs.end()) {
59         LOG(FATAL) << "Unknown blob input " << blob_name <<
60             " to layer" << j;
61       }
62       LOG(INFO) << layer_param.name() << " <- " << blob_name;
63       bottom_vecs_[i].push_back(
64           blobs_[blob_id].get());
65       bottom_id_vecs_[i].push_back(blob_id);
66       // If a blob needs backward, this layer should provide it.
67       need_backward |= blob_need_backward_[blob_id];
68       available_blobs.erase(blob_name);
69     }
70     for (int j = 0; j < layer_connection.top_size(); ++j) {
71       const string& blob_name = layer_connection.top(j);
72       // Check if we are doing in-place computation
73       if (layer_connection.bottom_size() > j &&
74           blob_name == layer_connection.bottom(j)) {
75         // In-place computation
76         LOG(INFO) << layer_param.name() << " -> " << blob_name << " (in-place)";
77         available_blobs.insert(blob_name);
78         top_vecs_[i].push_back(
79             blobs_[blob_name_to_idx[blob_name]].get());
80         top_id_vecs_[i].push_back(blob_name_to_idx[blob_name]);
81       } else if (blob_name_to_idx.find(blob_name) != blob_name_to_idx.end()) {
82         // If we are not doing in-place computation but has duplicated blobs,
83         // raise an error.
84         LOG(FATAL) << "Duplicate blobs produced by multiple sources.";
85       } else {
86         // Normal output.
87         LOG(INFO) << layer_param.name() << " -> " << blob_name;
88         shared_ptr<Blob<Dtype> > blob_pointer(new Blob<Dtype>());
89         blobs_.push_back(blob_pointer);
90         blob_names_.push_back(blob_name);
91         blob_need_backward_.push_back(false);
92         blob_name_to_idx[blob_name] = blob_names_.size() - 1;
93         available_blobs.insert(blob_name);
94         top_vecs_[i].push_back(blobs_[blob_names_.size() - 1].get());
95         top_id_vecs_[i].push_back(blob_names_.size() - 1);
96       }
97     }
98     // After this layer is connected, set it up.
99     // LOG(INFO) << "Setting up " << layer_names_[i];
100     layers_[i]->SetUp(bottom_vecs_[i], &top_vecs_[i]);
101     for (int topid = 0; topid < top_vecs_[i].size(); ++topid) {
102       LOG(INFO) << "Top shape: " << top_vecs_[i][topid]->channels() << " "
103           << top_vecs_[i][topid]->height() << " "
104           << top_vecs_[i][topid]->width();
105     }
106     // Check if this layer needs backward operation itself
107     for (int j = 0; j < layers_[i]->layer_param().blobs_lr_size(); ++j) {
108       need_backward |= (layers_[i]->layer_param().blobs_lr(j) > 0);
109     }
110     // Finally, set the backward flag
111     layer_need_backward_.push_back(need_backward);
112     if (need_backward) {
113       LOG(INFO) << layer_names_[i] << " needs backward computation.";
114       for (int j = 0; j < top_id_vecs_[i].size(); ++j) {
115         blob_need_backward_[top_id_vecs_[i][j]] = true;
116       }
117     } else {
118       LOG(INFO) << layer_names_[i] << " does not need backward computation.";
119     }
120   }
121   // In the end, all remaining blobs are considered output blobs.
122   for (set<string>::iterator it = available_blobs.begin();
123       it != available_blobs.end(); ++it) {
124     LOG(INFO) << "This network produces output " << *it;
125     net_output_blob_indices_.push_back(blob_name_to_idx[*it]);
126     net_output_blobs_.push_back(blobs_[blob_name_to_idx[*it]].get());
127   }
128   GetLearningRateAndWeightDecay();
129   LOG(INFO) << "Network initialization done.";
133 template <typename Dtype>
134 void Net<Dtype>::GetLearningRateAndWeightDecay() {
135   LOG(INFO) << "Collecting Learning Rate and Weight Decay.";
136   for (int i = 0; i < layers_.size(); ++i) {
137     vector<shared_ptr<Blob<Dtype> > >& layer_blobs = layers_[i]->blobs();
138     for (int j = 0; j < layer_blobs.size(); ++j) {
139       params_.push_back(layer_blobs[j]);
140     }
141     // push the learning rate mutlipliers
142     if (layers_[i]->layer_param().blobs_lr_size()) {
143       CHECK_EQ(layers_[i]->layer_param().blobs_lr_size(), layer_blobs.size());
144       for (int j = 0; j < layer_blobs.size(); ++j) {
145         float local_lr = layers_[i]->layer_param().blobs_lr(j);
146         CHECK_GE(local_lr, 0.);
147         params_lr_.push_back(local_lr);
148       }
149     } else {
150       for (int j = 0; j < layer_blobs.size(); ++j) {
151         params_lr_.push_back(1.);
152       }
153     }
154     // push the weight decay multipliers
155     if (layers_[i]->layer_param().weight_decay_size()) {
156       CHECK_EQ(layers_[i]->layer_param().weight_decay_size(),
157           layer_blobs.size());
158       for (int j = 0; j < layer_blobs.size(); ++j) {
159         float local_decay = layers_[i]->layer_param().weight_decay(j);
160         CHECK_GE(local_decay, 0.);
161         params_weight_decay_.push_back(local_decay);
162       }
163     } else {
164       for (int j = 0; j < layer_blobs.size(); ++j) {
165         params_weight_decay_.push_back(1.);
166       }
167     }
168   }
171 template <typename Dtype>
172 const vector<Blob<Dtype>*>& Net<Dtype>::Forward(
173     const vector<Blob<Dtype>*> & bottom) {
174   // Copy bottom to internal bottom
175   for (int i = 0; i < bottom.size(); ++i) {
176     blobs_[net_input_blob_indices_[i]]->CopyFrom(*bottom[i]);
177   }
178   for (int i = 0; i < layers_.size(); ++i) {
179     // LOG(ERROR) << "Forwarding " << layer_names_[i];
180     layers_[i]->Forward(bottom_vecs_[i], &top_vecs_[i]);
181   }
182   return net_output_blobs_;
185 template <typename Dtype>
186 Dtype Net<Dtype>::Backward() {
187   Dtype loss = 0;
188   for (int i = layers_.size() - 1; i >= 0; --i) {
189     if (layer_need_backward_[i]) {
190       Dtype layer_loss = layers_[i]->Backward(
191           top_vecs_[i], true, &bottom_vecs_[i]);
192       loss += layer_loss;
193     }
194   }
195   return loss;
198 template <typename Dtype>
199 void Net<Dtype>::CopyTrainedLayersFrom(const NetParameter& param) {
200   int num_source_layers = param.layers_size();
201   for (int i = 0; i < num_source_layers; ++i) {
202     const LayerParameter& source_layer = param.layers(i).layer();
203     const string& source_layer_name = source_layer.name();
204     int target_layer_id = 0;
205     while (target_layer_id != layer_names_.size() &&
206         layer_names_[target_layer_id] != source_layer_name) {
207       ++target_layer_id;
208     }
209     if (target_layer_id == layer_names_.size()) {
210       DLOG(INFO) << "Ignoring source layer " << source_layer_name;
211       continue;
212     }
213     DLOG(INFO) << "Loading source layer " << source_layer_name;
214     vector<shared_ptr<Blob<Dtype> > >& target_blobs =
215         layers_[target_layer_id]->blobs();
216     CHECK_EQ(target_blobs.size(), source_layer.blobs_size())
217         << "Incompatible number of blobs for layer " << source_layer_name;
218     for (int j = 0; j < target_blobs.size(); ++j) {
219       CHECK_EQ(target_blobs[j]->num(), source_layer.blobs(j).num());
220       CHECK_EQ(target_blobs[j]->channels(), source_layer.blobs(j).channels());
221       CHECK_EQ(target_blobs[j]->height(), source_layer.blobs(j).height());
222       CHECK_EQ(target_blobs[j]->width(), source_layer.blobs(j).width());
223       target_blobs[j]->FromProto(source_layer.blobs(j));
224     }
225   }
228 template <typename Dtype>
229 void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) {
230   param->Clear();
231   param->set_name(name_);
232   // Add bottom and top
233   for (int i = 0; i < net_input_blob_indices_.size(); ++i) {
234     param->add_input(blob_names_[net_input_blob_indices_[i]]);
235   }
236   DLOG(INFO) << "Serializing " << layers_.size() << " layers";
237   for (int i = 0; i < layers_.size(); ++i) {
238     LayerConnection* layer_connection = param->add_layers();
239     for (int j = 0; j < bottom_id_vecs_[i].size(); ++j) {
240       layer_connection->add_bottom(blob_names_[bottom_id_vecs_[i][j]]);
241     }
242     for (int j = 0; j < top_id_vecs_[i].size(); ++j) {
243       layer_connection->add_top(blob_names_[top_id_vecs_[i][j]]);
244     }
245     LayerParameter* layer_parameter = layer_connection->mutable_layer();
246     layers_[i]->ToProto(layer_parameter, write_diff);
247   }
250 template <typename Dtype>
251 void Net<Dtype>::Update() {
252   for (int i = 0; i < params_.size(); ++i) {
253     params_[i]->Update();
254   }
257 INSTANTIATE_CLASS(Net);
259 }  // namespace caffe