]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - dense-linear-algebra-libraries/linalg.git/blob - examples/dsponly/dgemm_test/dgemm_test.c
DSP-only BILS test suite works on C6678 EVM.
[dense-linear-algebra-libraries/linalg.git] / examples / dsponly / dgemm_test / dgemm_test.c
1 /******************************************************************************
2  * Copyright (c) 2015, Texas Instruments Incorporated - http://www.ti.com
3  *   All rights reserved.
4  *
5  *   Redistribution and use in source and binary forms, with or without
6  *   modification, are permitted provided that the following conditions are met:
7  *       * Redistributions of source code must retain the above copyright
8  *         notice, this list of conditions and the following disclaimer.
9  *       * Redistributions in binary form must reproduce the above copyright
10  *         notice, this list of conditions and the following disclaimer in the
11  *         documentation and/or other materials provided with the distribution.
12  *       * Neither the name of Texas Instruments Incorporated nor the
13  *         names of its contributors may be used to endorse or promote products
14  *         derived from this software without specific prior written permission.
15  *
16  *   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *   AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *   IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *   ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
20  *   LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *   CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *   SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *   INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *   CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *   ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26  *   THE POSSIBILITY OF SUCH DAMAGE.
27  *****************************************************************************/
28 /******************************************************************************
29 * FILE: dgemm_test.c 
30 ******************************************************************************/
31 #include <omp.h>
32 #include <string.h>
33 #include <stdio.h>
34 #include <libarch.h>
35 #include <ticblas.h>
36 #include <cblas.h>
38 #define FLOPS_PER_UNIT_PERF          1e9
40 extern void cleanup_after_ticblas();
41 extern void prepare_for_ticblas();
42 extern double omp_get_wtime(void);
44 void matrix_gen(double *A, double *B, double *C, int m, int k, int n);
45 void mat_mpy(const double * A, const double * B, double * C, int mat_N, 
46              int mat_K, int mat_M, double alpha, double beta);
47 double dotprod(const double * A, const double * B, int n);
48 void print_matrix(double *mat, int m, int n);
49 double diff_matrix(double *mat1, double * mat2, int m, int n);
51 int main (int argc, char *argv[]) 
52 {
53   double *A, *B, *C, *C_copy;
54   int m, n, k;
55   double alpha, beta, precision_diff, time, time_diff, gflops;
57   int nthreads, tid;
59   /* Verify OpenMP working properly */
60   #pragma omp parallel private(nthreads, tid)
61   {
62     tid = omp_get_thread_num(); /* Obtain thread number */
63     printf("Hello World from thread = %d\n", tid);
65     /* Only master thread does this */
66     if (tid == 0) {
67       nthreads = omp_get_num_threads();
68       printf("Number of threads = %d\n", nthreads);
69     }
71   }  /* All threads join master thread and disband */
73   /* hard code dgemm parameters */
74   m = k = n = 1000;
75   alpha = 0.7; 
76   beta  = 1.3; 
78   /* Allocate memory for matrices */
79   A = (double *)malloc( m*k*sizeof( double ) );
80   B = (double *)malloc( k*n*sizeof( double ) );
81   C = (double *)malloc( m*n*sizeof( double ) );
82   C_copy = (double *)malloc( m*n*sizeof( double ) );
83   if (A == NULL || B == NULL || C == NULL || C_copy == NULL) {
84     printf( "\nERROR: Can't allocate memory for matrices. Aborting... \n\n");
85     free(A);
86     free(B);
87     free(C);
88     return 1;
89   }   
91   /* Initialize random number generator */    
92   srand(123456789);
94   /* Configure memory and initialize TI CBLAS */
95   prepare_for_ticblas();
97   /* Generate matrices */
98   matrix_gen(A, B, C, m, k, n);
99   memcpy(C_copy, C, m*n*sizeof(double));
101   /* Call standard CBLAS API for dgemm */
102   time = omp_get_wtime();
103   cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, alpha, A, k, B, n, beta, C, n);
104   time_diff = omp_get_wtime() - time;
105   gflops = ( 2.0 * m * n * k ) / time_diff / FLOPS_PER_UNIT_PERF;
106   printf("DGEMM time for (m,n,k) = (%d,%d,%d) is %e, GFLOPS is %e.\n", m,n,k, time_diff, gflops);
108   /* Straightforward matrix multiplication as reference */
109   mat_mpy(A, B, C_copy, m, n, k, alpha, beta);
111   /* Find the difference between dgemm and reference */
112   precision_diff = diff_matrix(C, C_copy, m, k);
113   printf("Precision error is %e.\n", precision_diff);
114  
115   /* Finalize TI CBLAS and reconfigure memory */
116   cleanup_after_ticblas();
118   return 0;
121 /*==============================================================================
122  * This function generates matrices of random data
123  *============================================================================*/
124 void matrix_gen(double *A, double *B, double *C, int m, int k, int n)
127     int i;
128     for (i = 0; i < (m*k); i++) {
129         A[i] = (double)rand()/RAND_MAX - 0.5;
130     }
132     for (i = 0; i < (k*n); i++) {
133         B[i] = (double)rand()/RAND_MAX - 0.5;
134     }
136     for (i = 0; i < (m*n); i++) {
137         C[i] = (double)rand()/RAND_MAX - 0.5;
138     }
139     
143 /******************************************************************************
144 * Straightforward implementation of matrix multiplication with row-major
145 ******************************************************************************/
146 void mat_mpy(const double * A, const double * B, double * C, int mat_N, 
147              int mat_K, int mat_M, double alpha, double beta)
149     int col, row;
150     double b_col[mat_K];
152     for (col = 0; col < mat_M; ++col)
153     {
154         for (row = 0; row < mat_K; ++row)
155             b_col[row] = B[row*mat_M+col];
157         for (row = 0; row < mat_N; ++row)
158             C[row*mat_M+col] =  alpha*dotprod(A + (row * mat_K), b_col, mat_K)
159                               + beta*C[row*mat_M+col];
160     }
163 /******************************************************************************
164 * dot product for matrix multiplication
165 ******************************************************************************/
166 double dotprod(const double * A, const double * B, int n)
168     int i;
169     
170     float result = 0;
171     for (i = 0; i < n; ++i) result += A[i] * B[i];
172     
173     return result;
176 /******************************************************************************
177 * Print a row-major matrix
178 ******************************************************************************/
179 void print_matrix(double *mat, int m, int n) 
181    int i, j;
182     
183    for(i=0; i<m; i++) {
184       for(j=0; j<n; j++) {
185          printf( " %10.5f ", mat[i*n+j]);
186       }
187       printf( "\n" );
188   }
191 /******************************************************************************
192 * Find the maximum absolute difference of two matrices
193 ******************************************************************************/
194 double diff_matrix(double *mat1, double * mat2, int m, int n)
196     int i, j;
197     double abs_max_err, err;
198     
199     abs_max_err = 0.0f;
200     for(i=0; i<m; i++) 
201     {
202        for(j=0; j<n; j++) 
203        {
204            err = fabs(mat1[i*n+j] - mat2[i*n+j]);
205            if(abs_max_err < err) {
206                abs_max_err = err;
207            }
208        }
209     }
210     
211     return (abs_max_err);
214 /* Nothing past this point */