1 #include <mkl.h>
2 #include <cublas_v2.h>
3 #include "caffeine/common.hpp"
4 #include "caffeine/util/blas.hpp"
6 namespace caffeine {
8 template<>
9 void caffeine_cpu_gemm<float>(const CBLAS_TRANSPOSE TransA,
10 const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
11 const float alpha, const float* A, const float* B, const float beta,
12 float* C) {
13 int lda = (TransA == CblasNoTrans) ? K : M;
14 int ldb = (TransB == CblasNoTrans) ? N : K;
15 cblas_sgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha, A, lda, B,
16 ldb, beta, C, N);
17 }
19 template<>
20 void caffeine_cpu_gemm<double>(const CBLAS_TRANSPOSE TransA,
21 const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
22 const double alpha, const double* A, const double* B, const double beta,
23 double* C) {
24 int lda = (TransA == CblasNoTrans) ? K : M;
25 int ldb = (TransB == CblasNoTrans) ? N : K;
26 cblas_dgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha, A, lda, B,
27 ldb, beta, C, N);
28 }
30 template <>
31 void caffeine_gpu_gemm<float>(const CBLAS_TRANSPOSE TransA,
32 const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
33 const float alpha, const float* A, const float* B, const float beta,
34 float* C) {
35 // Note that cublas follows fortran order.
36 int lda = (TransA == CblasNoTrans) ? K : M;
37 int ldb = (TransB == CblasNoTrans) ? N : K;
38 cublasOperation_t cuTransA =
39 (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
40 cublasOperation_t cuTransB =
41 (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
42 CUBLAS_CHECK(cublasSgemm(Caffeine::cublas_handle(), cuTransB, cuTransA,
43 N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
44 }
46 template <>
47 void caffeine_gpu_gemm<double>(const CBLAS_TRANSPOSE TransA,
48 const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
49 const double alpha, const double* A, const double* B, const double beta,
50 double* C) {
51 // Note that cublas follows fortran order.
52 int lda = (TransA == CblasNoTrans) ? K : M;
53 int ldb = (TransB == CblasNoTrans) ? N : K;
54 cublasOperation_t cuTransA =
55 (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
56 cublasOperation_t cuTransB =
57 (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
58 CUBLAS_CHECK(cublasDgemm(Caffeine::cublas_handle(), cuTransB, cuTransA,
59 N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
60 }
62 template <>
63 void caffeine_cpu_gemv<float>(const CBLAS_TRANSPOSE TransA, const int M,
64 const int N, const float alpha, const float* A, const float* x,
65 const float beta, float* y) {
66 cblas_sgemv(CblasRowMajor, TransA, M, N, alpha, A, N, x, 1, beta, y, 1);
67 }
69 template <>
70 void caffeine_cpu_gemv<double>(const CBLAS_TRANSPOSE TransA, const int M,
71 const int N, const double alpha, const double* A, const double* x,
72 const double beta, double* y) {
73 cblas_dgemv(CblasRowMajor, TransA, M, N, alpha, A, N, x, 1, beta, y, 1);
74 }
76 template <>
77 void caffeine_gpu_gemv<float>(const CBLAS_TRANSPOSE TransA, const int M,
78 const int N, const float alpha, const float* A, const float* x,
79 const float beta, float* y) {
80 cublasOperation_t cuTransA =
81 (TransA == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
82 CUBLAS_CHECK(cublasSgemv(Caffeine::cublas_handle(), cuTransA, N, M, &alpha,
83 A, N, x, 1, &beta, y, 1));
84 }
86 template <>
87 void caffeine_gpu_gemv<double>(const CBLAS_TRANSPOSE TransA, const int M,
88 const int N, const double alpha, const double* A, const double* x,
89 const double beta, double* y) {
90 cublasOperation_t cuTransA =
91 (TransA == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
92 CUBLAS_CHECK(cublasDgemv(Caffeine::cublas_handle(), cuTransA, N, M, &alpha,
93 A, N, x, 1, &beta, y, 1));
94 }
96 } // namespace caffeine