1 // Copyright Yangqing Jia 2013
3 #ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_
4 #define CAFFE_OPTIMIZATION_SOLVER_HPP_
6 #include <vector>
8 namespace caffe {
10 template <typename Dtype>
11 class Solver {
12 public:
13 explicit Solver(const SolverParameter& param);
14 // The main entry of the solver function. In default, iter will be zero. Pass
15 // in a non-zero iter number to resume training for a pre-trained net.
16 virtual void Solve(const char* resume_file = NULL);
17 inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
18 virtual ~Solver() {}
19 inline Net<Dtype>* net() { return net_.get(); }
21 protected:
22 // PreSolve is run before any solving iteration starts, allowing one to
23 // put up some scaffold.
24 virtual void PreSolve() {}
25 // Get the update value for the current iteration.
26 virtual void ComputeUpdateValue() = 0;
27 // The Solver::Snapshot function implements the basic snapshotting utility
28 // that stores the learned net. You should implement the SnapshotSolverState()
29 // function that produces a SolverState protocol buffer that needs to be
30 // written to disk together with the learned net.
31 void Snapshot();
32 // The test routine
33 void Test();
34 virtual void SnapshotSolverState(SolverState* state) = 0;
35 // The Restore function implements how one should restore the solver to a
36 // previously snapshotted state. You should implement the RestoreSolverState()
37 // function that restores the state from a SolverState protocol buffer.
38 void Restore(const char* resume_file);
39 virtual void RestoreSolverState(const SolverState& state) = 0;
41 SolverParameter param_;
42 int iter_;
43 shared_ptr<Net<Dtype> > net_;
44 shared_ptr<Net<Dtype> > test_net_;
46 DISABLE_COPY_AND_ASSIGN(Solver);
47 };
50 template <typename Dtype>
51 class SGDSolver : public Solver<Dtype> {
52 public:
53 explicit SGDSolver(const SolverParameter& param)
54 : Solver<Dtype>(param) {}
56 protected:
57 virtual void PreSolve();
58 Dtype GetLearningRate();
59 virtual void ComputeUpdateValue();
60 virtual void SnapshotSolverState(SolverState * state);
61 virtual void RestoreSolverState(const SolverState& state);
62 // history maintains the historical momentum data.
63 vector<shared_ptr<Blob<Dtype> > > history_;
65 DISABLE_COPY_AND_ASSIGN(SGDSolver);
66 };
}  // namespace caffe
71 #endif // CAFFE_OPTIMIZATION_SOLVER_HPP_