]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/caffe-jacinto.git/commitdiff
lrn layer gpu forward
authorYangqing Jia <jiayq84@gmail.com>
Sat, 21 Sep 2013 04:28:39 +0000 (21:28 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Sat, 21 Sep 2013 04:28:39 +0000 (21:28 -0700)
src/caffeine/layers/lrn_layer.cpp [new file with mode: 0644]
src/caffeine/layers/lrn_layer.cu
src/caffeine/test/test_lrn_layer.cpp
src/caffeine/vision_layers.hpp

diff --git a/src/caffeine/layers/lrn_layer.cpp b/src/caffeine/layers/lrn_layer.cpp
new file mode 100644 (file)
index 0000000..2c62136
--- /dev/null
@@ -0,0 +1,131 @@
+#include "caffeine/layer.hpp"
+#include "caffeine/vision_layers.hpp"
+#include "caffeine/util/math_functions.hpp"
+
+namespace caffeine {
+
+template <typename Dtype>
+void LRNLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  CHECK_EQ(bottom.size(), 1) <<
+      "Local Response Normalization Layer takes a single blob as input.";
+  CHECK_EQ(top->size(), 1) << 
+      "Local Response Normalization Layer takes a single blob as output.";
+  num_ = bottom[0]->num();
+  channels_ = bottom[0]->channels();
+  height_ = bottom[0]->height();
+  width_ = bottom[0]->width();
+  (*top)[0]->Reshape(num_, channels_, height_, width_);
+  scale_.Reshape(num_, channels_, height_, width_);
+  size_ = this->layer_param_.local_size();
+  pre_pad_ = (size_ - 1) / 2;
+  alpha_ = this->layer_param_.alpha();
+  beta_ = this->layer_param_.beta();
+};
+
+template <typename Dtype>
+void LRNLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* top_data = (*top)[0]->mutable_cpu_data();
+  Dtype* scale_data = scale_.mutable_cpu_data();
+  // start with the constant value
+  for (int i = 0; i < scale_.count(); ++i) {
+    scale_data[i] = 1.;
+  }
+  Blob<Dtype> padded_square(1, channels_ + size_ - 1, height_, width_);
+  Dtype* padded_square_data = padded_square.mutable_cpu_data();
+  memset(padded_square_data, 0, sizeof(Dtype) * padded_square.count());
+  Dtype alpha_over_size = alpha_ / size_;
+  // go through the images
+  for (int n = 0; n < num_; ++n) {
+    // compute the padded square
+    caffeine_sqr(channels_ * height_ * width_,
+        bottom_data + bottom[0]->offset(n),
+        padded_square_data + padded_square.offset(0, pre_pad_));
+    // Create the first channel scale
+    for (int c = 0; c < size_; ++c) {
+      caffeine_axpy<Dtype>(height_ * width_, alpha_over_size,
+          padded_square_data + padded_square.offset(0, c),
+          scale_data + scale_.offset(n, 0));
+    }
+    for (int c = 1; c < channels_; ++c) {
+      // copy previous scale
+      caffeine_copy<Dtype>(height_ * width_,
+          scale_data + scale_.offset(n, c - 1),
+          scale_data + scale_.offset(n, c));
+      // add head
+      caffeine_axpy<Dtype>(height_ * width_, alpha_over_size,
+          padded_square_data + padded_square.offset(0, c + size_ - 1),
+          scale_data + scale_.offset(n, c));
+      // subtract tail
+      caffeine_axpy<Dtype>(height_ * width_, -alpha_over_size,
+          padded_square_data + padded_square.offset(0, c - 1),
+          scale_data + scale_.offset(n, c));
+    }
+  }
+
+  // In the end, compute output
+  caffeine_powx<Dtype>(scale_.count(), scale_data, -beta_, top_data);
+  caffeine_mul<Dtype>(scale_.count(), top_data, bottom_data, top_data);
+}
+
+template <typename Dtype>
+Dtype LRNLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+  const Dtype* top_diff = top[0]->cpu_diff();
+  const Dtype* top_data = top[0]->cpu_data();
+  const Dtype* bottom_data = (*bottom)[0]->cpu_data();
+  const Dtype* scale_data = scale_.cpu_data();
+  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+  Blob<Dtype> padded_ratio(1, channels_ + size_ - 1, height_, width_);
+  Blob<Dtype> accum_ratio(1, 1, height_, width_);
+  Dtype* padded_ratio_data = padded_ratio.mutable_cpu_data();
+  Dtype* accum_ratio_data = accum_ratio.mutable_cpu_data();
+  // We hack a little bit by using the diff() to store an additional result
+  Dtype* accum_ratio_times_bottom = accum_ratio.mutable_cpu_diff();
+  memset(padded_ratio_data, 0, sizeof(Dtype) * padded_ratio.count());
+  Dtype cache_ratio_value = 2. * alpha_ * beta_ / size_;
+
+  caffeine_powx<Dtype>(scale_.count(), scale_data, -beta_, bottom_diff);
+  caffeine_mul<Dtype>(scale_.count(), top_diff, bottom_diff, bottom_diff);
+
+  // go through individual data
+  int inverse_pre_pad = size_ - (size_ + 1) / 2;
+  for (int n = 0; n < num_; ++n) {
+    int block_offset = scale_.offset(n);
+    // first, compute diff_i * y_i / s_i
+    caffeine_mul<Dtype>(channels_ * height_ * width_,
+        top_diff + block_offset, top_data + block_offset,
+        padded_ratio_data + padded_ratio.offset(0, inverse_pre_pad));
+    caffeine_div<Dtype>(channels_ * height_ * width_,
+        padded_ratio_data + padded_ratio.offset(0, inverse_pre_pad),
+        scale_data + block_offset,
+        padded_ratio_data + padded_ratio.offset(0, inverse_pre_pad));
+    // Now, compute the accumulated ratios and the bottom diff
+    memset(accum_ratio_data, 0, sizeof(Dtype) * accum_ratio.count());
+    for (int c = 0; c < size_ - 1; ++c) {
+      caffeine_axpy<Dtype>(height_ * width_, 1.,
+          padded_ratio_data + padded_ratio.offset(0, c), accum_ratio_data);
+    }
+    for (int c = 0; c < channels_; ++c) {
+      caffeine_axpy<Dtype>(height_ * width_, 1.,
+          padded_ratio_data + padded_ratio.offset(0, c + size_ - 1),
+          accum_ratio_data);
+      // compute bottom diff
+      caffeine_mul<Dtype>(height_ * width_,
+          bottom_data + top[0]->offset(n, c),
+          accum_ratio_data, accum_ratio_times_bottom);
+      caffeine_axpy<Dtype>(height_ * width_, -cache_ratio_value,
+          accum_ratio_times_bottom, bottom_diff + top[0]->offset(n,c));
+      caffeine_axpy<Dtype>(height_ * width_, -1.,
+          padded_ratio_data + padded_ratio.offset(0, c), accum_ratio_data);
+    }
+  }
+  return Dtype(0.);
+}
+
+INSTANTIATE_CLASS(LRNLayer);
+
+
+}  // namespace caffeine
index 2c62136a8a7112c1682c8e2d5a946dddd2258d31..5eb7efaa739e1f77d2650a02d97549497649d96a 100644 (file)
 namespace caffeine {
 
 template <typename Dtype>
-void LRNLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
-      vector<Blob<Dtype>*>* top) {
-  CHECK_EQ(bottom.size(), 1) <<
-      "Local Response Normalization Layer takes a single blob as input.";
-  CHECK_EQ(top->size(), 1) << 
-      "Local Response Normalization Layer takes a single blob as output.";
-  num_ = bottom[0]->num();
-  channels_ = bottom[0]->channels();
-  height_ = bottom[0]->height();
-  width_ = bottom[0]->width();
-  (*top)[0]->Reshape(num_, channels_, height_, width_);
-  scale_.Reshape(num_, channels_, height_, width_);
-  size_ = this->layer_param_.local_size();
-  pre_pad_ = (size_ - 1) / 2;
-  alpha_ = this->layer_param_.alpha();
-  beta_ = this->layer_param_.beta();
-};
-
-template <typename Dtype>
-void LRNLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
-    vector<Blob<Dtype>*>* top) {
-  const Dtype* bottom_data = bottom[0]->cpu_data();
-  Dtype* top_data = (*top)[0]->mutable_cpu_data();
-  Dtype* scale_data = scale_.mutable_cpu_data();
-  // start with the constant value
-  for (int i = 0; i < scale_.count(); ++i) {
-    scale_data[i] = 1.;
-  }
-  Blob<Dtype> padded_square(1, channels_ + size_ - 1, height_, width_);
-  Dtype* padded_square_data = padded_square.mutable_cpu_data();
-  memset(padded_square_data, 0, sizeof(Dtype) * padded_square.count());
-  Dtype alpha_over_size = alpha_ / size_;
-  // go through the images
-  for (int n = 0; n < num_; ++n) {
-    // compute the padded square
-    caffeine_sqr(channels_ * height_ * width_,
-        bottom_data + bottom[0]->offset(n),
-        padded_square_data + padded_square.offset(0, pre_pad_));
-    // Create the first channel scale
-    for (int c = 0; c < size_; ++c) {
-      caffeine_axpy<Dtype>(height_ * width_, alpha_over_size,
-          padded_square_data + padded_square.offset(0, c),
-          scale_data + scale_.offset(n, 0));
+__global__ void LRNFillScale(const int nthreads, const Dtype* in,
+    const int num, const int channels, const int height,
+    const int width, const int size, const Dtype alpha_over_size,
+    Dtype* scale) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < nthreads) {
+    // find out the local offset
+    int w = index % width;
+    int h = (index / width) % height;
+    int n = index / width / height;
+    int offset = (n * channels * height + h) * width + w;
+    int step = height * width;
+    in += offset;
+    scale += offset;
+    int head = 0;
+    int pre_pad = (size - 1) / 2;
+    int post_pad = size - pre_pad - 1;
+    Dtype accum_scale = 0;
+    // fill the scale at [n, :, h, w]
+    // accumulate values 
+    while (head < post_pad) {
+      accum_scale += in[head * step] * in[head * step];
+      ++head;
+    }
+    // until we reach size, nothing needs to be subtracted
+    while (head < size) {
+      accum_scale += in[head * step] * in[head * step];
+      scale[(head - post_pad) * step] = 1. + accum_scale * alpha_over_size;
+      ++head;
     }
-    for (int c = 1; c < channels_; ++c) {
-      // copy previous scale
-      caffeine_copy<Dtype>(height_ * width_,
-          scale_data + scale_.offset(n, c - 1),
-          scale_data + scale_.offset(n, c));
-      // add head
-      caffeine_axpy<Dtype>(height_ * width_, alpha_over_size,
-          padded_square_data + padded_square.offset(0, c + size_ - 1),
-          scale_data + scale_.offset(n, c));
-      // subtract tail
-      caffeine_axpy<Dtype>(height_ * width_, -alpha_over_size,
-          padded_square_data + padded_square.offset(0, c - 1),
-          scale_data + scale_.offset(n, c));
+    // both add and subtract
+    while (head < channels) {
+      accum_scale += in[head * step] * in[head * step];
+      accum_scale -= in[(head - size) * step] * in[(head - size) * step];
+      scale[(head - post_pad) * step] = 1. + accum_scale * alpha_over_size;
+      ++head;
+    }
+    // subtract only
+    while (head < size + post_pad) {
+      accum_scale -= in[(head - size) * step] * in[(head - size) * step];
+      scale[(head - post_pad) * step] = 1. + accum_scale * alpha_over_size;
+      ++head;
     }
   }
-
-  // In the end, compute output
-  caffeine_powx<Dtype>(scale_.count(), scale_data, -beta_, top_data);
-  caffeine_mul<Dtype>(scale_.count(), top_data, bottom_data, top_data);
 }
 
 template <typename Dtype>
-Dtype LRNLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
-  const Dtype* top_diff = top[0]->cpu_diff();
-  const Dtype* top_data = top[0]->cpu_data();
-  const Dtype* bottom_data = (*bottom)[0]->cpu_data();
-  const Dtype* scale_data = scale_.cpu_data();
-  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
-  Blob<Dtype> padded_ratio(1, channels_ + size_ - 1, height_, width_);
-  Blob<Dtype> accum_ratio(1, 1, height_, width_);
-  Dtype* padded_ratio_data = padded_ratio.mutable_cpu_data();
-  Dtype* accum_ratio_data = accum_ratio.mutable_cpu_data();
-  // We hack a little bit by using the diff() to store an additional result
-  Dtype* accum_ratio_times_bottom = accum_ratio.mutable_cpu_diff();
-  memset(padded_ratio_data, 0, sizeof(Dtype) * padded_ratio.count());
-  Dtype cache_ratio_value = 2. * alpha_ * beta_ / size_;
+__global__ void LRNComputeOutput(const int nthreads, const Dtype* in,
+    const Dtype* scale, const Dtype negative_beta, Dtype* out) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < nthreads) {
+    out[index] = in[index] * pow(scale[index], negative_beta);
+  }
+}
 
-  caffeine_powx<Dtype>(scale_.count(), scale_data, -beta_, bottom_diff);
-  caffeine_mul<Dtype>(scale_.count(), top_diff, bottom_diff, bottom_diff);
+template <typename Dtype>
+void LRNLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  // First, compute scale
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = (*top)[0]->mutable_gpu_data();
+  Dtype* scale_data = scale_.mutable_gpu_data();
+  // We will launch one kernel for each pixel location, and have the kernel
+  // go through all the channels.
+  int n_threads = num_ * height_ * width_;
+  LRNFillScale<<<CAFFEINE_GET_BLOCKS(n_threads), CAFFEINE_CUDA_NUM_THREADS>>>(
+      n_threads, bottom_data, num_, channels_, height_, width_, size_,
+      alpha_ / size_, scale_data);
+  CUDA_POST_KERNEL_CHECK;
+  n_threads = bottom[0]->count();
+  LRNComputeOutput<<<CAFFEINE_GET_BLOCKS(n_threads), CAFFEINE_CUDA_NUM_THREADS>>>(
+      n_threads, bottom_data, scale_data, -beta_, top_data);
+  CUDA_POST_KERNEL_CHECK;
+}
 
-  // go through individual data
-  int inverse_pre_pad = size_ - (size_ + 1) / 2;
-  for (int n = 0; n < num_; ++n) {
-    int block_offset = scale_.offset(n);
-    // first, compute diff_i * y_i / s_i
-    caffeine_mul<Dtype>(channels_ * height_ * width_,
-        top_diff + block_offset, top_data + block_offset,
-        padded_ratio_data + padded_ratio.offset(0, inverse_pre_pad));
-    caffeine_div<Dtype>(channels_ * height_ * width_,
-        padded_ratio_data + padded_ratio.offset(0, inverse_pre_pad),
-        scale_data + block_offset,
-        padded_ratio_data + padded_ratio.offset(0, inverse_pre_pad));
-    // Now, compute the accumulated ratios and the bottom diff
-    memset(accum_ratio_data, 0, sizeof(Dtype) * accum_ratio.count());
-    for (int c = 0; c < size_ - 1; ++c) {
-      caffeine_axpy<Dtype>(height_ * width_, 1.,
-          padded_ratio_data + padded_ratio.offset(0, c), accum_ratio_data);
-    }
-    for (int c = 0; c < channels_; ++c) {
-      caffeine_axpy<Dtype>(height_ * width_, 1.,
-          padded_ratio_data + padded_ratio.offset(0, c + size_ - 1),
-          accum_ratio_data);
-      // compute bottom diff
-      caffeine_mul<Dtype>(height_ * width_,
-          bottom_data + top[0]->offset(n, c),
-          accum_ratio_data, accum_ratio_times_bottom);
-      caffeine_axpy<Dtype>(height_ * width_, -cache_ratio_value,
-          accum_ratio_times_bottom, bottom_diff + top[0]->offset(n,c));
-      caffeine_axpy<Dtype>(height_ * width_, -1.,
-          padded_ratio_data + padded_ratio.offset(0, c), accum_ratio_data);
-    }
-  }
+template <typename Dtype>
+Dtype LRNLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+  NOT_IMPLEMENTED;
   return Dtype(0.);
 }
 
 INSTANTIATE_CLASS(LRNLayer);
 
-
 }  // namespace caffeine
index c850eef52074a6e6ccd85427c1c6eee980a56dac..54daecb1c20d3ac9f5efca1c8d54d6e476d634b1 100644 (file)
@@ -102,6 +102,15 @@ TYPED_TEST(LRNLayerTest, TestCPU) {
     EXPECT_LE(this->blob_top_->cpu_data()[i],
         top_reference.cpu_data()[i] + 1e-5);
   }
+
+  Caffeine::set_mode(Caffeine::GPU);
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    EXPECT_GE(this->blob_top_->cpu_data()[i],
+        top_reference.cpu_data()[i] - 1e-5);
+    EXPECT_LE(this->blob_top_->cpu_data()[i],
+        top_reference.cpu_data()[i] + 1e-5);
+  }
 }
 
 TYPED_TEST(LRNLayerTest, TestCPUGradient) {
index d5f939a210d4e4e97ab3237b9558d3ce3f39a7ad..e609f0c61790a88f74edbe5952b3f0fe78e185e4 100644 (file)
@@ -118,12 +118,12 @@ class LRNLayer : public Layer<Dtype> {
  protected:
   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
-  //virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
-  //    vector<Blob<Dtype>*>* top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
   virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
-  //virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
-  //    const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
   // scale_ stores the intermediate summing results
   Blob<Dtype> scale_;
   int size_;