]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - dense-linear-algebra-libraries/linalg.git/blob - blis/testsuite/src/test_gemm.c
33ef0c8192fb6dc47aff4544df5003077883c6d7
[dense-linear-algebra-libraries/linalg.git] / blis / testsuite / src / test_gemm.c
1 /*
3    BLIS    
4    An object-based framework for developing high-performance BLAS-like
5    libraries.
7    Copyright (C) 2014, The University of Texas at Austin
9    Redistribution and use in source and binary forms, with or without
10    modification, are permitted provided that the following conditions are
11    met:
12     - Redistributions of source code must retain the above copyright
13       notice, this list of conditions and the following disclaimer.
14     - Redistributions in binary form must reproduce the above copyright
15       notice, this list of conditions and the following disclaimer in the
16       documentation and/or other materials provided with the distribution.
17     - Neither the name of The University of Texas at Austin nor the names
18       of its contributors may be used to endorse or promote products
19       derived from this software without specific prior written permission.
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25    HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26    SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27    LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28    DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29    THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30    (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31    OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33 */
35 #include "blis.h"
36 #include "test_libblis.h"
38 // Static variables.
39 static char*     op_str                    = "gemm";
40 static char*     o_types                   = "mmm"; // a b c
41 static char*     p_types                   = "hh";  // transa transb
42 static thresh_t  thresh[BLIS_NUM_FP_TYPES] = { { 1e-04, 1e-05 },   // warn, pass for s
43                                                { 1e-04, 1e-05 },   // warn, pass for c
44                                                { 1e-13, 1e-14 },   // warn, pass for d
45                                                { 1e-13, 1e-14 } }; // warn, pass for z
47 // Local prototypes.
48 void libblis_test_gemm_deps( test_params_t* params,
49                              test_op_t*     op );
51 void libblis_test_gemm_experiment( test_params_t* params,
52                                    test_op_t*     op,
53                                    iface_t        iface,
54                                    num_t          datatype,
55                                    char*          pc_str,
56                                    char*          sc_str,
57                                    unsigned int   p_cur,
58                                    perf_t*        perf,
59                                    double*        resid );
61 void libblis_test_gemm_impl( iface_t   iface,
62                              obj_t*    alpha,
63                              obj_t*    a,
64                              obj_t*    b,
65                              obj_t*    beta,
66                              obj_t*    c );
68 void libblis_test_gemm_check( obj_t*  alpha,
69                               obj_t*  a,
70                               obj_t*  b,
71                               obj_t*  beta,
72                               obj_t*  c,
73                               obj_t*  c_orig,
74                               double* resid );
78 void libblis_test_gemm_deps( test_params_t* params, test_op_t* op )
79 {
80         libblis_test_randv( params, &(op->ops->randv) );
81         libblis_test_randm( params, &(op->ops->randm) );
82         libblis_test_setv( params, &(op->ops->setv) );
83         libblis_test_normfv( params, &(op->ops->normfv) );
84         libblis_test_subv( params, &(op->ops->subv) );
85         libblis_test_scalv( params, &(op->ops->scalv) );
86         libblis_test_copym( params, &(op->ops->copym) );
87         libblis_test_scalm( params, &(op->ops->scalm) );
88         libblis_test_gemv( params, &(op->ops->gemv) );
89 }
93 void libblis_test_gemm( test_params_t* params, test_op_t* op )
94 {
96         // Return early if this test has already been done.
97         if ( op->test_done == TRUE ) return;
99         // Return early if operation is disabled.
100         if ( op->op_switch == DISABLE_ALL ||
101              op->ops->l3_over == DISABLE_ALL ) return;
103         // Call dependencies first.
104         if ( TRUE ) libblis_test_gemm_deps( params, op );
106         // Execute the test driver for each implementation requested.
107         if ( op->front_seq == ENABLE )
108         {
109                 libblis_test_op_driver( params,
110                                         op,
111                                         BLIS_TEST_SEQ_FRONT_END,
112                                         op_str,
113                                         p_types,
114                                         o_types,
115                                         thresh,
116                                         libblis_test_gemm_experiment );
117         }
122 void libblis_test_gemm_experiment( test_params_t* params,
123                                    test_op_t*     op,
124                                    iface_t        iface,
125                                    num_t          datatype,
126                                    char*          pc_str,
127                                    char*          sc_str,
128                                    unsigned int   p_cur,
129                                    perf_t*        perf,
130                                    double*        resid )
132         unsigned int n_repeats = params->n_repeats;
133         unsigned int i,j;
135         double       time_min  = 1e9;
136         double       time;
138         dim_t        m, n, k;
140         trans_t      transa;
141         trans_t      transb;
143         obj_t        kappa;
144         obj_t        alpha, a, b, beta;
146 #ifdef BLIS_ENABLE_MULTITHREAD_TEST
147         dim_t            test_way = bli_read_nway_from_env( "BLIS_LB_NT" );
148         obj_t        c[test_way];
149         obj_t        c_save[test_way];
150         double       resid_local[test_way];
151 #else
152         obj_t        c, c_save;
153 #endif
156         // Map the dimension specifier to actual dimensions.
157         m = libblis_test_get_dim_from_prob_size( op->dim_spec[0], p_cur );
158         n = libblis_test_get_dim_from_prob_size( op->dim_spec[1], p_cur );
159         k = libblis_test_get_dim_from_prob_size( op->dim_spec[2], p_cur );
161         // Map parameter characters to BLIS constants.
162         bli_param_map_char_to_blis_trans( pc_str[0], &transa );
163         bli_param_map_char_to_blis_trans( pc_str[1], &transb );
165         // Create test scalars.
166         bli_obj_scalar_init_detached( datatype, &kappa );
167         bli_obj_scalar_init_detached( datatype, &alpha );
168         bli_obj_scalar_init_detached( datatype, &beta );
170         // Create test operands (vectors and/or matrices).
171         libblis_test_mobj_create( params, datatype, transa,
172                                   sc_str[0], m, k, &a );
173         //printf("Created object a, buffer address is 0x%x.\n", (unsigned int)bli_obj_buffer(a));
174         
175         libblis_test_mobj_create( params, datatype, transb,
176                                   sc_str[1], k, n, &b );
177         //printf("Created object b, buffer address is 0x%x.\n", (unsigned int)bli_obj_buffer(b));
179 #ifdef BLIS_ENABLE_MULTITHREAD_TEST
180         for(i = 0; i < test_way; i++)
181         {
182                 libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE,
183                                           sc_str[2], m, n, &c[i] );
184                 libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE,
185                                           sc_str[2], m, n, &c_save[i] );
186         }
187 #else
188         libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE,
189                                   sc_str[2], m, n, &c );
190         libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE,
191                                   sc_str[2], m, n, &c_save );
192 #endif
193         //printf("Created object c, buffer address is 0x%x.\n", (unsigned int)bli_obj_buffer(c[0]));
194         //printf("Created object c_save, buffer address is 0x%x.\n", (unsigned int)bli_obj_buffer(c_save[0]));
196         // Set alpha and beta.
197 #ifdef BLIS_ENABLE_MULTITHREAD_TEST
198         if ( bli_obj_is_real( c[0] ) )
199 #else
200         if ( bli_obj_is_real( c ) )
201 #endif
202         {
203                 bli_setsc(  1.2,  0.0, &alpha );
204                 bli_setsc( -1.0,  0.0, &beta );
205         }
206         else
207         {
208                 bli_setsc(  1.2,  0.8, &alpha );
209                 bli_setsc( -1.0,  1.0, &beta );
210         }
212 //      //dnparikh: setting beta & alpha for sgemm tests
213         //bli_setsc(0.0, 0.0, &beta);
214         //bli_setsc(1.0, 0.0, &alpha);
216         // Randomize A, B, and C, and save C.
217         bli_randm( &a );
218         bli_randm( &b );
220 #ifdef BLIS_ENABLE_MULTITHREAD_TEST
221         for(i = 0; i < test_way; i++)
222         {
223                 bli_randm( &c[i] );
224                 bli_copym( &c[i], &c_save[i] );
225         }
226 #else
227         bli_randm( &c );
228         bli_copym( &c, &c_save );
230 #endif
233         // Normalize by k.
234         bli_setsc( 1.0/( double )k, 0.0, &kappa );
235         bli_scalm( &kappa, &a );
236         bli_scalm( &kappa, &b );
238         // Apply the parameters.
239         bli_obj_set_conjtrans( transa, a );
240         bli_obj_set_conjtrans( transb, b );
242         //bli_printm( "c_save = [", &c_save[0], "%f", "];" );
244         // Repeat the experiment n_repeats times and record results. 
245         for ( i = 0; i < n_repeats; ++i )
246         {
247 #ifdef BLIS_ENABLE_MULTITHREAD_TEST
248                 // Need only one call to initialize the CBLAS OpenCL kernel
249                 bli_copym( &c_save[0], &c[0] );
250                 libblis_test_gemm_impl( iface, &alpha, &a, &b, &beta, &c[0] );
252                 //but need to re-initialize C for each of iteration of n_repeats
253                 for(j = 0; j < test_way; j++)
254                 {
255                         bli_copym( &c_save[j], &c[j] );
256                 }
258 #else
259                 bli_copym( &c_save, &c );
260                 libblis_test_gemm_impl( iface, &alpha, &a, &b, &beta, &c );
261                 bli_copym( &c_save, &c );
262 #endif
264                 time = bli_clock();
266 //              bli_printm( "a = [", &a, "%f", "];" );
267 //              bli_printm( "b = [", &b, "%f", "];" );
268 //              bli_printm( "c = [", &c, "%f", "];" );
270 #ifdef BLIS_ENABLE_MULTITHREAD_TEST
271 #pragma omp parallel num_threads(test_way)
272                 {
273                         #pragma omp for
274                         for(j = 0; j < test_way; j++)
275                         {
276                                 libblis_test_gemm_impl( iface, &alpha, &a, &b, &beta, &c[j] );
277                         }
278                 }
279 #else
280                 libblis_test_gemm_impl( iface, &alpha, &a, &b, &beta, &c );
281 #endif
283 //              bli_printm( "a_after = [", &a, "%f", "];" );
284 //              bli_printm( "b_after = [", &b, "%f", "];" );
285 //              bli_printm( "c_after = [", &c, "%f", "];" );
287                 time_min = bli_clock_min_diff( time_min, time );
288         }
290         // Estimate the performance of the best experiment repeat.
291 #ifdef BLIS_ENABLE_MULTITHREAD_TEST
292         perf->gflops = ( 2.0 * m * n * k ) / time_min * test_way / FLOPS_PER_UNIT_PERF;
293         if ( bli_obj_is_complex( c[0] ) ) perf->gflops *= 4.0;
294 #else
295         perf->gflops = ( 2.0 * m * n * k ) / time_min / FLOPS_PER_UNIT_PERF;
296         if ( bli_obj_is_complex( c ) ) perf->gflops *= 4.0;
297 #endif
298     perf->time = time_min;
300         // Perform checks.
301 #ifdef BLIS_ENABLE_MULTITHREAD_TEST
302         // Check output of each thread, and send max residue to main
303         for(i = 0; i < test_way; i++)
304         {
305                 libblis_test_gemm_check( &alpha, &a, &b, &beta, &c[i], &c_save[i], &resid_local[i] );
306                 libblis_test_check_empty_problem( &c[i], perf, &resid_local[i] );
308                 if(i == 0)
309                 {
310                         *resid = resid_local[i];
311                 }
312                 else if (resid_local[i] > *resid)
313                 {
314                         *resid = resid_local[i];
315                 }
316         }
318 #else
319         libblis_test_gemm_check( &alpha, &a, &b, &beta, &c, &c_save, resid );
321         // Zero out performance and residual if output matrix is empty.
322         libblis_test_check_empty_problem( &c, perf, resid );
323 #endif
325         // Free the test objects.
326         bli_obj_free( &a );
327         bli_obj_free( &b );
328 #ifdef BLIS_ENABLE_MULTITHREAD_TEST
329         for(i = 0; i < test_way; i++)
330         {
331                 bli_obj_free( &c[i] );
332                 bli_obj_free( &c_save[i] );
333         }
334 #else
335         bli_obj_free( &c );
336         bli_obj_free( &c_save );
337 #endif
342 void libblis_test_gemm_impl( iface_t   iface,
343                              obj_t*    alpha,
344                              obj_t*    a,
345                              obj_t*    b,
346                              obj_t*    beta,
347                              obj_t*    c )
349         switch ( iface )
350         {
351                 case BLIS_TEST_SEQ_FRONT_END:
353 #ifdef CBLAS
354        enum CBLAS_ORDER order;
355        enum CBLAS_TRANSPOSE transA, transB;
356        int m, n, k;
357        int lda, ldb, ldc;
359        m     = bli_obj_length( *c );
360        k     = bli_obj_width_after_trans( *a );
361        n     = bli_obj_width( *c );
363 //       printf("m %d k %d n %d \n", m, k, n);
364 //       printf("A %d %d\n", bli_obj_length( *a ), bli_obj_width( *a) );
365 //       printf("B %d %d\n", bli_obj_length( *b ), bli_obj_width( *b) );
366 //       printf("C %d %d\n", bli_obj_length( *c ), bli_obj_width( *c) );
367 //       printf("A %d %d\n", bli_obj_row_stride( *a ), bli_obj_col_stride( *a) );
368 //       printf("B %d %d\n", bli_obj_row_stride( *b ), bli_obj_col_stride( *b) );
369 //       printf("C %d %d\n", bli_obj_row_stride( *c ), bli_obj_col_stride( *c) );
371        if(bli_obj_is_row_stored( *c ))
372        {
373            order = CblasRowMajor;
374            lda = bli_obj_row_stride(*a);
375            ldb = bli_obj_row_stride(*b);
376            ldc = bli_obj_row_stride(*c);
377        }
378        else if(bli_obj_is_col_stored( *c ))
379        {
380            order = CblasColMajor;
381            lda = bli_obj_col_stride(*a);
382            ldb = bli_obj_col_stride(*b);
383            ldc = bli_obj_col_stride(*c);
384        }
385        else
386        {
387            bli_check_error_code( BLIS_INVALID_DIM_STRIDE_COMBINATION );
388            return;
389        }
391        if(bli_obj_has_notrans( *a ) && bli_obj_has_noconj( *a ))
392        {
393            transA = CblasNoTrans;
394        }
395        else if(bli_obj_has_trans( *a ) && bli_obj_has_noconj( *a ))
396        {
397            transA = CblasTrans;
398        }
399        else if(bli_obj_has_trans( *a ) && bli_obj_has_conj( *a ))
400        {
401            transA = CblasConjTrans;
402        }
403        else
404        {
405            bli_check_error_code( BLIS_INVALID_TRANS );
406            return;
407        }
409        if(bli_obj_has_notrans( *b ) && bli_obj_has_noconj( *b ))
410        {
411            transB = CblasNoTrans;
412        }
413        else if(bli_obj_has_trans( *b ) && bli_obj_has_noconj( *b ))
414        {
415            transB = CblasTrans;
416        }
417        else if(bli_obj_has_trans( *b ) && bli_obj_has_conj( *b ))
418        {
419            transB = CblasConjTrans;
420        }
421        else
422        {
423            bli_check_error_code( BLIS_INVALID_TRANS );
424            return;
425        }
427        if (bli_obj_is_float( *a ))
428        {
429            float *cblas_alpha, *cblas_beta;
430            float *cblas_a, *cblas_b, *cblas_c;
432            //obj_t *c_save;
435            cblas_alpha = (float *) bli_obj_buffer( *alpha );
436            cblas_beta  = (float *) bli_obj_buffer( *beta );
437            cblas_a     = (float *) bli_obj_buffer( *a );
438            cblas_b     = (float *) bli_obj_buffer( *b );
439            cblas_c     = (float *) bli_obj_buffer( *c );
441            //printf("test_gemm %d %d %d %d %d, 0x%x, 0x%x, 0x%x\n", order, transA, transB, lda, ldb, (unsigned int)cblas_a,(unsigned int)cblas_b,(unsigned int)cblas_c);
442            cblas_sgemm(order, transA, transB, m, n, k, *cblas_alpha, cblas_a, lda, cblas_b, ldb, *cblas_beta, cblas_c, ldc);
444        }
445        else if (bli_obj_is_double( *a ))
446        {
447            double *cblas_alpha, *cblas_beta;
448            double *cblas_a, *cblas_b, *cblas_c;
450            cblas_alpha = (double *) bli_obj_buffer( *alpha );
451            cblas_beta  = (double *) bli_obj_buffer( *beta );
452            cblas_a     = (double *) bli_obj_buffer( *a );
453            cblas_b     = (double *) bli_obj_buffer( *b );
454            cblas_c     = (double *) bli_obj_buffer( *c );
456            //printf("test_gemm %d %d %d %d %d, 0x%x, 0x%x, 0x%x\n", order, transA, transB, lda, ldb, (unsigned int)cblas_a,(unsigned int)cblas_b,(unsigned int)cblas_c);
457            cblas_dgemm(order, transA, transB, m, n, k, *cblas_alpha, cblas_a, lda, cblas_b, ldb, *cblas_beta, cblas_c, ldc);
459        }
460        else if (bli_obj_is_scomplex( *a ))
461        {
462            void *cblas_alpha, *cblas_beta;
463            void *cblas_a, *cblas_b, *cblas_c;
465            cblas_alpha = bli_obj_buffer( *alpha );
466            cblas_beta  = bli_obj_buffer( *beta );
467            cblas_a     = bli_obj_buffer( *a );
468            cblas_b     = bli_obj_buffer( *b );
469            cblas_c     = bli_obj_buffer( *c );
471            cblas_cgemm(order, transA, transB, m, n, k, cblas_alpha, cblas_a, lda, cblas_b, ldb, cblas_beta, cblas_c, ldc);
472        }
473        else if (bli_obj_is_dcomplex( *a ))
474        {
475            void *cblas_alpha, *cblas_beta;
476            void *cblas_a, *cblas_b, *cblas_c;
478            cblas_alpha = bli_obj_buffer( *alpha );
479            cblas_beta  = bli_obj_buffer( *beta );
480            cblas_a     = bli_obj_buffer( *a );
481            cblas_b     = bli_obj_buffer( *b );
482            cblas_c     = bli_obj_buffer( *c );
484            cblas_zgemm(order, transA, transB, m, n, k, cblas_alpha, cblas_a, lda, cblas_b, ldb, cblas_beta, cblas_c, ldc);
485        }
486 #else
487                 bli_gemm( alpha, a, b, beta, c );
488                 //bli_gemm4m( alpha, a, b, beta, c );
489                 //bli_gemm3m( alpha, a, b, beta, c );
490 #endif
491                 break;
493                 default:
494                 libblis_test_printf_error( "Invalid interface type.\n" );
495         }
500 void libblis_test_gemm_check( obj_t*  alpha,
501                               obj_t*  a,
502                               obj_t*  b,
503                               obj_t*  beta,
504                               obj_t*  c,
505                               obj_t*  c_orig,
506                               double* resid )
508         num_t  dt      = bli_obj_datatype( *c );
509         num_t  dt_real = bli_obj_datatype_proj_to_real( *c );
511         dim_t  m       = bli_obj_length( *c );
512         dim_t  n       = bli_obj_width( *c );
513         dim_t  k       = bli_obj_width_after_trans( *a );
515         obj_t  kappa, norm;
516         obj_t  t, v, w, z;
518         double junk;
520         //
521         // Pre-conditions:
522         // - a is randomized.
523         // - b is randomized.
524         // - c_orig is randomized.
525         // Note:
526         // - alpha and beta should have non-zero imaginary components in the
527         //   complex cases in order to more fully exercise the implementation.
528         //
529         // Under these conditions, we assume that the implementation for
530         //
531         //   C := beta * C_orig + alpha * transa(A) * transb(B)
532         //
533         // is functioning correctly if
534         //
535         //   normf( v - z )
536         //
537         // is negligible, where
538         //
539         //   v = C * t
540         //   z = ( beta * C_orig + alpha * transa(A) * transb(B) ) * t
541         //     = beta * C_orig * t + alpha * transa(A) * transb(B) * t
542         //     = beta * C_orig * t + alpha * transa(A) * w
543         //     = beta * C_orig * t + z
544         //
546         bli_obj_scalar_init_detached( dt,      &kappa );
547         bli_obj_scalar_init_detached( dt_real, &norm );
549         bli_obj_create( dt, n, 1, 0, 0, &t );
550         bli_obj_create( dt, m, 1, 0, 0, &v );
551         bli_obj_create( dt, k, 1, 0, 0, &w );
552         bli_obj_create( dt, m, 1, 0, 0, &z );
554         bli_randv( &t );
555         bli_setsc( 1.0/( double )n, 0.0, &kappa );
556         bli_scalv( &kappa, &t );
558         bli_gemv( &BLIS_ONE, c, &t, &BLIS_ZERO, &v );
560         bli_gemv( &BLIS_ONE, b, &t, &BLIS_ZERO, &w );
561         bli_gemv( alpha, a, &w, &BLIS_ZERO, &z );
562         bli_gemv( beta, c_orig, &t, &BLIS_ONE, &z );
564         bli_subv( &z, &v );
565         bli_normfv( &v, &norm );
566         bli_getsc( &norm, resid, &junk );
568         bli_obj_free( &t );
569         bli_obj_free( &v );
570         bli_obj_free( &w );
571         bli_obj_free( &z );