]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/caffe-jacinto.git/blobdiff - src/caffe/layers/conv_layer.cpp
naming. I might regret it someday.
[jacinto-ai/caffe-jacinto.git] / src / caffe / layers / conv_layer.cpp
similarity index 90%
rename from src/caffeine/layers/conv_layer.cpp
rename to src/caffe/layers/conv_layer.cpp
index 8670d81f1c60b606c15e0755bd62de28fe6367b0..c9dc2f62c453ff14458a904dfd0d9cf86f75c506 100644 (file)
@@ -1,10 +1,10 @@
-#include "caffeine/layer.hpp"
-#include "caffeine/vision_layers.hpp"
-#include "caffeine/util/im2col.hpp"
-#include "caffeine/filler.hpp"
-#include "caffeine/util/math_functions.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/util/im2col.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/util/math_functions.hpp"
 
-namespace caffeine {
+namespace caffe {
 
 template <typename Dtype>
 void ConvolutionLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
@@ -76,13 +76,13 @@ void ConvolutionLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
         WIDTH_, KSIZE_, STRIDE_, col_data);
     // Second, innerproduct with groups
     for (int g = 0; g < GROUP_; ++g) {
-      caffeine_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_,
+      caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_,
         (Dtype)1., weight + weight_offset * g, col_data + col_offset * g,
         (Dtype)0., top_data + (*top)[0]->offset(n) + top_offset * g);
     }
     // third, add bias
     if (biasterm_) {
-      caffeine_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, NUM_OUTPUT_,
+      caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, NUM_OUTPUT_,
           N_, 1, (Dtype)1., this->blobs_[1].cpu_data(),
           (Dtype*)bias_multiplier_->cpu_data(), (Dtype)1.,
           top_data + (*top)[0]->offset(n));
@@ -106,13 +106,13 @@ void ConvolutionLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
         WIDTH_, KSIZE_, STRIDE_, col_data);
     // Second, innerproduct with groups
     for (int g = 0; g < GROUP_; ++g) {
-      caffeine_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_,
+      caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_,
         (Dtype)1., weight + weight_offset * g, col_data + col_offset * g,
         (Dtype)0., top_data + (*top)[0]->offset(n) + top_offset * g);
     }
     // third, add bias
     if (biasterm_) {
-      caffeine_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, NUM_OUTPUT_,
+      caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, NUM_OUTPUT_,
           N_, 1, (Dtype)1., this->blobs_[1].gpu_data(),
           (Dtype*)bias_multiplier_->gpu_data(), (Dtype)1.,
           top_data + (*top)[0]->offset(n));
@@ -137,7 +137,7 @@ Dtype ConvolutionLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
     bias_diff = this->blobs_[1].mutable_cpu_diff();
     memset(bias_diff, 0., sizeof(Dtype) * this->blobs_[1].count());
     for (int n = 0; n < NUM_; ++n) {
-      caffeine_cpu_gemv<Dtype>(CblasNoTrans, NUM_OUTPUT_, N_,
+      caffe_cpu_gemv<Dtype>(CblasNoTrans, NUM_OUTPUT_, N_,
         1., top_diff + top[0]->offset(n),
         (Dtype*)bias_multiplier_->cpu_data(), 1., bias_diff);
     }
@@ -154,7 +154,7 @@ Dtype ConvolutionLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
         WIDTH_, KSIZE_, STRIDE_, col_data);
     // gradient w.r.t. weight. Note that we will accumulate diffs.
     for (int g = 0; g < GROUP_; ++g) {
-      caffeine_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_,
+      caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_,
         (Dtype)1., top_diff + top[0]->offset(n) + top_offset * g,
         col_data + col_offset * g, (Dtype)1.,
         weight_diff + weight_offset * g);
@@ -162,7 +162,7 @@ Dtype ConvolutionLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
     // gradient w.r.t. bottom data, if necessary
     if (propagate_down) {
       for (int g = 0; g < GROUP_; ++g) {
-        caffeine_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_,
+        caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_,
           (Dtype)1., weight + weight_offset * g,
           top_diff + top[0]->offset(n) + top_offset * g,
           (Dtype)0., col_diff + col_offset * g);
@@ -193,7 +193,7 @@ Dtype ConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     CUDA_CHECK(cudaMemset(bias_diff, 0.,
         sizeof(Dtype) * this->blobs_[1].count()));
     for (int n = 0; n < NUM_; ++n) {
-      caffeine_gpu_gemv<Dtype>(CblasNoTrans, NUM_OUTPUT_, N_,
+      caffe_gpu_gemv<Dtype>(CblasNoTrans, NUM_OUTPUT_, N_,
         1., top_diff + top[0]->offset(n),
         (Dtype*)bias_multiplier_->gpu_data(), 1., bias_diff);
     }
@@ -211,7 +211,7 @@ Dtype ConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
         WIDTH_, KSIZE_, STRIDE_, col_data);
     // gradient w.r.t. weight. Note that we will accumulate diffs.
     for (int g = 0; g < GROUP_; ++g) {
-      caffeine_gpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_,
+      caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_,
         (Dtype)1., top_diff + top[0]->offset(n) + top_offset * g,
         col_data + col_offset * g, (Dtype)1.,
         weight_diff + weight_offset * g);
@@ -219,7 +219,7 @@ Dtype ConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     // gradient w.r.t. bottom data, if necessary
     if (propagate_down) {
       for (int g = 0; g < GROUP_; ++g) {
-        caffeine_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_,
+        caffe_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_,
           (Dtype)1., weight + weight_offset * g,
           top_diff + top[0]->offset(n) + top_offset * g,
           (Dtype)0., col_diff + col_offset * g);
@@ -234,4 +234,4 @@ Dtype ConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
 
 INSTANTIATE_CLASS(ConvolutionLayer);
 
-}  // namespace caffeine
+}  // namespace caffe