]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/caffe-jacinto.git/blob - src/caffeine/layers/inner_product_layer.cu
inner product forward backward
[jacinto-ai/caffe-jacinto.git] / src / caffeine / layers / inner_product_layer.cu
1 #include <mkl.h>
2 #include <cublas_v2.h>
4 #include "caffeine/blob.hpp"
5 #include "caffeine/common.hpp"
6 #include "caffeine/filler.hpp"
7 #include "caffeine/layer.hpp"
8 #include "caffeine/vision_layers.hpp"
9 #include "caffeine/util/blas.hpp"
11 namespace caffeine {
13 template <typename Dtype>
14 void InnerProductLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
15       vector<Blob<Dtype>*>* top) {
16   CHECK_EQ(bottom.size(), 1) << "IP Layer takes a single blob as input.";
17   CHECK_EQ(top->size(), 1) << "IP Layer takes a single blob as output.";
18   const int num_output = this->layer_param_.num_output();
19   biasterm_ = this->layer_param_.biasterm();
20   // Figure out the dimensions
21   M_ = bottom[0]->num();
22   K_ = bottom[0]->count() / bottom[0]->num();
23   N_ = num_output;
24   (*top)[0]->Reshape(bottom[0]->num(), num_output, 1, 1);
25   if (biasterm_) {
26     this->blobs_.resize(2);
27   } else {
28     this->blobs_.resize(1);
29   }
30   // Intialize the weight
31   this->blobs_[0].Reshape(1, 1, K_, N_);
32   // fill the weights
33   shared_ptr<Filler<Dtype> > weight_filler(
34       GetFiller<Dtype>(this->layer_param_.weight_filler()));
35   weight_filler->Fill(&this->blobs_[0]);
36   // If necessary, intiialize and fill the bias term
37   if (biasterm_) {
38     this->blobs_[1].Reshape(1, 1, 1, N_);
39     shared_ptr<Filler<Dtype> > bias_filler(
40         GetFiller<Dtype>(this->layer_param_.bias_filler()));
41     bias_filler->Fill(&this->blobs_[1]);
42     bias_multiplier_.reset(new SyncedMemory(M_ * sizeof(Dtype)));
43     Dtype* bias_multiplier_data = (Dtype*)bias_multiplier_->mutable_cpu_data();
44     for (int i = 0; i < M_; ++i) {
45         bias_multiplier_data[i] = 1.;
46     }
47   }
48 };
50 template <typename Dtype>
51 void InnerProductLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
52     vector<Blob<Dtype>*>* top) {
53   const Dtype* bottom_data = bottom[0]->cpu_data();
54   Dtype* top_data = (*top)[0]->mutable_cpu_data();
55   const Dtype* weight = this->blobs_[0].cpu_data();
56   caffeine_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_, (Dtype)1.,
57       bottom_data, weight, (Dtype)0., top_data);
58   if (biasterm_) {
59     caffeine_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1.,
60         (Dtype*)bias_multiplier_->cpu_data(), this->blobs_[1].cpu_data(),
61         (Dtype)1., top_data);
62   }
63 }
65 template <typename Dtype>
66 Dtype InnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
67     const bool propagate_down,
68     vector<Blob<Dtype>*>* bottom) {
69   const Dtype* top_diff = top[0]->cpu_diff();
70   const Dtype* bottom_data = (*bottom)[0]->cpu_data();
71   // Gradient with respect to weight
72   caffeine_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_, (Dtype)1.,
73       bottom_data, top_diff, (Dtype)0., this->blobs_[0].mutable_cpu_diff());
74   if (biasterm_) {
75     // Gradient with respect to bias
76     caffeine_cpu_gemv<Dtype>(CblasTrans, M_, N_, (Dtype)1., top_diff,
77         (Dtype*)bias_multiplier_->cpu_data(), (Dtype)0.,
78         this->blobs_[1].mutable_cpu_diff());
79   }
80   if (propagate_down) {
81     // Gradient with respect to bottom data
82     caffeine_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_, (Dtype)1.,
83         top_diff, this->blobs_[0].cpu_data(), (Dtype)0.,
84         (*bottom)[0]->mutable_cpu_diff());
85   }
86   return Dtype(0);
87 }
89 template <typename Dtype>
90 __global__ void BroadcastRow(const int total, const int vec_len,
91         const Dtype* in_vec, Dtype* out_matrix) {
92   int index = threadIdx.x + blockIdx.x * blockDim.x;
93   if (index < total) {
94     int v_index = index % vec_len;
95     out_matrix[index] = in_vec[v_index];
96   }
97 }
101 template <typename Dtype>
102 void InnerProductLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
103     vector<Blob<Dtype>*>* top) {
104   const Dtype* bottom_data = bottom[0]->gpu_data();
105   Dtype* top_data = (*top)[0]->mutable_gpu_data();
106   const Dtype* weight = this->blobs_[0].gpu_data();
107   caffeine_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_, (Dtype)1.,
108       bottom_data, weight, (Dtype)0., top_data);
109   if (biasterm_) {
110     caffeine_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1.,
111         (Dtype*)bias_multiplier_->gpu_data(), this->blobs_[1].gpu_data(),
112         (Dtype)1., top_data);
113   }
116 template <typename Dtype>
117 Dtype InnerProductLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
118     const bool propagate_down,
119     vector<Blob<Dtype>*>* bottom) {
120   const Dtype* top_diff = top[0]->gpu_diff();
121   const Dtype* bottom_data = (*bottom)[0]->gpu_data();
122   // Gradient with respect to weight
123   caffeine_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_, (Dtype)1.,
124       bottom_data, top_diff, (Dtype)0., this->blobs_[0].mutable_gpu_diff());
125   if (biasterm_) {
126     // Gradient with respect to bias
127     caffeine_gpu_gemv<Dtype>(CblasTrans, M_, N_, (Dtype)1., top_diff,
128         (Dtype*)bias_multiplier_->gpu_data(), (Dtype)0.,
129         this->blobs_[1].mutable_gpu_diff());
130   }
131   if (propagate_down) {
132     // Gradient with respect to bottom data
133     caffeine_gpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_, (Dtype)1.,
134         top_diff, this->blobs_[0].gpu_data(), (Dtype)0.,
135         (*bottom)[0]->mutable_gpu_diff());
136   }
137   return Dtype(0);
140 INSTANTIATE_CLASS(InnerProductLayer);
142 }  // namespace caffeine