// Copyright Yangqing Jia 2013

#include <cmath>
#include <cstdio>

#include <algorithm>
#include <string>
#include <vector>

#include "caffe/net.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/solver.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/math_functions.hpp"

using std::max;
using std::min;

namespace caffe {

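// Typical usage (a minimal sketch; assumes `param` is a filled-in
// SolverParameter and `net` an already initialized Net<float>):
//
//   SGDSolver<float> solver(param);
//   solver.Solve(&net, NULL);         // train from scratch
//   solver.Solve(&net, resume_file);  // or resume from a saved solver state
//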
template <typename Dtype>
void Solver<Dtype>::Solve(Net<Dtype>* net, char* resume_file) {
  net_ = net;
  LOG(INFO) << "Solving " << net_->name();
  PreSolve();

  iter_ = 0;
  if (resume_file) {
    LOG(INFO) << "Restoring previous solver status from " << resume_file;
    Restore(resume_file);
  }

  // For a network that is trained by the solver, no bottom or top vecs
  // should be given; we just provide a dummy, empty bottom vector.
  vector<Blob<Dtype>*> bottom_vec;
  while (iter_++ < param_.max_iter()) {
    Dtype loss = net_->ForwardBackward(bottom_vec);
    ComputeUpdateValue();
    net_->Update();

    // Check whether we need to take a snapshot.
    if (param_.snapshot() && iter_ % param_.snapshot() == 0) {
      Snapshot();
    }
    if (param_.display() && iter_ % param_.display() == 0) {
      LOG(ERROR) << "Iteration " << iter_ << ", loss = " << loss;
    }
  }
  LOG(INFO) << "Optimization Done.";
}

template <typename Dtype>
void Solver<Dtype>::Snapshot() {
  NetParameter net_param;
  // For intermediate results, also dump the gradient values if
  // param_.snapshot_diff() is set.
  net_->ToProto(&net_param, param_.snapshot_diff());
  string filename(param_.snapshot_prefix());
  char iter_str_buffer[20];
  snprintf(iter_str_buffer, sizeof(iter_str_buffer), "_iter_%d", iter_);
  filename += iter_str_buffer;
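  // e.g. with snapshot_prefix "lenet" at iteration 10000, this yields
  // "lenet_iter_10000" for the net, and "lenet_iter_10000.solverstate"
  // (appended below) for the solver state.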
  LOG(INFO) << "Snapshotting to " << filename;
  WriteProtoToBinaryFile(net_param, filename.c_str());
  SolverState state;
  SnapshotSolverState(&state);
  state.set_iter(iter_);
  state.set_learned_net(filename);
  filename += ".solverstate";
  LOG(INFO) << "Snapshotting solver state to " << filename;
  WriteProtoToBinaryFile(state, filename.c_str());
}

template <typename Dtype>
void Solver<Dtype>::Restore(char* state_file) {
  SolverState state;
  NetParameter net_param;
  ReadProtoFromBinaryFile(state_file, &state);
  ReadProtoFromBinaryFile(state.learned_net().c_str(), &net_param);
  net_->CopyTrainedLayersFrom(net_param);
  iter_ = state.iter();
  RestoreSolverState(state);
}

// Return the current learning rate. The currently implemented learning rate
// policies are as follows:
//    - fixed: always return base_lr.
//    - step: return base_lr * gamma ^ (floor(iter / stepsize))
//    - exp: return base_lr * gamma ^ iter
//    - inv: return base_lr * (1 + gamma * iter) ^ (- power)
// where base_lr, gamma, stepsize and power are defined in the solver
// parameter protocol buffer, and iter is the current iteration. The result
// is clamped to the [min_lr, max_lr] range before being returned.
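//
// Example: with base_lr = 0.01, gamma = 0.1 and stepsize = 10000, the "step"
// policy yields 0.01 for iterations 0-9999, 0.001 for 10000-19999, 0.0001
// for 20000-29999, and so on.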
template <typename Dtype>
Dtype SGDSolver<Dtype>::GetLearningRate() {
  Dtype rate;
  const string& lr_policy = this->param_.lr_policy();
  if (lr_policy == "fixed") {
    rate = this->param_.base_lr();
  } else if (lr_policy == "step") {
    int current_step = this->iter_ / this->param_.stepsize();
    rate = this->param_.base_lr() *
        pow(this->param_.gamma(), current_step);
  } else if (lr_policy == "exp") {
    rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
  } else if (lr_policy == "inv") {
    rate = this->param_.base_lr() *
        pow(Dtype(1) + this->param_.gamma() * this->iter_,
            -this->param_.power());
  } else {
    LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
  }
  rate = min(max(rate, Dtype(this->param_.min_lr())),
      Dtype(this->param_.max_lr()));
  return rate;
}

template <typename Dtype>
void SGDSolver<Dtype>::PreSolve() {
  // Initialize the momentum history: one blob per learnable parameter,
  // with the same shape as that parameter.
  vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
  history_.clear();
  for (int i = 0; i < net_params.size(); ++i) {
    const Blob<Dtype>* net_param = net_params[i].get();
    history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
        net_param->num(), net_param->channels(), net_param->height(),
        net_param->width())));
  }
}

template <typename Dtype>
void SGDSolver<Dtype>::ComputeUpdateValue() {
  vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
  vector<float>& net_params_lr = this->net_->params_lr();
  vector<float>& net_params_weight_decay = this->net_->params_weight_decay();
  // get the learning rate
  Dtype rate = GetLearningRate();
  if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
    LOG(ERROR) << "Iteration " << this->iter_ << ", lr = " << rate;
  }
  Dtype momentum = this->param_.momentum();
  Dtype weight_decay = this->param_.weight_decay();
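  // For each parameter blob, the update computed below is
  //   history = momentum * history + local_rate * (diff + local_decay * data)
  //   diff = history
  // after which Net::Update() applies data -= diff.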
  switch (Caffe::mode()) {
  case Caffe::CPU:
    for (int param_id = 0; param_id < net_params.size(); ++param_id) {
      // Compute the update value into the history, then copy it to the
      // blob's diff.
      Dtype local_rate = rate * net_params_lr[param_id];
      Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
      caffe_axpby(net_params[param_id]->count(), local_rate,
          net_params[param_id]->cpu_diff(), momentum,
          history_[param_id]->mutable_cpu_data());
      if (local_decay) {
        // Add weight decay.
        caffe_axpy(net_params[param_id]->count(),
            local_decay * local_rate,
            net_params[param_id]->cpu_data(),
            history_[param_id]->mutable_cpu_data());
      }
      // Copy the history to the diff.
      caffe_copy(net_params[param_id]->count(),
          history_[param_id]->cpu_data(),
          net_params[param_id]->mutable_cpu_diff());
    }
    break;
  case Caffe::GPU:
    for (int param_id = 0; param_id < net_params.size(); ++param_id) {
      // Compute the update value into the history, then copy it to the
      // blob's diff.
      Dtype local_rate = rate * net_params_lr[param_id];
      Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
      caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
          net_params[param_id]->gpu_diff(), momentum,
          history_[param_id]->mutable_gpu_data());
      if (local_decay) {
        // Add weight decay.
        caffe_gpu_axpy(net_params[param_id]->count(),
            local_decay * local_rate,
            net_params[param_id]->gpu_data(),
            history_[param_id]->mutable_gpu_data());
      }
      // Copy the history to the diff.
      caffe_gpu_copy(net_params[param_id]->count(),
          history_[param_id]->gpu_data(),
          net_params[param_id]->mutable_gpu_diff());
    }
    break;
  default:
    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
  }
}

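// Serialize the momentum history into the solver state so that a restored
// solver resumes with exactly the same update dynamics.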
template <typename Dtype>
void SGDSolver<Dtype>::SnapshotSolverState(SolverState* state) {
  state->clear_history();
  for (int i = 0; i < history_.size(); ++i) {
    // Add history
    BlobProto* history_blob = state->add_history();
    history_[i]->ToProto(history_blob);
  }
}

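// Reload the momentum history saved by SnapshotSolverState.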
template <typename Dtype>
void SGDSolver<Dtype>::RestoreSolverState(const SolverState& state) {
  CHECK_EQ(state.history_size(), history_.size())
      << "Incorrect length of history blobs.";
  for (int i = 0; i < history_.size(); ++i) {
    history_[i]->FromProto(state.history(i));
  }
}

INSTANTIATE_CLASS(Solver);
INSTANTIATE_CLASS(SGDSolver);

}  // namespace caffe