summary | shortlog | log | commit | commitdiff | tree
raw | patch | inline | side by side (parent: 582fa14)
raw | patch | inline | side by side (parent: 582fa14)
author | Yangqing Jia <jiayq84@gmail.com> | |
Mon, 16 Sep 2013 18:25:43 +0000 (11:25 -0700) | ||
committer | Yangqing Jia <jiayq84@gmail.com> | |
Mon, 16 Sep 2013 18:25:43 +0000 (11:25 -0700) |
15 files changed:
diff --git a/src/Makefile b/src/Makefile
index 4c807de8488f4141958262bbfe4cc9d7456f33c1..3d2fca234e278fb60fc4531d4a93ed916933d0f2 100644 (file)
--- a/src/Makefile
+++ b/src/Makefile
NAME := lib$(PROJECT).so
TEST_NAME := test_$(PROJECT)
CXX_SRCS := $(shell find caffeine ! -name "test_*.cpp" -name "*.cpp")
+CU_SRCS := $(shell find caffeine -name "*.cu")
TEST_SRCS := $(shell find caffeine -name "test_*.cpp") gtest/gtest-all.cpp
PROTO_SRCS := $(wildcard caffeine/proto/*.proto)
PROTO_GEN_HEADER := ${PROTO_SRCS:.proto=.pb.h}
PROTO_GEN_CC := ${PROTO_SRCS:.proto=.pb.cc}
CXX_OBJS := ${CXX_SRCS:.cpp=.o}
+CU_OBJS := ${CU_SRCS:.cu=.o}
PROTO_OBJS := ${PROTO_SRCS:.proto=.pb.o}
-OBJS := $(PROTO_OBJS) $(CXX_OBJS)
+OBJS := $(PROTO_OBJS) $(CXX_OBJS) $(CU_OBJS)
TEST_OBJS := ${TEST_SRCS:.cpp=.o}
CUDA_DIR = /usr/local/cuda
+CUDA_ARCH = -arch=sm_20
MKL_DIR = /opt/intel/mkl
CUDA_INCLUDE_DIR = $(CUDA_DIR)/include
LIBRARIES := cuda cudart cublas protobuf glog mkl_rt mkl_intel_thread
WARNINGS := -Wall
-CXXFLAGS += -fPIC $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
+CXXFLAGS += $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
LDFLAGS += $(foreach librarydir,$(LIBRARY_DIRS),-L$(librarydir))
LDFLAGS += $(foreach library,$(LIBRARIES),-l$(library))
-LINK = $(CXX) $(CXXFLAGS) $(CPPFLAGS) $(LDFLAGS) $(WARNINGS)
+LINK = $(CXX) -fPIC $(CXXFLAGS) $(CPPFLAGS) $(LDFLAGS) $(WARNINGS)
+NVCC = nvcc $(CXXFLAGS) $(CPPFLAGS) $(CUDA_ARCH)
.PHONY: all test clean distclean
$(NAME): $(PROTO_GEN_CC) $(OBJS)
$(LINK) -shared $(OBJS) -o $(NAME)
+$(CU_OBJS): $(CU_SRCS)
+ $(NVCC) -c -o $(CU_OBJS) $(CU_SRCS)
+
$(PROTO_GEN_CC): $(PROTO_SRCS)
protoc $(PROTO_SRCS) --cpp_out=. --python_out=.
index f8446dada2b6f2388b520e8cfc440cfde2c63128..7ac0eada131fc6d84b15aa7a773238fb51518f15 100644 (file)
--- a/src/caffeine/common.cpp
+++ b/src/caffeine/common.cpp
shared_ptr<Caffeine> Caffeine::singleton_;
Caffeine::Caffeine()
- : mode_(Caffeine::CPU) {
+ : mode_(Caffeine::CPU), phase_(Caffeine::TRAIN) {
CUBLAS_CHECK(cublasCreate(&cublas_handle_));
VSL_CHECK(vslNewStream(&vsl_stream_, VSL_BRNG_MT19937, 1701));
}
return Get().mode_;
}
-
void Caffeine::set_mode(Caffeine::Brew mode) {
Get().mode_ = mode;
}
+Caffeine::Phase Caffeine::phase() {
+ return Get().phase_;
+}
+
+void Caffeine::set_phase(Caffeine::Phase phase) {
+ Get().phase_ = phase;
+}
+
} // namespace caffeine
index aab179b18ce1e3d179b13eb308ded076be8ad07b..060d1f7d51fe7fbb17cf097419f73b90a09e69ba 100644 (file)
--- a/src/caffeine/common.hpp
+++ b/src/caffeine/common.hpp
#include "driver_types.h"
+#define CUDA_CHECK(condition) CHECK_EQ((condition), cudaSuccess)
+#define CUBLAS_CHECK(condition) CHECK_EQ((condition), CUBLAS_STATUS_SUCCESS)
+#define VSL_CHECK(condition) CHECK_EQ((condition), VSL_STATUS_OK)
+
namespace caffeine {
// We will use the boost shared_ptr instead of the new C++11 one mainly
// because cuda does not work (at least now) well with C++11 features.
using boost::shared_ptr;
-#define CUDA_CHECK(condition) CHECK_EQ((condition), cudaSuccess)
-#define CUBLAS_CHECK(condition) CHECK_EQ((condition), CUBLAS_STATUS_SUCCESS)
-#define VSL_CHECK(condition) CHECK_EQ((condition), VSL_STATUS_OK)
+// For backward compatibility we will just use 512 threads per block
+const int CAFFEINE_CUDA_NUM_THREADS = 512;
// A singleton class to hold common caffeine stuff, such as the handler that
// caffeine is going to use for cublas.
~Caffeine();
static Caffeine& Get();
enum Brew { CPU, GPU };
+ enum Phase { TRAIN, TEST};
// The getters for the variables.
static cublasHandle_t cublas_handle();
static VSLStreamStatePtr vsl_stream();
static Brew mode();
+ static Phase phase();
// The setters for the variables
static void set_mode(Brew mode);
+ static void set_phase(Phase phase);
private:
Caffeine();
static shared_ptr<Caffeine> singleton_;
cublasHandle_t cublas_handle_;
VSLStreamStatePtr vsl_stream_;
Brew mode_;
+ Phase phase_;
};
} // namespace caffeine
diff --git a/src/caffeine/layer.cpp b/src/caffeine/layer.cpp
index f832b546d7c387da831c291fd2c59bf9cb5c64a4..c33746c4484764a3aa112523d0831b5f07b541df 100644 (file)
--- a/src/caffeine/layer.cpp
+++ b/src/caffeine/layer.cpp
}
};
-template <typename Dtype>
-inline void Layer<Dtype>::Predict(const vector<Blob<Dtype>*>& bottom,
- vector<Blob<Dtype>*>* top) {
- switch(Caffeine::mode()) {
- case Caffeine::CPU:
- Predict_cpu(bottom, top);
- break;
- case Caffeine::GPU:
- Predict_gpu(bottom, top);
- break;
- default:
- LOG(FATAL) << "Unknown caffeine mode.";
- }
-};
+template class Layer<float>;
+template class Layer<double>;
} // namespace caffeine
diff --git a/src/caffeine/layer.hpp b/src/caffeine/layer.hpp
index b7fbe29a4b2895922855df24a9cd25501a09233a..554507616dae41d2f33cf554d5131d0e1a1146b7 100644 (file)
--- a/src/caffeine/layer.hpp
+++ b/src/caffeine/layer.hpp
// layer.
explicit Layer(const LayerParameter& param)
: layer_param_(param) {};
- virtual ~Layer();
+ virtual ~Layer() {};
// SetUp: your function should implement this.
- virtual void SetUp(vector<const Blob<Dtype>*>& bottom,
+ virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) = 0;
- // Forward, backward and predict wrappers. You should implement the cpu and
+ // Forward and backward wrappers. You should implement the cpu and
// gpu specific implementations instead, and should not change these
// functions.
inline void Forward(const vector<Blob<Dtype>*>& bottom,
inline Dtype Backward(const vector<Blob<Dtype>*>& top,
const bool propagate_down,
vector<Blob<Dtype>*>* bottom);
- inline void Predict(const vector<Blob<Dtype>*>& bottom,
- vector<Blob<Dtype>*>* top);
protected:
// The protobuf that stores the layer parameters
LOG(WARNING) << "Using CPU code as backup.";
return Backward_cpu(top, propagate_down, bottom);
};
-
- // Prediction functions: could be overridden, but the default behavior is to
- // simply call the forward functions.
- virtual void Predict_cpu(const vector<Blob<Dtype>*>& bottom,
- vector<Blob<Dtype>*>* top) { Forward_cpu(bottom, top); };
- // For prediction, if there is no Predict_gpu, then there are two options:
- // to use predict_cpu as a backup, or to use forward_gpu (e.g. maybe the
- // author forgot to write what backup s/he wants?). Thus, we will require
- // the author to explicitly specify which fallback s/he wants.
- virtual void Predict_gpu(const vector<Blob<Dtype>*>& bottom,
- vector<Blob<Dtype>*>* top) = 0;
}; // class Layer
} // namespace caffeine
diff --git a/src/caffeine/neuron_layer.cpp b/src/caffeine/neuron_layer.cpp
+++ /dev/null
@@ -1,53 +0,0 @@
-#include "caffeine/vision_layers.hpp"
-#include <algorithm>
-
-using std::max;
-
-namespace caffeine {
-
-template <typename Dtype>
-void NeuronLayer<Dtype>::SetUp(vector<const Blob<Dtype>*>& bottom,
- vector<Blob<Dtype>*>* top) {
- CHECK_EQ(bottom.size(), 1) << "Neuron Layer takes a single blob as input.";
- CHECK_EQ(top->size(), 1) << "Neuron Layer takes a single blob as output.";
- for (int i = 0; i < bottom.size(); ++i) {
- (*top)[i].Reshape(bottom.num(), bottom.channels(), bottom.height(),
- bottom.width());
- }
-};
-
-template <typename Dtype>
-void ReLULayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
- vector<Blob<Dtype>*>* top) {
- const Dtype* bottom_data = bottom[0].cpu_data();
- Dtype* top_data = (*top)[0].mutable_cpu_data();
- const int count = bottom[0].count();
- for (int i = 0; i < count; ++i) {
- top_data[i] = max(bottom_data[i], Dtype(0));
- }
-}
-
-template <typename Dtype>
-Dtype ReLULayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
- const bool propagate_down,
- vector<Blob<Dtype>*>* bottom) {
- if (propagate_down) {
- const Dtype* bottom_data = (*bottom)[0].cpu_data();
- const Dtype* top_diff = top[0].cpu_diff();
- Dtype* bottom_diff = (*bottom)[0].mutable_cpu_diff();
- const int count = (*bottom)[0].count();
- for (int i = 0; i < count; ++i) {
- bottom_diff[i] = top_diff[i] * (bottom_data[i] >= 0);
- }
- }
- return Dtype(0);
-}
-
-template <typename Dtype>
-inline void ReLULayer<Dtype>::Predict_cpu(const vector<Blob<Dtype>*>& bottom,
- vector<Blob<Dtype>*>* top) {
- Forward_cpu(bottom, top);
-}
-
-
-} // namespace caffeine
diff --git a/src/caffeine/neuron_layer.cu b/src/caffeine/neuron_layer.cu
--- /dev/null
@@ -0,0 +1,98 @@
+#include "caffeine/layer.hpp"
+#include "caffeine/vision_layers.hpp"
+#include <algorithm>
+
+using std::max;
+
+namespace caffeine {
+
+template <typename Dtype>
+void NeuronLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) {
+ CHECK_EQ(bottom.size(), 1) << "Neuron Layer takes a single blob as input.";
+ CHECK_EQ(top->size(), 1) << "Neuron Layer takes a single blob as output.";
+ (*top)[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
+ bottom[0]->height(), bottom[0]->width());
+};
+
+template class NeuronLayer<float>;
+template class NeuronLayer<double>;
+
+template <typename Dtype>
+void ReLULayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) {
+ const Dtype* bottom_data = bottom[0]->cpu_data();
+ Dtype* top_data = (*top)[0]->mutable_cpu_data();
+ const int count = bottom[0]->count();
+ for (int i = 0; i < count; ++i) {
+ top_data[i] = max(bottom_data[i], Dtype(0));
+ }
+}
+
+template <typename Dtype>
+Dtype ReLULayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down,
+ vector<Blob<Dtype>*>* bottom) {
+ if (propagate_down) {
+ const Dtype* bottom_data = (*bottom)[0]->cpu_data();
+ const Dtype* top_diff = top[0]->cpu_diff();
+ Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+ const int count = (*bottom)[0]->count();
+ for (int i = 0; i < count; ++i) {
+ bottom_diff[i] = top_diff[i] * (bottom_data[i] >= 0);
+ }
+ }
+ return Dtype(0);
+}
+
+template <typename Dtype>
+__global__ void ReLUForward(const int n, const Dtype* in, Dtype* out) {
+ int index = threadIdx.x + blockIdx.x * blockDim.x;
+ if (index < n) {
+ out[index] = max(in[index], Dtype(0.));
+ }
+}
+
+template <typename Dtype>
+void ReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) {
+ const Dtype* bottom_data = bottom[0]->gpu_data();
+ Dtype* top_data = (*top)[0]->mutable_gpu_data();
+ const int count = bottom[0]->count();
+ const int blocks = (count + CAFFEINE_CUDA_NUM_THREADS - 1) /
+ CAFFEINE_CUDA_NUM_THREADS;
+ ReLUForward<<<blocks, CAFFEINE_CUDA_NUM_THREADS>>>(count, bottom_data,
+ top_data);
+}
+
+template <typename Dtype>
+__global__ void ReLUBackward(const int n, const Dtype* in_diff,
+ const Dtype* in_data, Dtype* out_diff) {
+ int index = threadIdx.x + blockIdx.x * blockDim.x;
+ if (index < n) {
+ out_diff[index] = in_diff[index] * (in_data[index] >= 0);
+ }
+}
+
+template <typename Dtype>
+Dtype ReLULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down,
+ vector<Blob<Dtype>*>* bottom) {
+ if (propagate_down) {
+ const Dtype* bottom_data = (*bottom)[0]->gpu_data();
+ const Dtype* top_diff = top[0]->gpu_diff();
+ Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
+ const int count = (*bottom)[0]->count();
+ const int blocks = (count + CAFFEINE_CUDA_NUM_THREADS - 1) /
+ CAFFEINE_CUDA_NUM_THREADS;
+ ReLUBackward<<<blocks, CAFFEINE_CUDA_NUM_THREADS>>>(count, top_diff,
+ bottom_data, bottom_diff);
+ }
+ return Dtype(0);
+}
+
+template class ReLULayer<float>;
+template class ReLULayer<double>;
+
+
+} // namespace caffeine
similarity index 100%
rename from src/caffeine/test_blob.cpp
rename to src/caffeine/test/test_blob.cpp
rename from src/caffeine/test_blob.cpp
rename to src/caffeine/test/test_blob.cpp
similarity index 100%
rename from src/caffeine/test_caffeine_main.cpp
rename to src/caffeine/test/test_caffeine_main.cpp
rename from src/caffeine/test_caffeine_main.cpp
rename to src/caffeine/test/test_caffeine_main.cpp
similarity index 77%
rename from src/caffeine/test_common.cpp
rename to src/caffeine/test/test_common.cpp
index 1fed2c0dd2ebe9d6a6a76f013ffb9ef8b46e5dad..acf898183eae40f9140eba39352623a3e8989183 100644 (file)
rename from src/caffeine/test_common.cpp
rename to src/caffeine/test/test_common.cpp
index 1fed2c0dd2ebe9d6a6a76f013ffb9ef8b46e5dad..acf898183eae40f9140eba39352623a3e8989183 100644 (file)
EXPECT_EQ(Caffeine::mode(), Caffeine::GPU);
}
+TEST_F(CommonTest, TestPhase) {
+ EXPECT_EQ(Caffeine::phase(), Caffeine::TRAIN);
+ Caffeine::set_phase(Caffeine::TEST);
+ EXPECT_EQ(Caffeine::phase(), Caffeine::TEST);
+}
+
}
similarity index 88%
rename from src/caffeine/test_filler.cpp
rename to src/caffeine/test/test_filler.cpp
index 1d446b35c687ff31b6d0f79afa91fa6b8600974c..61dfdfe1b76c8c28750ca442b3bb74793d3b05fa 100644 (file)
rename from src/caffeine/test_filler.cpp
rename to src/caffeine/test/test_filler.cpp
index 1d446b35c687ff31b6d0f79afa91fa6b8600974c..61dfdfe1b76c8c28750ca442b3bb74793d3b05fa 100644 (file)
const int count = this->blob_->count();
const TypeParam* data = this->blob_->cpu_data();
TypeParam mean = 0.;
+ TypeParam var = 0.;
for (int i = 0; i < count; ++i) {
mean += data[i];
+ var += (data[i] - this->filler_param_.mean()) *
+ (data[i] - this->filler_param_.mean());
}
mean /= count;
- EXPECT_GE(mean, this->filler_param_.mean() - this->filler_param_.std() * 10);
- EXPECT_LE(mean, this->filler_param_.mean() + this->filler_param_.std() * 10);
+ var /= count;
+ // Very loose test.
+ EXPECT_GE(mean, this->filler_param_.mean() - this->filler_param_.std() * 5);
+ EXPECT_LE(mean, this->filler_param_.mean() + this->filler_param_.std() * 5);
+ TypeParam target_var = this->filler_param_.std() * this->filler_param_.std();
+ EXPECT_GE(var, target_var / 5.);
+ EXPECT_LE(var, target_var * 5.);
}
}
diff --git a/src/caffeine/test/test_neuron_layer.cpp b/src/caffeine/test/test_neuron_layer.cpp
--- /dev/null
@@ -0,0 +1,48 @@
+#include <cstring>
+#include <cuda_runtime.h>
+
+#include "gtest/gtest.h"
+#include "caffeine/blob.hpp"
+#include "caffeine/common.hpp"
+#include "caffeine/filler.hpp"
+#include "caffeine/vision_layers.hpp"
+
+namespace caffeine {
+
+template <typename Dtype>
+class NeuronLayerTest : public ::testing::Test {
+ protected:
+ NeuronLayerTest()
+ : blob_bottom_(new Blob<Dtype>(2, 3, 4, 5)),
+ blob_top_(new Blob<Dtype>(2, 3, 4, 5)) {
+ // fill the values
+ FillerParameter filler_param;
+ GaussianFiller<Dtype> filler(filler_param);
+ filler.Fill(this->blob_bottom_);
+ blob_bottom_vec_.push_back(blob_bottom_);
+ blob_top_vec_.push_back(blob_top_);
+ };
+ virtual ~NeuronLayerTest() { delete blob_bottom_; delete blob_top_; }
+ Blob<Dtype>* const blob_bottom_;
+ Blob<Dtype>* const blob_top_;
+ vector<Blob<Dtype>*> blob_bottom_vec_;
+ vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+typedef ::testing::Types<float, double> Dtypes;
+TYPED_TEST_CASE(NeuronLayerTest, Dtypes);
+
+TYPED_TEST(NeuronLayerTest, TestReLU) {
+ LayerParameter layer_param;
+ ReLULayer<TypeParam> layer(layer_param);
+ layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+ // Now, check values
+ const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
+ const TypeParam* top_data = this->blob_top_->cpu_data();
+ for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+ EXPECT_GE(top_data[i], 0.);
+ EXPECT_TRUE(top_data[i] == 0 || top_data[i] == bottom_data[i]);
+ }
+}
+
+}
similarity index 100%
rename from src/caffeine/test_syncedmem.cpp
rename to src/caffeine/test/test_syncedmem.cpp
rename from src/caffeine/test_syncedmem.cpp
rename to src/caffeine/test/test_syncedmem.cpp
diff --git a/src/caffeine/test_neuron_layer.cpp b/src/caffeine/test_neuron_layer.cpp
+++ /dev/null
@@ -1,27 +0,0 @@
-#include <cstring>
-#include <cuda_runtime.h>
-
-#include "gtest/gtest.h"
-#include "caffeine/common.hpp"
-#include "caffeine/blob.hpp"
-
-namespace caffeine {
-
-template <typename Dtype>
-class NeuronLayerTest : public ::testing::Test {
- protected:
- NeuronLayerTest()
- : blob_bottom_(new Blob<Dtype>(2, 3, 4, 5)),
- blob_top_(new Blob<Dtype>(2, 3, 4, 5)) {
- // fill the values
-
- };
- virtual ~NeuronLayerTest() { delete blob_bottom_; delete blob_top_; }
- Blob<Dtype>* const blob_bottom_;
- Blob<Dtype>* const blob_top_;
-};
-
-typedef ::testing::Types<float, double> Dtypes;
-TYPED_TEST_CASE(NeuronLayerTest, Dtypes);
-
-}
index 19548890c533aa3680fdaca5b8dd6e74cea5f132..b2f492690f1d7eac5b3252968a245177fb74425e 100644 (file)
template <typename Dtype>
class NeuronLayer : public Layer<Dtype> {
public:
- virtual void SetUp(vector<const Blob<Dtype>*>& bottom,
+ explicit NeuronLayer(const LayerParameter& param)
+ : Layer<Dtype>(param) {};
+ virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
};
template <typename Dtype>
class ReLULayer : public NeuronLayer<Dtype> {
+ public:
+ explicit ReLULayer(const LayerParameter& param)
+ : NeuronLayer<Dtype>(param) {};
protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
const bool propagate_down, vector<Blob<Dtype>*>* bottom);
virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-
- virtual void Predict_cpu(const vector<Blob<Dtype>*>& bottom,
- vector<Blob<Dtype>*>* top);
- virtual void Predict_gpu(const vector<Blob<Dtype>*>& bottom,
- vector<Blob<Dtype>*>* top);
};
} // namespace caffeine