1 #include "caffeine/layer.hpp"
2 #include "caffeine/vision_layers.hpp"
3 #include "caffeine/util/math_functions.hpp"
5 namespace caffeine {
// LRNFillScale: computes the per-element LRN normalizer
//   scale[n, c, h, w] = 1 + alpha_over_size * sum_{c' in window(c)} in[n, c', h, w]^2
// where window(c) spans `size` channels centered (with pre/post padding) on c.
// Launch layout: 1-D grid, one thread per (n, h, w) spatial location
// (nthreads = num * height * width); each thread walks the full channel axis
// with a sliding-window running sum, so each squared value is added once and
// subtracted once instead of being re-summed per output channel.
template <typename Dtype>
__global__ void LRNFillScale(const int nthreads, const Dtype* in,
    const int num, const int channels, const int height,
    const int width, const int size, const Dtype alpha_over_size,
    Dtype* scale) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;
  if (index < nthreads) {
    // find out the local offset of this thread's (n, h, w) column
    int w = index % width;
    int h = (index / width) % height;
    int n = index / width / height;
    int offset = (n * channels * height + h) * width + w;
    int step = height * width;  // distance between consecutive channels
    in += offset;
    scale += offset;
    int head = 0;
    int pre_pad = (size - 1) / 2;
    int post_pad = size - pre_pad - 1;
    Dtype accum_scale = 0;
    // fill the scale at [n, :, h, w]
    // accumulate leading values; the `head < channels` guard prevents
    // out-of-bounds reads when channels < post_pad
    while (head < post_pad && head < channels) {
      accum_scale += in[head * step] * in[head * step];
      ++head;
    }
    // slide the window forward: add the incoming channel, and subtract the
    // outgoing one only once the window is full (head - size >= 0).  The
    // guard also fixes the out-of-bounds reads the unguarded version made
    // when channels < size.
    while (head < channels) {
      accum_scale += in[head * step] * in[head * step];
      if (head - size >= 0) {
        accum_scale -= in[(head - size) * step] * in[(head - size) * step];
      }
      // Dtype(1) avoids promoting the whole expression to double when
      // Dtype is float.
      scale[(head - post_pad) * step] =
          Dtype(1) + accum_scale * alpha_over_size;
      ++head;
    }
    // tail: no more values to add, only subtract as the window slides off
    while (head < channels + post_pad) {
      if (head - size >= 0) {
        accum_scale -= in[(head - size) * step] * in[(head - size) * step];
      }
      scale[(head - post_pad) * step] =
          Dtype(1) + accum_scale * alpha_over_size;
      ++head;
    }
  }
}
// TODO: check if it would be faster to just put it into the previous kernel.
// Elementwise pass: out[i] = in[i] * scale[i]^negative_beta.
// Launch layout: 1-D grid, one thread per element (nthreads = total count).
template <typename Dtype>
__global__ void LRNComputeOutput(const int nthreads, const Dtype* in,
    const Dtype* scale, const Dtype negative_beta, Dtype* out) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= nthreads) {
    return;  // tail guard: grid may overshoot the element count
  }
  out[idx] = in[idx] * pow(scale[idx], negative_beta);
}
template <typename Dtype>
void LRNLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
  const Dtype* bottom_data = bottom[0]->gpu_data();
  Dtype* top_data = (*top)[0]->mutable_gpu_data();
  Dtype* scale_data = scale_.mutable_gpu_data();
  // Stage 1: fill scale_. One thread per (n, h, w) location; each thread
  // walks the channel axis inside the kernel.
  const int num_locations = num_ * height_ * width_;
  LRNFillScale<<<CAFFEINE_GET_BLOCKS(num_locations),
                 CAFFEINE_CUDA_NUM_THREADS>>>(
      num_locations, bottom_data, num_, channels_, height_, width_, size_,
      alpha_ / size_, scale_data);
  CUDA_POST_KERNEL_CHECK;
  // Stage 2: elementwise output = input * scale^(-beta), one thread per
  // element of the blob.
  const int num_elements = bottom[0]->count();
  LRNComputeOutput<<<CAFFEINE_GET_BLOCKS(num_elements),
                     CAFFEINE_CUDA_NUM_THREADS>>>(
      num_elements, bottom_data, scale_data, -beta_, top_data);
  CUDA_POST_KERNEL_CHECK;
}
// LRNComputeDiff: backward pass of LRN.  For each bottom element it computes
//   bottom_diff = top_diff * scale^negative_beta
//                 - cache_ratio * bottom_data * accum_ratio
// where accum_ratio is the sliding-window sum of top_diff * top_data / scale
// over the `size`-channel LRN window (cache_ratio = 2 * alpha * beta / size,
// supplied by the caller).
// Launch layout: 1-D grid, one thread per (n, h, w) spatial location
// (nthreads = num * height * width); each thread walks the channel axis.
template <typename Dtype>
__global__ void LRNComputeDiff(const int nthreads, const Dtype* bottom_data,
    const Dtype* top_data, const Dtype* scale, const Dtype* top_diff,
    const int num, const int channels, const int height,
    const int width, const int size, const Dtype negative_beta,
    const Dtype cache_ratio,
    Dtype* bottom_diff) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;
  if (index < nthreads) {
    // find out the local offset of this thread's (n, h, w) column
    int w = index % width;
    int h = (index / width) % height;
    int n = index / width / height;
    int offset = (n * channels * height + h) * width + w;
    int step = height * width;  // distance between consecutive channels
    bottom_data += offset;
    top_data += offset;
    scale += offset;
    top_diff += offset;
    bottom_diff += offset;
    int head = 0;
    int pre_pad = size - (size + 1) / 2;
    int post_pad = size - pre_pad - 1;
    Dtype accum_ratio = 0;
    // accumulate leading values; the `head < channels` guard prevents
    // out-of-bounds reads when channels < post_pad
    while (head < post_pad && head < channels) {
      accum_ratio += top_diff[head * step] * top_data[head * step] /
          scale[head * step];
      ++head;
    }
    // slide the window forward: add the incoming channel, and subtract the
    // outgoing one only once the window is full (head - size >= 0).  The
    // guard also fixes the out-of-bounds reads the unguarded version made
    // when channels < size.
    while (head < channels) {
      accum_ratio += top_diff[head * step] * top_data[head * step] /
          scale[head * step];
      if (head - size >= 0) {
        accum_ratio -= top_diff[(head - size) * step] *
            top_data[(head - size) * step] / scale[(head - size) * step];
      }
      bottom_diff[(head - post_pad) * step] = top_diff[(head - post_pad) * step]
          * pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
          bottom_data[(head - post_pad) * step] * accum_ratio;
      ++head;
    }
    // tail: no more values to add, only subtract as the window slides off
    while (head < channels + post_pad) {
      if (head - size >= 0) {
        accum_ratio -= top_diff[(head - size) * step] *
            top_data[(head - size) * step] / scale[(head - size) * step];
      }
      bottom_diff[(head - post_pad) * step] = top_diff[(head - post_pad) * step]
          * pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
          bottom_data[(head - post_pad) * step] * accum_ratio;
      ++head;
    }
  }
}
template <typename Dtype>
Dtype LRNLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
  // NOTE(review): propagate_down is ignored here, matching the apparent
  // convention of always filling bottom diffs — confirm against other layers.
  // One thread per (n, h, w) location; the kernel walks the channel axis.
  int n_threads = num_ * height_ * width_;
  LRNComputeDiff<<<CAFFEINE_GET_BLOCKS(n_threads), CAFFEINE_CUDA_NUM_THREADS>>>(
      n_threads, (*bottom)[0]->gpu_data(), top[0]->gpu_data(),
      scale_.gpu_data(), top[0]->gpu_diff(), num_, channels_, height_, width_,
      size_, -beta_, Dtype(2. * alpha_ * beta_ / size_),
      (*bottom)[0]->mutable_gpu_diff());
  // Check the launch like Forward_gpu does; without this, a bad launch
  // configuration or kernel fault would be silently swallowed.
  CUDA_POST_KERNEL_CHECK;
  return Dtype(0.);
}
// NOTE(review): presumably expands to explicit instantiations of
// LRNLayer<float> and LRNLayer<double>; macro is defined elsewhere — confirm.
INSTANTIATE_CLASS(LRNLayer);
163 } // namespace caffeine