summary | shortlog | log | commit | commitdiff | tree
raw | patch | inline | side by side (from parent 1: 1e31fc5)
raw | patch | inline | side by side (from parent 1: 1e31fc5)
author	Yangqing Jia <jiayq84@gmail.com>
	Thu, 19 Sep 2013 23:26:01 +0000 (16:26 -0700)
committer	Yangqing Jia <jiayq84@gmail.com>
	Thu, 19 Sep 2013 23:26:01 +0000 (16:26 -0700)
diff --git a/.gitignore b/.gitignore
index f4489a608a0ffe670978b72fddd6b4395f6bad0c..ca27edc412e61b21d3b08b2e9ded4a51e45c505e 100644 (file)
--- a/.gitignore
+++ b/.gitignore
*.slo
*.lo
*.o
+*.cuo
# Compiled Dynamic libraries
*.so
diff --git a/src/Makefile b/src/Makefile
index cd89c7d1ad817203d1e9f6068b456c3c6e0369a0..2b53501a05841c6821c1fb8c69583042e19cc1a7 100644 (file)
--- a/src/Makefile
+++ b/src/Makefile
PROTO_GEN_CC := ${PROTO_SRCS:.proto=.pb.cc}
PROTO_GEN_PY := ${PROTO_SRCS:.proto=_pb2.py}
CXX_OBJS := ${CXX_SRCS:.cpp=.o}
-CU_OBJS := ${CU_SRCS:.cu=.o}
+CU_OBJS := ${CU_SRCS:.cu=.cuo}
PROTO_OBJS := ${PROTO_SRCS:.proto=.pb.o}
OBJS := $(PROTO_OBJS) $(CXX_OBJS) $(CU_OBJS)
TEST_OBJS := ${TEST_SRCS:.cpp=.o}
$(NAME): $(PROTO_GEN_CC) $(OBJS)
$(LINK) -shared $(OBJS) -o $(NAME)
-$(CU_OBJS): %.o: %.cu
+$(CU_OBJS): %.cuo: %.cu
$(NVCC) -c $< -o $@
$(PROTO_GEN_CC): $(PROTO_SRCS)
index 9818907fc0e3978f5844b71c3c08f768a8d68e0f..2d9e2600afb6bbfb3e49b6c40ae3f19eff41386d 100644 (file)
DropoutForward<Dtype><<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
count, bottom_data, (unsigned int*)rand_vec_->gpu_data(), uint_thres_, scale_,
top_data);
+ CUDA_POST_KERNEL_CHECK;
} else {
CUDA_CHECK(cudaMemcpy(top_data, bottom_data,
count * sizeof(Dtype), cudaMemcpyDeviceToDevice));
const int count = (*bottom)[0]->count();
DropoutBackward<Dtype><<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
count, top_diff, mask, uint_thres_, scale_, bottom_diff);
+ CUDA_POST_KERNEL_CHECK;
}
return Dtype(0);
}
index c6020e378e4b2d24a175ccd48a6f7bfd2a157a2f..ee874932d8fd7e2802aab38d234eccddbda01b34 100644 (file)
#include "caffeine/layer.hpp"
#include "caffeine/util/im2col.hpp"
#include "caffeine/vision_layers.hpp"
+#include "caffeine/common.hpp"
namespace caffeine {
}
}
+// GPU forward pass: unfolds each image of the bottom blob into column
+// ("im2col") form in the top blob, one im2col_gpu call per image.
+template <typename Dtype>
+void Im2colLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) {
+ const Dtype* bottom_data = bottom[0]->gpu_data();
+ Dtype* top_data = (*top)[0]->mutable_gpu_data();
+ // offset(n) advances both pointers to the n-th image in the batch.
+ // NOTE(review): CHANNELS_/HEIGHT_/WIDTH_/KSIZE_/STRIDE_ are member state —
+ // presumably set in SetUp; not visible in this hunk.
+ for (int n = 0; n < bottom[0]->num(); ++n) {
+ im2col_gpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
+ WIDTH_, KSIZE_, STRIDE_, top_data + (*top)[0]->offset(n));
+ }
+}
+
template <typename Dtype>
Dtype Im2colLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
index 21714679b39ff3c497c9c630bb284e688d559114..53bd1f50b8eef20cb5ec330c7a8535e0ad455a47 100644 (file)
PaddingBackward<Dtype><<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
count, top_diff, bottom_diff, NUM_, CHANNEL_, HEIGHT_IN_, WIDTH_IN_,
PAD_);
+ CUDA_POST_KERNEL_CHECK;
}
return Dtype(0);
}
index 7456dfb54ef5e62d79337d22c3974e55e06ede8c..e2b28e6e6c6e94f9c0a1cec778905bbbb166b827 100644 (file)
}
}
+// GPU smoke test for Im2colLayer::Forward: with kernelsize=3 and stride=2,
+// the first 27 output channels at spatial position (0,0) must equal the
+// top-left 3x3 patch of each input channel. Output channel c maps to input
+// channel c/9 and in-kernel offset ((c/3)%3, c%3), matching the layout
+// written by im2col. Assumes the fixture's bottom blob has at least 3
+// channels and a 3x3 top-left region — TODO confirm against the fixture.
+TYPED_TEST(Im2colLayerTest, TestGPU) {
+ LayerParameter layer_param;
+ layer_param.set_kernelsize(3);
+ layer_param.set_stride(2);
+ Im2colLayer<TypeParam> layer(layer_param);
+ Caffeine::set_mode(Caffeine::GPU);
+ layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+ layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+ // We are lazy and will only check the top left block
+ for (int c = 0; c < 27; ++c) {
+ EXPECT_EQ(this->blob_bottom_->data_at(0, (c / 9), (c / 3) % 3, c % 3),
+ this->blob_top_->data_at(0, c, 0, 0));
+ }
+}
+
TYPED_TEST(Im2colLayerTest, TestCPUGradient) {
LayerParameter layer_param;
layer_param.set_kernelsize(3);
diff --git a/src/caffeine/util/im2col.cu b/src/caffeine/util/im2col.cu
--- /dev/null
@@ -0,0 +1,87 @@
+#include <cmath>
+#include <cstdlib>
+#include <cstring>
+
+#include "caffeine/common.hpp"
+#include "caffeine/util/im2col.hpp"
+
+namespace caffeine {
+
+// CUDA kernel: each thread copies one ksize x ksize patch of one input
+// channel into the column buffer.
+//   n        - total number of patches = channels * height_col * width_col
+//   data_im  - input image, laid out channels x height x width
+//   data_col - output, (channels*ksize*ksize) x height_col x width_col
+// Launched 1-D; threads with index >= n fall through the guard (the grid
+// rarely divides n evenly). No padding is applied: the patch origin is
+// simply (h_out*stride, w_out*stride).
+template <typename Dtype>
+__global__ void im2col_gpu_kernel(const int n, const Dtype* data_im,
+ const int height, const int width, const int ksize,
+ const int stride, const int height_col, const int width_col, Dtype* data_col) {
+ int index = threadIdx.x + blockIdx.x * blockDim.x;
+ if (index < n) {
+ // Decompose the flat index into (channel_in, h_out, w_out); w varies
+ // fastest, so adjacent threads write adjacent output columns.
+ int w_out = index % width_col;
+ index /= width_col;
+ int h_out = index % height_col;
+ int channel_in = index / height_col;
+ // Each input channel expands into ksize*ksize consecutive output channels.
+ int channel_out = channel_in * ksize * ksize;
+ int h_in = h_out * stride;
+ int w_in = w_out * stride;
+ // Point data_col at output channel channel_out, position (h_out, w_out),
+ // and data_im at the top-left corner of this thread's patch.
+ data_col += (channel_out * height_col + h_out) * width_col + w_out;
+ data_im += (channel_in * height + h_in) * width + w_in;
+ // Copy the patch; consecutive (i, j) offsets land in consecutive output
+ // channels, hence the stride of one full height_col*width_col plane.
+ for (int i = 0; i < ksize; ++i) {
+ for (int j = 0; j < ksize; ++j) {
+ *data_col = data_im[i * width + j];
+ data_col += height_col * width_col;
+ }
+ }
+ }
+}
+
+// Host-side wrapper: launches im2col_gpu_kernel with one thread per patch.
+//   data_im  - device pointer to one image, channels x height x width
+//   data_col - device pointer to the column buffer,
+//              (channels*ksize*ksize) x height_col x width_col
+// No padding: output spatial dims are (dim - ksize)/stride + 1.
+template <typename Dtype>
+void im2col_gpu(const Dtype* data_im, const int channels,
+ const int height, const int width, const int ksize, const int stride,
+ Dtype* data_col) {
+ // We are going to launch channels * height_col * width_col CUDA threads
+ // (one kernel launch total), each thread responsible for copying a
+ // single-channel ksize x ksize patch.
+ int height_col = (height - ksize) / stride + 1;
+ int width_col = (width - ksize) / stride + 1;
+ int num_kernels = channels * height_col * width_col;
+ im2col_gpu_kernel<<<CAFFEINE_GET_BLOCKS(num_kernels), CAFFEINE_CUDA_NUM_THREADS>>>(
+ num_kernels, data_im, height, width, ksize, stride, height_col, width_col,
+ data_col);
+ // Surface launch/async errors from the kernel (macro presumably from
+ // caffeine/common.hpp — wraps cudaGetLastError or similar; confirm there).
+ CUDA_POST_KERNEL_CHECK;
+}
+
+// Explicit instantiation for the two element types used by the library.
+template void im2col_gpu<float>(const float* data_im, const int channels,
+ const int height, const int width, const int ksize, const int stride,
+ float* data_col);
+template void im2col_gpu<double>(const double* data_im, const int channels,
+ const int height, const int width, const int ksize, const int stride,
+ double* data_col);
+
+/*
+template <typename Dtype>
+void col2im_gpu(const Dtype* data_col, const int channels,
+ const int height, const int width, const int ksize, const int stride,
+ Dtype* data_im) {
+ memset(data_im, 0, sizeof(Dtype) * height * width * channels);
+ int height_col = (height - ksize) / stride + 1;
+ int width_col = (width - ksize) / stride + 1;
+ int channels_col = channels * ksize * ksize;
+ for (int c = 0; c < channels_col; ++c) {
+ int w_offset = c % ksize;
+ int h_offset = (c / ksize) % ksize;
+ int c_im = c / ksize / ksize;
+ for (int h = 0; h < height_col; ++h) {
+ for (int w = 0; w < width_col; ++w) {
+ data_im[(c_im * height + h * stride + h_offset) * width + w * stride
+ + w_offset] += data_col[(c * height_col + h) * width_col + w];
+ }
+ }
+ }
+}
+
+// Explicit instantiation
+template void col2im_gpu<float>(const float* data_col, const int channels,
+ const int height, const int width, const int psize, const int stride,
+ float* data_im);
+template void col2im_gpu<double>(const double* data_col, const int channels,
+ const int height, const int width, const int psize, const int stride,
+ double* data_im);
+*/
+} // namespace caffeine
index 76f401d1727ec024764d4f88d1d8cca4e13769b9..f6349909f7ac15999d52cea4b25eb03412ec6a4e 100644 (file)
const int height, const int width, const int psize, const int stride,
Dtype* data_im);
+template <typename Dtype>
+void im2col_gpu(const Dtype* data_im, const int channels,
+ const int height, const int width, const int ksize, const int stride,
+ Dtype* data_col);
+template <typename Dtype>
+void col2im_gpu(const Dtype* data_col, const int channels,
+ const int height, const int width, const int psize, const int stride,
+ Dtype* data_im);
} // namespace caffeine
index d931bc245907c2173e2fa9055afbabec48fce772..2d24bf874748d3c76d1f8b8098ada833e1866e0c 100644 (file)
protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
- //virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
- // vector<Blob<Dtype>*>* top);
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom);
//virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,