allow setting custom weight decay
authorYangqing Jia <jiayq84@gmail.com>
Tue, 22 Oct 2013 20:18:50 +0000 (13:18 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Tue, 22 Oct 2013 20:18:50 +0000 (13:18 -0700)
examples/imagenet.prototxt
include/caffe/net.hpp
src/caffe/net.cpp
src/caffe/proto/caffe.proto
src/caffe/solver.cpp

index 9f24cfa81bac3b1388c628739ee6a1517159fd4a..a821f438cf4749e5b5c220ffaf09c7abf91263a8 100644 (file)
@@ -5,7 +5,7 @@ layers {
     type: "data"
     source: "/home/jiayq/caffe-train-leveldb"
     meanfile: "/home/jiayq/ilsvrc2012_mean.binaryproto"
-    batchsize: 64
+    batchsize: 256
     cropsize: 227
     mirror: true
   }
@@ -25,10 +25,12 @@ layers {
     }
     bias_filler {
       type: "constant"
-      value: 0.1
+      value: 0.
     }
     blobs_lr: 1.
     blobs_lr: 2.
+    weight_decay: 1.
+    weight_decay: 0.
   }
   bottom: "data"
   top: "conv1"
@@ -85,10 +87,12 @@ layers {
     }
     bias_filler {
       type: "constant"
-      value: 0.1
+      value: 1.
     }
     blobs_lr: 1.
     blobs_lr: 2.
+    weight_decay: 1.
+    weight_decay: 0.
   }
   bottom: "pad2"
   top: "conv2"
@@ -144,10 +148,12 @@ layers {
     }
     bias_filler {
       type: "constant"
-      value: 0.1
+      value: 0.
     }
     blobs_lr: 1.
     blobs_lr: 2.
+    weight_decay: 1.
+    weight_decay: 0.
   }
   bottom: "pad3"
   top: "conv3"
@@ -182,10 +188,12 @@ layers {
     }
     bias_filler {
       type: "constant"
-      value: 0.1
+      value: 1.
     }
     blobs_lr: 1.
     blobs_lr: 2.
+    weight_decay: 1.
+    weight_decay: 0.
   }
   bottom: "pad4"
   top: "conv4"
@@ -220,10 +228,12 @@ layers {
     }
     bias_filler {
       type: "constant"
-      value: 0.1
+      value: 1.
     }
     blobs_lr: 1.
     blobs_lr: 2.
+    weight_decay: 1.
+    weight_decay: 0.
   }
   bottom: "pad5"
   top: "conv5"
@@ -258,10 +268,12 @@ layers {
     }
     bias_filler {
       type: "constant"
-      value: 0.1
+      value: 1.
     }
     blobs_lr: 1.
     blobs_lr: 2.
+    weight_decay: 1.
+    weight_decay: 0.
   }
   bottom: "pool5"
   top: "fc6"
@@ -294,10 +306,12 @@ layers {
     }
     bias_filler {
       type: "constant"
-      value: 0.1
+      value: 1.
     }
     blobs_lr: 1.
     blobs_lr: 2.
+    weight_decay: 1.
+    weight_decay: 0.
   }
   bottom: "fc6"
   top: "fc7"
@@ -334,6 +348,8 @@ layers {
     }
     blobs_lr: 1.
     blobs_lr: 2.
+    weight_decay: 1.
+    weight_decay: 0.
   }
   bottom: "fc7"
   top: "fc8"
index c27442b8243da29d3f2adadf0211d1b7651a6350..f0a5ebb93c018177294ef1cb45a8f9cc989713d4 100644 (file)
@@ -60,6 +60,7 @@ class Net {
   inline vector<shared_ptr<Blob<Dtype> > >& params() { return params_; }
   // returns the parameter learning rate multipliers
   inline vector<float>& params_lr() {return params_lr_; }
+  inline vector<float>& params_weight_decay() { return params_weight_decay_; }
   // Updates the network
   void Update();
 
@@ -86,6 +87,8 @@ class Net {
   vector<shared_ptr<Blob<Dtype> > > params_;
   // the learning rate multipliers
   vector<float> params_lr_;
+  // the weight decay multipliers
+  vector<float> params_weight_decay_;
   DISABLE_COPY_AND_ASSIGN(Net);
 };
 
index e1442ecb184859614641ef1e1f2d5e59e7a1828e..165869d4b3353781cf22b4ee546b3f02a0322097 100644 (file)
@@ -119,6 +119,20 @@ Net<Dtype>::Net(const NetParameter& param,
         params_lr_.push_back(1.);
       }
     }
+    // push the weight decay multipliers
+    if (layers_[i]->layer_param().weight_decay_size()) {
+      CHECK_EQ(layers_[i]->layer_param().weight_decay_size(),
+          layer_blobs.size());
+      for (int j = 0; j < layer_blobs.size(); ++j) {
+        float local_decay = layers_[i]->layer_param().weight_decay(j);
+        CHECK_GT(local_decay, 0.);
+        params_weight_decay_.push_back(local_decay);
+      }
+    } else {
+      for (int j = 0; j < layer_blobs.size(); ++j) {
+        params_weight_decay_.push_back(1.);
+      }
+    }
     for (int topid = 0; topid < top_vecs_[i].size(); ++topid) {
       LOG(INFO) << "Top shape: " << top_vecs_[i][topid]->channels() << " "
           << top_vecs_[i][topid]->height() << " "
index 0aa90fd2af3c1b3afb0a9d787b882668bfd00009..9d50c36d90970f629ce61277863a9a02bf11b0a2 100644 (file)
@@ -76,6 +76,8 @@ message LayerParameter {
   // The ratio that is multiplied on the global learning rate. If you want to set
   // the learning ratio for one blob, you need to set it for all blobs.
   repeated float blobs_lr = 51;
+  // The weight decay that is multiplied on the global weight decay.
+  repeated float weight_decay = 52;
 }
 
 message LayerConnection {
index 425bd421e082e4e62fde685d207c339b366f9e5b..6fe2ce91257d7601ce44093cbf0a5312445ed8af 100644 (file)
@@ -132,6 +132,7 @@ template <typename Dtype>
 void SGDSolver<Dtype>::ComputeUpdateValue() {
   vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
   vector<float>& net_params_lr = this->net_->params_lr();
+  vector<float>& net_params_weight_decay = this->net_->params_weight_decay();
   // get the learning rate
   Dtype rate = GetLearningRate();
   if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
@@ -139,20 +140,19 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
   }
   Dtype momentum = this->param_.momentum();
   Dtype weight_decay = this->param_.weight_decay();
-  // LOG(ERROR) << "rate:" << rate << " momentum:" << momentum
-  //    << " weight_decay:" << weight_decay;
   switch (Caffe::mode()) {
   case Caffe::CPU:
     for (int param_id = 0; param_id < net_params.size(); ++param_id) {
       // Compute the value to history, and then copy them to the blob's diff.
       Dtype local_rate = rate * net_params_lr[param_id];
+      Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
       caffe_axpby(net_params[param_id]->count(), local_rate,
           net_params[param_id]->cpu_diff(), momentum,
           history_[param_id]->mutable_cpu_data());
-      if (weight_decay) {
+      if (local_decay) {
         // add weight decay
         caffe_axpy(net_params[param_id]->count(),
-            weight_decay * local_rate,
+            local_decay * local_rate,
             net_params[param_id]->cpu_data(),
             history_[param_id]->mutable_cpu_data());
       }
@@ -166,13 +166,14 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
     for (int param_id = 0; param_id < net_params.size(); ++param_id) {
       // Compute the value to history, and then copy them to the blob's diff.
       Dtype local_rate = rate * net_params_lr[param_id];
+      Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
       caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
           net_params[param_id]->gpu_diff(), momentum,
           history_[param_id]->mutable_gpu_data());
-      if (weight_decay) {
+      if (local_decay) {
         // add weight decay
         caffe_gpu_axpy(net_params[param_id]->count(),
-            weight_decay * local_rate,
+            local_decay * local_rate,
             net_params[param_id]->gpu_data(),
             history_[param_id]->mutable_gpu_data());
       }