solver restructuring: now all prototxt are specified in the solver protocol buffer
[jacinto-ai/caffe-jacinto.git] / src / caffe / solver.cpp
index 87c346f7036bd32a29d6364cc619efcd3ea2d7e2..07b46a42e971de8093471cc8a42b1774182e2d23 100644 (file)
@@ -18,8 +18,31 @@ using std::min;
 namespace caffe {
 
 template <typename Dtype>
-void Solver<Dtype>::Solve(Net<Dtype>* net, char* resume_file) {
-  net_ = net;
+Solver<Dtype>::Solver(const SolverParameter& param)
+    : param_(param), net_(NULL), test_net_(NULL) {
+  // Scaffolding code
+  NetParameter train_net_param;
+  ReadProtoFromTextFile(param_.train_net(), &train_net_param);
+  // For the training network, there should be no input - so we simply create
+  // a dummy bottom_vec instance to initialize the networks.
+  vector<Blob<Dtype>*> bottom_vec;
+  LOG(INFO) << "Creating training net.";
+  net_ = new Net<Dtype>(train_net_param, bottom_vec);
+  if (param_.has_test_net()) {
+    LOG(INFO) << "Creating testing net.";
+    NetParameter test_net_param;
+    ReadProtoFromTextFile(param_.test_net(), &test_net_param);
+    test_net_ = new Net<Dtype>(test_net_param, bottom_vec);
+    CHECK_GT(param_.test_iter(), 0);
+    CHECK_GT(param_.test_interval(), 0);
+  }
+  LOG(INFO) << "Solver scaffolding done.";
+}
+
+
+template <typename Dtype>
+void Solver<Dtype>::Solve(const char* resume_file) {
+  Caffe::set_phase(Caffe::TRAIN);
   LOG(INFO) << "Solving " << net_->name();
   PreSolve();
 
@@ -42,13 +65,54 @@ void Solver<Dtype>::Solve(Net<Dtype>* net, char* resume_file) {
       Snapshot();
     }
     if (param_.display() && iter_ % param_.display() == 0) {
-      LOG(ERROR) << "Iteration " << iter_ << ", loss = " << loss;
+      LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss;
+    }
+    if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
+      // We need to set phase to test before running.
+      Caffe::set_phase(Caffe::TEST);
+      Test();
+      Caffe::set_phase(Caffe::TRAIN);
     }
   }
   LOG(INFO) << "Optimization Done.";
 }
 
 
+template <typename Dtype>
+void Solver<Dtype>::Test() {
+  LOG(INFO) << "Testing net";
+  NetParameter net_param;
+  net_->ToProto(&net_param);
+  CHECK_NOTNULL(test_net_)->CopyTrainedLayersFrom(net_param);
+  vector<Dtype> test_score;
+  vector<Blob<Dtype>*> bottom_vec;
+  for (int i = 0; i < param_.test_iter(); ++i) {
+    const vector<Blob<Dtype>*>& result =
+        test_net_->Forward(bottom_vec);
+    if (i == 0) {
+      for (int j = 0; j < result.size(); ++j) {
+        const Dtype* result_vec = result[j]->cpu_data();
+        for (int k = 0; k < result[j]->count(); ++k) {
+          test_score.push_back(result_vec[k]);
+        }
+      }
+    } else {
+      int idx = 0;
+      for (int j = 0; j < result.size(); ++j) {
+        const Dtype* result_vec = result[j]->cpu_data();
+        for (int k = 0; k < result[j]->count(); ++k) {
+          test_score[idx++] += result_vec[k];
+        }
+      }
+    }
+  }
+  for (int i = 0; i < test_score.size(); ++i) {
+    LOG(INFO) << "Test score #" << i << ": "
+        << test_score[i] / param_.test_iter();
+  }
+}
+
+
 template <typename Dtype>
 void Solver<Dtype>::Snapshot() {
   NetParameter net_param;
@@ -70,7 +134,7 @@ void Solver<Dtype>::Snapshot() {
 }
 
 template <typename Dtype>
-void Solver<Dtype>::Restore(char* state_file) {
+void Solver<Dtype>::Restore(const char* state_file) {
   SolverState state;
   NetParameter net_param;
   ReadProtoFromBinaryFile(state_file, &state);
@@ -108,15 +172,13 @@ Dtype SGDSolver<Dtype>::GetLearningRate() {
   } else {
     LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
   }
-  rate = min(max(rate, Dtype(this->param_.min_lr())),
-      Dtype(this->param_.max_lr()));
   return rate;
 }
 
 
 template <typename Dtype>
 void SGDSolver<Dtype>::PreSolve() {
-  // First of all, see if we need to initialize the history
+  // Initialize the history
   vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
   history_.clear();
   for (int i = 0; i < net_params.size(); ++i) {
@@ -136,7 +198,7 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
   // get the learning rate
   Dtype rate = GetLearningRate();
   if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
-    LOG(ERROR) << "Iteration " << this->iter_ << ", lr = " << rate;
+    LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
   }
   Dtype momentum = this->param_.momentum();
   Dtype weight_decay = this->param_.weight_decay();