]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/caffe-jacinto.git/blob - src/caffe/solver.cpp
Several changes:
[jacinto-ai/caffe-jacinto.git] / src / caffe / solver.cpp
1 // Copyright Yangqing Jia 2013
3 #include <cstdio>
5 #include <algorithm>
6 #include <string>
7 #include <vector>
9 #include "caffe/net.hpp"
10 #include "caffe/proto/caffe.pb.h"
11 #include "caffe/solver.hpp"
12 #include "caffe/util/io.hpp"
13 #include "caffe/util/math_functions.hpp"
15 using std::max;
16 using std::min;
18 namespace caffe {
20 template <typename Dtype>
21 Solver<Dtype>::Solver(const SolverParameter& param)
22     : param_(param), net_(), test_net_() {
23   // Scaffolding code
24   NetParameter train_net_param;
25   ReadProtoFromTextFile(param_.train_net(), &train_net_param);
26   LOG(INFO) << "Creating training net.";
27   net_.reset(new Net<Dtype>(train_net_param));
28   if (param_.has_test_net()) {
29     LOG(INFO) << "Creating testing net.";
30     NetParameter test_net_param;
31     ReadProtoFromTextFile(param_.test_net(), &test_net_param);
32     test_net_.reset(new Net<Dtype>(test_net_param));
33     CHECK_GT(param_.test_iter(), 0);
34     CHECK_GT(param_.test_interval(), 0);
35   }
36   LOG(INFO) << "Solver scaffolding done.";
37 }
40 template <typename Dtype>
41 void Solver<Dtype>::Solve(const char* resume_file) {
42   Caffe::set_mode(Caffe::Brew(param_.solver_mode()));
43   Caffe::set_phase(Caffe::TRAIN);
44   LOG(INFO) << "Solving " << net_->name();
45   PreSolve();
47   iter_ = 0;
48   if (resume_file) {
49     LOG(INFO) << "Restoring previous solver status from " << resume_file;
50     Restore(resume_file);
51   }
53   // For a network that is trained by the solver, no bottom or top vecs
54   // should be given, and we will just provide dummy vecs.
55   vector<Blob<Dtype>*> bottom_vec;
56   while (iter_++ < param_.max_iter()) {
57     Dtype loss = net_->ForwardBackward(bottom_vec);
58     ComputeUpdateValue();
59     net_->Update();
61     // Check if we need to do snapshot
62     if (param_.snapshot() && iter_ % param_.snapshot() == 0) {
63       Snapshot();
64     }
65     if (param_.display() && iter_ % param_.display() == 0) {
66       LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss;
67     }
68     if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
69       // We need to set phase to test before running.
70       Caffe::set_phase(Caffe::TEST);
71       Test();
72       Caffe::set_phase(Caffe::TRAIN);
73     }
74   }
75   // After the optimization is done, always do a snapshot.
76   iter_--;
77   Snapshot();
78   LOG(INFO) << "Optimization Done.";
79 }
82 template <typename Dtype>
83 void Solver<Dtype>::Test() {
84   LOG(INFO) << "Testing net";
85   NetParameter net_param;
86   net_->ToProto(&net_param);
87   CHECK_NOTNULL(test_net_.get())->CopyTrainedLayersFrom(net_param);
88   vector<Dtype> test_score;
89   vector<Blob<Dtype>*> bottom_vec;
90   for (int i = 0; i < param_.test_iter(); ++i) {
91     const vector<Blob<Dtype>*>& result =
92         test_net_->Forward(bottom_vec);
93     if (i == 0) {
94       for (int j = 0; j < result.size(); ++j) {
95         const Dtype* result_vec = result[j]->cpu_data();
96         for (int k = 0; k < result[j]->count(); ++k) {
97           test_score.push_back(result_vec[k]);
98         }
99       }
100     } else {
101       int idx = 0;
102       for (int j = 0; j < result.size(); ++j) {
103         const Dtype* result_vec = result[j]->cpu_data();
104         for (int k = 0; k < result[j]->count(); ++k) {
105           test_score[idx++] += result_vec[k];
106         }
107       }
108     }
109   }
110   for (int i = 0; i < test_score.size(); ++i) {
111     LOG(INFO) << "Test score #" << i << ": "
112         << test_score[i] / param_.test_iter();
113   }
117 template <typename Dtype>
118 void Solver<Dtype>::Snapshot() {
119   NetParameter net_param;
120   // For intermediate results, we will also dump the gradient values.
121   net_->ToProto(&net_param, param_.snapshot_diff());
122   string filename(param_.snapshot_prefix());
123   char iter_str_buffer[20];
124   sprintf(iter_str_buffer, "_iter_%d", iter_);
125   filename += iter_str_buffer;
126   LOG(INFO) << "Snapshotting to " << filename;
127   WriteProtoToBinaryFile(net_param, filename.c_str());
128   SolverState state;
129   SnapshotSolverState(&state);
130   state.set_iter(iter_);
131   state.set_learned_net(filename);
132   filename += ".solverstate";
133   LOG(INFO) << "Snapshotting solver state to " << filename;
134   WriteProtoToBinaryFile(state, filename.c_str());
137 template <typename Dtype>
138 void Solver<Dtype>::Restore(const char* state_file) {
139   SolverState state;
140   NetParameter net_param;
141   ReadProtoFromBinaryFile(state_file, &state);
142   if (state.has_learned_net()) {
143     ReadProtoFromBinaryFile(state.learned_net().c_str(), &net_param);
144     net_->CopyTrainedLayersFrom(net_param);
145   }
146   iter_ = state.iter();
147   RestoreSolverState(state);
151 // Return the current learning rate. The currently implemented learning rate
152 // policies are as follows:
153 //    - fixed: always return base_lr.
154 //    - step: return base_lr * gamma ^ (floor(iter / step))
155 //    - exp: return base_lr * gamma ^ iter
156 //    - inv: return base_lr * (1 + gamma * iter) ^ (- power)
157 // where base_lr, gamma, step and power are defined in the solver parameter
158 // protocol buffer, and iter is the current iteration.
159 template <typename Dtype>
160 Dtype SGDSolver<Dtype>::GetLearningRate() {
161   Dtype rate;
162   const string& lr_policy = this->param_.lr_policy();
163   if (lr_policy == "fixed") {
164     rate = this->param_.base_lr();
165   } else if (lr_policy == "step") {
166     int current_step = this->iter_ / this->param_.stepsize();
167     rate = this->param_.base_lr() *
168         pow(this->param_.gamma(), current_step);
169   } else if (lr_policy == "exp") {
170     rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
171   } else if (lr_policy == "inv") {
172     rate = this->param_.base_lr() *
173         pow(Dtype(1) + this->param_.gamma() * this->iter_,
174             - this->param_.power());
175   } else {
176     LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
177   }
178   return rate;
182 template <typename Dtype>
183 void SGDSolver<Dtype>::PreSolve() {
184   // Initialize the history
185   vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
186   history_.clear();
187   for (int i = 0; i < net_params.size(); ++i) {
188     const Blob<Dtype>* net_param = net_params[i].get();
189     history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
190         net_param->num(), net_param->channels(), net_param->height(),
191         net_param->width())));
192   }
196 template <typename Dtype>
197 void SGDSolver<Dtype>::ComputeUpdateValue() {
198   vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
199   vector<float>& net_params_lr = this->net_->params_lr();
200   vector<float>& net_params_weight_decay = this->net_->params_weight_decay();
201   // get the learning rate
202   Dtype rate = GetLearningRate();
203   if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
204     LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
205   }
206   Dtype momentum = this->param_.momentum();
207   Dtype weight_decay = this->param_.weight_decay();
208   switch (Caffe::mode()) {
209   case Caffe::CPU:
210     for (int param_id = 0; param_id < net_params.size(); ++param_id) {
211       // Compute the value to history, and then copy them to the blob's diff.
212       Dtype local_rate = rate * net_params_lr[param_id];
213       Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
214       caffe_axpby(net_params[param_id]->count(), local_rate,
215           net_params[param_id]->cpu_diff(), momentum,
216           history_[param_id]->mutable_cpu_data());
217       if (local_decay) {
218         // add weight decay
219         caffe_axpy(net_params[param_id]->count(),
220             local_decay * local_rate,
221             net_params[param_id]->cpu_data(),
222             history_[param_id]->mutable_cpu_data());
223       }
224       // copy
225       caffe_copy(net_params[param_id]->count(),
226           history_[param_id]->cpu_data(),
227           net_params[param_id]->mutable_cpu_diff());
228     }
229     break;
230   case Caffe::GPU:
231     for (int param_id = 0; param_id < net_params.size(); ++param_id) {
232       // Compute the value to history, and then copy them to the blob's diff.
233       Dtype local_rate = rate * net_params_lr[param_id];
234       Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
235       caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
236           net_params[param_id]->gpu_diff(), momentum,
237           history_[param_id]->mutable_gpu_data());
238       if (local_decay) {
239         // add weight decay
240         caffe_gpu_axpy(net_params[param_id]->count(),
241             local_decay * local_rate,
242             net_params[param_id]->gpu_data(),
243             history_[param_id]->mutable_gpu_data());
244       }
245       // copy
246       caffe_gpu_copy(net_params[param_id]->count(),
247           history_[param_id]->gpu_data(),
248           net_params[param_id]->mutable_gpu_diff());
249     }
250     break;
251   default:
252     LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
253   }
256 template <typename Dtype>
257 void SGDSolver<Dtype>::SnapshotSolverState(SolverState* state) {
258   state->clear_history();
259   for (int i = 0; i < history_.size(); ++i) {
260     // Add history
261     BlobProto* history_blob = state->add_history();
262     history_[i]->ToProto(history_blob);
263   }
266 template <typename Dtype>
267 void SGDSolver<Dtype>::RestoreSolverState(const SolverState& state) {
268   CHECK_EQ(state.history_size(), history_.size())
269       << "Incorrect length of history blobs.";
270   LOG(INFO) << "SGDSolver: restoring history";
271   for (int i = 0; i < history_.size(); ++i) {
272     history_[i]->FromProto(state.history(i));
273   }
276 INSTANTIATE_CLASS(Solver);
277 INSTANTIATE_CLASS(SGDSolver);
279 }  // namespace caffe