]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/caffe-jacinto.git/blob - src/caffe/layers/lrn_layer.cu
solver restructuring: now all prototxt are specified in the solver protocol buffer
[jacinto-ai/caffe-jacinto.git] / src / caffe / layers / lrn_layer.cu
1 // Copyright 2013 Yangqing Jia
3 #include "caffe/layer.hpp"
4 #include "caffe/vision_layers.hpp"
5 #include "caffe/util/math_functions.hpp"
7 namespace caffe {
9 template <typename Dtype>
10 __global__ void LRNFillScale(const int nthreads, const Dtype* in,
11     const int num, const int channels, const int height,
12     const int width, const int size, const Dtype alpha_over_size,
13     Dtype* scale) {
14   int index = threadIdx.x + blockIdx.x * blockDim.x;
15   if (index < nthreads) {
16     // find out the local offset
17     int w = index % width;
18     int h = (index / width) % height;
19     int n = index / width / height;
20     int offset = (n * channels * height + h) * width + w;
21     int step = height * width;
22     in += offset;
23     scale += offset;
24     int head = 0;
25     int pre_pad = (size - 1) / 2;
26     int post_pad = size - pre_pad - 1;
27     Dtype accum_scale = 0;
28     // fill the scale at [n, :, h, w]
29     // accumulate values
30     while (head < post_pad) {
31       accum_scale += in[head * step] * in[head * step];
32       ++head;
33     }
34     // until we reach size, nothing needs to be subtracted
35     while (head < size) {
36       accum_scale += in[head * step] * in[head * step];
37       scale[(head - post_pad) * step] = 1. + accum_scale * alpha_over_size;
38       ++head;
39     }
40     // both add and subtract
41     while (head < channels) {
42       accum_scale += in[head * step] * in[head * step];
43       accum_scale -= in[(head - size) * step] * in[(head - size) * step];
44       scale[(head - post_pad) * step] = 1. + accum_scale * alpha_over_size;
45       ++head;
46     }
47     // subtract only
48     while (head < channels + post_pad) {
49       accum_scale -= in[(head - size) * step] * in[(head - size) * step];
50       scale[(head - post_pad) * step] = 1. + accum_scale * alpha_over_size;
51       ++head;
52     }
53   }
54 }
57 // TODO: check if it would be faster to just put it into the previous kernel.
58 template <typename Dtype>
59 __global__ void LRNComputeOutput(const int nthreads, const Dtype* in,
60     const Dtype* scale, const Dtype negative_beta, Dtype* out) {
61   int index = threadIdx.x + blockIdx.x * blockDim.x;
62   if (index < nthreads) {
63     out[index] = in[index] * pow(scale[index], negative_beta);
64   }
65 }
67 template <typename Dtype>
68 void LRNLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
69     vector<Blob<Dtype>*>* top) {
70   // First, compute scale
71   const Dtype* bottom_data = bottom[0]->gpu_data();
72   Dtype* top_data = (*top)[0]->mutable_gpu_data();
73   Dtype* scale_data = scale_.mutable_gpu_data();
74   // We will launch one kernel for each pixel location, and have the kernel
75   // go through all the channels.
76   int n_threads = num_ * height_ * width_;
77   LRNFillScale<<<CAFFE_GET_BLOCKS(n_threads), CAFFE_CUDA_NUM_THREADS>>>(
78       n_threads, bottom_data, num_, channels_, height_, width_, size_,
79       alpha_ / size_, scale_data);
80   CUDA_POST_KERNEL_CHECK;
81   n_threads = bottom[0]->count();
82   LRNComputeOutput<<<CAFFE_GET_BLOCKS(n_threads), CAFFE_CUDA_NUM_THREADS>>>(
83       n_threads, bottom_data, scale_data, -beta_, top_data);
84   CUDA_POST_KERNEL_CHECK;
85 }
88 template <typename Dtype>
89 __global__ void LRNComputeDiff(const int nthreads, const Dtype* bottom_data,
90     const Dtype* top_data, const Dtype* scale, const Dtype* top_diff,
91     const int num, const int channels, const int height,
92     const int width, const int size, const Dtype negative_beta,
93     const Dtype cache_ratio,
94     Dtype* bottom_diff) {
95   int index = threadIdx.x + blockIdx.x * blockDim.x;
96   if (index < nthreads) {
97     // find out the local offset
98     int w = index % width;
99     int h = (index / width) % height;
100     int n = index / width / height;
101     int offset = (n * channels * height + h) * width + w;
102     int step = height * width;
103     bottom_data += offset;
104     top_data += offset;
105     scale += offset;
106     top_diff += offset;
107     bottom_diff += offset;
108     int head = 0;
109     int pre_pad = size - (size + 1) / 2;
110     int post_pad = size - pre_pad - 1;
111     Dtype accum_ratio = 0;
112     // accumulate values
113     while (head < post_pad) {
114       accum_ratio += top_diff[head * step] * top_data[head * step] /
115           scale[head * step];
116       ++head;
117     }
118     // until we reach size, nothing needs to be subtracted
119     while (head < size) {
120       accum_ratio += top_diff[head * step] * top_data[head * step] /
121           scale[head * step];
122       bottom_diff[(head - post_pad) * step] = top_diff[(head - post_pad) * step]
123           * pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
124           bottom_data[(head - post_pad) * step] * accum_ratio;
125       ++head;
126     }
127     // both add and subtract
128     while (head < channels) {
129       accum_ratio += top_diff[head * step] * top_data[head * step] /
130           scale[head * step];
131       accum_ratio -= top_diff[(head - size) * step] *
132           top_data[(head - size) * step] / scale[(head - size) * step];
133       bottom_diff[(head - post_pad) * step] = top_diff[(head - post_pad) * step]
134           * pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
135           bottom_data[(head - post_pad) * step] * accum_ratio;
136       ++head;
137     }
138     // subtract only
139     while (head < channels + post_pad) {
140       accum_ratio -= top_diff[(head - size) * step] *
141           top_data[(head - size) * step] / scale[(head - size) * step];
142       bottom_diff[(head - post_pad) * step] = top_diff[(head - post_pad) * step]
143           * pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
144           bottom_data[(head - post_pad) * step] * accum_ratio;
145       ++head;
146     }
147   }
150 template <typename Dtype>
151 Dtype LRNLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
152     const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
153   int n_threads = num_ * height_ * width_;
154   LRNComputeDiff<<<CAFFE_GET_BLOCKS(n_threads), CAFFE_CUDA_NUM_THREADS>>>(
155       n_threads, (*bottom)[0]->gpu_data(), top[0]->gpu_data(),
156       scale_.gpu_data(), top[0]->gpu_diff(), num_, channels_, height_, width_,
157       size_, -beta_, Dtype(2. * alpha_ * beta_ / size_),
158       (*bottom)[0]->mutable_gpu_diff());
159   return Dtype(0.);
163 INSTANTIATE_CLASS(LRNLayer);
165 }  // namespace caffe