b9055d26eace5f50a1a8c05865c0ed64413f04ab
[jacinto-ai/caffe-jacinto.git] / src / caffe / optimization / solver.cpp
1 // Copyright Yangqing Jia 2013
3 #include <fstream>
4 #include <string>
6 #include "caffe/proto/caffe.pb.h"
7 #include "caffe/net.hpp"
8 #include "caffe/optimization/solver.hpp"
10 using std::stringstream;
11 using std::ofstream;
13 namespace caffe {
15 template <typename Dtype>
16 void Solver<Dtype>::Solve(Net<Dtype>* net) {
17   net_ = net;
18   LOG(INFO) << "Solving net " << net_->name();
19   iter_ = 0;
20   // For a network that is trained by the solver, no bottom or top vecs
21   // should be given, and we will just provide dummy vecs.
22   vector<Blob<Dtype>*> bottom_vec;
23   vector<Blob<Dtype>*> top_vec;
24   while (iter_++ < param_.max_iter()) {
25     Dtype loss = net_->ForwardBackWard(bottom_vec, &top_vec);
26     ComputeUpdateValue();
27     net->Update();
29     // Check if we need to do snapshot
30     if (iter_ % param_.snapshot()) {
31       // TODO(Yangqing): snapshot
32     }
33     LOG(INFO) << "Iteration" << iter_ << ", loss=" << loss;
34   }
35   LOG(INFO) << "Optimization Done.";
36 }
38 template <typename Dtype>
39 void Solver<Dtype>::Snapshot(bool is_final) {
40   NetParameter net_param;
41   net_->ToProto(&net_param);
42   stringstream ss;
43   ss << param_.snapshot_prefix();
44   if (is_final) {
45     ss << "_final";
46   } else {
47     ss << "_iter_" << iter_;
48   }
49   ofstream output_file;
50   output_file.open(ss.str().c_str());
51   CHECK(net_param.SerializeToOstream(&output_file));
52   output_file.close();
53 }
55 template <typename Dtype>
56 Dtype SGDSolver<Dtype>::GetLearningRate() {
57   Dtype rate;
58   const string& lr_policy = this->param_.lr_policy();
59   if (lr_policy == "fixed") {
60     rate = this->param_.base_lr();
61   } else if (lr_policy == "exp") {
62     rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
63   } else if (lr_policy == "inv") {
64     rate = this->param_.base_lr() *
65         pow(Dtype(1) + this->param_.gamma() * this->iter_,
66             this->param_.power());
67   } else {
68     LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
69   }
70   rate = min(max(rate, this->param_.min_pr()), this->param_.max_lr());
71   return rate;
72 }
74 template <typename Dtype>
75 void SGDSolver<Dtype>::ComputeUpdateValue() {
76   // First of all, see if we need to initialize the history
77   vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_.params();
78   if (this->iter_ == 1 && this->param_.momentum() > 0) {
79     LOG(INFO) << "Using momentum " << this->param_.momentum();
80     for (int i = 0; i < net_params.size(); ++i) {
81       const Blob<Dtype>* net_param = net_params[i].get();
82       history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
83           net_param->num(), net_param->channels(), net_param->height(),
84           net_param->width())));
85     }
86   }
87   // get the learning rate
88   Dtype rate = GetLearningRate();
89   if (this->param_.momentum == 0) {
90     for (int i = 0; i < net_params.size(); ++i) {
91       switch (Caffe::mode()) {
92       case Caffe::CPU:
93         caffe_scal(net_params[i]->count(), rate,
94             net_params[i]->mutable_cpu_data());
95         break;
96       case Caffe::GPU:
97         caffe_gpu_scal(net_params[i]->count(), rate,
98             net_params[i]->mutable_gpu_data());
99         break;
100       default:
101         LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
102       }
103     }
104   } else {
105     NOT_IMPLEMENTED;
106   }
111 INSTANTIATE_CLASS(Solver);
113 }  // namespace caffe