1 #include <mkl.h>
2 #include <cublas_v2.h>
4 #include "caffeine/blob.hpp"
5 #include "caffeine/common.hpp"
6 #include "caffeine/filler.hpp"
7 #include "caffeine/layer.hpp"
8 #include "caffeine/vision_layers.hpp"
9 #include "caffeine/util/blas.hpp"
11 namespace caffeine {
// Prepares the inner product layer for the given input blob.
//
// Reads num_output and biasterm from the layer parameters, derives the
// matrix dimensions used by the forward/backward passes from the single
// bottom blob, reshapes the single top blob to (num, num_output, 1, 1), and
// allocates and fills the parameter blobs:
//   blobs_[0]  weights, reshaped to (1, 1, K_, N_), filled by weight_filler
//   blobs_[1]  bias, reshaped to (1, 1, 1, N_), filled by bias_filler
//              (only present when biasterm_ is set)
// When a bias is used, bias_multiplier_ is allocated with M_ elements, all
// set to one.
//
// bottom: input blobs; must hold exactly one blob whose num() is positive.
// top:    output blobs; must hold exactly one blob, which is reshaped here.
template <typename Dtype>
void InnerProductLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
  CHECK_EQ(bottom.size(), 1) << "IP Layer takes a single blob as input.";
  CHECK_EQ(top->size(), 1) << "IP Layer takes a single blob as output.";
  // K_ below is count() / num(); fail with a readable message instead of an
  // integer division by zero when the batch is empty.
  CHECK_GT(bottom[0]->num(), 0) << "IP Layer requires a non-empty input blob.";
  const int num_output = this->layer_param_.num_output();
  biasterm_ = this->layer_param_.biasterm();
  // Figure out the dimensions
  M_ = bottom[0]->num();                        // number of input items
  K_ = bottom[0]->count() / bottom[0]->num();   // elements per input item
  N_ = num_output;                              // elements per output item
  (*top)[0]->Reshape(bottom[0]->num(), num_output, 1, 1);
  // One parameter blob for the weights, plus one for the bias if enabled.
  this->blobs_.resize(biasterm_ ? 2 : 1);
  // Initialize the weight
  this->blobs_[0].Reshape(1, 1, K_, N_);
  // Fill the weights
  shared_ptr<Filler<Dtype> > weight_filler(
      GetFiller<Dtype>(this->layer_param_.weight_filler()));
  weight_filler->Fill(&this->blobs_[0]);
  // If necessary, initialize and fill the bias term
  if (biasterm_) {
    this->blobs_[1].Reshape(1, 1, 1, N_);
    shared_ptr<Filler<Dtype> > bias_filler(
        GetFiller<Dtype>(this->layer_param_.bias_filler()));
    bias_filler->Fill(&this->blobs_[1]);
    // Vector of M_ ones; the forward/backward passes multiply with it to
    // apply the bias to every input item and to sum the bias gradient.
    bias_multiplier_.reset(new SyncedMemory(M_ * sizeof(Dtype)));
    Dtype* bias_multiplier_data =
        reinterpret_cast<Dtype*>(bias_multiplier_->mutable_cpu_data());
    for (int i = 0; i < M_; ++i) {
      bias_multiplier_data[i] = 1.;
    }
  }
}
// CPU forward pass of the inner product layer.
//
// Computes the top blob as the product of the bottom data with the weight
// blob (blobs_[0]) and, when biasterm_ is set, accumulates the product of
// bias_multiplier_ with the bias blob (blobs_[1]) into the same output.
// Dimensions are the M_, N_, K_ members passed to the gemm helper
// (assumes caffeine_cpu_gemm follows the BLAS gemm convention
// C = alpha * A * B + beta * C -- confirm in util/blas.hpp).
//
// bottom: input blobs; only bottom[0] is read.
// top:    output blobs; only (*top)[0] is written.
template <typename Dtype>
void InnerProductLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
  const Dtype kOne = static_cast<Dtype>(1.);
  const Dtype kZero = static_cast<Dtype>(0.);

  const Dtype* input = bottom[0]->cpu_data();
  Dtype* output = (*top)[0]->mutable_cpu_data();
  const Dtype* weights = this->blobs_[0].cpu_data();

  // output = input * weights (beta = 0, so previous output is discarded)
  caffeine_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_, kOne,
      input, weights, kZero, output);

  if (!biasterm_) {
    return;
  }

  // output += bias_multiplier * bias (beta = 1, so the result accumulates)
  const Dtype* multiplier =
      reinterpret_cast<const Dtype*>(bias_multiplier_->cpu_data());
  const Dtype* bias = this->blobs_[1].cpu_data();
  caffeine_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, kOne,
      multiplier, bias, kOne, output);
}
// CPU backward pass of the inner product layer.
//
// Using the diff of top[0] and the data of (*bottom)[0], writes:
//   - the weight gradient into the diff of blobs_[0] (always),
//   - the bias gradient into the diff of blobs_[1] (when biasterm_ is set),
//   - the input gradient into the diff of (*bottom)[0] (when propagate_down).
// Every product is issued with beta = 0, i.e. the destination is passed as
// overwritten rather than accumulated (assumes the caffeine_cpu_gemm/gemv
// helpers follow BLAS semantics -- confirm in util/blas.hpp).
//
// top:            output blobs; only the diff of top[0] is read.
// propagate_down: whether to compute the gradient for the bottom blob.
// bottom:         input blobs; (*bottom)[0] data is read, its diff is
//                 written only when propagate_down is true.
// Returns Dtype(0).
template <typename Dtype>
Dtype InnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
    const bool propagate_down,
    vector<Blob<Dtype>*>* bottom) {
  const Dtype kOne = static_cast<Dtype>(1.);
  const Dtype kZero = static_cast<Dtype>(0.);

  const Dtype* output_grad = top[0]->cpu_diff();
  const Dtype* input = (*bottom)[0]->cpu_data();

  // Weight gradient: transposed input times output gradient.
  Dtype* weight_grad = this->blobs_[0].mutable_cpu_diff();
  caffeine_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_, kOne,
      input, output_grad, kZero, weight_grad);

  if (biasterm_) {
    // Bias gradient: transposed output gradient times bias_multiplier_.
    const Dtype* multiplier =
        reinterpret_cast<const Dtype*>(bias_multiplier_->cpu_data());
    Dtype* bias_grad = this->blobs_[1].mutable_cpu_diff();
    caffeine_cpu_gemv<Dtype>(CblasTrans, M_, N_, kOne, output_grad,
        multiplier, kZero, bias_grad);
  }

  if (propagate_down) {
    // Input gradient: output gradient times transposed weights.
    const Dtype* weights = this->blobs_[0].cpu_data();
    Dtype* input_grad = (*bottom)[0]->mutable_cpu_diff();
    caffeine_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_, kOne,
        output_grad, weights, kZero, input_grad);
  }

  return Dtype(0);
}
// Fills out_matrix by repeating in_vec:
//   out_matrix[i] = in_vec[i % vec_len]   for every i in [0, total).
//
// Launch layout: 1-D grid of 1-D blocks. Any grid/block size is valid: the
// kernel walks the output with a grid-stride loop, so a launch that does not
// provide one thread per element still fills the whole output. No shared
// memory is used.
//
// Preconditions: vec_len > 0 (it is used as a modulus); in_vec holds at
// least min(vec_len, total) elements and out_matrix holds at least total
// elements.
//
// total:      number of elements to write to out_matrix.
// vec_len:    length of the repeated vector.
// in_vec:     source vector.
// out_matrix: destination buffer.
template <typename Dtype>
__global__ void BroadcastRow(const int total, const int vec_len,
    const Dtype* in_vec, Dtype* out_matrix) {
  // 64-bit index arithmetic: blockIdx.x * blockDim.x can exceed INT_MAX for
  // large grids, and a wrapped (negative) int index would pass the bounds
  // check and write out of range.
  const long long stride = static_cast<long long>(blockDim.x) * gridDim.x;
  for (long long index =
           static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
       index < total; index += stride) {
    // index < total here, so it fits in an int; keep the modulo 32-bit.
    const int v_index = static_cast<int>(index) % vec_len;
    out_matrix[index] = in_vec[v_index];
  }
}
// GPU forward pass of the inner product layer.
//
// Same computation as the CPU path, issued through caffeine_gpu_gemm on the
// pointers returned by the blobs' gpu_* accessors: the top blob receives the
// product of the bottom data with the weight blob (blobs_[0]) and, when
// biasterm_ is set, the product of bias_multiplier_ with the bias blob
// (blobs_[1]) is accumulated into it (assumes caffeine_gpu_gemm follows the
// BLAS gemm convention C = alpha * A * B + beta * C -- confirm in
// util/blas.hpp).
//
// bottom: input blobs; only bottom[0] is read.
// top:    output blobs; only (*top)[0] is written.
template <typename Dtype>
void InnerProductLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
  const Dtype kOne = static_cast<Dtype>(1.);
  const Dtype kZero = static_cast<Dtype>(0.);

  const Dtype* input = bottom[0]->gpu_data();
  Dtype* output = (*top)[0]->mutable_gpu_data();
  const Dtype* weights = this->blobs_[0].gpu_data();

  // output = input * weights (beta = 0, so previous output is discarded)
  caffeine_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_, kOne,
      input, weights, kZero, output);

  if (!biasterm_) {
    return;
  }

  // output += bias_multiplier * bias (beta = 1, so the result accumulates)
  const Dtype* multiplier =
      reinterpret_cast<const Dtype*>(bias_multiplier_->gpu_data());
  const Dtype* bias = this->blobs_[1].gpu_data();
  caffeine_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, kOne,
      multiplier, bias, kOne, output);
}
// GPU backward pass of the inner product layer.
//
// Same computation as the CPU path, issued through the caffeine_gpu_gemm /
// caffeine_gpu_gemv helpers on the pointers returned by the blobs' gpu_*
// accessors. Using the diff of top[0] and the data of (*bottom)[0], writes:
//   - the weight gradient into the diff of blobs_[0] (always),
//   - the bias gradient into the diff of blobs_[1] (when biasterm_ is set),
//   - the input gradient into the diff of (*bottom)[0] (when propagate_down).
// Every product is issued with beta = 0, i.e. the destination is passed as
// overwritten rather than accumulated (assumes BLAS semantics for the
// helpers -- confirm in util/blas.hpp).
//
// top:            output blobs; only the diff of top[0] is read.
// propagate_down: whether to compute the gradient for the bottom blob.
// bottom:         input blobs; (*bottom)[0] data is read, its diff is
//                 written only when propagate_down is true.
// Returns Dtype(0).
template <typename Dtype>
Dtype InnerProductLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
    const bool propagate_down,
    vector<Blob<Dtype>*>* bottom) {
  const Dtype kOne = static_cast<Dtype>(1.);
  const Dtype kZero = static_cast<Dtype>(0.);

  const Dtype* output_grad = top[0]->gpu_diff();
  const Dtype* input = (*bottom)[0]->gpu_data();

  // Weight gradient: transposed input times output gradient.
  Dtype* weight_grad = this->blobs_[0].mutable_gpu_diff();
  caffeine_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_, kOne,
      input, output_grad, kZero, weight_grad);

  if (biasterm_) {
    // Bias gradient: transposed output gradient times bias_multiplier_.
    const Dtype* multiplier =
        reinterpret_cast<const Dtype*>(bias_multiplier_->gpu_data());
    Dtype* bias_grad = this->blobs_[1].mutable_gpu_diff();
    caffeine_gpu_gemv<Dtype>(CblasTrans, M_, N_, kOne, output_grad,
        multiplier, kZero, bias_grad);
  }

  if (propagate_down) {
    // Input gradient: output gradient times transposed weights.
    const Dtype* weights = this->blobs_[0].gpu_data();
    Dtype* input_grad = (*bottom)[0]->mutable_gpu_diff();
    caffeine_gpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_, kOne,
        output_grad, weights, kZero, input_grad);
  }

  return Dtype(0);
}
// Invokes the project's INSTANTIATE_CLASS macro for InnerProductLayer. The
// macro is defined outside this file; presumably it emits the explicit
// template instantiations of the class for the supported Dtype values so the
// member definitions above are compiled into this translation unit --
// TODO confirm against the macro definition.
140 INSTANTIATE_CLASS(InnerProductLayer);
142 } // namespace caffeine