]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/caffe-jacinto.git/blobdiff - src/caffe/solver.cpp
Several changes:
[jacinto-ai/caffe-jacinto.git] / src / caffe / solver.cpp
index 425bd421e082e4e62fde685d207c339b366f9e5b..e02b72f31f657f2bc7bf84c72853d2e64ddb00f7 100644 (file)
@@ -18,8 +18,29 @@ using std::min;
 namespace caffe {
 
 template <typename Dtype>
-void Solver<Dtype>::Solve(Net<Dtype>* net, char* resume_file) {
-  net_ = net;
+Solver<Dtype>::Solver(const SolverParameter& param)
+    : param_(param), net_(), test_net_() {
+  // Scaffolding code
+  NetParameter train_net_param;
+  ReadProtoFromTextFile(param_.train_net(), &train_net_param);
+  LOG(INFO) << "Creating training net.";
+  net_.reset(new Net<Dtype>(train_net_param));
+  if (param_.has_test_net()) {
+    LOG(INFO) << "Creating testing net.";
+    NetParameter test_net_param;
+    ReadProtoFromTextFile(param_.test_net(), &test_net_param);
+    test_net_.reset(new Net<Dtype>(test_net_param));
+    CHECK_GT(param_.test_iter(), 0);
+    CHECK_GT(param_.test_interval(), 0);
+  }
+  LOG(INFO) << "Solver scaffolding done.";
+}
+
+
+template <typename Dtype>
+void Solver<Dtype>::Solve(const char* resume_file) {
+  Caffe::set_mode(Caffe::Brew(param_.solver_mode()));
+  Caffe::set_phase(Caffe::TRAIN);
   LOG(INFO) << "Solving " << net_->name();
   PreSolve();
 
@@ -42,13 +63,57 @@ void Solver<Dtype>::Solve(Net<Dtype>* net, char* resume_file) {
       Snapshot();
     }
     if (param_.display() && iter_ % param_.display() == 0) {
-      LOG(ERROR) << "Iteration " << iter_ << ", loss = " << loss;
+      LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss;
+    }
+    if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
+      // We need to set phase to test before running.
+      Caffe::set_phase(Caffe::TEST);
+      Test();
+      Caffe::set_phase(Caffe::TRAIN);
     }
   }
+  // After the optimization is done, always do a snapshot.
+  iter_--;
+  Snapshot();
   LOG(INFO) << "Optimization Done.";
 }
 
 
+template <typename Dtype>
+void Solver<Dtype>::Test() {
+  LOG(INFO) << "Testing net";
+  NetParameter net_param;
+  net_->ToProto(&net_param);
+  CHECK_NOTNULL(test_net_.get())->CopyTrainedLayersFrom(net_param);
+  vector<Dtype> test_score;
+  vector<Blob<Dtype>*> bottom_vec;
+  for (int i = 0; i < param_.test_iter(); ++i) {
+    const vector<Blob<Dtype>*>& result =
+        test_net_->Forward(bottom_vec);
+    if (i == 0) {
+      for (int j = 0; j < result.size(); ++j) {
+        const Dtype* result_vec = result[j]->cpu_data();
+        for (int k = 0; k < result[j]->count(); ++k) {
+          test_score.push_back(result_vec[k]);
+        }
+      }
+    } else {
+      int idx = 0;
+      for (int j = 0; j < result.size(); ++j) {
+        const Dtype* result_vec = result[j]->cpu_data();
+        for (int k = 0; k < result[j]->count(); ++k) {
+          test_score[idx++] += result_vec[k];
+        }
+      }
+    }
+  }
+  for (int i = 0; i < test_score.size(); ++i) {
+    LOG(INFO) << "Test score #" << i << ": "
+        << test_score[i] / param_.test_iter();
+  }
+}
+
+
 template <typename Dtype>
 void Solver<Dtype>::Snapshot() {
   NetParameter net_param;
@@ -70,12 +135,14 @@ void Solver<Dtype>::Snapshot() {
 }
 
 template <typename Dtype>
-void Solver<Dtype>::Restore(char* state_file) {
+void Solver<Dtype>::Restore(const char* state_file) {
   SolverState state;
   NetParameter net_param;
   ReadProtoFromBinaryFile(state_file, &state);
-  ReadProtoFromBinaryFile(state.learned_net().c_str(), &net_param);
-  net_->CopyTrainedLayersFrom(net_param);
+  if (state.has_learned_net()) {
+    ReadProtoFromBinaryFile(state.learned_net().c_str(), &net_param);
+    net_->CopyTrainedLayersFrom(net_param);
+  }
   iter_ = state.iter();
   RestoreSolverState(state);
 }
@@ -108,15 +175,13 @@ Dtype SGDSolver<Dtype>::GetLearningRate() {
   } else {
     LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
   }
-  rate = min(max(rate, Dtype(this->param_.min_lr())),
-      Dtype(this->param_.max_lr()));
   return rate;
 }
 
 
 template <typename Dtype>
 void SGDSolver<Dtype>::PreSolve() {
-  // First of all, see if we need to initialize the history
+  // Initialize the history
   vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
   history_.clear();
   for (int i = 0; i < net_params.size(); ++i) {
@@ -132,27 +197,27 @@ template <typename Dtype>
 void SGDSolver<Dtype>::ComputeUpdateValue() {
   vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
   vector<float>& net_params_lr = this->net_->params_lr();
+  vector<float>& net_params_weight_decay = this->net_->params_weight_decay();
   // get the learning rate
   Dtype rate = GetLearningRate();
   if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
-    LOG(ERROR) << "Iteration " << this->iter_ << ", lr = " << rate;
+    LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
   }
   Dtype momentum = this->param_.momentum();
   Dtype weight_decay = this->param_.weight_decay();
-  // LOG(ERROR) << "rate:" << rate << " momentum:" << momentum
-  //    << " weight_decay:" << weight_decay;
   switch (Caffe::mode()) {
   case Caffe::CPU:
     for (int param_id = 0; param_id < net_params.size(); ++param_id) {
       // Compute the value to history, and then copy them to the blob's diff.
       Dtype local_rate = rate * net_params_lr[param_id];
+      Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
       caffe_axpby(net_params[param_id]->count(), local_rate,
           net_params[param_id]->cpu_diff(), momentum,
           history_[param_id]->mutable_cpu_data());
-      if (weight_decay) {
+      if (local_decay) {
         // add weight decay
         caffe_axpy(net_params[param_id]->count(),
-            weight_decay * local_rate,
+            local_decay * local_rate,
             net_params[param_id]->cpu_data(),
             history_[param_id]->mutable_cpu_data());
       }
@@ -166,13 +231,14 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
     for (int param_id = 0; param_id < net_params.size(); ++param_id) {
       // Compute the value to history, and then copy them to the blob's diff.
       Dtype local_rate = rate * net_params_lr[param_id];
+      Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
       caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
           net_params[param_id]->gpu_diff(), momentum,
           history_[param_id]->mutable_gpu_data());
-      if (weight_decay) {
+      if (local_decay) {
         // add weight decay
         caffe_gpu_axpy(net_params[param_id]->count(),
-            weight_decay * local_rate,
+            local_decay * local_rate,
             net_params[param_id]->gpu_data(),
             history_[param_id]->mutable_gpu_data());
       }
@@ -201,6 +267,7 @@ template <typename Dtype>
 void SGDSolver<Dtype>::RestoreSolverState(const SolverState& state) {
   CHECK_EQ(state.history_size(), history_.size())
       << "Incorrect length of history blobs.";
+  LOG(INFO) << "SGDSolver: restoring history";
   for (int i = 0; i < history_.size(); ++i) {
     history_[i]->FromProto(state.history(i));
   }