started writing solver
authorYangqing Jia <jiayq84@gmail.com>
Fri, 27 Sep 2013 21:56:06 +0000 (14:56 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Fri, 27 Sep 2013 21:56:06 +0000 (14:56 -0700)
src/caffe/net.cpp
src/caffe/net.hpp
src/caffe/optimization/solver.hpp [new file with mode: 0644]
src/caffe/proto/caffe.proto

index ac9f6a9238a24c5cdad56ae061c7b2b491673eae..c6dfce19c038239b841b00b24c18c36bee8af025 100644 (file)
@@ -95,7 +95,12 @@ Net<Dtype>::Net(const NetParameter& param,
   for (int i = 0; i < layers_.size(); ++i) {
     LOG(INFO) << "Setting up " << layer_names_[i];
     layers_[i]->SetUp(bottom_vecs_[i], &top_vecs_[i]);
   for (int i = 0; i < layers_.size(); ++i) {
     LOG(INFO) << "Setting up " << layer_names_[i];
     layers_[i]->SetUp(bottom_vecs_[i], &top_vecs_[i]);
+    vector<shared_ptr<Blob<Dtype> > >& layer_params = layers_[i].params();
+    for (int j = 0; j < layer_params.size(); ++j) {
+      params_.push_back(layer_params[j]);
+    }
   }
   }
+
   LOG(INFO) << "Network initialization done.";
 }
 
   LOG(INFO) << "Network initialization done.";
 }
 
index e91081b9c82e9a781599b50056908137c6497d96..719267c6402c57c7f30b4582d176410ad79773c0 100644 (file)
@@ -45,8 +45,10 @@ class Net {
   inline const vector<string>& blob_names() { return blob_names_; }
   // returns the blobs
   inline const vector<shared_ptr<Blob<Dtype> > >& blobs() { return blobs_; }
   inline const vector<string>& blob_names() { return blob_names_; }
   // returns the blobs
   inline const vector<shared_ptr<Blob<Dtype> > >& blobs() { return blobs_; }
-  // rethrns the layers
+  // returns the layers
   inline const vector<shared_ptr<Layer<Dtype> > >& layers() { return layers_; }
   inline const vector<shared_ptr<Layer<Dtype> > >& layers() { return layers_; }
+  // returns the parameters
+  vector<shared_ptr<Blob<Dtype> > >& params() { return params_; };
 
  protected:
   // Individual layers in the net
 
  protected:
   // Individual layers in the net
@@ -66,6 +68,8 @@ class Net {
   vector<int> net_input_blob_indices_;
   vector<int> net_output_blob_indices_;
   string name_;
   vector<int> net_input_blob_indices_;
   vector<int> net_output_blob_indices_;
   string name_;
+  // The parameters in the network.
+  vector<shared_ptr<Blob<Dtype> > > params_;
 
   DISABLE_COPY_AND_ASSIGN(Net);
 };
 
   DISABLE_COPY_AND_ASSIGN(Net);
 };
diff --git a/src/caffe/optimization/solver.hpp b/src/caffe/optimization/solver.hpp
new file mode 100644 (file)
index 0000000..0c680e3
--- /dev/null
@@ -0,0 +1,18 @@
+#ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_
+#define CAFFE_OPTIMIZATION_SOLVER_HPP_
+
+namespace caffe {
+
+class Solver {
+ public:
+  explicit Solver(const SolverParameter& param)
+      : param_(param) {}
+  void Solve(Net* net);
+
+ protected:
+  SolverParameter param_;
+};
+
+}  // namspace caffe
+
+#endif  // CAFFE_OPTIMIZATION_SOLVER_HPP_
\ No newline at end of file
index 8f7e0c3328622a41c0d625ddb21863c1b65379ec..732c2eecfda153dbf111d0c573de4e2aa7544078 100644 (file)
@@ -77,3 +77,16 @@ message NetParameter {
   repeated string bottom = 3; // The input to the network
   repeated string top = 4; // The output of the network.
 }
   repeated string bottom = 3; // The input to the network
   repeated string top = 4; // The output of the network.
 }
+
+message SolverParameter {
+  optional float base_lr = 1; // The base learning rate
+  optional int32 display = 2; // display options. 0 = no display
+  optional int32 max_iter = 3; // the maximum number of iterations
+  optional int32 snapshot = 4; // The snapshot interval
+  optional string lr_policy = 5; // The learning rate decay policy.
+  optional float min_lr = 6 [default = 0]; // The mininum learning rate
+  optional float max_lr = 7 [default = 1e10]; // The maximum learning rate
+  optional float gamma = 8; // The parameter to compute the learning rate.
+  optional float power = 9; // The parameter to compute the learning rate.
+  optional float momentum = 10; // The momentum value.
+}
\ No newline at end of file