1 #include <mkl.h>
2 #include <cublas_v2.h>
3 #include "caffeine/common.hpp"
4 #include "caffeine/util/gemm.hpp"
6 namespace caffeine {
8 template<>
9 void decaf_cpu_gemm<float>(const CBLAS_TRANSPOSE TransA,
10 const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
11 const float alpha, const float* A, const float* B, const float beta,
12 float* C) {
13 int lda = (TransA == CblasNoTrans) ? K : M;
14 int ldb = (TransB == CblasNoTrans) ? N : K;
15 cblas_sgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha, A, lda, B,
16 ldb, beta, C, N);
17 }
19 template<>
20 void decaf_cpu_gemm<double>(const CBLAS_TRANSPOSE TransA,
21 const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
22 const double alpha, const double* A, const double* B, const double beta,
23 double* C) {
24 int lda = (TransA == CblasNoTrans) ? K : M;
25 int ldb = (TransB == CblasNoTrans) ? N : K;
26 cblas_dgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha, A, lda, B,
27 ldb, beta, C, N);
28 }
30 template <>
31 void decaf_gpu_gemm<float>(const CBLAS_TRANSPOSE TransA,
32 const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
33 const float alpha, const float* A, const float* B, const float beta,
34 float* C) {
35 // Note that cublas follows fortran order.
36 int lda = (TransA == CblasNoTrans) ? K : M;
37 int ldb = (TransB == CblasNoTrans) ? N : K;
38 cublasOperation_t cuTransA =
39 (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
40 cublasOperation_t cuTransB =
41 (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
42 CUBLAS_CHECK(cublasSgemm(Caffeine::cublas_handle(), cuTransB, cuTransA,
43 N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
44 }
46 template <>
47 void decaf_gpu_gemm<double>(const CBLAS_TRANSPOSE TransA,
48 const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
49 const double alpha, const double* A, const double* B, const double beta,
50 double* C) {
51 // Note that cublas follows fortran order.
52 int lda = (TransA == CblasNoTrans) ? K : M;
53 int ldb = (TransB == CblasNoTrans) ? N : K;
54 cublasOperation_t cuTransA =
55 (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
56 cublasOperation_t cuTransB =
57 (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
58 CUBLAS_CHECK(cublasDgemm(Caffeine::cublas_handle(), cuTransA, cuTransB,
59 N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
60 }
63 } // namespace caffeine