2055ab6b24c5e42bf1d8d1360d6d53ab18038bed
1 // Copyright Yangqing Jia 2013
3 #ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_
4 #define CAFFE_OPTIMIZATION_SOLVER_HPP_
6 #include <vector>
8 namespace caffe {
10 template <typename Dtype>
11 class Solver {
12 public:
13 explicit Solver(const SolverParameter& param)
14 : param_(param) {}
15 // The main entry of the solver function. In default, iter will be zero. Pass
16 // in a non-zero iter number to resume training for a pre-trained net.
17 void Solve(Net<Dtype>* net, char* state_file = NULL);
18 virtual ~Solver() {}
20 protected:
21 // PreSolve is run before any solving iteration starts, allowing one to
22 // put up some scaffold.
23 virtual void PreSolve() {}
24 // Get the update value for the current iteration.
25 virtual void ComputeUpdateValue() = 0;
26 // The Solver::Snapshot function implements the basic snapshotting utility
27 // that stores the learned net. You should implement the SnapshotSolverState()
28 // function that produces a SolverState protocol buffer that needs to be
29 // written to disk together with the learned net.
30 void Snapshot();
31 virtual void SnapshotSolverState(SolverState* state) = 0;
32 // The Restore function implements how one should restore the solver to a
33 // previously snapshotted state. You should implement the RestoreSolverState()
34 // function that restores the state from a SolverState protocol buffer.
35 void Restore(char* state_file);
36 virtual void RestoreSolverState(const SolverState& state) = 0;
37 SolverParameter param_;
38 int iter_;
39 Net<Dtype>* net_;
41 DISABLE_COPY_AND_ASSIGN(Solver);
42 };
45 template <typename Dtype>
46 class SGDSolver : public Solver<Dtype> {
47 public:
48 explicit SGDSolver(const SolverParameter& param)
49 : Solver<Dtype>(param) {}
51 protected:
52 virtual void PreSolve();
53 Dtype GetLearningRate();
54 virtual void ComputeUpdateValue();
55 virtual void SnapshotSolverState(SolverState * state);
56 virtual void RestoreSolverState(const SolverState& state);
57 // history maintains the historical momentum data.
58 vector<shared_ptr<Blob<Dtype> > > history_;
60 DISABLE_COPY_AND_ASSIGN(SGDSolver);
61 };
}  // namespace caffe
66 #endif // CAFFE_OPTIMIZATION_SOLVER_HPP_