Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - processor-sdk/kaldi.git/blobdiff - src/cudamatrix/cu-matrix.cc
[src] Fix CU_SAFE_CALL wrapper so it correctly prints CuBLAS error codes (#1900)
[processor-sdk/kaldi.git] / src / cudamatrix / cu-matrix.cc
index 30a79b5c93c1b7e3a66471a9cc8d9a0ed967dacf..5860c1938ce5444eb86ddd61494cf4b9472a3d6a 100644 (file)
@@ -1161,7 +1161,7 @@ void CuMatrixBase<Real>::AddMatMat(
 #if HAVE_CUDA == 1
   if (CuDevice::Instantiate().Enabled()) {
     CuTimer tim;
-    CU_SAFE_CALL(cublas_gemm(GetCublasHandle(),
+    CUBLAS_SAFE_CALL(cublas_gemm(GetCublasHandle(),
                              (transB==kTrans? CUBLAS_OP_T:CUBLAS_OP_N),
                              (transA==kTrans? CUBLAS_OP_T:CUBLAS_OP_N),
                              m, n, k, alpha, B.data_, B.Stride(),
@@ -1188,8 +1188,8 @@ void CuMatrixBase<Real>::AddVecVec(
 #if HAVE_CUDA == 1
   if (CuDevice::Instantiate().Enabled()) {
     CuTimer tim;
-    CU_SAFE_CALL(cublas_ger(GetCublasHandle(), m, n, alpha,
-                 y.Data(), 1, x.Data(), 1, data_, Stride()));
+    CUBLAS_SAFE_CALL(cublas_ger(GetCublasHandle(), m, n, alpha,
+                     y.Data(), 1, x.Data(), 1, data_, Stride()));
 
     CuDevice::Instantiate().AccuProfile(__func__, tim);
   } else
@@ -1215,9 +1215,10 @@ void CuMatrixBase<Real>::SymAddMat2(
     CuTimer tim;
     cublasOperation_t trans = (transA == kTrans ? CUBLAS_OP_N : CUBLAS_OP_T);
     MatrixIndexT A_other_dim = (transA == kNoTrans ? A.num_cols_ : A.num_rows_);
-    CU_SAFE_CALL(cublas_syrk(GetCublasHandle(), CUBLAS_FILL_MODE_UPPER, trans,
-                             num_rows_, A_other_dim, alpha, A.Data(), A.Stride(),
-                             beta, this->data_, this->stride_));
+    CUBLAS_SAFE_CALL(cublas_syrk(GetCublasHandle(), CUBLAS_FILL_MODE_UPPER,
+                                 trans, num_rows_, A_other_dim, 
+                                 alpha, A.Data(), A.Stride(),
+                                 beta, this->data_, this->stride_));
 
     CuDevice::Instantiate().AccuProfile(__func__, tim);
   } else
@@ -2106,13 +2107,13 @@ void AddMatMatBatched(const Real alpha, std::vector<CuSubMatrix<Real>* > &C,
 
     CU_SAFE_CALL(cudaMemcpy(device_abc_array, host_abc_array, 3*size*sizeof(Real*), cudaMemcpyHostToDevice));
 
-    CU_SAFE_CALL(cublas_gemmBatched(GetCublasHandle(),
-                                    (transB==kTrans? CUBLAS_OP_T:CUBLAS_OP_N),
-                                    (transA==kTrans? CUBLAS_OP_T:CUBLAS_OP_N),
-                                    m, n, k, alpha, device_b_array,
-                                    B[0]->Stride(), device_a_array,
-                                    A[0]->Stride(), beta, device_c_array,
-                                    C[0]->Stride(), size));
+    CUBLAS_SAFE_CALL(cublas_gemmBatched(GetCublasHandle(),
+                                        (transB==kTrans? CUBLAS_OP_T:CUBLAS_OP_N),
+                                        (transA==kTrans? CUBLAS_OP_T:CUBLAS_OP_N),
+                                        m, n, k, alpha, device_b_array,
+                                        B[0]->Stride(), device_a_array,
+                                        A[0]->Stride(), beta, device_c_array,
+                                        C[0]->Stride(), size));
 
     CuDevice::Instantiate().Free(device_abc_array);
     delete[] host_abc_array;