a bunch of updates.
[jacinto-ai/caffe-jacinto.git] / src / caffe / optimization / solver.cpp
1 // Copyright Yangqing Jia 2013
3 #include <fstream>
4 #include <string>
6 #include "caffe/proto/caffe.pb.h"
7 #include "caffe/net.hpp"
8 #include "caffe/optimization/solver.hpp"
10 using std::stringstream;
11 using std::ofstream;
13 namespace caffe {
15 template <typename Dtype>
16 void Solver<Dtype>::Solve(Net<Dtype>* net) {
17   net_ = net;
18   LOG(INFO) << "Solving net " << net_->name();
19   iter_ = 0;
20   // For a network that is trained by the solver, no bottom or top vecs
21   // should be given, and we will just provide dummy vecs.
22   vector<Blob<Dtype>*> bottom_vec;
23   while (iter_++ < param_.max_iter()) {
24     Dtype loss = net_->ForwardBackWard(bottom_vec);
25     ComputeUpdateValue();
26     net->Update();
28     // Check if we need to do snapshot
29     if (iter_ % param_.snapshot()) {
30       // TODO(Yangqing): snapshot
31     }
32     LOG(INFO) << "Iteration" << iter_ << ", loss=" << loss;
33   }
34   LOG(INFO) << "Optimization Done.";
35 }
37 template <typename Dtype>
38 void Solver<Dtype>::Snapshot(bool is_final) {
39   NetParameter net_param;
40   net_->ToProto(&net_param);
41   stringstream ss;
42   ss << param_.snapshot_prefix();
43   if (is_final) {
44     ss << "_final";
45   } else {
46     ss << "_iter_" << iter_;
47   }
48   ofstream output_file;
49   output_file.open(ss.str().c_str());
50   CHECK(net_param.SerializeToOstream(&output_file));
51   output_file.close();
52 }
54 template <typename Dtype>
55 Dtype SGDSolver<Dtype>::GetLearningRate() {
56   Dtype rate;
57   const string& lr_policy = this->param_.lr_policy();
58   if (lr_policy == "fixed") {
59     rate = this->param_.base_lr();
60   } else if (lr_policy == "exp") {
61     rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
62   } else if (lr_policy == "inv") {
63     rate = this->param_.base_lr() *
64         pow(Dtype(1) + this->param_.gamma() * this->iter_,
65             this->param_.power());
66   } else {
67     LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
68   }
69   rate = min(max(rate, this->param_.min_pr()), this->param_.max_lr());
70   return rate;
71 }
73 template <typename Dtype>
74 void SGDSolver<Dtype>::ComputeUpdateValue() {
75   // First of all, see if we need to initialize the history
76   vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_.params();
77   if (this->iter_ == 1 && this->param_.momentum() > 0) {
78     LOG(INFO) << "Using momentum " << this->param_.momentum();
79     for (int i = 0; i < net_params.size(); ++i) {
80       const Blob<Dtype>* net_param = net_params[i].get();
81       history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
82           net_param->num(), net_param->channels(), net_param->height(),
83           net_param->width())));
84     }
85   }
86   // get the learning rate
87   Dtype rate = GetLearningRate();
88   if (this->param_.momentum == 0) {
89     for (int i = 0; i < net_params.size(); ++i) {
90       switch (Caffe::mode()) {
91       case Caffe::CPU:
92         caffe_scal(net_params[i]->count(), rate,
93             net_params[i]->mutable_cpu_data());
94         break;
95       case Caffe::GPU:
96         caffe_gpu_scal(net_params[i]->count(), rate,
97             net_params[i]->mutable_gpu_data());
98         break;
99       default:
100         LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
101       }
102     }
103   } else {
104     NOT_IMPLEMENTED;
105   }
110 INSTANTIATE_CLASS(Solver);
112 }  // namespace caffe