1 #include <cmath>
2 #include <cstdlib>
3 #include <cstring>
5 #include <device_functions.h>
7 #include "caffeine/common.hpp"
8 #include "caffeine/util/im2col.hpp"
10 namespace caffeine {
template <typename Dtype>
__global__ void im2col_gpu_kernel(const int n, const Dtype* data_im,
    const int height, const int width, const int ksize,
    const int stride, const int height_col, const int width_col, Dtype* data_col) {
  // One thread per output location (channel_in, h_out, w_out), n threads in
  // total. Each thread copies its ksize x ksize input patch into data_col,
  // one patch element per output channel.
  const int index = threadIdx.x + blockIdx.x * blockDim.x;
  if (index >= n) {
    return;  // grid tail guard: the launch may overshoot n
  }
  // Decompose the flat index: fastest-varying is w_out, then h_out, then
  // the input channel.
  const int w_out = index % width_col;
  const int h_out = (index / width_col) % height_col;
  const int channel_in = index / (width_col * height_col);
  // First of the ksize*ksize output channels this thread writes.
  const int channel_out = channel_in * ksize * ksize;
  // Top-left corner of the input patch (no padding in this version).
  const int h_in = h_out * stride;
  const int w_in = w_out * stride;
  const Dtype* im_ptr = data_im + (channel_in * height + h_in) * width + w_in;
  Dtype* col_ptr = data_col + (channel_out * height_col + h_out) * width_col + w_out;
  for (int i = 0; i < ksize; ++i) {
    for (int j = 0; j < ksize; ++j) {
      // Consecutive patch elements go to consecutive output channels,
      // i.e. strides of height_col * width_col in data_col.
      *col_ptr = im_ptr[i * width + j];
      col_ptr += height_col * width_col;
    }
  }
}
template <typename Dtype>
void im2col_gpu(const Dtype* data_im, const int channels,
    const int height, const int width, const int ksize, const int stride,
    Dtype* data_col) {
  // Launch one thread per output location (channels * height_col *
  // width_col in total); each thread copies one ksize x ksize patch.
  const int height_col = (height - ksize) / stride + 1;
  const int width_col = (width - ksize) / stride + 1;
  const int num_kernels = channels * height_col * width_col;
  im2col_gpu_kernel<Dtype><<<CAFFEINE_GET_BLOCKS(num_kernels),
                             CAFFEINE_CUDA_NUM_THREADS>>>(
      num_kernels, data_im, height, width, ksize, stride, height_col,
      width_col, data_col);
  // Surface any launch-configuration or execution error.
  CUDA_POST_KERNEL_CHECK;
}
// Explicit instantiation for the two floating-point types used by the
// framework; keeps the template definition out of the header.
template void im2col_gpu<float>(const float* data_im, const int channels,
    const int height, const int width, const int ksize, const int stride,
    float* data_col);
template void im2col_gpu<double>(const double* data_im, const int channels,
    const int height, const int width, const int ksize, const int stride,
    double* data_col);
59 /*
60 // A bunch of stuff dealing with double atomic add
61 template <typename Dtype>
62 __device__ inline Dtype MyAtomicAdd(Dtype* address, Dtype val);
64 template <>
65 __device__ float MyAtomicAdd<float>(float* address, float val) {
66 return atomicAdd(address, val);
67 }
68 template <>
69 __device__ double MyAtomicAdd<double>(double* address, double val)
70 {
71 unsigned long long int* address_as_ull = (unsigned long long int*)address;
72 unsigned long long int old = *address_as_ull, assumed;
73 do {
74 assumed = old;
75 old = atomicCAS(address_as_ull, assumed,__double_as_longlong(val +
76 __longlong_as_double(assumed)));
77 } while (assumed != old);
78 return __longlong_as_double(old);
79 }
80 */
template <typename Dtype>
__global__ void col2im_gpu_kernel(const int n, const Dtype* data_col,
    const int height, const int width, const int channels, const int ksize,
    const int stride, const int height_col, const int width_col, Dtype* data_im) {
  // One thread per input (image) element, n = channels * height * width.
  // Each thread sums every data_col entry whose patch covers its (c, h, w)
  // location, avoiding atomics entirely.
  int index = threadIdx.x + blockIdx.x * blockDim.x;
  if (index < n) {
    // Accumulate in a register and store once at the end, instead of the
    // original global read-modify-write (data_im[index] += ...) on every
    // inner-loop iteration. This also makes the kernel independent of any
    // prior zeroing of data_im.
    Dtype val = 0;
    int w = index % width;
    int h = (index / width) % height;
    int c = index / (width * height);
    // Range of column-buffer windows that overlap input position (h, w).
    int w_col_start = (w < ksize) ? 0 : (w - ksize) / stride + 1;
    int w_col_end = min(w / stride + 1, width_col);
    int h_col_start = (h < ksize) ? 0 : (h - ksize) / stride + 1;
    int h_col_end = min(h / stride + 1, height_col);
    for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
      for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
        // Channel in the column buffer: base channel block plus the offset
        // of (h, w) within the ksize x ksize patch anchored at this window.
        int c_col = c * ksize * ksize + (h - h_col * stride) * ksize
            + (w - w_col * stride);
        val += data_col[(c_col * height_col + h_col) * width_col + w_col];
      }
    }
    data_im[index] = val;
  }
}
template <typename Dtype>
void col2im_gpu(const Dtype* data_col, const int channels,
    const int height, const int width, const int ksize, const int stride,
    Dtype* data_im) {
  // Zero the destination image first; the kernel accumulates into it.
  CUDA_CHECK(cudaMemset(data_im, 0, sizeof(Dtype) * height * width * channels));
  const int height_col = (height - ksize) / stride + 1;
  const int width_col = (width - ksize) / stride + 1;
  // One thread per bottom (image) element; each thread gathers and sums
  // the top (column) entries that map onto it, so no atomics are needed.
  const int num_kernels = channels * height * width;
  col2im_gpu_kernel<Dtype><<<CAFFEINE_GET_BLOCKS(num_kernels),
                             CAFFEINE_CUDA_NUM_THREADS>>>(
      num_kernels, data_col, height, width, channels, ksize, stride,
      height_col, width_col, data_im);
  // Surface any launch-configuration or execution error.
  CUDA_POST_KERNEL_CHECK;
}
// Explicit instantiation for the two floating-point types used by the
// framework. The kernel-size parameter is named `ksize` to match the
// template definition (it was previously `psize` here, which was
// inconsistent; parameter names in instantiation declarations have no
// semantic effect, so this is purely a consistency fix).
template void col2im_gpu<float>(const float* data_col, const int channels,
    const int height, const int width, const int ksize, const int stride,
    float* data_im);
template void col2im_gpu<double>(const double* data_col, const int channels,
    const int height, const int width, const int ksize, const int stride,
    double* data_im);
132 } // namespace caffeine