diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index 425bd421e082e4e62fde685d207c339b366f9e5b..07b46a42e971de8093471cc8a42b1774182e2d23 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
namespace caffe {
template <typename Dtype>
-void Solver<Dtype>::Solve(Net<Dtype>* net, char* resume_file) {
- net_ = net;
+Solver<Dtype>::Solver(const SolverParameter& param)
+ : param_(param), net_(NULL), test_net_(NULL) {
+ // Scaffolding code
+ NetParameter train_net_param;
+ ReadProtoFromTextFile(param_.train_net(), &train_net_param);
+ // For the training network, there should be no input - so we simply create
+ // a dummy bottom_vec instance to initialize the networks.
+ vector<Blob<Dtype>*> bottom_vec;
+ LOG(INFO) << "Creating training net.";
+ net_ = new Net<Dtype>(train_net_param, bottom_vec);
+ if (param_.has_test_net()) {
+ LOG(INFO) << "Creating testing net.";
+ NetParameter test_net_param;
+ ReadProtoFromTextFile(param_.test_net(), &test_net_param);
+ test_net_ = new Net<Dtype>(test_net_param, bottom_vec);
+ CHECK_GT(param_.test_iter(), 0);
+ CHECK_GT(param_.test_interval(), 0);
+ }
+ LOG(INFO) << "Solver scaffolding done.";
+}
+
+
+template <typename Dtype>
+void Solver<Dtype>::Solve(const char* resume_file) {
+ Caffe::set_phase(Caffe::TRAIN);
LOG(INFO) << "Solving " << net_->name();
PreSolve();
LOG(INFO) << "Solving " << net_->name();
PreSolve();
Snapshot();
}
if (param_.display() && iter_ % param_.display() == 0) {
- LOG(ERROR) << "Iteration " << iter_ << ", loss = " << loss;
+ LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss;
+ }
+ if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
+ // We need to set phase to test before running.
+ Caffe::set_phase(Caffe::TEST);
+ Test();
+ Caffe::set_phase(Caffe::TRAIN);
}
}
LOG(INFO) << "Optimization Done.";
}
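
For context, a minimal sketch of how a caller drives the new constructor-based API (not part of this commit). The solver now builds its own nets from the SolverParameter, replacing the old pattern of constructing a Net and passing it to Solve(). The prototxt path is hypothetical, and the sketch assumes SGDSolver<Dtype> forwards its constructor argument to Solver<Dtype>:

    SolverParameter solver_param;
    // Hypothetical path; train_net, test_net, test_iter, etc. are set there.
    ReadProtoFromTextFile("lenet_solver.prototxt", &solver_param);
    SGDSolver<float> solver(solver_param);
    solver.Solve(NULL);  // or pass a saved solver state file to resume
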
+template <typename Dtype>
+void Solver<Dtype>::Test() {
+ LOG(INFO) << "Testing net";
+ NetParameter net_param;
+ net_->ToProto(&net_param);
+ CHECK_NOTNULL(test_net_)->CopyTrainedLayersFrom(net_param);
+ vector<Dtype> test_score;
+ vector<Blob<Dtype>*> bottom_vec;
+ for (int i = 0; i < param_.test_iter(); ++i) {
+ const vector<Blob<Dtype>*>& result =
+ test_net_->Forward(bottom_vec);
+ if (i == 0) {
+ for (int j = 0; j < result.size(); ++j) {
+ const Dtype* result_vec = result[j]->cpu_data();
+ for (int k = 0; k < result[j]->count(); ++k) {
+ test_score.push_back(result_vec[k]);
+ }
+ }
+ } else {
+ int idx = 0;
+ for (int j = 0; j < result.size(); ++j) {
+ const Dtype* result_vec = result[j]->cpu_data();
+ for (int k = 0; k < result[j]->count(); ++k) {
+ test_score[idx++] += result_vec[k];
+ }
+ }
+ }
+ }
+ for (int i = 0; i < test_score.size(); ++i) {
+ LOG(INFO) << "Test score #" << i << ": "
+ << test_score[i] / param_.test_iter();
+ }
+}
+
+
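
Note how the new Test() aggregates results: every element of every output blob is accumulated across param_.test_iter() forward passes (the first pass initializes test_score, subsequent passes add into it), and the log reports the per-element mean. For example, with test_iter set to 2 and a single scalar accuracy output of 0.90 on the first test batch and 0.94 on the second, "Test score #0" would print (0.90 + 0.94) / 2 = 0.92. Also note that the test net does not share weights with the training net; each call serializes the training net to a NetParameter and copies it over via CopyTrainedLayersFrom().
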
template <typename Dtype>
void Solver<Dtype>::Snapshot() {
NetParameter net_param;
}
template <typename Dtype>
-void Solver<Dtype>::Restore(char* state_file) {
+void Solver<Dtype>::Restore(const char* state_file) {
SolverState state;
NetParameter net_param;
ReadProtoFromBinaryFile(state_file, &state);
} else {
LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
}
- rate = min(max(rate, Dtype(this->param_.min_lr())),
- Dtype(this->param_.max_lr()));
return rate;
}
template <typename Dtype>
void SGDSolver<Dtype>::PreSolve() {
- // First of all, see if we need to initialize the history
+ // Initialize the history
vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
history_.clear();
for (int i = 0; i < net_params.size(); ++i) {
void SGDSolver<Dtype>::ComputeUpdateValue() {
vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
vector<float>& net_params_lr = this->net_->params_lr();
+ vector<float>& net_params_weight_decay = this->net_->params_weight_decay();
// get the learning rate
Dtype rate = GetLearningRate();
if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
- LOG(ERROR) << "Iteration " << this->iter_ << ", lr = " << rate;
+ LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
}
Dtype momentum = this->param_.momentum();
Dtype weight_decay = this->param_.weight_decay();
- // LOG(ERROR) << "rate:" << rate << " momentum:" << momentum
- // << " weight_decay:" << weight_decay;
switch (Caffe::mode()) {
case Caffe::CPU:
for (int param_id = 0; param_id < net_params.size(); ++param_id) {
// Compute the value to history, and then copy them to the blob's diff.
Dtype local_rate = rate * net_params_lr[param_id];
+ Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
caffe_axpby(net_params[param_id]->count(), local_rate,
net_params[param_id]->cpu_diff(), momentum,
history_[param_id]->mutable_cpu_data());
- if (weight_decay) {
+ if (local_decay) {
// add weight decay
caffe_axpy(net_params[param_id]->count(),
- weight_decay * local_rate,
+ local_decay * local_rate,
net_params[param_id]->cpu_data(),
history_[param_id]->mutable_cpu_data());
}
for (int param_id = 0; param_id < net_params.size(); ++param_id) {
// Compute the value to history, and then copy them to the blob's diff.
Dtype local_rate = rate * net_params_lr[param_id];
+ Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
net_params[param_id]->gpu_diff(), momentum,
history_[param_id]->mutable_gpu_data());
- if (weight_decay) {
+ if (local_decay) {
// add weight decay
caffe_gpu_axpy(net_params[param_id]->count(),
- weight_decay * local_rate,
+ local_decay * local_rate,
net_params[param_id]->gpu_data(),
history_[param_id]->mutable_gpu_data());
}
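
In update-rule form, the two BLAS calls implement momentum SGD with per-parameter multipliers. Writing alpha for the global rate, mu for momentum, lambda for weight_decay, and lr_i / wd_i for the per-parameter multipliers (so local_rate = alpha * lr_i and local_decay = lambda * wd_i), each history blob v_i for parameters w_i with gradient g_i becomes

    v_i = mu * v_i + (alpha * lr_i) * g_i             // caffe_axpby: y = a*x + b*y
    v_i += (alpha * lr_i) * (lambda * wd_i) * w_i     // caffe_axpy:  y += a*x

after which (elided from this hunk) the history is copied into the parameter diffs and applied by net_->Update(). The behavioral change is confined to the decay term: previously the global lambda applied uniformly to every blob, whereas now each blob scales it by its own wd_i, so a bias blob can, for example, opt out of decay with a multiplier of 0.
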
void SGDSolver<Dtype>::RestoreSolverState(const SolverState& state) {
CHECK_EQ(state.history_size(), history_.size())
<< "Incorrect length of history blobs.";
+ LOG(INFO) << "SGDSolver: restoring history";
for (int i = 0; i < history_.size(); ++i) {
history_[i]->FromProto(state.history(i));
}