diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index bf4bdbc620a7aa96a7adfef2e0d27075a575bdd6..07b46a42e971de8093471cc8a42b1774182e2d23 100644 (file)
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
namespace caffe {
template <typename Dtype>
namespace caffe {
template <typename Dtype>
-void Solver<Dtype>::Solve(Net<Dtype>* net, char* resume_file) {
- net_ = net;
+// Builds the solver from `param`: always constructs the training net, and
+// additionally constructs a test net when the parameters name one.
+Solver<Dtype>::Solver(const SolverParameter& param)
+    : param_(param),
+      net_(NULL),
+      test_net_(NULL) {
+  // Load the training net definition from the text proto named in param_.
+  NetParameter train_spec;
+  ReadProtoFromTextFile(param_.train_net(), &train_spec);
+  // The nets are constructed with an empty bottom vector; per the original
+  // author's note, the training network is expected to take no input.
+  vector<Blob<Dtype>*> no_inputs;
+  LOG(INFO) << "Creating training net.";
+  net_ = new Net<Dtype>(train_spec, no_inputs);
+  // A test net is optional. When present, both test_iter and test_interval
+  // must be strictly positive; the checks run after the net is built.
+  if (param_.has_test_net()) {
+    LOG(INFO) << "Creating testing net.";
+    NetParameter test_spec;
+    ReadProtoFromTextFile(param_.test_net(), &test_spec);
+    test_net_ = new Net<Dtype>(test_spec, no_inputs);
+    CHECK_GT(param_.test_iter(), 0);
+    CHECK_GT(param_.test_interval(), 0);
+  }
+  LOG(INFO) << "Solver scaffolding done.";
+}
+
+
+template <typename Dtype>
+void Solver<Dtype>::Solve(const char* resume_file) {
+ Caffe::set_phase(Caffe::TRAIN);
LOG(INFO) << "Solving " << net_->name();
PreSolve();
LOG(INFO) << "Solving " << net_->name();
PreSolve();
Snapshot();
}
if (param_.display() && iter_ % param_.display() == 0) {
Snapshot();
}
if (param_.display() && iter_ % param_.display() == 0) {
- LOG(ERROR) << "Iteration " << iter_ << ", loss = " << loss;
+ LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss;
+ }
+ if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
+ // We need to set phase to test before running.
+ Caffe::set_phase(Caffe::TEST);
+ Test();
+ Caffe::set_phase(Caffe::TRAIN);
}
}
LOG(INFO) << "Optimization Done.";
}
}
}
LOG(INFO) << "Optimization Done.";
}
+// Runs the test net for param_.test_iter() forward passes and logs, for every
+// output value, its mean over those passes.
+//
+// The training net is first serialized with ToProto() and the result is handed
+// to the test net's CopyTrainedLayersFrom(). Aborts (CHECK_NOTNULL) if no test
+// net was created.
+template <typename Dtype>
+void Solver<Dtype>::Test() {
+  LOG(INFO) << "Testing net";
+  NetParameter net_param;
+  net_->ToProto(&net_param);
+  CHECK_NOTNULL(test_net_)->CopyTrainedLayersFrom(net_param);
+  // Running sum of every output value, flattened across all result blobs.
+  vector<Dtype> test_score;
+  vector<Blob<Dtype>*> bottom_vec;
+  for (int i = 0; i < param_.test_iter(); ++i) {
+    const vector<Blob<Dtype>*>& result = test_net_->Forward(bottom_vec);
+    // Flat position of the current value within test_score.
+    size_t idx = 0;
+    for (size_t j = 0; j < result.size(); ++j) {
+      const Dtype* result_vec = result[j]->cpu_data();
+      for (int k = 0; k < result[j]->count(); ++k, ++idx) {
+        if (i == 0) {
+          // The first pass defines how many scores are tracked.
+          test_score.push_back(result_vec[k]);
+        } else {
+          // Guard against a later pass producing more values than the first
+          // one, which would otherwise write past the end of test_score.
+          CHECK_LT(idx, test_score.size())
+              << "Test net produced more outputs than in the first pass.";
+          test_score[idx] += result_vec[k];
+        }
+      }
+    }
+  }
+  for (size_t i = 0; i < test_score.size(); ++i) {
+    LOG(INFO) << "Test score #" << i << ": "
+        << test_score[i] / param_.test_iter();
+  }
+}
+
+
template <typename Dtype>
void Solver<Dtype>::Snapshot() {
NetParameter net_param;
template <typename Dtype>
void Solver<Dtype>::Snapshot() {
NetParameter net_param;
}
template <typename Dtype>
}
template <typename Dtype>
-void Solver<Dtype>::Restore(char* state_file) {
+void Solver<Dtype>::Restore(const char* state_file) {
SolverState state;
NetParameter net_param;
ReadProtoFromBinaryFile(state_file, &state);
SolverState state;
NetParameter net_param;
ReadProtoFromBinaryFile(state_file, &state);
} else {
LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
}
} else {
LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
}
- rate = min(max(rate, Dtype(this->param_.min_lr())),
- Dtype(this->param_.max_lr()));
return rate;
}
return rate;
}
// get the learning rate
Dtype rate = GetLearningRate();
if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
// get the learning rate
Dtype rate = GetLearningRate();
if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
- LOG(ERROR) << "Iteration " << this->iter_ << ", lr = " << rate;
+ LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
}
Dtype momentum = this->param_.momentum();
Dtype weight_decay = this->param_.weight_decay();
}
Dtype momentum = this->param_.momentum();
Dtype weight_decay = this->param_.weight_decay();