// Copyright 2013 Yangqing Jia

#include <algorithm>

#include "caffe/layer.hpp"
#include "caffe/vision_layers.hpp"

using std::max;

namespace caffe {

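// CPU forward pass: the rectified linear unit is applied elementwise,
// top_data[i] = max(bottom_data[i], 0).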
template <typename Dtype>
void ReLULayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
  const Dtype* bottom_data = bottom[0]->cpu_data();
  Dtype* top_data = (*top)[0]->mutable_cpu_data();
  const int count = bottom[0]->count();
  for (int i = 0; i < count; ++i) {
    top_data[i] = max(bottom_data[i], Dtype(0));
  }
}

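// CPU backward pass: the incoming gradient is passed through only where the
// input was non-negative (this code uses >= 0, so the derivative at exactly 0
// is taken to be 1). The layer contributes no loss of its own, so it returns
// Dtype(0).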
template <typename Dtype>
Dtype ReLULayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
    const bool propagate_down,
    vector<Blob<Dtype>*>* bottom) {
  if (propagate_down) {
    const Dtype* bottom_data = (*bottom)[0]->cpu_data();
    const Dtype* top_diff = top[0]->cpu_diff();
    Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
    const int count = (*bottom)[0]->count();
    for (int i = 0; i < count; ++i) {
      bottom_diff[i] = top_diff[i] * (bottom_data[i] >= 0);
    }
  }
  return Dtype(0);
}

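// CUDA kernel for the forward pass: one thread per element, with a bounds
// check so surplus threads in the last block do nothing.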
template <typename Dtype>
__global__ void ReLUForward(const int n, const Dtype* in, Dtype* out) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;
  if (index < n) {
    out[index] = max(in[index], Dtype(0.));
  }
}

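// GPU forward pass: launches ReLUForward over all `count` elements using
// Caffe's launch-configuration helpers (CAFFE_GET_BLOCKS blocks of
// CAFFE_CUDA_NUM_THREADS threads each).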
template <typename Dtype>
void ReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
  const Dtype* bottom_data = bottom[0]->gpu_data();
  Dtype* top_data = (*top)[0]->mutable_gpu_data();
  const int count = bottom[0]->count();
  ReLUForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
      count, bottom_data, top_data);
}

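// CUDA kernel for the backward pass: mirrors the CPU loop, multiplying the
// incoming gradient by the mask (in_data >= 0) for each element.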
template <typename Dtype>
__global__ void ReLUBackward(const int n, const Dtype* in_diff,
    const Dtype* in_data, Dtype* out_diff) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;
  if (index < n) {
    out_diff[index] = in_diff[index] * (in_data[index] >= 0);
  }
}

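// GPU backward pass: same structure as Forward_gpu, but runs only when the
// bottom blob actually needs gradients.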
template <typename Dtype>
Dtype ReLULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
    const bool propagate_down,
    vector<Blob<Dtype>*>* bottom) {
  if (propagate_down) {
    const Dtype* bottom_data = (*bottom)[0]->gpu_data();
    const Dtype* top_diff = top[0]->gpu_diff();
    Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
    const int count = (*bottom)[0]->count();
    ReLUBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
        count, top_diff, bottom_data, bottom_diff);
  }
  return Dtype(0);
}

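// Instantiate the layer for the floating point types Caffe supports
// (float and double).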
INSTANTIATE_CLASS(ReLULayer);

}  // namespace caffe