1 #include <mkl.h>
2 #include <cublas_v2.h>
3 #include "caffe/common.hpp"
4 #include "caffe/util/math_functions.hpp"
6 namespace caffe {
8 template<>
9 void caffe_cpu_gemm<float>(const CBLAS_TRANSPOSE TransA,
10 const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
11 const float alpha, const float* A, const float* B, const float beta,
12 float* C) {
13 int lda = (TransA == CblasNoTrans) ? K : M;
14 int ldb = (TransB == CblasNoTrans) ? N : K;
15 cblas_sgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha, A, lda, B,
16 ldb, beta, C, N);
17 }
19 template<>
20 void caffe_cpu_gemm<double>(const CBLAS_TRANSPOSE TransA,
21 const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
22 const double alpha, const double* A, const double* B, const double beta,
23 double* C) {
24 int lda = (TransA == CblasNoTrans) ? K : M;
25 int ldb = (TransB == CblasNoTrans) ? N : K;
26 cblas_dgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha, A, lda, B,
27 ldb, beta, C, N);
28 }
30 template <>
31 void caffe_gpu_gemm<float>(const CBLAS_TRANSPOSE TransA,
32 const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
33 const float alpha, const float* A, const float* B, const float beta,
34 float* C) {
35 // Note that cublas follows fortran order.
36 int lda = (TransA == CblasNoTrans) ? K : M;
37 int ldb = (TransB == CblasNoTrans) ? N : K;
38 cublasOperation_t cuTransA =
39 (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
40 cublasOperation_t cuTransB =
41 (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
42 CUBLAS_CHECK(cublasSgemm(Caffe::cublas_handle(), cuTransB, cuTransA,
43 N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
44 }
46 template <>
47 void caffe_gpu_gemm<double>(const CBLAS_TRANSPOSE TransA,
48 const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
49 const double alpha, const double* A, const double* B, const double beta,
50 double* C) {
51 // Note that cublas follows fortran order.
52 int lda = (TransA == CblasNoTrans) ? K : M;
53 int ldb = (TransB == CblasNoTrans) ? N : K;
54 cublasOperation_t cuTransA =
55 (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
56 cublasOperation_t cuTransB =
57 (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
58 CUBLAS_CHECK(cublasDgemm(Caffe::cublas_handle(), cuTransB, cuTransA,
59 N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
60 }
62 template <>
63 void caffe_cpu_gemv<float>(const CBLAS_TRANSPOSE TransA, const int M,
64 const int N, const float alpha, const float* A, const float* x,
65 const float beta, float* y) {
66 cblas_sgemv(CblasRowMajor, TransA, M, N, alpha, A, N, x, 1, beta, y, 1);
67 }
69 template <>
70 void caffe_cpu_gemv<double>(const CBLAS_TRANSPOSE TransA, const int M,
71 const int N, const double alpha, const double* A, const double* x,
72 const double beta, double* y) {
73 cblas_dgemv(CblasRowMajor, TransA, M, N, alpha, A, N, x, 1, beta, y, 1);
74 }
76 template <>
77 void caffe_gpu_gemv<float>(const CBLAS_TRANSPOSE TransA, const int M,
78 const int N, const float alpha, const float* A, const float* x,
79 const float beta, float* y) {
80 cublasOperation_t cuTransA =
81 (TransA == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
82 CUBLAS_CHECK(cublasSgemv(Caffe::cublas_handle(), cuTransA, N, M, &alpha,
83 A, N, x, 1, &beta, y, 1));
84 }
86 template <>
87 void caffe_gpu_gemv<double>(const CBLAS_TRANSPOSE TransA, const int M,
88 const int N, const double alpha, const double* A, const double* x,
89 const double beta, double* y) {
90 cublasOperation_t cuTransA =
91 (TransA == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
92 CUBLAS_CHECK(cublasDgemv(Caffe::cublas_handle(), cuTransA, N, M, &alpha,
93 A, N, x, 1, &beta, y, 1));
94 }
96 template <>
97 void caffe_axpy<float>(const int N, const float alpha, const float* X,
98 float* Y) { cblas_saxpy(N, alpha, X, 1, Y, 1); }
100 template <>
101 void caffe_axpy<double>(const int N, const double alpha, const double* X,
102 double* Y) { cblas_daxpy(N, alpha, X, 1, Y, 1); }
104 template <>
105 void caffe_copy<float>(const int N, const float* X, float* Y) {
106 cblas_scopy(N, X, 1, Y, 1);
107 }
109 template <>
110 void caffe_copy<double>(const int N, const double* X, double* Y) {
111 cblas_dcopy(N, X, 1, Y, 1);
112 }
114 template <>
115 void caffe_sqr<float>(const int n, const float* a, float* y){
116 vsSqr(n, a, y);
117 }
119 template <>
120 void caffe_sqr<double>(const int n, const double* a, double* y) {
121 vdSqr(n, a, y);
122 }
124 template <>
125 void caffe_mul<float>(const int n, const float* a, const float* b,
126 float* y) { vsMul(n, a, b, y); }
128 template <>
129 void caffe_mul<double>(const int n, const double* a, const double* b,
130 double* y) { vdMul(n, a, b, y); }
132 template <>
133 void caffe_div<float>(const int n, const float* a, const float* b,
134 float* y) { vsDiv(n, a, b, y); }
136 template <>
137 void caffe_div<double>(const int n, const double* a, const double* b,
138 double* y) { vdDiv(n, a, b, y); }
140 template <>
141 void caffe_powx<float>(const int n, const float* a, const float b,
142 float* y) { vsPowx(n, a, b, y); }
144 template <>
145 void caffe_powx<double>(const int n, const double* a, const double b,
146 double* y) { vdPowx(n, a, b, y); }
148 } // namespace caffe