1 // Copyright Yangqing Jia 2013
3 #include <cstdio>
5 #include <algorithm>
6 #include <string>
7 #include <vector>
9 #include "caffe/net.hpp"
10 #include "caffe/proto/caffe.pb.h"
11 #include "caffe/solver.hpp"
12 #include "caffe/util/io.hpp"
13 #include "caffe/util/math_functions.hpp"
15 using std::max;
16 using std::min;
18 namespace caffe {
20 template <typename Dtype>
21 Solver<Dtype>::Solver(const SolverParameter& param)
22 : param_(param), net_(), test_net_() {
23 // Scaffolding code
24 NetParameter train_net_param;
25 ReadProtoFromTextFile(param_.train_net(), &train_net_param);
26 LOG(INFO) << "Creating training net.";
27 net_.reset(new Net<Dtype>(train_net_param));
28 if (param_.has_test_net()) {
29 LOG(INFO) << "Creating testing net.";
30 NetParameter test_net_param;
31 ReadProtoFromTextFile(param_.test_net(), &test_net_param);
32 test_net_.reset(new Net<Dtype>(test_net_param));
33 CHECK_GT(param_.test_iter(), 0);
34 CHECK_GT(param_.test_interval(), 0);
35 }
36 LOG(INFO) << "Solver scaffolding done.";
37 }
40 template <typename Dtype>
41 void Solver<Dtype>::Solve(const char* resume_file) {
42 Caffe::set_phase(Caffe::TRAIN);
43 LOG(INFO) << "Solving " << net_->name();
44 PreSolve();
46 iter_ = 0;
47 if (resume_file) {
48 LOG(INFO) << "Restoring previous solver status from " << resume_file;
49 Restore(resume_file);
50 }
52 // For a network that is trained by the solver, no bottom or top vecs
53 // should be given, and we will just provide dummy vecs.
54 vector<Blob<Dtype>*> bottom_vec;
55 while (iter_++ < param_.max_iter()) {
56 Dtype loss = net_->ForwardBackward(bottom_vec);
57 ComputeUpdateValue();
58 net_->Update();
60 // Check if we need to do snapshot
61 if (param_.snapshot() && iter_ % param_.snapshot() == 0) {
62 Snapshot();
63 }
64 if (param_.display() && iter_ % param_.display() == 0) {
65 LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss;
66 }
67 if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
68 // We need to set phase to test before running.
69 Caffe::set_phase(Caffe::TEST);
70 Test();
71 Caffe::set_phase(Caffe::TRAIN);
72 }
73 }
74 LOG(INFO) << "Optimization Done.";
75 }
78 template <typename Dtype>
79 void Solver<Dtype>::Test() {
80 LOG(INFO) << "Testing net";
81 NetParameter net_param;
82 net_->ToProto(&net_param);
83 CHECK_NOTNULL(test_net_.get())->CopyTrainedLayersFrom(net_param);
84 vector<Dtype> test_score;
85 vector<Blob<Dtype>*> bottom_vec;
86 for (int i = 0; i < param_.test_iter(); ++i) {
87 const vector<Blob<Dtype>*>& result =
88 test_net_->Forward(bottom_vec);
89 if (i == 0) {
90 for (int j = 0; j < result.size(); ++j) {
91 const Dtype* result_vec = result[j]->cpu_data();
92 for (int k = 0; k < result[j]->count(); ++k) {
93 test_score.push_back(result_vec[k]);
94 }
95 }
96 } else {
97 int idx = 0;
98 for (int j = 0; j < result.size(); ++j) {
99 const Dtype* result_vec = result[j]->cpu_data();
100 for (int k = 0; k < result[j]->count(); ++k) {
101 test_score[idx++] += result_vec[k];
102 }
103 }
104 }
105 }
106 for (int i = 0; i < test_score.size(); ++i) {
107 LOG(INFO) << "Test score #" << i << ": "
108 << test_score[i] / param_.test_iter();
109 }
110 }
113 template <typename Dtype>
114 void Solver<Dtype>::Snapshot() {
115 NetParameter net_param;
116 // For intermediate results, we will also dump the gradient values.
117 net_->ToProto(&net_param, param_.snapshot_diff());
118 string filename(param_.snapshot_prefix());
119 char iter_str_buffer[20];
120 sprintf(iter_str_buffer, "_iter_%d", iter_);
121 filename += iter_str_buffer;
122 LOG(INFO) << "Snapshotting to " << filename;
123 WriteProtoToBinaryFile(net_param, filename.c_str());
124 SolverState state;
125 SnapshotSolverState(&state);
126 state.set_iter(iter_);
127 state.set_learned_net(filename);
128 filename += ".solverstate";
129 LOG(INFO) << "Snapshotting solver state to " << filename;
130 WriteProtoToBinaryFile(state, filename.c_str());
131 }
133 template <typename Dtype>
134 void Solver<Dtype>::Restore(const char* state_file) {
135 SolverState state;
136 NetParameter net_param;
137 ReadProtoFromBinaryFile(state_file, &state);
138 if (state.has_learned_net()) {
139 ReadProtoFromBinaryFile(state.learned_net().c_str(), &net_param);
140 net_->CopyTrainedLayersFrom(net_param);
141 }
142 iter_ = state.iter();
143 RestoreSolverState(state);
144 }
147 // Return the current learning rate. The currently implemented learning rate
148 // policies are as follows:
149 // - fixed: always return base_lr.
150 // - step: return base_lr * gamma ^ (floor(iter / step))
151 // - exp: return base_lr * gamma ^ iter
152 // - inv: return base_lr * (1 + gamma * iter) ^ (- power)
153 // where base_lr, gamma, step and power are defined in the solver parameter
154 // protocol buffer, and iter is the current iteration.
155 template <typename Dtype>
156 Dtype SGDSolver<Dtype>::GetLearningRate() {
157 Dtype rate;
158 const string& lr_policy = this->param_.lr_policy();
159 if (lr_policy == "fixed") {
160 rate = this->param_.base_lr();
161 } else if (lr_policy == "step") {
162 int current_step = this->iter_ / this->param_.stepsize();
163 rate = this->param_.base_lr() *
164 pow(this->param_.gamma(), current_step);
165 } else if (lr_policy == "exp") {
166 rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
167 } else if (lr_policy == "inv") {
168 rate = this->param_.base_lr() *
169 pow(Dtype(1) + this->param_.gamma() * this->iter_,
170 - this->param_.power());
171 } else {
172 LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
173 }
174 return rate;
175 }
178 template <typename Dtype>
179 void SGDSolver<Dtype>::PreSolve() {
180 // Initialize the history
181 vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
182 history_.clear();
183 for (int i = 0; i < net_params.size(); ++i) {
184 const Blob<Dtype>* net_param = net_params[i].get();
185 history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
186 net_param->num(), net_param->channels(), net_param->height(),
187 net_param->width())));
188 }
189 }
192 template <typename Dtype>
193 void SGDSolver<Dtype>::ComputeUpdateValue() {
194 vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
195 vector<float>& net_params_lr = this->net_->params_lr();
196 vector<float>& net_params_weight_decay = this->net_->params_weight_decay();
197 // get the learning rate
198 Dtype rate = GetLearningRate();
199 if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
200 LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
201 }
202 Dtype momentum = this->param_.momentum();
203 Dtype weight_decay = this->param_.weight_decay();
204 switch (Caffe::mode()) {
205 case Caffe::CPU:
206 for (int param_id = 0; param_id < net_params.size(); ++param_id) {
207 // Compute the value to history, and then copy them to the blob's diff.
208 Dtype local_rate = rate * net_params_lr[param_id];
209 Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
210 caffe_axpby(net_params[param_id]->count(), local_rate,
211 net_params[param_id]->cpu_diff(), momentum,
212 history_[param_id]->mutable_cpu_data());
213 if (local_decay) {
214 // add weight decay
215 caffe_axpy(net_params[param_id]->count(),
216 local_decay * local_rate,
217 net_params[param_id]->cpu_data(),
218 history_[param_id]->mutable_cpu_data());
219 }
220 // copy
221 caffe_copy(net_params[param_id]->count(),
222 history_[param_id]->cpu_data(),
223 net_params[param_id]->mutable_cpu_diff());
224 }
225 break;
226 case Caffe::GPU:
227 for (int param_id = 0; param_id < net_params.size(); ++param_id) {
228 // Compute the value to history, and then copy them to the blob's diff.
229 Dtype local_rate = rate * net_params_lr[param_id];
230 Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
231 caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
232 net_params[param_id]->gpu_diff(), momentum,
233 history_[param_id]->mutable_gpu_data());
234 if (local_decay) {
235 // add weight decay
236 caffe_gpu_axpy(net_params[param_id]->count(),
237 local_decay * local_rate,
238 net_params[param_id]->gpu_data(),
239 history_[param_id]->mutable_gpu_data());
240 }
241 // copy
242 caffe_gpu_copy(net_params[param_id]->count(),
243 history_[param_id]->gpu_data(),
244 net_params[param_id]->mutable_gpu_diff());
245 }
246 break;
247 default:
248 LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
249 }
250 }
252 template <typename Dtype>
253 void SGDSolver<Dtype>::SnapshotSolverState(SolverState* state) {
254 state->clear_history();
255 for (int i = 0; i < history_.size(); ++i) {
256 // Add history
257 BlobProto* history_blob = state->add_history();
258 history_[i]->ToProto(history_blob);
259 }
260 }
262 template <typename Dtype>
263 void SGDSolver<Dtype>::RestoreSolverState(const SolverState& state) {
264 CHECK_EQ(state.history_size(), history_.size())
265 << "Incorrect length of history blobs.";
266 LOG(INFO) << "SGDSolver: restoring history";
267 for (int i = 0; i < history_.size(); ++i) {
268 history_[i]->FromProto(state.history(i));
269 }
270 }
// Explicitly instantiate the Solver and SGDSolver templates in this
// translation unit, since their member definitions live in this .cpp file.
// NOTE(review): INSTANTIATE_CLASS is a project macro, presumably expanding to
// float and double instantiations; confirm in caffe/common.hpp.
272 INSTANTIATE_CLASS(Solver);
273 INSTANTIATE_CLASS(SGDSolver);
275 } // namespace caffe