author    Yangqing Jia <jiayq84@gmail.com>    Fri, 20 Sep 2013 20:45:10 +0000 (13:45 -0700)
committer Yangqing Jia <jiayq84@gmail.com>    Fri, 20 Sep 2013 20:45:10 +0000 (13:45 -0700)
diff --git a/src/caffeine/layer.cpp b/src/caffeine/layer.cpp
index 9d61fe9cd9daf59e27af17378bd158cde00286c2..8e3ac8cdee9060fc750479a73d31995873c261cd 100644
--- a/src/caffeine/layer.cpp
+++ b/src/caffeine/layer.cpp
namespace caffeine {
-// Forward, backward and predict wrappers. You should implement the cpu and
-// gpu specific implementations instead, and should not change these
-// functions.
-template <typename Dtype>
-inline void Layer<Dtype>::Forward(const vector<Blob<Dtype>*>& bottom,
- vector<Blob<Dtype>*>* top) {
- switch(Caffeine::mode()) {
- case Caffeine::CPU:
- Forward_cpu(bottom, top);
- break;
- case Caffeine::GPU:
- Forward_gpu(bottom, top);
- break;
- default:
- LOG(FATAL) << "Unknown caffeine mode.";
- }
-};
-
-template <typename Dtype>
-inline Dtype Layer<Dtype>::Backward(const vector<Blob<Dtype>*>& top,
- const bool propagate_down,
- vector<Blob<Dtype>*>* bottom) {
- switch(Caffeine::mode()) {
- case Caffeine::CPU:
- return Backward_cpu(top, propagate_down, bottom);
- break;
- case Caffeine::GPU:
- return Backward_gpu(top, propagate_down, bottom);
- break;
- default:
- LOG(FATAL) << "Unknown caffeine mode.";
- }
-};
-
INSTANTIATE_CLASS(Layer);
} // namespace caffeine
diff --git a/src/caffeine/layer.hpp b/src/caffeine/layer.hpp
index 4a9247d3cd34c37769eb8711db7355addcf3578c..f81535c671df11ef6e4059e5eb2f2c02e6c5b126 100644
--- a/src/caffeine/layer.hpp
+++ b/src/caffeine/layer.hpp
};
}; // class Layer
+// Forward and backward wrappers. You should implement the cpu and
+// gpu specific implementations instead, and should not change these
+// functions.
+template <typename Dtype>
+inline void Layer<Dtype>::Forward(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) {
+ switch(Caffeine::mode()) {
+ case Caffeine::CPU:
+ Forward_cpu(bottom, top);
+ break;
+ case Caffeine::GPU:
+ Forward_gpu(bottom, top);
+ break;
+ default:
+ LOG(FATAL) << "Unknown caffeine mode.";
+ }
+};
+
+template <typename Dtype>
+inline Dtype Layer<Dtype>::Backward(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down,
+ vector<Blob<Dtype>*>* bottom) {
+ switch(Caffeine::mode()) {
+ case Caffeine::CPU:
+ return Backward_cpu(top, propagate_down, bottom);
+ case Caffeine::GPU:
+ return Backward_gpu(top, propagate_down, bottom);
+ default:
+ LOG(FATAL) << "Unknown caffeine mode.";
+ }
+};
+
} // namespace caffeine
#endif // CAFFEINE_LAYER_H_
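
The wrappers moved into the header follow a non-virtual-interface pattern: the public Forward/Backward never change, and each concrete layer only supplies the protected _cpu/_gpu implementations that the switch on Caffeine::mode() selects. The self-contained sketch below distills that dispatch idea; all names in it (Mode, g_mode, SimpleLayer, ScaleLayer) are hypothetical stand-ins, not caffeine types.

#include <cstddef>
#include <iostream>
#include <vector>

// Illustrative mode switch standing in for Caffeine::mode().
enum class Mode { CPU, GPU };
static Mode g_mode = Mode::CPU;

class SimpleLayer {
 public:
  virtual ~SimpleLayer() {}
  // Non-virtual wrapper: callers never choose a device path themselves.
  void Forward(const std::vector<float>& bottom, std::vector<float>* top) {
    switch (g_mode) {
      case Mode::CPU: Forward_cpu(bottom, top); break;
      case Mode::GPU: Forward_gpu(bottom, top); break;
    }
  }
 protected:
  virtual void Forward_cpu(const std::vector<float>& bottom,
                           std::vector<float>* top) = 0;
  // A real GPU path would launch a kernel; this sketch just falls back.
  virtual void Forward_gpu(const std::vector<float>& bottom,
                           std::vector<float>* top) {
    Forward_cpu(bottom, top);
  }
};

// A derived layer only fills in the device-specific hooks.
class ScaleLayer : public SimpleLayer {
 protected:
  void Forward_cpu(const std::vector<float>& bottom,
                   std::vector<float>* top) override {
    top->resize(bottom.size());
    for (std::size_t i = 0; i < bottom.size(); ++i) {
      (*top)[i] = 2.0f * bottom[i];
    }
  }
};

int main() {
  ScaleLayer layer;
  std::vector<float> bottom = {1.0f, 2.0f, 3.0f};
  std::vector<float> top;
  layer.Forward(bottom, &top);  // dispatches to Forward_cpu under g_mode
  std::cout << top[0] << " " << top[1] << " " << top[2] << std::endl;
  return 0;
}
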
index 77108af409decd6cd4ecfaef15bb36edcd7a8699..9fc62966ecd675b08f14de154e8fa589fd10d1c9 100644
template <typename Dtype>
Dtype Im2colLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
- LOG(ERROR) << "Warning: still CPU version";
- return Backward_cpu(top, propagate_down, bottom);
- /*
const Dtype* top_diff = top[0]->gpu_diff();
Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
for (int n = 0; n < top[0]->num(); ++n) {
col2im_gpu(top_diff + top[0]->offset(n), CHANNELS_, HEIGHT_,
WIDTH_, KSIZE_, STRIDE_, bottom_diff + (*bottom)[0]->offset(n));
}
return Dtype(0.);
- */
}
INSTANTIATE_CLASS(Im2colLayer);
index 4d835a235f2c3f61785084a0c58195311dd910d2..0dd5572dd15fe3ab2cc085ebf8b58b73712e178d 100644
} while (assumed != old);
return __longlong_as_double(old);
}
+*/
template <typename Dtype>
__global__ void col2im_gpu_kernel(const int n, const Dtype* data_col,
- const int height, const int width, const int ksize,
+ const int height, const int width, const int channels, const int ksize,
const int stride, const int height_col, const int width_col, Dtype* data_im) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < n) {
- int w_out = index % width_col;
- index /= width_col;
- int h_out = index % height_col;
- int channel_out = index / height_col;
- int w_in = w_out * stride + channel_out % ksize;
- int h_in = h_out * stride + (channel_out / ksize) % ksize;
- int channel_in = channel_out / ksize / ksize;
- MyAtomicAdd(data_im + (channel_in * height + h_in) * width + w_in,
- data_col[(channel_out* height_col + h_out) * width_col + w_out]);
+ int w = index % width;
+ int h = (index / width) % height;
+ int c = index / (width * height);
+ // compute the start and end of the output
+ int w_col_start = (w < ksize) ? 0 : (w - ksize) / stride + 1;
+ int w_col_end = min(w / stride + 1, width_col);
+ int h_col_start = (h < ksize) ? 0 : (h - ksize) / stride + 1;
+ int h_col_end = min(h / stride + 1, height_col);
+ for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
+ for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
+ // the col location: [c * width * height + h_out, w_out]
+ int c_col = c * ksize * ksize + (h - h_col * stride) * ksize + (w - w_col * stride);
+ data_im[index] += data_col[(c_col * height_col + h_col) * width_col + w_col];
+ }
+ }
}
}
CUDA_CHECK(cudaMemset(data_im, 0, sizeof(Dtype) * height * width * channels));
int height_col = (height - ksize) / stride + 1;
int width_col = (width - ksize) / stride + 1;
- int channels_col = channels * ksize * ksize;
- int num_kernels = channels_col * height_col * width_col;
+ int num_kernels = channels * height * width;
+ // To avoid involving atomic operations, we will launch one kernel per
+ // bottom dimension, and then in the kernel add up the top dimensions.
col2im_gpu_kernel<Dtype><<<CAFFEINE_GET_BLOCKS(num_kernels), CAFFEINE_CUDA_NUM_THREADS>>>(
- num_kernels, data_col, height, width, ksize, stride, height_col, width_col,
- data_im);
+ num_kernels, data_col, height, width, channels, ksize, stride,
+ height_col, width_col, data_im);
CUDA_POST_KERNEL_CHECK;
}
template void col2im_gpu<double>(const double* data_col, const int channels,
const int height, const int width, const int psize, const int stride,
double* data_im);
-*/
+
} // namespace caffeine
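
The rewritten kernel gathers rather than scatters: each thread owns one bottom pixel (c, h, w), works out which columns' ksize x ksize windows cover that pixel, and sums their entries, so no two threads ever write the same output and no atomics are needed. Below is a plain C++ restatement of that index math, offered only as an illustrative reference; col2im_cpu_reference is a hypothetical name, not caffeine's im2col.cpp code.

#include <algorithm>  // std::min
#include <cstring>    // std::memset

template <typename Dtype>
void col2im_cpu_reference(const Dtype* data_col, const int channels,
    const int height, const int width, const int ksize, const int stride,
    Dtype* data_im) {
  const int height_col = (height - ksize) / stride + 1;
  const int width_col = (width - ksize) / stride + 1;
  std::memset(data_im, 0, sizeof(Dtype) * channels * height * width);
  for (int index = 0; index < channels * height * width; ++index) {
    const int w = index % width;
    const int h = (index / width) % height;
    const int c = index / (width * height);
    // Only columns whose ksize x ksize window covers pixel (h, w) contribute.
    const int w_col_start = (w < ksize) ? 0 : (w - ksize) / stride + 1;
    const int w_col_end = std::min(w / stride + 1, width_col);
    const int h_col_start = (h < ksize) ? 0 : (h - ksize) / stride + 1;
    const int h_col_end = std::min(h / stride + 1, height_col);
    for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
      for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
        const int c_col = c * ksize * ksize
            + (h - h_col * stride) * ksize + (w - w_col * stride);
        data_im[index] +=
            data_col[(c_col * height_col + h_col) * width_col + w_col];
      }
    }
  }
}

For a concrete sense of the launch-size change: with a hypothetical 3x32x32 bottom, ksize 5 and stride 1, height_col = width_col = 28, so the old scatter scheme launched 3*5*5*28*28 = 58,800 threads contending through atomic adds, while the new gather scheme launches 3*32*32 = 3,072 threads, each summing at most 25 column entries.
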
index 23e571e29a139cb531b8fc2c9d7258bcd479949d..f6349909f7ac15999d52cea4b25eb03412ec6a4e 100644
const int height, const int width, const int ksize, const int stride,
Dtype* data_col);
-/*
template <typename Dtype>
void col2im_gpu(const Dtype* data_col, const int channels,
const int height, const int width, const int psize, const int stride,
Dtype* data_im);
-*/
+
} // namespace caffeine
#endif // CAFFEINE_UTIL_IM2COL_HPP_
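
With the declaration uncommented, other code can drive the GPU path directly. A small usage sketch, assuming the header above is included: col2im_batch and its packed num x channels x height x width layout are hypothetical, and simply mirror the per-image offset loop from Backward_gpu earlier in this commit.

// Hypothetical helper: apply col2im_gpu to every image in a batch.
// Both buffers must already live in device memory.
template <typename Dtype>
void col2im_batch(const Dtype* col_buffer, const int num, const int channels,
    const int height, const int width, const int ksize, const int stride,
    Dtype* im_buffer) {
  const int height_col = (height - ksize) / stride + 1;
  const int width_col = (width - ksize) / stride + 1;
  const int col_size = channels * ksize * ksize * height_col * width_col;
  const int im_size = channels * height * width;
  for (int n = 0; n < num; ++n) {
    col2im_gpu(col_buffer + n * col_size, channels, height, width,
        ksize, stride, im_buffer + n * im_size);
  }
}
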