]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/caffe-jacinto.git/commitdiff
working version
authorYangqing Jia <jiayq84@gmail.com>
Wed, 18 Sep 2013 01:35:25 +0000 (18:35 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Wed, 18 Sep 2013 01:35:25 +0000 (18:35 -0700)
src/caffeine/layers/inner_product_layer.cu
src/caffeine/vision_layers.hpp

index baeca0e8729797e308155f81d8994db095fbd1db..88dd41c2c664199854127fca11b244331c9e43e0 100644 (file)
@@ -14,40 +14,45 @@ void InnerProductLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) {
   CHECK_EQ(bottom.size(), 1) << "IP Layer takes a single blob as input.";
   CHECK_EQ(top->size(), 1) << "IP Layer takes a single blob as output.";
       vector<Blob<Dtype>*>* top) {
   CHECK_EQ(bottom.size(), 1) << "IP Layer takes a single blob as input.";
   CHECK_EQ(top->size(), 1) << "IP Layer takes a single blob as output.";
-       const int num_output = this->layer_param_.num_output();
-       const bool gemm_last_dim = this->layer_param_.gemm_last_dim();
-       biasterm_ = this->layer_param_.biasterm();
-       // Figure out the dimensions
-       if (gemm_last_dim) {
-               M_ = bottom[0]->count() / bottom[0]->channels();
-       K_ = bottom[0]->channels();
-       N_ = num_output;
-       (*top)[0]->Reshape(bottom[0]->num(), bottom[0]->height(),
-                       bottom[0]->width(), num_output);
-       } else {
-               M_ = bottom[0]->num();
-               K_ = bottom[0]->count() / bottom[0]->num();
-               N_ = num_output;
-               (*top)[0]->Reshape(bottom[0]->num(), 1, 1, num_output);
-       }
+  const int num_output = this->layer_param_.num_output();
+  const bool gemm_last_dim = this->layer_param_.gemm_last_dim();
+  biasterm_ = this->layer_param_.biasterm();
+  // Figure out the dimensions
+  if (gemm_last_dim) {
+    M_ = bottom[0]->count() / bottom[0]->channels();
+    K_ = bottom[0]->channels();
+    N_ = num_output;
+    (*top)[0]->Reshape(bottom[0]->num(), bottom[0]->height(),
+        bottom[0]->width(), num_output);
+  } else {
+    M_ = bottom[0]->num();
+    K_ = bottom[0]->count() / bottom[0]->num();
+    N_ = num_output;
+    (*top)[0]->Reshape(bottom[0]->num(), 1, 1, num_output);
+  }
   if (biasterm_) {
     this->blobs_.resize(2);
   } else {
     this->blobs_.resize(1);
   }
   if (biasterm_) {
     this->blobs_.resize(2);
   } else {
     this->blobs_.resize(1);
   }
-       // Intialize the weight
+  // Intialize the weight
   this->blobs_[0].Reshape(1, 1, K_, N_);
   this->blobs_[0].Reshape(1, 1, K_, N_);
-       // fill the weights
-       shared_ptr<Filler<Dtype> > weight_filler(
-                       GetFiller<Dtype>(this->layer_param_.weight_filler()));
-       weight_filler->Fill(&this->blobs_[0]);
-       // If necessary, intiialize and fill the bias term
-       if (biasterm_) {
+  // fill the weights
+  shared_ptr<Filler<Dtype> > weight_filler(
+      GetFiller<Dtype>(this->layer_param_.weight_filler()));
+  weight_filler->Fill(&this->blobs_[0]);
+  // If necessary, intiialize and fill the bias term
+  if (biasterm_) {
     this->blobs_[1].Reshape(1, 1, 1, N_);
     this->blobs_[1].Reshape(1, 1, 1, N_);
-               shared_ptr<Filler<Dtype> > bias_filler(
-                               GetFiller<Dtype>(this->layer_param_.bias_filler()));
-               bias_filler->Fill(&this->blobs_[1]);
-       }
+    shared_ptr<Filler<Dtype> > bias_filler(
+        GetFiller<Dtype>(this->layer_param_.bias_filler()));
+    bias_filler->Fill(&this->blobs_[1]);
+    bias_multiplier_.reset(new SyncedMemory(M_ * sizeof(Dtype)));
+    Dtype* bias_multiplier_data = (Dtype*)bias_multiplier_->mutable_cpu_data();
+    for (int i = 0; i < M_; ++i) {
+        bias_multiplier_data[i] = 1.;
+    }
+  }
 };
 
 template <typename Dtype>
 };
 
 template <typename Dtype>
@@ -58,37 +63,34 @@ void InnerProductLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   const Dtype* weight = this->blobs_[0].cpu_data();
   const Dtype* bias = NULL;
   if (biasterm_) {
   const Dtype* weight = this->blobs_[0].cpu_data();
   const Dtype* bias = NULL;
   if (biasterm_) {
-       bias = this->blobs_[1].cpu_data();
+    bias = this->blobs_[1].cpu_data();
   }
   switch(sizeof(Dtype)) {
   case sizeof(float):
   }
   switch(sizeof(Dtype)) {
   case sizeof(float):
-       // matrix multiply
-       cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M_, N_, K_,
-                       1., (const float*)bottom_data, K_, (const float*)weight, N_, 0.,
-                       (float*)top_data, N_);
-       if (bias) {
-               // add bias
-               for (int i = 0; i < M_; ++i) {
-                       cblas_saxpy(N_, 1., (const float*)bias, 1,
-                                       (float*)(top_data) + (N_ * i), 1);
-               }
-       }
+    // matrix multiply
+    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M_, N_, K_,
+        1., (const float*)bottom_data, K_, (const float*)weight, N_, 0.,
+        (float*)top_data, N_);
+    if (bias) {
+      // add bias
+      cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M_, N_, 1,
+          1., (const float*)bias_multiplier_->cpu_data(), 1,
+          (const float*)bias, N_, 1., (float*)top_data, N_);
+    }
     break;
   case sizeof(double):
     // matrix multiply
     break;
   case sizeof(double):
     // matrix multiply
-       cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M_, N_, K_,
-                       1., (const double*)bottom_data, K_, (const double*)weight, N_, 0.,
-                       (double*)top_data, N_);
-       if (bias) {
-               // add bias
-               for (int i = 0; i < M_; ++i) {
-                       cblas_daxpy(N_, 1., (const double*)bias, 1,
-                                       (double*)(top_data) + (N_ * i), 1);
-               }
-       }
+    cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M_, N_, K_,
+        1., (const double*)bottom_data, K_, (const double*)weight, N_, 0.,
+        (double*)top_data, N_);
+    if (bias) {
+      cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M_, N_, 1,
+          1., (const float*)bias_multiplier_->cpu_data(), 1,
+          (const float*)bias, N_, 1., (float*)top_data, N_);
+    }
     break;
   default:
     break;
   default:
-       CHECK(false) << "Unknown data type.";
+    CHECK(false) << "Unknown data type.";
   }
 }
 
   }
 }
 
@@ -96,17 +98,20 @@ template <typename Dtype>
 Dtype InnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
     const bool propagate_down,
     vector<Blob<Dtype>*>* bottom) {
 Dtype InnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
     const bool propagate_down,
     vector<Blob<Dtype>*>* bottom) {
-  CHECK(false);
+  // TODO: gradient w.r.t the params
+  if (propagate_down) {
+    // TODO: gradient w.r.t. the bottom
+  }
   return Dtype(0);
 }
 
 template <typename Dtype>
 __global__ void BroadcastCopy(const int total, const int vec_len,
   return Dtype(0);
 }
 
 template <typename Dtype>
 __global__ void BroadcastCopy(const int total, const int vec_len,
-       const Dtype* in_vec, Dtype* out_matrix) {
+  const Dtype* in_vec, Dtype* out_matrix) {
   int index = threadIdx.x + blockIdx.x * blockDim.x;
   if (index < total) {
   int index = threadIdx.x + blockIdx.x * blockDim.x;
   if (index < total) {
-       int v_index = index % vec_len;
-       out_matrix[index] = in_vec[v_index];
+    int v_index = index % vec_len;
+    out_matrix[index] = in_vec[v_index];
   }
 }
 
   }
 }
 
@@ -119,29 +124,29 @@ void InnerProductLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   const Dtype* bias = NULL;
   Dtype alpha = 1., beta = 0.;
   if (biasterm_) {
   const Dtype* bias = NULL;
   Dtype alpha = 1., beta = 0.;
   if (biasterm_) {
-       bias = this->blobs_[1].gpu_data();
-       beta = 1.;
-       const int count = (*top)[0]->count();
-       // we pre-copy the bias to the results, and then call gemm.
-       BroadcastCopy<<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
-                       count, N_, bias, top_data);
+    bias = this->blobs_[1].gpu_data();
+    beta = 1.;
+    const int count = (*top)[0]->count();
+    // we pre-copy the bias to the results, and then call gemm.
+    BroadcastCopy<<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
+        count, N_, bias, top_data);
   }
   switch(sizeof(Dtype)) {
   case sizeof(float):
   }
   switch(sizeof(Dtype)) {
   case sizeof(float):
-       // matrix multiply: since cublas uses Fortran major, we actually do
-       // C' = B' A'
-       CUBLAS_CHECK(cublasSgemm(Caffeine::cublas_handle(), CUBLAS_OP_N,
-                       CUBLAS_OP_N, N_, M_, K_, (float*)&alpha, (const float*)weight, N_,
-                       (const float*)bottom_data, K_, (float*)&beta, (float*)top_data, N_));
+    // matrix multiply: since cublas uses Fortran major, we actually do
+    // C' = B' A'
+    CUBLAS_CHECK(cublasSgemm(Caffeine::cublas_handle(), CUBLAS_OP_N,
+        CUBLAS_OP_N, N_, M_, K_, (float*)&alpha, (const float*)weight, N_,
+        (const float*)bottom_data, K_, (float*)&beta, (float*)top_data, N_));
     break;
   case sizeof(double):
     // matrix multiply
     break;
   case sizeof(double):
     // matrix multiply
-       CUBLAS_CHECK(cublasDgemm(Caffeine::cublas_handle(), CUBLAS_OP_N,
-                       CUBLAS_OP_N, N_, M_, K_, (double*)&alpha, (const double*)weight, N_,
-                       (const double*)bottom_data, K_, (double*)&beta, (double*)top_data, N_));
+    CUBLAS_CHECK(cublasDgemm(Caffeine::cublas_handle(), CUBLAS_OP_N,
+        CUBLAS_OP_N, N_, M_, K_, (double*)&alpha, (const double*)weight, N_,
+        (const double*)bottom_data, K_, (double*)&beta, (double*)top_data, N_));
     break;
   default:
     break;
   default:
-       CHECK(false) << "Unknown data type.";
+    CHECK(false) << "Unknown data type.";
   }
 }
 
   }
 }
 
index 8cf361c586b8808e4ab0403a8bcb54efac2fd14a..e324c8ef22b3c8ab42b5dd1086d389e1427dbf34 100644 (file)
@@ -80,6 +80,7 @@ class InnerProductLayer : public Layer<Dtype> {
   int K_;
   int N_;
   bool biasterm_;
   int K_;
   int N_;
   bool biasterm_;
+  shared_ptr<SyncedMemory> bias_multiplier_;
 };
 
 }  // namespace caffeine
 };
 
 }  // namespace caffeine