solver restructuring: now all prototxt are specified in the solver protocol buffer
authorYangqing Jia <jiayq84@gmail.com>
Thu, 31 Oct 2013 23:52:22 +0000 (16:52 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Thu, 31 Oct 2013 23:52:22 +0000 (16:52 -0700)
Makefile
data/lenet_solver.prototxt [new file with mode: 0644]
examples/demo_mnist.cpp [deleted file]
examples/train_net.cpp
include/caffe/common.hpp
include/caffe/solver.hpp
src/caffe/common.cpp
src/caffe/layers/data_layer.cpp
src/caffe/net.cpp
src/caffe/proto/caffe.proto
src/caffe/solver.cpp

index 6bdab19f1baa8ecb46806b90c2a76dff1fac2510..a74d8b51bcdf44f7585d9a38a8f3d3d9f5f1b07b 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -47,7 +47,7 @@ LIBRARIES := cuda cudart cublas curand protobuf opencv_core opencv_highgui \
        glog mkl_rt mkl_intel_thread leveldb snappy pthread
 WARNINGS := -Wall
 
-COMMON_FLAGS := $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
+COMMON_FLAGS := -DNDEBUG $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
 CXXFLAGS += -pthread -fPIC -O2 $(COMMON_FLAGS)
 NVCCFLAGS := -Xcompiler -fPIC -O2 $(COMMON_FLAGS)
 LDFLAGS += $(foreach librarydir,$(LIBRARY_DIRS),-L$(librarydir)) \
diff --git a/data/lenet_solver.prototxt b/data/lenet_solver.prototxt
new file mode 100644 (file)
index 0000000..d58255b
--- /dev/null
@@ -0,0 +1,12 @@
+train_net: "data/lenet.prototxt"
+test_net: "data/lenet_test.prototxt"
+base_lr: 0.01
+lr_policy: "inv"
+gamma: 0.0001
+power: 0.75
+display: 100
+max_iter: 5000
+momentum: 0.9
+weight_decay: 0.0005
+test_iter: 100
+test_interval: 500
\ No newline at end of file
diff --git a/examples/demo_mnist.cpp b/examples/demo_mnist.cpp
deleted file mode 100644 (file)
index 11d3fc5..0000000
+++ /dev/null
@@ -1,96 +0,0 @@
-// Copyright 2013 Yangqing Jia
-// This example shows how to run a modified version of LeNet using Caffe.
-
-#include <cuda_runtime.h>
-#include <fcntl.h>
-#include <google/protobuf/text_format.h>
-
-#include <cstring>
-#include <iostream>
-
-#include "caffe/blob.hpp"
-#include "caffe/common.hpp"
-#include "caffe/net.hpp"
-#include "caffe/filler.hpp"
-#include "caffe/proto/caffe.pb.h"
-#include "caffe/util/io.hpp"
-#include "caffe/solver.hpp"
-
-using namespace caffe;
-
-int main(int argc, char** argv) {
-  if (argc < 3) {
-    std::cout << "Usage:" << std::endl;
-    std::cout << "demo_mnist.bin train_file test_file [CPU/GPU]" << std::endl;
-    return 0;
-  }
-  google::InitGoogleLogging(argv[0]);
-  Caffe::DeviceQuery();
-
-  if (argc == 4 && strcmp(argv[3], "GPU") == 0) {
-    LOG(ERROR) << "Using GPU";
-    Caffe::set_mode(Caffe::GPU);
-  } else {
-    LOG(ERROR) << "Using CPU";
-    Caffe::set_mode(Caffe::CPU);
-  }
-
-  // Start training
-  Caffe::set_phase(Caffe::TRAIN);
-
-  NetParameter net_param;
-  ReadProtoFromTextFile(argv[1],
-      &net_param);
-  vector<Blob<float>*> bottom_vec;
-  Net<float> caffe_net(net_param, bottom_vec);
-
-  // Run the network without training.
-  LOG(ERROR) << "Performing Forward";
-  caffe_net.Forward(bottom_vec);
-  LOG(ERROR) << "Performing Backward";
-  LOG(ERROR) << "Initial loss: " << caffe_net.Backward();
-
-  SolverParameter solver_param;
-  // Solver Parameters are hard-coded in this case, but you can write a
-  // SolverParameter protocol buffer to specify all these values.
-  solver_param.set_base_lr(0.01);
-  solver_param.set_display(100);
-  solver_param.set_max_iter(5000);
-  solver_param.set_lr_policy("inv");
-  solver_param.set_gamma(0.0001);
-  solver_param.set_power(0.75);
-  solver_param.set_momentum(0.9);
-  solver_param.set_weight_decay(0.0005);
-
-  LOG(ERROR) << "Starting Optimization";
-  SGDSolver<float> solver(solver_param);
-  solver.Solve(&caffe_net);
-  LOG(ERROR) << "Optimization Done.";
-
-  // Write the trained network to a NetParameter protobuf. If you are training
-  // the model and saving it for later, this is what you want to serialize and
-  // store.
-  NetParameter trained_net_param;
-  caffe_net.ToProto(&trained_net_param);
-
-  // Now, let's starting doing testing.
-  Caffe::set_phase(Caffe::TEST);
-
-  // Using the testing data to test the accuracy.
-  NetParameter test_net_param;
-  ReadProtoFromTextFile(argv[2], &test_net_param);
-  Net<float> caffe_test_net(test_net_param, bottom_vec);
-  caffe_test_net.CopyTrainedLayersFrom(trained_net_param);
-
-  double test_accuracy = 0;
-  int batch_size = test_net_param.layers(0).layer().batchsize();
-  for (int i = 0; i < 10000 / batch_size; ++i) {
-    const vector<Blob<float>*>& result =
-        caffe_test_net.Forward(bottom_vec);
-    test_accuracy += result[0]->cpu_data()[0];
-  }
-  test_accuracy /= 10000 / batch_size;
-  LOG(ERROR) << "Test accuracy:" << test_accuracy;
-
-  return 0;
-}
index 3abb1c3955557486d54534666ea8be2e0b78861d..d84619b2d44c10028fa5fdd215a020c03b677b9a 100644 (file)
 using namespace caffe;
 
 int main(int argc, char** argv) {
-  if (argc < 3) {
-    LOG(ERROR) << "Usage: train_net net_proto_file solver_proto_file "
-               << "[resume_point_file]";
+  ::google::InitGoogleLogging(argv[0]);
+  if (argc < 2) {
+    LOG(ERROR) << "Usage: train_net solver_proto_file [resume_point_file]";
     return 0;
   }
 
-  cudaSetDevice(0);
+  Caffe::SetDevice(0);
   Caffe::set_mode(Caffe::GPU);
-  Caffe::set_phase(Caffe::TRAIN);
-
-  NetParameter net_param;
-  ReadProtoFromTextFile(argv[1], &net_param);
-  vector<Blob<float>*> bottom_vec;
-  Net<float> caffe_net(net_param, bottom_vec);
 
   SolverParameter solver_param;
-  ReadProtoFromTextFile(argv[2], &solver_param);
+  ReadProtoFromTextFile(argv[1], &solver_param);
 
-  LOG(ERROR) << "Starting Optimization";
+  LOG(INFO) << "Starting Optimization";
   SGDSolver<float> solver(solver_param);
-  if (argc == 4) {
-    LOG(ERROR) << "Resuming from " << argv[3];
-    solver.Solve(&caffe_net, argv[3]);
+  if (argc == 3) {
+    LOG(INFO) << "Resuming from " << argv[2];
+    solver.Solve(argv[2]);
   } else {
-    solver.Solve(&caffe_net);
+    solver.Solve();
   }
-  LOG(ERROR) << "Optimization Done.";
+  LOG(INFO) << "Optimization Done.";
 
   return 0;
 }
index 485a86947f5f9834b056593dd60ca6507b3f935e..af42772bac43951a2861892e1dfd5d4b57ce355e 100644 (file)
@@ -94,6 +94,9 @@ class Caffe {
   inline static void set_phase(Phase phase) { Get().phase_ = phase; }
   // Sets the random seed of both MKL and curand
   static void set_random_seed(const unsigned int seed);
+  // Sets the device. Since we have cublas and curand stuff, set device also
+  // requires us to reset those values.
+  static void SetDevice(const int device_id);
   // Prints the current GPU status.
   static void DeviceQuery();
 
index 2055ab6b24c5e42bf1d8d1360d6d53ab18038bed..168f4b43f0090f2787b618f22b35e4158b492637 100644 (file)
@@ -10,11 +10,11 @@ namespace caffe {
 template <typename Dtype>
 class Solver {
  public:
-  explicit Solver(const SolverParameter& param)
-      : param_(param) {}
+  explicit Solver(const SolverParameter& param);
   // The main entry of the solver function. In default, iter will be zero. Pass
   // in a non-zero iter number to resume training for a pre-trained net.
-  void Solve(Net<Dtype>* net, char* state_file = NULL);
+  void Solve(const char* resume_file = NULL);
+  inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
   virtual ~Solver() {}
 
  protected:
@@ -28,15 +28,18 @@ class Solver {
   // function that produces a SolverState protocol buffer that needs to be
   // written to disk together with the learned net.
   void Snapshot();
+  // The test routine
+  void Test();
   virtual void SnapshotSolverState(SolverState* state) = 0;
   // The Restore function implements how one should restore the solver to a
   // previously snapshotted state. You should implement the RestoreSolverState()
   // function that restores the state from a SolverState protocol buffer.
-  void Restore(char* state_file);
+  void Restore(const char* resume_file);
   virtual void RestoreSolverState(const SolverState& state) = 0;
   SolverParameter param_;
   int iter_;
   Net<Dtype>* net_;
+  Net<Dtype>* test_net_;
 
   DISABLE_COPY_AND_ASSIGN(Solver);
 };
index aecdc6e12666345606eb3e011cf6547f639ec385..a70a28087d391a1c750303ce4d5a98afd3662d47 100644 (file)
@@ -74,6 +74,24 @@ void Caffe::set_random_seed(const unsigned int seed) {
   VSL_CHECK(vslNewStream(&(Get().vsl_stream_), VSL_BRNG_MT19937, seed));
 }
 
+void Caffe::SetDevice(const int device_id) {
+  int current_device;
+  CUDA_CHECK(cudaGetDevice(&current_device));
+  if (current_device == device_id) {
+    return;
+  }
+  if (Get().cublas_handle_) CUBLAS_CHECK(cublasDestroy(Get().cublas_handle_));
+  if (Get().curand_generator_) {
+    CURAND_CHECK(curandDestroyGenerator(Get().curand_generator_));
+  }
+  CUDA_CHECK(cudaSetDevice(device_id));
+  CUBLAS_CHECK(cublasCreate(&Get().cublas_handle_));
+  CURAND_CHECK(curandCreateGenerator(&Get().curand_generator_,
+      CURAND_RNG_PSEUDO_DEFAULT));
+  CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(Get().curand_generator_,
+      time(NULL)));
+}
+
 void Caffe::DeviceQuery() {
   cudaDeviceProp prop;
   int device;
index 8ce6d09a129fddde0ce6844c3afe3fa2b923e29a..12fd6d94625e68e82188951e208c8af728ae85c3 100644 (file)
@@ -101,7 +101,7 @@ void* DataLayerPrefetch(void* layer_pointer) {
     layer->iter_->Next();
     if (!layer->iter_->Valid()) {
       // We have reached the end. Restart from the first.
-      LOG(INFO) << "Restarting data prefetching from start.";
+      DLOG(INFO) << "Restarting data prefetching from start.";
       layer->iter_->SeekToFirst();
     }
   }
@@ -180,10 +180,10 @@ void DataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
   prefetch_data_->mutable_cpu_data();
   prefetch_label_->mutable_cpu_data();
   data_mean_.cpu_data();
-  // LOG(INFO) << "Initializing prefetch";
+  DLOG(INFO) << "Initializing prefetch";
   CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
       reinterpret_cast<void*>(this))) << "Pthread execution failed.";
-  // LOG(INFO) << "Prefetch initialized.";
+  DLOG(INFO) << "Prefetch initialized.";
 }
 
 template <typename Dtype>
index 38a806df17cd50c4d65f934a68f2df8e13d3aa9b..0c344faeb1cd3216230f64461285e9447214f81f 100644 (file)
@@ -121,7 +121,7 @@ Net<Dtype>::Net(const NetParameter& param,
   // In the end, all remaining blobs are considered output blobs.
   for (set<string>::iterator it = available_blobs.begin();
       it != available_blobs.end(); ++it) {
-    LOG(ERROR) << "This network produces output " << *it;
+    LOG(INFO) << "This network produces output " << *it;
     net_output_blob_indices_.push_back(blob_name_to_idx[*it]);
     net_output_blobs_.push_back(blobs_[blob_name_to_idx[*it]].get());
   }
@@ -207,10 +207,10 @@ void Net<Dtype>::CopyTrainedLayersFrom(const NetParameter& param) {
       ++target_layer_id;
     }
     if (target_layer_id == layer_names_.size()) {
-      LOG(INFO) << "Ignoring source layer " << source_layer_name;
+      DLOG(INFO) << "Ignoring source layer " << source_layer_name;
       continue;
     }
-    LOG(INFO) << "Loading source layer " << source_layer_name;
+    DLOG(INFO) << "Loading source layer " << source_layer_name;
     vector<shared_ptr<Blob<Dtype> > >& target_blobs =
         layers_[target_layer_id]->blobs();
     CHECK_EQ(target_blobs.size(), source_layer.blobs_size())
@@ -233,7 +233,7 @@ void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) {
   for (int i = 0; i < net_input_blob_indices_.size(); ++i) {
     param->add_input(blob_names_[net_input_blob_indices_[i]]);
   }
-  LOG(INFO) << "Serializing " << layers_.size() << " layers";
+  DLOG(INFO) << "Serializing " << layers_.size() << " layers";
   for (int i = 0; i < layers_.size(); ++i) {
     LayerConnection* layer_connection = param->add_layers();
     for (int j = 0; j < bottom_id_vecs_[i].size(); ++j) {
index 16bb352e964c0accfb9b385f65c2efcaa02c9b71..c3140708d93f84bd42ee91f3ad5968ff4274fc5f 100644 (file)
@@ -94,25 +94,28 @@ message NetParameter {
 }
 
 message SolverParameter {
-  optional float base_lr = 1; // The base learning rate
+  optional string train_net = 1; // The proto file for the training net.
+  optional string test_net = 2; // The proto file for the testing net.
+  // The number of iterations for each testing phase.
+  optional int32 test_iter = 3 [ default = 0 ];
+  // The number of iterations between two testing phases.
+  optional int32 test_interval = 4 [ default = 0 ];
+  optional float base_lr = 5; // The base learning rate
   // the number of iterations between displaying info. If display = 0, no info
   // will be displayed.
-  optional int32 display = 2;
-  optional int32 max_iter = 3; // the maximum number of iterations
-  optional int32 snapshot = 4 [default = 0]; // The snapshot interval
-  optional string lr_policy = 5; // The learning rate decay policy.
-  optional float min_lr = 6 [default = 0]; // The mininum learning rate
-  optional float max_lr = 7 [default = 1e10]; // The maximum learning rate
-  optional float gamma = 8; // The parameter to compute the learning rate.
-  optional float power = 9; // The parameter to compute the learning rate.
-  optional float momentum = 10; // The momentum value.
-  optional float weight_decay = 11; // The weight decay.
-  optional int32 stepsize = 12; // the stepsize for learning rate policy "step"
-
-  optional string snapshot_prefix = 13; // The prefix for the snapshot.
+  optional int32 display = 6;
+  optional int32 max_iter = 7; // the maximum number of iterations
+  optional string lr_policy = 8; // The learning rate decay policy.
+  optional float gamma = 9; // The parameter to compute the learning rate.
+  optional float power = 10; // The parameter to compute the learning rate.
+  optional float momentum = 11; // The momentum value.
+  optional float weight_decay = 12; // The weight decay.
+  optional int32 stepsize = 13; // the stepsize for learning rate policy "step"
+  optional int32 snapshot = 14 [default = 0]; // The snapshot interval
+  optional string snapshot_prefix = 15; // The prefix for the snapshot.
   // whether to snapshot diff in the results or not. Snapshotting diff will help
   // debugging but the final protocol buffer size will be much larger.
-  optional bool snapshot_diff = 14 [ default = false];
+  optional bool snapshot_diff = 16 [ default = false];
 }
 
 // A message that stores the solver snapshots
index bf4bdbc620a7aa96a7adfef2e0d27075a575bdd6..07b46a42e971de8093471cc8a42b1774182e2d23 100644 (file)
@@ -18,8 +18,31 @@ using std::min;
 namespace caffe {
 
 template <typename Dtype>
-void Solver<Dtype>::Solve(Net<Dtype>* net, char* resume_file) {
-  net_ = net;
+Solver<Dtype>::Solver(const SolverParameter& param)
+    : param_(param), net_(NULL), test_net_(NULL) {
+  // Scaffolding code
+  NetParameter train_net_param;
+  ReadProtoFromTextFile(param_.train_net(), &train_net_param);
+  // For the training network, there should be no input - so we simply create
+  // a dummy bottom_vec instance to initialize the networks.
+  vector<Blob<Dtype>*> bottom_vec;
+  LOG(INFO) << "Creating training net.";
+  net_ = new Net<Dtype>(train_net_param, bottom_vec);
+  if (param_.has_test_net()) {
+    LOG(INFO) << "Creating testing net.";
+    NetParameter test_net_param;
+    ReadProtoFromTextFile(param_.test_net(), &test_net_param);
+    test_net_ = new Net<Dtype>(test_net_param, bottom_vec);
+    CHECK_GT(param_.test_iter(), 0);
+    CHECK_GT(param_.test_interval(), 0);
+  }
+  LOG(INFO) << "Solver scaffolding done.";
+}
+
+
+template <typename Dtype>
+void Solver<Dtype>::Solve(const char* resume_file) {
+  Caffe::set_phase(Caffe::TRAIN);
   LOG(INFO) << "Solving " << net_->name();
   PreSolve();
 
@@ -42,13 +65,54 @@ void Solver<Dtype>::Solve(Net<Dtype>* net, char* resume_file) {
       Snapshot();
     }
     if (param_.display() && iter_ % param_.display() == 0) {
-      LOG(ERROR) << "Iteration " << iter_ << ", loss = " << loss;
+      LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss;
+    }
+    if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
+      // We need to set phase to test before running.
+      Caffe::set_phase(Caffe::TEST);
+      Test();
+      Caffe::set_phase(Caffe::TRAIN);
     }
   }
   LOG(INFO) << "Optimization Done.";
 }
 
 
+template <typename Dtype>
+void Solver<Dtype>::Test() {
+  LOG(INFO) << "Testing net";
+  NetParameter net_param;
+  net_->ToProto(&net_param);
+  CHECK_NOTNULL(test_net_)->CopyTrainedLayersFrom(net_param);
+  vector<Dtype> test_score;
+  vector<Blob<Dtype>*> bottom_vec;
+  for (int i = 0; i < param_.test_iter(); ++i) {
+    const vector<Blob<Dtype>*>& result =
+        test_net_->Forward(bottom_vec);
+    if (i == 0) {
+      for (int j = 0; j < result.size(); ++j) {
+        const Dtype* result_vec = result[j]->cpu_data();
+        for (int k = 0; k < result[j]->count(); ++k) {
+          test_score.push_back(result_vec[k]);
+        }
+      }
+    } else {
+      int idx = 0;
+      for (int j = 0; j < result.size(); ++j) {
+        const Dtype* result_vec = result[j]->cpu_data();
+        for (int k = 0; k < result[j]->count(); ++k) {
+          test_score[idx++] += result_vec[k];
+        }
+      }
+    }
+  }
+  for (int i = 0; i < test_score.size(); ++i) {
+    LOG(INFO) << "Test score #" << i << ": "
+        << test_score[i] / param_.test_iter();
+  }
+}
+
+
 template <typename Dtype>
 void Solver<Dtype>::Snapshot() {
   NetParameter net_param;
@@ -70,7 +134,7 @@ void Solver<Dtype>::Snapshot() {
 }
 
 template <typename Dtype>
-void Solver<Dtype>::Restore(char* state_file) {
+void Solver<Dtype>::Restore(const char* state_file) {
   SolverState state;
   NetParameter net_param;
   ReadProtoFromBinaryFile(state_file, &state);
@@ -108,8 +172,6 @@ Dtype SGDSolver<Dtype>::GetLearningRate() {
   } else {
     LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
   }
-  rate = min(max(rate, Dtype(this->param_.min_lr())),
-      Dtype(this->param_.max_lr()));
   return rate;
 }
 
@@ -136,7 +198,7 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
   // get the learning rate
   Dtype rate = GetLearningRate();
   if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
-    LOG(ERROR) << "Iteration " << this->iter_ << ", lr = " << rate;
+    LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
   }
   Dtype momentum = this->param_.momentum();
   Dtype weight_decay = this->param_.weight_decay();