1949a703d75a9e71c91182dff0a7c6ae24689ef8
1 // Copyright 2013 Yangqing Jia
3 #include <mkl.h>
4 #include <cublas_v2.h>
5 #include "caffe/common.hpp"
6 #include "caffe/util/math_functions.hpp"
8 namespace caffe {
10 template<>
11 void caffe_cpu_gemm<float>(const CBLAS_TRANSPOSE TransA,
12 const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
13 const float alpha, const float* A, const float* B, const float beta,
14 float* C) {
15 int lda = (TransA == CblasNoTrans) ? K : M;
16 int ldb = (TransB == CblasNoTrans) ? N : K;
17 cblas_sgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha, A, lda, B,
18 ldb, beta, C, N);
19 }
21 template<>
22 void caffe_cpu_gemm<double>(const CBLAS_TRANSPOSE TransA,
23 const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
24 const double alpha, const double* A, const double* B, const double beta,
25 double* C) {
26 int lda = (TransA == CblasNoTrans) ? K : M;
27 int ldb = (TransB == CblasNoTrans) ? N : K;
28 cblas_dgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha, A, lda, B,
29 ldb, beta, C, N);
30 }
32 template <>
33 void caffe_gpu_gemm<float>(const CBLAS_TRANSPOSE TransA,
34 const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
35 const float alpha, const float* A, const float* B, const float beta,
36 float* C) {
37 // Note that cublas follows fortran order.
38 int lda = (TransA == CblasNoTrans) ? K : M;
39 int ldb = (TransB == CblasNoTrans) ? N : K;
40 cublasOperation_t cuTransA =
41 (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
42 cublasOperation_t cuTransB =
43 (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
44 CUBLAS_CHECK(cublasSgemm(Caffe::cublas_handle(), cuTransB, cuTransA,
45 N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
46 }
48 template <>
49 void caffe_gpu_gemm<double>(const CBLAS_TRANSPOSE TransA,
50 const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
51 const double alpha, const double* A, const double* B, const double beta,
52 double* C) {
53 // Note that cublas follows fortran order.
54 int lda = (TransA == CblasNoTrans) ? K : M;
55 int ldb = (TransB == CblasNoTrans) ? N : K;
56 cublasOperation_t cuTransA =
57 (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
58 cublasOperation_t cuTransB =
59 (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
60 CUBLAS_CHECK(cublasDgemm(Caffe::cublas_handle(), cuTransB, cuTransA,
61 N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
62 }
64 template <>
65 void caffe_cpu_gemv<float>(const CBLAS_TRANSPOSE TransA, const int M,
66 const int N, const float alpha, const float* A, const float* x,
67 const float beta, float* y) {
68 cblas_sgemv(CblasRowMajor, TransA, M, N, alpha, A, N, x, 1, beta, y, 1);
69 }
71 template <>
72 void caffe_cpu_gemv<double>(const CBLAS_TRANSPOSE TransA, const int M,
73 const int N, const double alpha, const double* A, const double* x,
74 const double beta, double* y) {
75 cblas_dgemv(CblasRowMajor, TransA, M, N, alpha, A, N, x, 1, beta, y, 1);
76 }
78 template <>
79 void caffe_gpu_gemv<float>(const CBLAS_TRANSPOSE TransA, const int M,
80 const int N, const float alpha, const float* A, const float* x,
81 const float beta, float* y) {
82 cublasOperation_t cuTransA =
83 (TransA == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
84 CUBLAS_CHECK(cublasSgemv(Caffe::cublas_handle(), cuTransA, N, M, &alpha,
85 A, N, x, 1, &beta, y, 1));
86 }
88 template <>
89 void caffe_gpu_gemv<double>(const CBLAS_TRANSPOSE TransA, const int M,
90 const int N, const double alpha, const double* A, const double* x,
91 const double beta, double* y) {
92 cublasOperation_t cuTransA =
93 (TransA == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
94 CUBLAS_CHECK(cublasDgemv(Caffe::cublas_handle(), cuTransA, N, M, &alpha,
95 A, N, x, 1, &beta, y, 1));
96 }
98 template <>
99 void caffe_axpy<float>(const int N, const float alpha, const float* X,
100 float* Y) { cblas_saxpy(N, alpha, X, 1, Y, 1); }
102 template <>
103 void caffe_axpy<double>(const int N, const double alpha, const double* X,
104 double* Y) { cblas_daxpy(N, alpha, X, 1, Y, 1); }
106 template <>
107 void caffe_copy<float>(const int N, const float* X, float* Y) {
108 cblas_scopy(N, X, 1, Y, 1);
109 }
111 template <>
112 void caffe_copy<double>(const int N, const double* X, double* Y) {
113 cblas_dcopy(N, X, 1, Y, 1);
114 }
116 template <>
117 void caffe_scal<float>(const int N, const float alpha, float *X) {
118 cblas_sscal(N, alpha, X, 1);
119 }
121 template <>
122 void caffe_scal<double>(const int N, const double alpha, double *X) {
123 cblas_dscal(N, alpha, X, 1);
124 }
126 template <>
127 void caffe_sqr<float>(const int n, const float* a, float* y) {
128 vsSqr(n, a, y);
129 }
131 template <>
132 void caffe_sqr<double>(const int n, const double* a, double* y) {
133 vdSqr(n, a, y);
134 }
136 template <>
137 void caffe_mul<float>(const int n, const float* a, const float* b,
138 float* y) { vsMul(n, a, b, y); }
140 template <>
141 void caffe_mul<double>(const int n, const double* a, const double* b,
142 double* y) { vdMul(n, a, b, y); }
144 template <>
145 void caffe_div<float>(const int n, const float* a, const float* b,
146 float* y) { vsDiv(n, a, b, y); }
148 template <>
149 void caffe_div<double>(const int n, const double* a, const double* b,
150 double* y) { vdDiv(n, a, b, y); }
152 template <>
153 void caffe_powx<float>(const int n, const float* a, const float b,
154 float* y) { vsPowx(n, a, b, y); }
156 template <>
157 void caffe_powx<double>(const int n, const double* a, const double b,
158 double* y) { vdPowx(n, a, b, y); }
160 template <>
161 void caffe_vRngUniform<float>(const int n, float* r,
162 const float a, const float b) {
163 VSL_CHECK(vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffe::vsl_stream(),
164 n, r, a, b));
165 }
167 template <>
168 void caffe_vRngUniform<double>(const int n, double* r,
169 const double a, const double b) {
170 VSL_CHECK(vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD, Caffe::vsl_stream(),
171 n, r, a, b));
172 }
174 template <>
175 void caffe_vRngGaussian<float>(const int n, float* r, const float a,
176 const float sigma) {
177 VSL_CHECK(vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER,
178 Caffe::vsl_stream(), n, r, a, sigma));
179 }
182 template <>
183 void caffe_vRngGaussian<double>(const int n, double* r, const double a,
184 const double sigma) {
185 VSL_CHECK(vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER,
186 Caffe::vsl_stream(), n, r, a, sigma));
187 }
189 template <>
190 void caffe_exp<float>(const int n, const float* a, float* y) {
191 vsExp(n, a, y);
192 }
194 template <>
195 void caffe_exp<double>(const int n, const double* a, double* y) {
196 vdExp(n, a, y);
197 }
199 template <>
200 float caffe_cpu_dot<float>(const int n, const float* x, const float* y) {
201 return cblas_sdot(n, x, 1, y, 1);
202 }
204 template <>
205 double caffe_cpu_dot<double>(const int n, const double* x, const double* y) {
206 return cblas_ddot(n, x, 1, y, 1);
207 }
209 } // namespace caffe