// Copyright Yangqing Jia 2013

#ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_
#define CAFFE_OPTIMIZATION_SOLVER_HPP_

#include <string>
#include <vector>

namespace caffe {

template <typename Dtype>
class Solver {
 public:
  explicit Solver(const SolverParameter& param);
  // The main entry of the solver function. By default, iter will be zero. Pass
  // in a non-zero iter number to resume training for a pre-trained net.
  void Solve(const char* resume_file = NULL);
  inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
  virtual ~Solver() {}

 protected:
  // PreSolve is run before any solving iteration starts, allowing one to
  // put up some scaffold.
  virtual void PreSolve() {}
  // Get the update value for the current iteration.
  virtual void ComputeUpdateValue() = 0;
  // The Solver::Snapshot function implements the basic snapshotting utility
  // that stores the learned net. You should implement the SnapshotSolverState()
  // function that produces a SolverState protocol buffer that needs to be
  // written to disk together with the learned net.
  void Snapshot();
  // The test routine.
  void Test();
  virtual void SnapshotSolverState(SolverState* state) = 0;
  // The Restore function implements how one should restore the solver to a
  // previously snapshotted state. You should implement the RestoreSolverState()
  // function that restores the state from a SolverState protocol buffer.
  void Restore(const char* resume_file);
  virtual void RestoreSolverState(const SolverState& state) = 0;

  SolverParameter param_;
  int iter_;
  Net<Dtype>* net_;
  Net<Dtype>* test_net_;

  DISABLE_COPY_AND_ASSIGN(Solver);
};


template <typename Dtype>
class SGDSolver : public Solver<Dtype> {
 public:
  explicit SGDSolver(const SolverParameter& param)
      : Solver<Dtype>(param) {}

 protected:
  virtual void PreSolve();
  Dtype GetLearningRate();
  virtual void ComputeUpdateValue();
  virtual void SnapshotSolverState(SolverState* state);
  virtual void RestoreSolverState(const SolverState& state);
  // history maintains the historical momentum data.
  vector<shared_ptr<Blob<Dtype> > > history_;

  DISABLE_COPY_AND_ASSIGN(SGDSolver);
};


}  // namespace caffe

#endif  // CAFFE_OPTIMIZATION_SOLVER_HPP_
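
// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of the original header). It shows how a
// caller might drive SGDSolver from a .cpp file, either training from scratch
// or resuming via Solve(resume_file). The SolverParameter fields set below
// (base_lr, max_iter) and the snapshot file name are assumptions made for
// illustration; consult the SolverParameter message in caffe.proto for the
// actual fields.
//
//   caffe::SolverParameter param;
//   param.set_base_lr(0.01);     // assumed proto field: base learning rate
//   param.set_max_iter(10000);   // assumed proto field: iteration budget
//
//   caffe::SGDSolver<float> solver(param);
//   solver.Solve();              // fresh training: iter_ starts at zero
//
//   // Resuming from a previously written snapshot instead (file name is
//   // hypothetical):
//   // solver.Solve("lenet_iter_5000.solverstate");
// ---------------------------------------------------------------------------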
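
// ---------------------------------------------------------------------------
// Illustrative sketch of the SnapshotSolverState()/RestoreSolverState() pair
// that the comments above ask subclasses to implement. It assumes SolverState
// exposes a repeated BlobProto history field (add_history(), history_size(),
// history(i)) and that Blob provides ToProto()/FromProto(); those names are
// assumptions about the proto and Blob APIs, not a definitive implementation.
//
//   template <typename Dtype>
//   void SGDSolver<Dtype>::SnapshotSolverState(SolverState* state) {
//     state->clear_history();
//     for (int i = 0; i < history_.size(); ++i) {
//       // Serialize each momentum buffer into the solver state.
//       BlobProto* history_blob = state->add_history();
//       history_[i]->ToProto(history_blob);
//     }
//   }
//
//   template <typename Dtype>
//   void SGDSolver<Dtype>::RestoreSolverState(const SolverState& state) {
//     CHECK_EQ(state.history_size(), history_.size())
//         << "Incorrect length of history blobs.";
//     for (int i = 0; i < history_.size(); ++i) {
//       history_[i]->FromProto(state.history(i));
//     }
//   }
// ---------------------------------------------------------------------------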