0a78d88000178387b26d5e84f2498a37a3a9c8ed
[jacinto-ai/caffe-jacinto.git] / src / caffe / optimization / solver.hpp
1 #ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_
2 #define CAFFE_OPTIMIZATION_SOLVER_HPP_
4 namespace caffe {
6 template <typename Dtype>
7 class Solver {
8  public:
9   explicit Solver(const SolverParameter& param)
10       : param_(param) {}
11   // The main entry of the solver function.
12   void Solve(Net<Dtype>* net);
14  protected:
15   // Get the update value for the current iteration.
16   virtual void ComputeUpdateValue() = 0;
17   void Snapshot(bool is_final = false);
18   SolverParameter param_;
19   int iter_;
20   Net<Dtype>* net_;
22   DISABLE_COPY_AND_ASSIGN(Solver);
23 };
25 template <typename Dtype>
26 class SGDSolver : public Solver<Dtype> {
27  public:
28   explicit SGDSolver(const SolverParameter& param)
29       : Solver<Dtype>(param) {}
31  protected:
32   Dtype GetLearningRate();
33   virtual void ComputeUpdateValue();
34   // history maintains the historical momentum data.
35   vector<shared_ptr<Blob<Dtype> > > history_;
36 };
39 }  // namspace caffe
41 #endif  // CAFFE_OPTIMIZATION_SOLVER_HPP_