1 // Copyright Yangqing Jia 2013
#include <algorithm>
#include <cmath>
#include <fstream>
#include <sstream>
#include <string>

#include "caffe/proto/caffe.pb.h"
#include "caffe/net.hpp"
#include "caffe/optimization/solver.hpp"
10 using std::stringstream;
11 using std::ofstream;
13 namespace caffe {
15 template <typename Dtype>
16 void Solver<Dtype>::Solve(Net<Dtype>* net) {
17 net_ = net;
18 LOG(INFO) << "Solving net " << net_->name();
19 iter_ = 0;
20 // For a network that is trained by the solver, no bottom or top vecs
21 // should be given, and we will just provide dummy vecs.
22 vector<Blob<Dtype>*> bottom_vec;
23 while (iter_++ < param_.max_iter()) {
24 Dtype loss = net_->ForwardBackWard(bottom_vec);
25 ComputeUpdateValue();
26 net->Update();
28 // Check if we need to do snapshot
29 if (iter_ % param_.snapshot()) {
30 // TODO(Yangqing): snapshot
31 }
32 LOG(INFO) << "Iteration" << iter_ << ", loss=" << loss;
33 }
34 LOG(INFO) << "Optimization Done.";
35 }
37 template <typename Dtype>
38 void Solver<Dtype>::Snapshot(bool is_final) {
39 NetParameter net_param;
40 net_->ToProto(&net_param);
41 stringstream ss;
42 ss << param_.snapshot_prefix();
43 if (is_final) {
44 ss << "_final";
45 } else {
46 ss << "_iter_" << iter_;
47 }
48 ofstream output_file;
49 output_file.open(ss.str().c_str());
50 CHECK(net_param.SerializeToOstream(&output_file));
51 output_file.close();
52 }
54 template <typename Dtype>
55 Dtype SGDSolver<Dtype>::GetLearningRate() {
56 Dtype rate;
57 const string& lr_policy = this->param_.lr_policy();
58 if (lr_policy == "fixed") {
59 rate = this->param_.base_lr();
60 } else if (lr_policy == "exp") {
61 rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
62 } else if (lr_policy == "inv") {
63 rate = this->param_.base_lr() *
64 pow(Dtype(1) + this->param_.gamma() * this->iter_,
65 this->param_.power());
66 } else {
67 LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
68 }
69 rate = min(max(rate, this->param_.min_pr()), this->param_.max_lr());
70 return rate;
71 }
73 template <typename Dtype>
74 void SGDSolver<Dtype>::ComputeUpdateValue() {
75 // First of all, see if we need to initialize the history
76 vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_.params();
77 if (this->iter_ == 1 && this->param_.momentum() > 0) {
78 LOG(INFO) << "Using momentum " << this->param_.momentum();
79 for (int i = 0; i < net_params.size(); ++i) {
80 const Blob<Dtype>* net_param = net_params[i].get();
81 history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
82 net_param->num(), net_param->channels(), net_param->height(),
83 net_param->width())));
84 }
85 }
86 // get the learning rate
87 Dtype rate = GetLearningRate();
88 if (this->param_.momentum == 0) {
89 for (int i = 0; i < net_params.size(); ++i) {
90 switch (Caffe::mode()) {
91 case Caffe::CPU:
92 caffe_scal(net_params[i]->count(), rate,
93 net_params[i]->mutable_cpu_data());
94 break;
95 case Caffe::GPU:
96 caffe_gpu_scal(net_params[i]->count(), rate,
97 net_params[i]->mutable_gpu_data());
98 break;
99 default:
100 LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
101 }
102 }
103 } else {
104 NOT_IMPLEMENTED;
105 }
106 }
// Presumably emits the explicit template instantiations of Solver; the
// macro itself is defined outside this file.
// NOTE(review): SGDSolver's member functions are defined in this file but
// only Solver is listed here — confirm whether SGDSolver needs its own
// instantiation.
INSTANTIATE_CLASS(Solver);
112 } // namespace caffe