]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/caffe-jacinto.git/commitdiff
working halfway into dropout, machine down, changing machine
authorYangqing Jia <jiayq84@gmail.com>
Mon, 16 Sep 2013 20:38:35 +0000 (13:38 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Mon, 16 Sep 2013 20:38:35 +0000 (13:38 -0700)
src/Makefile
src/caffeine/dropout_layer.cu [new file with mode: 0644]
src/caffeine/neuron_layer.cpp [new file with mode: 0644]
src/caffeine/proto/layer_param.proto
src/caffeine/relu_layer.cu [moved from src/caffeine/neuron_layer.cu with 85% similarity]
src/caffeine/vision_layers.hpp

index f3c83917219a97f6c64f0f8285208b98ab0d433b..9ab43e5b2a2dd5a6f7d65eb9bb0197bec43d977a 100644 (file)
@@ -34,7 +34,7 @@ LIBRARY_DIRS := . /usr/local/lib $(CUDA_LIB_DIR) $(MKL_LIB_DIR)
 LIBRARIES := cuda cudart cublas protobuf glog mkl_rt mkl_intel_thread
 WARNINGS := -Wall
 
-CXXFLAGS += -fPIC $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
+CXXFLAGS += -fPIC -O2 $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
 LDFLAGS += $(foreach librarydir,$(LIBRARY_DIRS),-L$(librarydir))
 LDFLAGS += $(foreach library,$(LIBRARIES),-l$(library))
 
@@ -53,8 +53,8 @@ $(TEST_NAME): $(OBJS) $(TEST_OBJS)
 $(NAME): $(PROTO_GEN_CC) $(OBJS)
        $(LINK) -shared $(OBJS) -o $(NAME)
 
-$(CU_OBJS): $(CU_SRCS)
-       $(NVCC) -c -o $(CU_OBJS) $(CU_SRCS)
+$(CU_OBJS): %.o: %.cu
+       $(NVCC) -c $< -o $@
 
 $(PROTO_GEN_CC): $(PROTO_SRCS)
        protoc $(PROTO_SRCS) --cpp_out=. --python_out=.
diff --git a/src/caffeine/dropout_layer.cu b/src/caffeine/dropout_layer.cu
new file mode 100644 (file)
index 0000000..23999fb
--- /dev/null
@@ -0,0 +1,101 @@
+#include "caffeine/layer.hpp"
+#include "caffeine/vision_layers.hpp"
+#include <algorithm>
+
+using std::max;
+
+namespace caffeine {
+
+template <typename Dtype>
+void DropoutLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  NeuronLayer<Dtype>::SetUp(bottom, top);
+  // Set up the cache for random number generation
+  rand_mat_.reset(new Blob<float>(bottom.num(), bottom.channels(),
+      bottom.height(), bottom.width());
+  filler_.reset(new UniformFiller<float>(FillerParameter()));
+};
+
+template <typename Dtype>
+void DropoutLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  // First, create the random matrix
+  filler_->Fill(rand_mat_.get()); 
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  const Dtype* rand_vals = rand_mat_->cpu_data();
+  Dtype* top_data = (*top)[0]->mutable_cpu_data();
+  float threshold = layer_param_->dropout_ratio();
+  float scale = layer_param_->dropo
+  const int count = bottom[0]->count();
+  for (int i = 0; i < count; ++i) {
+    top_data[i] = rand_mat_ > ;
+  }
+}
+
+template <typename Dtype>
+Dtype DropoutLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  if (propagate_down) {
+    const Dtype* bottom_data = (*bottom)[0]->cpu_data();
+    const Dtype* top_diff = top[0]->cpu_diff();
+    Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+    const int count = (*bottom)[0]->count();
+    for (int i = 0; i < count; ++i) {
+      bottom_diff[i] = top_diff[i] * (bottom_data[i] >= 0);
+    }
+  }
+  return Dtype(0);
+}
+
+template <typename Dtype>
+__global__ void DropoutForward(const int n, const Dtype* in, Dtype* out) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < n) {
+    out[index] = max(in[index], Dtype(0.));
+  }
+}
+
+template <typename Dtype>
+void DropoutLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = (*top)[0]->mutable_gpu_data();
+  const int count = bottom[0]->count();
+  const int blocks = (count + CAFFEINE_CUDA_NUM_THREADS - 1) /
+      CAFFEINE_CUDA_NUM_THREADS;
+  DropoutForward<<<blocks, CAFFEINE_CUDA_NUM_THREADS>>>(count, bottom_data,
+      top_data);
+}
+
+template <typename Dtype>
+__global__ void DropoutBackward(const int n, const Dtype* in_diff,
+    const Dtype* in_data, Dtype* out_diff) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < n) {
+    out_diff[index] = in_diff[index] * (in_data[index] >= 0);
+  }
+}
+
+template <typename Dtype>
+Dtype DropoutLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  if (propagate_down) {
+    const Dtype* bottom_data = (*bottom)[0]->gpu_data();
+    const Dtype* top_diff = top[0]->gpu_diff();
+    Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
+    const int count = (*bottom)[0]->count();
+    const int blocks = (count + CAFFEINE_CUDA_NUM_THREADS - 1) /
+        CAFFEINE_CUDA_NUM_THREADS;
+    DropoutBackward<<<blocks, CAFFEINE_CUDA_NUM_THREADS>>>(count, top_diff,
+        bottom_data, bottom_diff);
+  }
+  return Dtype(0);
+}
+
+template class DropoutLayer<float>;
+template class DropoutLayer<double>;
+
+
+}  // namespace caffeine
diff --git a/src/caffeine/neuron_layer.cpp b/src/caffeine/neuron_layer.cpp
new file mode 100644 (file)
index 0000000..050c690
--- /dev/null
@@ -0,0 +1,18 @@
+#include "caffeine/layer.hpp"
+#include "caffeine/vision_layers.hpp"
+
+namespace caffeine {
+
+template <typename Dtype>
+void NeuronLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  CHECK_EQ(bottom.size(), 1) << "Neuron Layer takes a single blob as input.";
+  CHECK_EQ(top->size(), 1) << "Neuron Layer takes a single blob as output.";
+  (*top)[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
+      bottom[0]->height(), bottom[0]->width());
+};
+
+template class NeuronLayer<float>;
+template class NeuronLayer<double>;
+
+}  // namespace caffeine
index 1ad42d277e925d358b030b8e7898344035b37a30..7bb37089b08043e835335429c104f96ea97580f3 100644 (file)
@@ -1,8 +1,24 @@
 package caffeine;
 
 message LayerParameter {
-  required string name = 1;
-  required string type = 2;
+  required string name = 1; // the layer name
+  required string type = 2; // the string to specify the layer type
+
+  // Parameters to specify layers with inner products.
+  optional int32 num_output = 3; // The number of outputs for the layer
+  optional bool biasterm = 4 [default = true]; // whether to have bias terms
+  optional FillerParameter weight_filler = 5; // The filler for the weight
+  optional FillerParameter bias_filler = 6; // The filler for the bias
+
+  optional uint32 pad = 7 [default = 0]; // The padding size
+  optional uint32 kernelsize = 8; // The kernel size
+  optional uint32 group = 9 [default = 1]; // The group size for group conv
+  optional uint32 stride = 10 [default = 1]; // The stride
+  optional string pool = 11 [default = 'max']; // The pooling method
+  optional float dropout_ratio = 12 [default = 0.5]; // dropout ratio
+
+  optional float alpha = 13 [default = 1.]; // for local response norm
+  optional float beta = 14 [default = 0.75]; // for local response norm
 }
 
 message FillerParameter {
@@ -21,4 +37,4 @@ message BlobProto {
   optional int32 channels = 4 [default = 0];
   repeated float data = 5;
   repeated float diff = 6;
-}
\ No newline at end of file
+}
similarity index 85%
rename from src/caffeine/neuron_layer.cu
rename to src/caffeine/relu_layer.cu
index 2801248d1d3e59c40a8b1886917469ddd3cce55c..158131a0cd307607954fba2b2eb4382b1b0f9d14 100644 (file)
@@ -6,18 +6,6 @@ using std::max;
 
 namespace caffeine {
 
-template <typename Dtype>
-void NeuronLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top) {
-  CHECK_EQ(bottom.size(), 1) << "Neuron Layer takes a single blob as input.";
-  CHECK_EQ(top->size(), 1) << "Neuron Layer takes a single blob as output.";
-  (*top)[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
-      bottom[0]->height(), bottom[0]->width());
-};
-
-template class NeuronLayer<float>;
-template class NeuronLayer<double>;
-
 template <typename Dtype>
 void ReLULayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
     vector<Blob<Dtype>*>* top) {
index b2f492690f1d7eac5b3252968a245177fb74425e..08561bcea8e840f265be5362d448365ca0775213 100644 (file)
@@ -31,6 +31,32 @@ class ReLULayer : public NeuronLayer<Dtype> {
       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
 };
 
+template <typename Dtype>
+class DropoutLayer : public NeuronLayer<Dtype> {
+ public:
+  explicit DropoutLayer(const LayerParameter& param)
+      : NeuronLayer<Dtype>(param) {};
+  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+ protected:
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ private:
+  shared_ptr<Blob<float> > rand_mat_;
+  shared_ptr<UniformFiller<float> > filler_;
+};
+
+
+
+
+
 }  // namespace caffeine
 
 #endif  // CAFFEINE_VISION_LAYERS_HPP_