07b46a42e971de8093471cc8a42b1774182e2d23
[jacinto-ai/caffe-jacinto.git] / src / caffe / solver.cpp
1 // Copyright Yangqing Jia 2013
3 #include <cstdio>
5 #include <algorithm>
6 #include <string>
7 #include <vector>
9 #include "caffe/net.hpp"
10 #include "caffe/proto/caffe.pb.h"
11 #include "caffe/solver.hpp"
12 #include "caffe/util/io.hpp"
13 #include "caffe/util/math_functions.hpp"
15 using std::max;
16 using std::min;
18 namespace caffe {
20 template <typename Dtype>
21 Solver<Dtype>::Solver(const SolverParameter& param)
22     : param_(param), net_(NULL), test_net_(NULL) {
23   // Scaffolding code
24   NetParameter train_net_param;
25   ReadProtoFromTextFile(param_.train_net(), &train_net_param);
26   // For the training network, there should be no input - so we simply create
27   // a dummy bottom_vec instance to initialize the networks.
28   vector<Blob<Dtype>*> bottom_vec;
29   LOG(INFO) << "Creating training net.";
30   net_ = new Net<Dtype>(train_net_param, bottom_vec);
31   if (param_.has_test_net()) {
32     LOG(INFO) << "Creating testing net.";
33     NetParameter test_net_param;
34     ReadProtoFromTextFile(param_.test_net(), &test_net_param);
35     test_net_ = new Net<Dtype>(test_net_param, bottom_vec);
36     CHECK_GT(param_.test_iter(), 0);
37     CHECK_GT(param_.test_interval(), 0);
38   }
39   LOG(INFO) << "Solver scaffolding done.";
40 }
43 template <typename Dtype>
44 void Solver<Dtype>::Solve(const char* resume_file) {
45   Caffe::set_phase(Caffe::TRAIN);
46   LOG(INFO) << "Solving " << net_->name();
47   PreSolve();
49   iter_ = 0;
50   if (resume_file) {
51     LOG(INFO) << "Restoring previous solver status from " << resume_file;
52     Restore(resume_file);
53   }
55   // For a network that is trained by the solver, no bottom or top vecs
56   // should be given, and we will just provide dummy vecs.
57   vector<Blob<Dtype>*> bottom_vec;
58   while (iter_++ < param_.max_iter()) {
59     Dtype loss = net_->ForwardBackward(bottom_vec);
60     ComputeUpdateValue();
61     net_->Update();
63     // Check if we need to do snapshot
64     if (param_.snapshot() && iter_ % param_.snapshot() == 0) {
65       Snapshot();
66     }
67     if (param_.display() && iter_ % param_.display() == 0) {
68       LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss;
69     }
70     if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
71       // We need to set phase to test before running.
72       Caffe::set_phase(Caffe::TEST);
73       Test();
74       Caffe::set_phase(Caffe::TRAIN);
75     }
76   }
77   LOG(INFO) << "Optimization Done.";
78 }
81 template <typename Dtype>
82 void Solver<Dtype>::Test() {
83   LOG(INFO) << "Testing net";
84   NetParameter net_param;
85   net_->ToProto(&net_param);
86   CHECK_NOTNULL(test_net_)->CopyTrainedLayersFrom(net_param);
87   vector<Dtype> test_score;
88   vector<Blob<Dtype>*> bottom_vec;
89   for (int i = 0; i < param_.test_iter(); ++i) {
90     const vector<Blob<Dtype>*>& result =
91         test_net_->Forward(bottom_vec);
92     if (i == 0) {
93       for (int j = 0; j < result.size(); ++j) {
94         const Dtype* result_vec = result[j]->cpu_data();
95         for (int k = 0; k < result[j]->count(); ++k) {
96           test_score.push_back(result_vec[k]);
97         }
98       }
99     } else {
100       int idx = 0;
101       for (int j = 0; j < result.size(); ++j) {
102         const Dtype* result_vec = result[j]->cpu_data();
103         for (int k = 0; k < result[j]->count(); ++k) {
104           test_score[idx++] += result_vec[k];
105         }
106       }
107     }
108   }
109   for (int i = 0; i < test_score.size(); ++i) {
110     LOG(INFO) << "Test score #" << i << ": "
111         << test_score[i] / param_.test_iter();
112   }
116 template <typename Dtype>
117 void Solver<Dtype>::Snapshot() {
118   NetParameter net_param;
119   // For intermediate results, we will also dump the gradient values.
120   net_->ToProto(&net_param, param_.snapshot_diff());
121   string filename(param_.snapshot_prefix());
122   char iter_str_buffer[20];
123   sprintf(iter_str_buffer, "_iter_%d", iter_);
124   filename += iter_str_buffer;
125   LOG(INFO) << "Snapshotting to " << filename;
126   WriteProtoToBinaryFile(net_param, filename.c_str());
127   SolverState state;
128   SnapshotSolverState(&state);
129   state.set_iter(iter_);
130   state.set_learned_net(filename);
131   filename += ".solverstate";
132   LOG(INFO) << "Snapshotting solver state to " << filename;
133   WriteProtoToBinaryFile(state, filename.c_str());
136 template <typename Dtype>
137 void Solver<Dtype>::Restore(const char* state_file) {
138   SolverState state;
139   NetParameter net_param;
140   ReadProtoFromBinaryFile(state_file, &state);
141   ReadProtoFromBinaryFile(state.learned_net().c_str(), &net_param);
142   net_->CopyTrainedLayersFrom(net_param);
143   iter_ = state.iter();
144   RestoreSolverState(state);
148 // Return the current learning rate. The currently implemented learning rate
149 // policies are as follows:
150 //    - fixed: always return base_lr.
151 //    - step: return base_lr * gamma ^ (floor(iter / step))
152 //    - exp: return base_lr * gamma ^ iter
153 //    - inv: return base_lr * (1 + gamma * iter) ^ (- power)
154 // where base_lr, gamma, step and power are defined in the solver parameter
155 // protocol buffer, and iter is the current iteration.
156 template <typename Dtype>
157 Dtype SGDSolver<Dtype>::GetLearningRate() {
158   Dtype rate;
159   const string& lr_policy = this->param_.lr_policy();
160   if (lr_policy == "fixed") {
161     rate = this->param_.base_lr();
162   } else if (lr_policy == "step") {
163     int current_step = this->iter_ / this->param_.stepsize();
164     rate = this->param_.base_lr() *
165         pow(this->param_.gamma(), current_step);
166   } else if (lr_policy == "exp") {
167     rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
168   } else if (lr_policy == "inv") {
169     rate = this->param_.base_lr() *
170         pow(Dtype(1) + this->param_.gamma() * this->iter_,
171             - this->param_.power());
172   } else {
173     LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
174   }
175   return rate;
179 template <typename Dtype>
180 void SGDSolver<Dtype>::PreSolve() {
181   // Initialize the history
182   vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
183   history_.clear();
184   for (int i = 0; i < net_params.size(); ++i) {
185     const Blob<Dtype>* net_param = net_params[i].get();
186     history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
187         net_param->num(), net_param->channels(), net_param->height(),
188         net_param->width())));
189   }
193 template <typename Dtype>
194 void SGDSolver<Dtype>::ComputeUpdateValue() {
195   vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
196   vector<float>& net_params_lr = this->net_->params_lr();
197   vector<float>& net_params_weight_decay = this->net_->params_weight_decay();
198   // get the learning rate
199   Dtype rate = GetLearningRate();
200   if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
201     LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
202   }
203   Dtype momentum = this->param_.momentum();
204   Dtype weight_decay = this->param_.weight_decay();
205   switch (Caffe::mode()) {
206   case Caffe::CPU:
207     for (int param_id = 0; param_id < net_params.size(); ++param_id) {
208       // Compute the value to history, and then copy them to the blob's diff.
209       Dtype local_rate = rate * net_params_lr[param_id];
210       Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
211       caffe_axpby(net_params[param_id]->count(), local_rate,
212           net_params[param_id]->cpu_diff(), momentum,
213           history_[param_id]->mutable_cpu_data());
214       if (local_decay) {
215         // add weight decay
216         caffe_axpy(net_params[param_id]->count(),
217             local_decay * local_rate,
218             net_params[param_id]->cpu_data(),
219             history_[param_id]->mutable_cpu_data());
220       }
221       // copy
222       caffe_copy(net_params[param_id]->count(),
223           history_[param_id]->cpu_data(),
224           net_params[param_id]->mutable_cpu_diff());
225     }
226     break;
227   case Caffe::GPU:
228     for (int param_id = 0; param_id < net_params.size(); ++param_id) {
229       // Compute the value to history, and then copy them to the blob's diff.
230       Dtype local_rate = rate * net_params_lr[param_id];
231       Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
232       caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
233           net_params[param_id]->gpu_diff(), momentum,
234           history_[param_id]->mutable_gpu_data());
235       if (local_decay) {
236         // add weight decay
237         caffe_gpu_axpy(net_params[param_id]->count(),
238             local_decay * local_rate,
239             net_params[param_id]->gpu_data(),
240             history_[param_id]->mutable_gpu_data());
241       }
242       // copy
243       caffe_gpu_copy(net_params[param_id]->count(),
244           history_[param_id]->gpu_data(),
245           net_params[param_id]->mutable_gpu_diff());
246     }
247     break;
248   default:
249     LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
250   }
253 template <typename Dtype>
254 void SGDSolver<Dtype>::SnapshotSolverState(SolverState* state) {
255   state->clear_history();
256   for (int i = 0; i < history_.size(); ++i) {
257     // Add history
258     BlobProto* history_blob = state->add_history();
259     history_[i]->ToProto(history_blob);
260   }
263 template <typename Dtype>
264 void SGDSolver<Dtype>::RestoreSolverState(const SolverState& state) {
265   CHECK_EQ(state.history_size(), history_.size())
266       << "Incorrect length of history blobs.";
267   LOG(INFO) << "SGDSolver: restoring history";
268   for (int i = 0; i < history_.size(); ++i) {
269     history_[i]->FromProto(state.history(i));
270   }
273 INSTANTIATE_CLASS(Solver);
274 INSTANTIATE_CLASS(SGDSolver);
276 }  // namespace caffe