1 // Copyright Yangqing Jia 2013
#include <algorithm>
#include <cmath>
#include <fstream>
#include <sstream>
#include <string>
#include <vector>

#include "caffe/proto/caffe.pb.h"
#include "caffe/net.hpp"
#include "caffe/optimization/solver.hpp"
10 using std::stringstream;
11 using std::ofstream;
13 namespace caffe {
15 template <typename Dtype>
16 void Solver<Dtype>::Solve(Net<Dtype>* net) {
17 net_ = net;
18 LOG(INFO) << "Solving net " << net_->name();
19 iter_ = 0;
20 // For a network that is trained by the solver, no bottom or top vecs
21 // should be given, and we will just provide dummy vecs.
22 vector<Blob<Dtype>*> bottom_vec;
23 vector<Blob<Dtype>*> top_vec;
24 while (iter_++ < param_.max_iter()) {
25 Dtype loss = net_->ForwardBackWard(bottom_vec, &top_vec);
26 ComputeUpdateValue();
27 net->Update();
29 // Check if we need to do snapshot
30 if (iter_ % param_.snapshot()) {
31 // TODO(Yangqing): snapshot
32 }
33 LOG(INFO) << "Iteration" << iter_ << ", loss=" << loss;
34 }
35 LOG(INFO) << "Optimization Done.";
36 }
38 template <typename Dtype>
39 void Solver<Dtype>::Snapshot(bool is_final) {
40 NetParameter net_param;
41 net_->ToProto(&net_param);
42 stringstream ss;
43 ss << param_.snapshot_prefix();
44 if (is_final) {
45 ss << "_final";
46 } else {
47 ss << "_iter_" << iter_;
48 }
49 ofstream output_file;
50 output_file.open(ss.str().c_str());
51 CHECK(net_param.SerializeToOstream(&output_file));
52 output_file.close();
53 }
55 template <typename Dtype>
56 Dtype SGDSolver<Dtype>::GetLearningRate() {
57 Dtype rate;
58 const string& lr_policy = this->param_.lr_policy();
59 if (lr_policy == "fixed") {
60 rate = this->param_.base_lr();
61 } else if (lr_policy == "exp") {
62 rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
63 } else if (lr_policy == "inv") {
64 rate = this->param_.base_lr() *
65 pow(Dtype(1) + this->param_.gamma() * this->iter_,
66 this->param_.power());
67 } else {
68 LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
69 }
70 rate = min(max(rate, this->param_.min_pr()), this->param_.max_lr());
71 return rate;
72 }
74 template <typename Dtype>
75 void SGDSolver<Dtype>::ComputeUpdateValue() {
76 // First of all, see if we need to initialize the history
77 vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_.params();
78 if (this->iter_ == 1 && this->param_.momentum() > 0) {
79 LOG(INFO) << "Using momentum " << this->param_.momentum();
80 for (int i = 0; i < net_params.size(); ++i) {
81 const Blob<Dtype>* net_param = net_params[i].get();
82 history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
83 net_param->num(), net_param->channels(), net_param->height(),
84 net_param->width())));
85 }
86 }
87 // get the learning rate
88 Dtype rate = GetLearningRate();
89 if (this->param_.momentum == 0) {
90 for (int i = 0; i < net_params.size(); ++i) {
91 switch (Caffe::mode()) {
92 case Caffe::CPU:
93 caffe_scal(net_params[i]->count(), rate,
94 net_params[i]->mutable_cpu_data());
95 break;
96 case Caffe::GPU:
97 caffe_gpu_scal(net_params[i]->count(), rate,
98 net_params[i]->mutable_gpu_data());
99 break;
100 default:
101 LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
102 }
103 }
104 } else {
105 NOT_IMPLEMENTED;
106 }
107 }
111 INSTANTIATE_CLASS(Solver);
113 } // namespace caffe