diff --git a/src/caffeine/layers/inner_product_layer.cu b/src/caffeine/layers/inner_product_layer.cu
index 26a7c4bc93c2cb35f715fe0d2722ef62fe8f9f5c..5b1124c3db41218e7463e3f4722de84ec44caae3 100644 (file)
const bool propagate_down,
vector<Blob<Dtype>*>* bottom) {
const Dtype* top_diff = top[0]->cpu_diff();
- CHECK(false);
+ const Dtype* bottom_data = (*bottom)[0]->cpu_data();
+ // Gradient with respect to weight
+ caffeine_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_, (Dtype)1.,
+ bottom_data, top_diff, (Dtype)0., this->blobs_[0].mutable_cpu_diff());
+ if (biasterm_) {
+ // Gradient with respect to bias
+ caffeine_cpu_gemv<Dtype>(CblasTrans, M_, N_, (Dtype)1., top_diff,
+ (Dtype*)bias_multiplier_->cpu_data(), (Dtype)0.,
+ this->blobs_[1].mutable_cpu_diff());
+ }
+ if (propagate_down) {
+ // Gradient with respect to bottom data
+ caffeine_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_, (Dtype)1.,
+ top_diff, this->blobs_[0].cpu_data(), (Dtype)0.,
+ (*bottom)[0]->mutable_cpu_diff());
+ }
return Dtype(0);
}
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
const Dtype* weight = this->blobs_[0].gpu_data();
- const Dtype* bias = NULL;
- Dtype alpha = 1., beta = 0.;
+ caffeine_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_, (Dtype)1.,
+ bottom_data, weight, (Dtype)0., top_data);
if (biasterm_) {
- bias = this->blobs_[1].gpu_data();
- beta = 1.;
- const int count = (*top)[0]->count();
- // we pre-copy the bias to the results, and then call gemm.
- BroadcastRow<<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
- count, N_, bias, top_data);
- }
- switch(sizeof(Dtype)) {
- case sizeof(float):
- // matrix multiply: since cublas uses Fortran major, we actually do
- // C' = B' A'
- CUBLAS_CHECK(cublasSgemm(Caffeine::cublas_handle(), CUBLAS_OP_N,
- CUBLAS_OP_N, N_, M_, K_, (float*)&alpha, (const float*)weight, N_,
- (const float*)bottom_data, K_, (float*)&beta, (float*)top_data, N_));
- break;
- case sizeof(double):
- // matrix multiply
- CUBLAS_CHECK(cublasDgemm(Caffeine::cublas_handle(), CUBLAS_OP_N,
- CUBLAS_OP_N, N_, M_, K_, (double*)&alpha, (const double*)weight, N_,
- (const double*)bottom_data, K_, (double*)&beta, (double*)top_data, N_));
- break;
- default:
- CHECK(false) << "Unknown data type.";
+ caffeine_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1.,
+ (Dtype*)bias_multiplier_->gpu_data(), this->blobs_[1].gpu_data(),
+ (Dtype)1., top_data);
}
}
Dtype InnerProductLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down,
vector<Blob<Dtype>*>* bottom) {
- CHECK(false);
- return Dtype(0.);
+ const Dtype* top_diff = top[0]->gpu_diff();
+ const Dtype* bottom_data = (*bottom)[0]->gpu_data();
+ // Gradient with respect to weight
+ caffeine_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_, (Dtype)1.,
+ bottom_data, top_diff, (Dtype)0., this->blobs_[0].mutable_gpu_diff());
+ if (biasterm_) {
+ // Gradient with respect to bias
+ caffeine_gpu_gemv<Dtype>(CblasTrans, M_, N_, (Dtype)1., top_diff,
+ (Dtype*)bias_multiplier_->gpu_data(), (Dtype)0.,
+ this->blobs_[1].mutable_gpu_diff());
+ }
+ if (propagate_down) {
+ // Gradient with respect to bottom data
+ caffeine_gpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_, (Dtype)1.,
+ top_diff, this->blobs_[0].gpu_data(), (Dtype)0.,
+ (*bottom)[0]->mutable_gpu_diff());
+ }
+ return Dtype(0);
}
INSTANTIATE_CLASS(InnerProductLayer);