X-Git-Url: https://git.ti.com/gitweb?p=jacinto-ai%2Fcaffe-jacinto.git;a=blobdiff_plain;f=src%2Fcaffe%2Fsolver.cpp;h=07b46a42e971de8093471cc8a42b1774182e2d23;hp=6fe2ce91257d7601ce44093cbf0a5312445ed8af;hb=82b912be849a7bec9bfee92a8e5d81182f4130f2;hpb=62089dd8da5bb78e758d3a7fe84095f75a4120f1 diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 6fe2ce91..07b46a42 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -18,8 +18,31 @@ using std::min; namespace caffe { template -void Solver::Solve(Net* net, char* resume_file) { - net_ = net; +Solver::Solver(const SolverParameter& param) + : param_(param), net_(NULL), test_net_(NULL) { + // Scaffolding code + NetParameter train_net_param; + ReadProtoFromTextFile(param_.train_net(), &train_net_param); + // For the training network, there should be no input - so we simply create + // a dummy bottom_vec instance to initialize the networks. + vector*> bottom_vec; + LOG(INFO) << "Creating training net."; + net_ = new Net(train_net_param, bottom_vec); + if (param_.has_test_net()) { + LOG(INFO) << "Creating testing net."; + NetParameter test_net_param; + ReadProtoFromTextFile(param_.test_net(), &test_net_param); + test_net_ = new Net(test_net_param, bottom_vec); + CHECK_GT(param_.test_iter(), 0); + CHECK_GT(param_.test_interval(), 0); + } + LOG(INFO) << "Solver scaffolding done."; +} + + +template +void Solver::Solve(const char* resume_file) { + Caffe::set_phase(Caffe::TRAIN); LOG(INFO) << "Solving " << net_->name(); PreSolve(); @@ -42,13 +65,54 @@ void Solver::Solve(Net* net, char* resume_file) { Snapshot(); } if (param_.display() && iter_ % param_.display() == 0) { - LOG(ERROR) << "Iteration " << iter_ << ", loss = " << loss; + LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss; + } + if (param_.test_interval() && iter_ % param_.test_interval() == 0) { + // We need to set phase to test before running. + Caffe::set_phase(Caffe::TEST); + Test(); + Caffe::set_phase(Caffe::TRAIN); } } LOG(INFO) << "Optimization Done."; } +template +void Solver::Test() { + LOG(INFO) << "Testing net"; + NetParameter net_param; + net_->ToProto(&net_param); + CHECK_NOTNULL(test_net_)->CopyTrainedLayersFrom(net_param); + vector test_score; + vector*> bottom_vec; + for (int i = 0; i < param_.test_iter(); ++i) { + const vector*>& result = + test_net_->Forward(bottom_vec); + if (i == 0) { + for (int j = 0; j < result.size(); ++j) { + const Dtype* result_vec = result[j]->cpu_data(); + for (int k = 0; k < result[j]->count(); ++k) { + test_score.push_back(result_vec[k]); + } + } + } else { + int idx = 0; + for (int j = 0; j < result.size(); ++j) { + const Dtype* result_vec = result[j]->cpu_data(); + for (int k = 0; k < result[j]->count(); ++k) { + test_score[idx++] += result_vec[k]; + } + } + } + } + for (int i = 0; i < test_score.size(); ++i) { + LOG(INFO) << "Test score #" << i << ": " + << test_score[i] / param_.test_iter(); + } +} + + template void Solver::Snapshot() { NetParameter net_param; @@ -70,7 +134,7 @@ void Solver::Snapshot() { } template -void Solver::Restore(char* state_file) { +void Solver::Restore(const char* state_file) { SolverState state; NetParameter net_param; ReadProtoFromBinaryFile(state_file, &state); @@ -108,15 +172,13 @@ Dtype SGDSolver::GetLearningRate() { } else { LOG(FATAL) << "Unknown learning rate policy: " << lr_policy; } - rate = min(max(rate, Dtype(this->param_.min_lr())), - Dtype(this->param_.max_lr())); return rate; } template void SGDSolver::PreSolve() { - // First of all, see if we need to initialize the history + // Initialize the history vector > >& net_params = this->net_->params(); history_.clear(); for (int i = 0; i < net_params.size(); ++i) { @@ -136,7 +198,7 @@ void SGDSolver::ComputeUpdateValue() { // get the learning rate Dtype rate = GetLearningRate(); if (this->param_.display() && this->iter_ % this->param_.display() == 0) { - LOG(ERROR) << "Iteration " << this->iter_ << ", lr = " << rate; + LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate; } Dtype momentum = this->param_.momentum(); Dtype weight_decay = this->param_.weight_decay(); @@ -202,6 +264,7 @@ template void SGDSolver::RestoreSolverState(const SolverState& state) { CHECK_EQ(state.history_size(), history_.size()) << "Incorrect length of history blobs."; + LOG(INFO) << "SGDSolver: restoring history"; for (int i = 0; i < history_.size(); ++i) { history_[i]->FromProto(state.history(i)); }