// Copyright Yangqing Jia 2013 #include #include #include #include #include "caffe/net.hpp" #include "caffe/proto/caffe.pb.h" #include "caffe/solver.hpp" #include "caffe/util/io.hpp" #include "caffe/util/math_functions.hpp" using std::max; using std::min; namespace caffe { template void Solver::Solve(Net* net, char* resume_file) { net_ = net; LOG(INFO) << "Solving " << net_->name(); PreSolve(); iter_ = 0; if (resume_file) { LOG(INFO) << "Restoring previous solver status from " << resume_file; Restore(resume_file); } // For a network that is trained by the solver, no bottom or top vecs // should be given, and we will just provide dummy vecs. vector*> bottom_vec; while (iter_++ < param_.max_iter()) { Dtype loss = net_->ForwardBackward(bottom_vec); ComputeUpdateValue(); net_->Update(); // Check if we need to do snapshot if (param_.snapshot() && iter_ % param_.snapshot() == 0) { Snapshot(); } if (param_.display() && iter_ % param_.display() == 0) { LOG(ERROR) << "Iteration " << iter_ << ", loss = " << loss; } } LOG(INFO) << "Optimization Done."; } template void Solver::Snapshot() { NetParameter net_param; // For intermediate results, we will also dump the gradient values. net_->ToProto(&net_param, param_.snapshot_diff()); string filename(param_.snapshot_prefix()); char iter_str_buffer[20]; sprintf(iter_str_buffer, "_iter_%d", iter_); filename += iter_str_buffer; LOG(INFO) << "Snapshotting to " << filename; WriteProtoToBinaryFile(net_param, filename.c_str()); SolverState state; SnapshotSolverState(&state); state.set_iter(iter_); state.set_learned_net(filename); filename += ".solverstate"; LOG(INFO) << "Snapshotting solver state to " << filename; WriteProtoToBinaryFile(state, filename.c_str()); } template void Solver::Restore(char* state_file) { SolverState state; NetParameter net_param; ReadProtoFromBinaryFile(state_file, &state); ReadProtoFromBinaryFile(state.learned_net().c_str(), &net_param); net_->CopyTrainedLayersFrom(net_param); iter_ = state.iter(); RestoreSolverState(state); } // Return the current learning rate. The currently implemented learning rate // policies are as follows: // - fixed: always return base_lr. // - step: return base_lr * gamma ^ (floor(iter / step)) // - exp: return base_lr * gamma ^ iter // - inv: return base_lr * (1 + gamma * iter) ^ (- power) // where base_lr, gamma, step and power are defined in the solver parameter // protocol buffer, and iter is the current iteration. template Dtype SGDSolver::GetLearningRate() { Dtype rate; const string& lr_policy = this->param_.lr_policy(); if (lr_policy == "fixed") { rate = this->param_.base_lr(); } else if (lr_policy == "step") { int current_step = this->iter_ / this->param_.stepsize(); rate = this->param_.base_lr() * pow(this->param_.gamma(), current_step); } else if (lr_policy == "exp") { rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_); } else if (lr_policy == "inv") { rate = this->param_.base_lr() * pow(Dtype(1) + this->param_.gamma() * this->iter_, - this->param_.power()); } else { LOG(FATAL) << "Unknown learning rate policy: " << lr_policy; } rate = min(max(rate, Dtype(this->param_.min_lr())), Dtype(this->param_.max_lr())); return rate; } template void SGDSolver::PreSolve() { // Initialize the history vector > >& net_params = this->net_->params(); history_.clear(); for (int i = 0; i < net_params.size(); ++i) { const Blob* net_param = net_params[i].get(); history_.push_back(shared_ptr >(new Blob( net_param->num(), net_param->channels(), net_param->height(), net_param->width()))); } } template void SGDSolver::ComputeUpdateValue() { vector > >& net_params = this->net_->params(); vector& net_params_lr = this->net_->params_lr(); vector& net_params_weight_decay = this->net_->params_weight_decay(); // get the learning rate Dtype rate = GetLearningRate(); if (this->param_.display() && this->iter_ % this->param_.display() == 0) { LOG(ERROR) << "Iteration " << this->iter_ << ", lr = " << rate; } Dtype momentum = this->param_.momentum(); Dtype weight_decay = this->param_.weight_decay(); switch (Caffe::mode()) { case Caffe::CPU: for (int param_id = 0; param_id < net_params.size(); ++param_id) { // Compute the value to history, and then copy them to the blob's diff. Dtype local_rate = rate * net_params_lr[param_id]; Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; caffe_axpby(net_params[param_id]->count(), local_rate, net_params[param_id]->cpu_diff(), momentum, history_[param_id]->mutable_cpu_data()); if (local_decay) { // add weight decay caffe_axpy(net_params[param_id]->count(), local_decay * local_rate, net_params[param_id]->cpu_data(), history_[param_id]->mutable_cpu_data()); } // copy caffe_copy(net_params[param_id]->count(), history_[param_id]->cpu_data(), net_params[param_id]->mutable_cpu_diff()); } break; case Caffe::GPU: for (int param_id = 0; param_id < net_params.size(); ++param_id) { // Compute the value to history, and then copy them to the blob's diff. Dtype local_rate = rate * net_params_lr[param_id]; Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; caffe_gpu_axpby(net_params[param_id]->count(), local_rate, net_params[param_id]->gpu_diff(), momentum, history_[param_id]->mutable_gpu_data()); if (local_decay) { // add weight decay caffe_gpu_axpy(net_params[param_id]->count(), local_decay * local_rate, net_params[param_id]->gpu_data(), history_[param_id]->mutable_gpu_data()); } // copy caffe_gpu_copy(net_params[param_id]->count(), history_[param_id]->gpu_data(), net_params[param_id]->mutable_gpu_diff()); } break; default: LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); } } template void SGDSolver::SnapshotSolverState(SolverState* state) { state->clear_history(); for (int i = 0; i < history_.size(); ++i) { // Add history BlobProto* history_blob = state->add_history(); history_[i]->ToProto(history_blob); } } template void SGDSolver::RestoreSolverState(const SolverState& state) { CHECK_EQ(state.history_size(), history_.size()) << "Incorrect length of history blobs."; LOG(INFO) << "SGDSolver: restoring history"; for (int i = 0; i < history_.size(); ++i) { history_[i]->FromProto(state.history(i)); } } INSTANTIATE_CLASS(Solver); INSTANTIATE_CLASS(SGDSolver); } // namespace caffe