]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/caffe-jacinto.git/commitdiff
misc update
authorYangqing Jia <jiayq84@gmail.com>
Mon, 16 Sep 2013 18:25:43 +0000 (11:25 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Mon, 16 Sep 2013 18:25:43 +0000 (11:25 -0700)
15 files changed:
src/Makefile
src/caffeine/common.cpp
src/caffeine/common.hpp
src/caffeine/layer.cpp
src/caffeine/layer.hpp
src/caffeine/neuron_layer.cpp [deleted file]
src/caffeine/neuron_layer.cu [new file with mode: 0644]
src/caffeine/test/test_blob.cpp [moved from src/caffeine/test_blob.cpp with 100% similarity]
src/caffeine/test/test_caffeine_main.cpp [moved from src/caffeine/test_caffeine_main.cpp with 100% similarity]
src/caffeine/test/test_common.cpp [moved from src/caffeine/test_common.cpp with 77% similarity]
src/caffeine/test/test_filler.cpp [moved from src/caffeine/test_filler.cpp with 88% similarity]
src/caffeine/test/test_neuron_layer.cpp [new file with mode: 0644]
src/caffeine/test/test_syncedmem.cpp [moved from src/caffeine/test_syncedmem.cpp with 100% similarity]
src/caffeine/test_neuron_layer.cpp [deleted file]
src/caffeine/vision_layers.hpp

index 4c807de8488f4141958262bbfe4cc9d7456f33c1..3d2fca234e278fb60fc4531d4a93ed916933d0f2 100644 (file)
@@ -9,16 +9,19 @@ PROJECT := caffeine
 NAME := lib$(PROJECT).so
 TEST_NAME := test_$(PROJECT)
 CXX_SRCS := $(shell find caffeine ! -name "test_*.cpp" -name "*.cpp")
+CU_SRCS := $(shell find caffeine -name "*.cu")
 TEST_SRCS := $(shell find caffeine -name "test_*.cpp") gtest/gtest-all.cpp
 PROTO_SRCS := $(wildcard caffeine/proto/*.proto)
 PROTO_GEN_HEADER := ${PROTO_SRCS:.proto=.pb.h}
 PROTO_GEN_CC := ${PROTO_SRCS:.proto=.pb.cc}
 CXX_OBJS := ${CXX_SRCS:.cpp=.o}
+CU_OBJS := ${CU_SRCS:.cu=.o}
 PROTO_OBJS := ${PROTO_SRCS:.proto=.pb.o}
-OBJS := $(PROTO_OBJS) $(CXX_OBJS)
+OBJS := $(PROTO_OBJS) $(CXX_OBJS) $(CU_OBJS)
 TEST_OBJS := ${TEST_SRCS:.cpp=.o}
 
 CUDA_DIR = /usr/local/cuda
+CUDA_ARCH = -arch=sm_20
 MKL_DIR = /opt/intel/mkl
 
 CUDA_INCLUDE_DIR = $(CUDA_DIR)/include
@@ -31,11 +34,12 @@ LIBRARY_DIRS := . /usr/local/lib $(CUDA_LIB_DIR) $(MKL_LIB_DIR)
 LIBRARIES := cuda cudart cublas protobuf glog mkl_rt mkl_intel_thread
 WARNINGS := -Wall
 
-CXXFLAGS += -fPIC $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
+CXXFLAGS += $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
 LDFLAGS += $(foreach librarydir,$(LIBRARY_DIRS),-L$(librarydir))
 LDFLAGS += $(foreach library,$(LIBRARIES),-l$(library))
 
-LINK = $(CXX) $(CXXFLAGS) $(CPPFLAGS) $(LDFLAGS) $(WARNINGS)
+LINK = $(CXX) -fPIC $(CXXFLAGS) $(CPPFLAGS) $(LDFLAGS) $(WARNINGS)
+NVCC = nvcc $(CXXFLAGS) $(CPPFLAGS) $(CUDA_ARCH)
 
 .PHONY: all test clean distclean
 
@@ -49,6 +53,9 @@ $(TEST_NAME): $(OBJS) $(TEST_OBJS)
 $(NAME): $(PROTO_GEN_CC) $(OBJS)
        $(LINK) -shared $(OBJS) -o $(NAME)
 
+$(CU_OBJS): $(CU_SRCS)
+       $(NVCC) -c -o $(CU_OBJS) $(CU_SRCS)
+
 $(PROTO_GEN_CC): $(PROTO_SRCS)
        protoc $(PROTO_SRCS) --cpp_out=. --python_out=.
 
index f8446dada2b6f2388b520e8cfc440cfde2c63128..7ac0eada131fc6d84b15aa7a773238fb51518f15 100644 (file)
@@ -5,7 +5,7 @@ namespace caffeine {
 shared_ptr<Caffeine> Caffeine::singleton_;
 
 Caffeine::Caffeine()
-    : mode_(Caffeine::CPU) {
+    : mode_(Caffeine::CPU), phase_(Caffeine::TRAIN) {
   CUBLAS_CHECK(cublasCreate(&cublas_handle_));
   VSL_CHECK(vslNewStream(&vsl_stream_, VSL_BRNG_MT19937, 1701));
 }
@@ -34,10 +34,17 @@ Caffeine::Brew Caffeine::mode() {
   return Get().mode_;
 }
 
-
 void Caffeine::set_mode(Caffeine::Brew mode) {
   Get().mode_ = mode;
 }
 
+Caffeine::Phase Caffeine::phase() {
+  return Get().phase_;
+}
+
+void Caffeine::set_phase(Caffeine::Phase phase) {
+  Get().phase_ = phase;
+}
+
 }  // namespace caffeine
 
index aab179b18ce1e3d179b13eb308ded076be8ad07b..060d1f7d51fe7fbb17cf097419f73b90a09e69ba 100644 (file)
@@ -8,15 +8,18 @@
 
 #include "driver_types.h"
 
+#define CUDA_CHECK(condition) CHECK_EQ((condition), cudaSuccess)
+#define CUBLAS_CHECK(condition) CHECK_EQ((condition), CUBLAS_STATUS_SUCCESS)
+#define VSL_CHECK(condition) CHECK_EQ((condition), VSL_STATUS_OK)
+
 namespace caffeine {
 
 // We will use the boost shared_ptr instead of the new C++11 one mainly
 // because cuda does not work (at least now) well with C++11 features.
 using boost::shared_ptr;
 
-#define CUDA_CHECK(condition) CHECK_EQ((condition), cudaSuccess)
-#define CUBLAS_CHECK(condition) CHECK_EQ((condition), CUBLAS_STATUS_SUCCESS)
-#define VSL_CHECK(condition) CHECK_EQ((condition), VSL_STATUS_OK)
+// For backward compatibility we will just use 512 threads per block
+const int CAFFEINE_CUDA_NUM_THREADS = 512;
 
 // A singleton class to hold common caffeine stuff, such as the handler that
 // caffeine is going to use for cublas.
@@ -25,19 +28,23 @@ class Caffeine {
   ~Caffeine();
   static Caffeine& Get();
   enum Brew { CPU, GPU };
+  enum Phase { TRAIN, TEST};
 
   // The getters for the variables. 
   static cublasHandle_t cublas_handle();
   static VSLStreamStatePtr vsl_stream();
   static Brew mode();
+  static Phase phase();
   // The setters for the variables
   static void set_mode(Brew mode);
+  static void set_phase(Phase phase);
  private:
   Caffeine();
   static shared_ptr<Caffeine> singleton_;
   cublasHandle_t cublas_handle_;
   VSLStreamStatePtr vsl_stream_;
   Brew mode_;
+  Phase phase_;
 };
 
 }  // namespace caffeine
index f832b546d7c387da831c291fd2c59bf9cb5c64a4..c33746c4484764a3aa112523d0831b5f07b541df 100644 (file)
@@ -36,19 +36,7 @@ inline Dtype Layer<Dtype>::Backward(const vector<Blob<Dtype>*>& top,
   }
 };
 
-template <typename Dtype>
-inline void Layer<Dtype>::Predict(const vector<Blob<Dtype>*>& bottom,
-    vector<Blob<Dtype>*>* top) {
-  switch(Caffeine::mode()) {
-  case Caffeine::CPU:
-    Predict_cpu(bottom, top);
-    break;
-  case Caffeine::GPU:
-    Predict_gpu(bottom, top);
-    break;
-  default:
-    LOG(FATAL) << "Unknown caffeine mode.";
-  }
-};
+template class Layer<float>;
+template class Layer<double>;
 
 }  // namespace caffeine
index b7fbe29a4b2895922855df24a9cd25501a09233a..554507616dae41d2f33cf554d5131d0e1a1146b7 100644 (file)
@@ -18,12 +18,12 @@ class Layer {
    // layer.
   explicit Layer(const LayerParameter& param)
     : layer_param_(param) {};
-  virtual ~Layer();
+  virtual ~Layer() {};
   // SetUp: your function should implement this.
-  virtual void SetUp(vector<const Blob<Dtype>*>& bottom,
+  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) = 0;
 
-  // Forward, backward and predict wrappers. You should implement the cpu and
+  // Forward and backward wrappers. You should implement the cpu and
   // gpu specific implementations instead, and should not change these
   // functions.
   inline void Forward(const vector<Blob<Dtype>*>& bottom,
@@ -31,8 +31,6 @@ class Layer {
   inline Dtype Backward(const vector<Blob<Dtype>*>& top,
       const bool propagate_down,
       vector<Blob<Dtype>*>* bottom);
-  inline void Predict(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
 
  protected:
   // The protobuf that stores the layer parameters
@@ -62,17 +60,6 @@ class Layer {
     LOG(WARNING) << "Using CPU code as backup.";
     return Backward_cpu(top, propagate_down, bottom);
   };
-
-  // Prediction functions: could be overridden, but the default behavior is to
-  // simply call the forward functions.
-  virtual void Predict_cpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top) { Forward_cpu(bottom, top); };
-  // For prediction, if there is no Predict_gpu, then there are two options:
-  // to use predict_cpu as a backup, or to use forward_gpu (e.g. maybe the
-  // author forgot to write what backup s/he wants?). Thus, we will require
-  // the author to explicitly specify which fallback s/he wants.
-  virtual void Predict_gpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top) = 0;
 };  // class Layer
 
 }  // namespace caffeine
diff --git a/src/caffeine/neuron_layer.cpp b/src/caffeine/neuron_layer.cpp
deleted file mode 100644 (file)
index 8f454c3..0000000
+++ /dev/null
@@ -1,53 +0,0 @@
-#include "caffeine/vision_layers.hpp"
-#include <algorithm>
-
-using std::max;
-
-namespace caffeine {
-
-template <typename Dtype>
-void NeuronLayer<Dtype>::SetUp(vector<const Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top) {
-  CHECK_EQ(bottom.size(), 1) << "Neuron Layer takes a single blob as input.";
-  CHECK_EQ(top->size(), 1) << "Neuron Layer takes a single blob as output.";
-  for (int i = 0; i < bottom.size(); ++i) {
-    (*top)[i].Reshape(bottom.num(), bottom.channels(), bottom.height(),
-                      bottom.width());
-  }
-};
-
-template <typename Dtype>
-void ReLULayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-    vector<Blob<Dtype>*>* top) {
-  const Dtype* bottom_data = bottom[0].cpu_data();
-  Dtype* top_data = (*top)[0].mutable_cpu_data();
-  const int count = bottom[0].count();
-  for (int i = 0; i < count; ++i) {
-    top_data[i] = max(bottom_data[i], Dtype(0));
-  }
-}
-
-template <typename Dtype>
-Dtype ReLULayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down,
-    vector<Blob<Dtype>*>* bottom) {
-  if (propagate_down) {
-    const Dtype* bottom_data = (*bottom)[0].cpu_data();
-    const Dtype* top_diff = top[0].cpu_diff();
-    Dtype* bottom_diff = (*bottom)[0].mutable_cpu_diff();
-    const int count = (*bottom)[0].count();
-    for (int i = 0; i < count; ++i) {
-      bottom_diff[i] = top_diff[i] * (bottom_data[i] >= 0);
-    }
-  }
-  return Dtype(0);
-}
-
-template <typename Dtype>
-inline void ReLULayer<Dtype>::Predict_cpu(const vector<Blob<Dtype>*>& bottom,
-    vector<Blob<Dtype>*>* top) {
-  Forward_cpu(bottom, top);
-}
-
-
-}  // namespace caffeine
diff --git a/src/caffeine/neuron_layer.cu b/src/caffeine/neuron_layer.cu
new file mode 100644 (file)
index 0000000..2801248
--- /dev/null
@@ -0,0 +1,98 @@
+#include "caffeine/layer.hpp"
+#include "caffeine/vision_layers.hpp"
+#include <algorithm>
+
+using std::max;
+
+namespace caffeine {
+
+template <typename Dtype>
+void NeuronLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  CHECK_EQ(bottom.size(), 1) << "Neuron Layer takes a single blob as input.";
+  CHECK_EQ(top->size(), 1) << "Neuron Layer takes a single blob as output.";
+  (*top)[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
+      bottom[0]->height(), bottom[0]->width());
+};
+
+template class NeuronLayer<float>;
+template class NeuronLayer<double>;
+
+template <typename Dtype>
+void ReLULayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* top_data = (*top)[0]->mutable_cpu_data();
+  const int count = bottom[0]->count();
+  for (int i = 0; i < count; ++i) {
+    top_data[i] = max(bottom_data[i], Dtype(0));
+  }
+}
+
+template <typename Dtype>
+Dtype ReLULayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  if (propagate_down) {
+    const Dtype* bottom_data = (*bottom)[0]->cpu_data();
+    const Dtype* top_diff = top[0]->cpu_diff();
+    Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+    const int count = (*bottom)[0]->count();
+    for (int i = 0; i < count; ++i) {
+      bottom_diff[i] = top_diff[i] * (bottom_data[i] >= 0);
+    }
+  }
+  return Dtype(0);
+}
+
+template <typename Dtype>
+__global__ void ReLUForward(const int n, const Dtype* in, Dtype* out) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < n) {
+    out[index] = max(in[index], Dtype(0.));
+  }
+}
+
+template <typename Dtype>
+void ReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = (*top)[0]->mutable_gpu_data();
+  const int count = bottom[0]->count();
+  const int blocks = (count + CAFFEINE_CUDA_NUM_THREADS - 1) /
+      CAFFEINE_CUDA_NUM_THREADS;
+  ReLUForward<<<blocks, CAFFEINE_CUDA_NUM_THREADS>>>(count, bottom_data,
+      top_data);
+}
+
+template <typename Dtype>
+__global__ void ReLUBackward(const int n, const Dtype* in_diff,
+    const Dtype* in_data, Dtype* out_diff) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < n) {
+    out_diff[index] = in_diff[index] * (in_data[index] >= 0);
+  }
+}
+
+template <typename Dtype>
+Dtype ReLULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  if (propagate_down) {
+    const Dtype* bottom_data = (*bottom)[0]->gpu_data();
+    const Dtype* top_diff = top[0]->gpu_diff();
+    Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
+    const int count = (*bottom)[0]->count();
+    const int blocks = (count + CAFFEINE_CUDA_NUM_THREADS - 1) /
+        CAFFEINE_CUDA_NUM_THREADS;
+    ReLUBackward<<<blocks, CAFFEINE_CUDA_NUM_THREADS>>>(count, top_diff,
+        bottom_data, bottom_diff);
+  }
+  return Dtype(0);
+}
+
+template class ReLULayer<float>;
+template class ReLULayer<double>;
+
+
+}  // namespace caffeine
similarity index 77%
rename from src/caffeine/test_common.cpp
rename to src/caffeine/test/test_common.cpp
index 1fed2c0dd2ebe9d6a6a76f013ffb9ef8b46e5dad..acf898183eae40f9140eba39352623a3e8989183 100644 (file)
@@ -24,4 +24,10 @@ TEST_F(CommonTest, TestBrewMode) {
  EXPECT_EQ(Caffeine::mode(), Caffeine::GPU);
 }
 
+TEST_F(CommonTest, TestPhase) {
+ EXPECT_EQ(Caffeine::phase(), Caffeine::TRAIN);
+ Caffeine::set_phase(Caffeine::TEST);
+ EXPECT_EQ(Caffeine::phase(), Caffeine::TEST);
+}
+
 }
similarity index 88%
rename from src/caffeine/test_filler.cpp
rename to src/caffeine/test/test_filler.cpp
index 1d446b35c687ff31b6d0f79afa91fa6b8600974c..61dfdfe1b76c8c28750ca442b3bb74793d3b05fa 100644 (file)
@@ -89,12 +89,20 @@ TYPED_TEST(GaussianFillerTest, TestFill) {
   const int count = this->blob_->count();
   const TypeParam* data = this->blob_->cpu_data();
   TypeParam mean = 0.;
+  TypeParam var = 0.;
   for (int i = 0; i < count; ++i) {
     mean += data[i];
+    var += (data[i] - this->filler_param_.mean()) * 
+        (data[i] - this->filler_param_.mean());
   }
   mean /= count;
-  EXPECT_GE(mean, this->filler_param_.mean() - this->filler_param_.std() * 10);
-  EXPECT_LE(mean, this->filler_param_.mean() + this->filler_param_.std() * 10);
+  var /= count;
+  // Very loose test.
+  EXPECT_GE(mean, this->filler_param_.mean() - this->filler_param_.std() * 5);
+  EXPECT_LE(mean, this->filler_param_.mean() + this->filler_param_.std() * 5);
+  TypeParam target_var = this->filler_param_.std() * this->filler_param_.std();
+  EXPECT_GE(var, target_var / 5.);
+  EXPECT_LE(var, target_var * 5.);
 }
 
 }
diff --git a/src/caffeine/test/test_neuron_layer.cpp b/src/caffeine/test/test_neuron_layer.cpp
new file mode 100644 (file)
index 0000000..db33f7b
--- /dev/null
@@ -0,0 +1,48 @@
+#include <cstring>
+#include <cuda_runtime.h>
+
+#include "gtest/gtest.h"
+#include "caffeine/blob.hpp"
+#include "caffeine/common.hpp"
+#include "caffeine/filler.hpp"
+#include "caffeine/vision_layers.hpp"
+
+namespace caffeine {
+  
+template <typename Dtype>
+class NeuronLayerTest : public ::testing::Test {
+ protected:
+  NeuronLayerTest()
+      : blob_bottom_(new Blob<Dtype>(2, 3, 4, 5)),
+        blob_top_(new Blob<Dtype>(2, 3, 4, 5)) {
+    // fill the values
+    FillerParameter filler_param;
+    GaussianFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_);
+    blob_bottom_vec_.push_back(blob_bottom_);
+    blob_top_vec_.push_back(blob_top_);
+  };
+  virtual ~NeuronLayerTest() { delete blob_bottom_; delete blob_top_; }
+  Blob<Dtype>* const blob_bottom_;
+  Blob<Dtype>* const blob_top_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+typedef ::testing::Types<float, double> Dtypes;
+TYPED_TEST_CASE(NeuronLayerTest, Dtypes);
+
+TYPED_TEST(NeuronLayerTest, TestReLU) {
+  LayerParameter layer_param;
+  ReLULayer<TypeParam> layer(layer_param);
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  // Now, check values
+  const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
+  const TypeParam* top_data = this->blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    EXPECT_GE(top_data[i], 0.);
+    EXPECT_TRUE(top_data[i] == 0 || top_data[i] == bottom_data[i]);
+  }
+}
+
+}
diff --git a/src/caffeine/test_neuron_layer.cpp b/src/caffeine/test_neuron_layer.cpp
deleted file mode 100644 (file)
index e38037f..0000000
+++ /dev/null
@@ -1,27 +0,0 @@
-#include <cstring>
-#include <cuda_runtime.h>
-
-#include "gtest/gtest.h"
-#include "caffeine/common.hpp"
-#include "caffeine/blob.hpp"
-
-namespace caffeine {
-  
-template <typename Dtype>
-class NeuronLayerTest : public ::testing::Test {
- protected:
-  NeuronLayerTest()
-      : blob_bottom_(new Blob<Dtype>(2, 3, 4, 5)),
-        blob_top_(new Blob<Dtype>(2, 3, 4, 5)) {
-    // fill the values
-    
-  };
-  virtual ~NeuronLayerTest() { delete blob_bottom_; delete blob_top_; }
-  Blob<Dtype>* const blob_bottom_;
-  Blob<Dtype>* const blob_top_;
-};
-
-typedef ::testing::Types<float, double> Dtypes;
-TYPED_TEST_CASE(NeuronLayerTest, Dtypes);
-
-}
index 19548890c533aa3680fdaca5b8dd6e74cea5f132..b2f492690f1d7eac5b3252968a245177fb74425e 100644 (file)
@@ -8,12 +8,17 @@ namespace caffeine {
 template <typename Dtype>
 class NeuronLayer : public Layer<Dtype> {
  public:
-  virtual void SetUp(vector<const Blob<Dtype>*>& bottom,
+  explicit NeuronLayer(const LayerParameter& param)
+     : Layer<Dtype>(param) {};
+  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
 };
 
 template <typename Dtype>
 class ReLULayer : public NeuronLayer<Dtype> {
+ public:
+  explicit ReLULayer(const LayerParameter& param)
+      : NeuronLayer<Dtype>(param) {};
  protected:
   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
@@ -24,11 +29,6 @@ class ReLULayer : public NeuronLayer<Dtype> {
       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-
-  virtual void Predict_cpu(const vector<Blob<Dtype>*>& bottom,
-    vector<Blob<Dtype>*>* top);
-  virtual void Predict_gpu(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top);
 };
 
 }  // namespace caffeine