]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/caffe-jacinto.git/blob - src/caffeine/util/gemm.cpp
gemm util
[jacinto-ai/caffe-jacinto.git] / src / caffeine / util / gemm.cpp
1 #include <mkl.h>
2 #include <cublas_v2.h>
3 #include "caffeine/common.hpp"
4 #include "caffeine/util/gemm.hpp"
6 namespace caffeine {
8 template<>
9 void decaf_cpu_gemm<float>(const CBLAS_TRANSPOSE TransA,
10     const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
11     const float alpha, const float* A, const float* B, const float beta,
12     float* C) {
13   int lda = (TransA == CblasNoTrans) ? K : M;
14   int ldb = (TransB == CblasNoTrans) ? N : K;
15   cblas_sgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha, A, lda, B,
16       ldb, beta, C, N);
17 }
19 template<>
20 void decaf_cpu_gemm<double>(const CBLAS_TRANSPOSE TransA,
21     const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
22     const double alpha, const double* A, const double* B, const double beta,
23     double* C) {
24   int lda = (TransA == CblasNoTrans) ? K : M;
25   int ldb = (TransB == CblasNoTrans) ? N : K;
26   cblas_dgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha, A, lda, B,
27       ldb, beta, C, N);
28 }
30 template <>
31 void decaf_gpu_gemm<float>(const CBLAS_TRANSPOSE TransA,
32     const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
33     const float alpha, const float* A, const float* B, const float beta,
34     float* C) {
35   // Note that cublas follows fortran order.
36   int lda = (TransA == CblasNoTrans) ? K : M;
37   int ldb = (TransB == CblasNoTrans) ? N : K;
38   cublasOperation_t cuTransA = 
39       (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
40   cublasOperation_t cuTransB =
41       (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
42   CUBLAS_CHECK(cublasSgemm(Caffeine::cublas_handle(), cuTransB, cuTransA,
43       N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));  
44 }
46 template <>
47 void decaf_gpu_gemm<double>(const CBLAS_TRANSPOSE TransA,
48     const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
49     const double alpha, const double* A, const double* B, const double beta,
50     double* C) {
51   // Note that cublas follows fortran order.
52   int lda = (TransA == CblasNoTrans) ? K : M;
53   int ldb = (TransB == CblasNoTrans) ? N : K;
54   cublasOperation_t cuTransA = 
55       (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
56   cublasOperation_t cuTransB =
57       (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
58   CUBLAS_CHECK(cublasDgemm(Caffeine::cublas_handle(), cuTransA, cuTransB,
59       N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));  
60 }
63 }  // namespace caffeine