solver restructuring: now all prototxt are specified in the solver protocol buffer
[jacinto-ai/caffe-jacinto.git] / include / caffe / solver.hpp
1 // Copyright Yangqing Jia 2013
3 #ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_
4 #define CAFFE_OPTIMIZATION_SOLVER_HPP_
6 #include <vector>
8 namespace caffe {
10 template <typename Dtype>
11 class Solver {
12  public:
13   explicit Solver(const SolverParameter& param);
14   // The main entry of the solver function. In default, iter will be zero. Pass
15   // in a non-zero iter number to resume training for a pre-trained net.
16   void Solve(const char* resume_file = NULL);
17   inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
18   virtual ~Solver() {}
20  protected:
21   // PreSolve is run before any solving iteration starts, allowing one to
22   // put up some scaffold.
23   virtual void PreSolve() {}
24   // Get the update value for the current iteration.
25   virtual void ComputeUpdateValue() = 0;
26   // The Solver::Snapshot function implements the basic snapshotting utility
27   // that stores the learned net. You should implement the SnapshotSolverState()
28   // function that produces a SolverState protocol buffer that needs to be
29   // written to disk together with the learned net.
30   void Snapshot();
31   // The test routine
32   void Test();
33   virtual void SnapshotSolverState(SolverState* state) = 0;
34   // The Restore function implements how one should restore the solver to a
35   // previously snapshotted state. You should implement the RestoreSolverState()
36   // function that restores the state from a SolverState protocol buffer.
37   void Restore(const char* resume_file);
38   virtual void RestoreSolverState(const SolverState& state) = 0;
39   SolverParameter param_;
40   int iter_;
41   Net<Dtype>* net_;
42   Net<Dtype>* test_net_;
44   DISABLE_COPY_AND_ASSIGN(Solver);
45 };
48 template <typename Dtype>
49 class SGDSolver : public Solver<Dtype> {
50  public:
51   explicit SGDSolver(const SolverParameter& param)
52       : Solver<Dtype>(param) {}
54  protected:
55   virtual void PreSolve();
56   Dtype GetLearningRate();
57   virtual void ComputeUpdateValue();
58   virtual void SnapshotSolverState(SolverState * state);
59   virtual void RestoreSolverState(const SolverState& state);
60   // history maintains the historical momentum data.
61   vector<shared_ptr<Blob<Dtype> > > history_;
63   DISABLE_COPY_AND_ASSIGN(SGDSolver);
64 };
67 }  // namspace caffe
69 #endif  // CAFFE_OPTIMIZATION_SOLVER_HPP_