1 /*
3 BLIS
4 An object-based framework for developing high-performance BLAS-like
5 libraries.
7 Copyright (C) 2014, The University of Texas at Austin
9 Redistribution and use in source and binary forms, with or without
10 modification, are permitted provided that the following conditions are
11 met:
12 - Redistributions of source code must retain the above copyright
13 notice, this list of conditions and the following disclaimer.
14 - Redistributions in binary form must reproduce the above copyright
15 notice, this list of conditions and the following disclaimer in the
16 documentation and/or other materials provided with the distribution.
17 - Neither the name of The University of Texas at Austin nor the names
18 of its contributors may be used to endorse or promote products
19 derived from this software without specific prior written permission.
21 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25 HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26 SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27 LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28 DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29 THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31 OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33 */
35 #include "blis.h"
36 #include "test_libblis.h"
38 // Static variables.
39 static char* op_str = "gemm";
40 static char* o_types = "mmm"; // a b c
41 static char* p_types = "hh"; // transa transb
42 static thresh_t thresh[BLIS_NUM_FP_TYPES] = { { 1e-04, 1e-05 }, // warn, pass for s
43 { 1e-04, 1e-05 }, // warn, pass for c
44 { 1e-13, 1e-14 }, // warn, pass for d
45 { 1e-13, 1e-14 } }; // warn, pass for z
47 // Local prototypes.
48 void libblis_test_gemm_deps( test_params_t* params,
49 test_op_t* op );
51 void libblis_test_gemm_experiment( test_params_t* params,
52 test_op_t* op,
53 iface_t iface,
54 num_t datatype,
55 char* pc_str,
56 char* sc_str,
57 unsigned int p_cur,
58 perf_t* perf,
59 double* resid );
61 void libblis_test_gemm_impl( iface_t iface,
62 obj_t* alpha,
63 obj_t* a,
64 obj_t* b,
65 obj_t* beta,
66 obj_t* c );
68 void libblis_test_gemm_check( obj_t* alpha,
69 obj_t* a,
70 obj_t* b,
71 obj_t* beta,
72 obj_t* c,
73 obj_t* c_orig,
74 double* resid );
78 void libblis_test_gemm_deps( test_params_t* params, test_op_t* op )
79 {
80 libblis_test_randv( params, &(op->ops->randv) );
81 libblis_test_randm( params, &(op->ops->randm) );
82 libblis_test_setv( params, &(op->ops->setv) );
83 libblis_test_normfv( params, &(op->ops->normfv) );
84 libblis_test_subv( params, &(op->ops->subv) );
85 libblis_test_scalv( params, &(op->ops->scalv) );
86 libblis_test_copym( params, &(op->ops->copym) );
87 libblis_test_scalm( params, &(op->ops->scalm) );
88 libblis_test_gemv( params, &(op->ops->gemv) );
89 }
93 void libblis_test_gemm( test_params_t* params, test_op_t* op )
94 {
96 // Return early if this test has already been done.
97 if ( op->test_done == TRUE ) return;
99 // Return early if operation is disabled.
100 if ( op->op_switch == DISABLE_ALL ||
101 op->ops->l3_over == DISABLE_ALL ) return;
103 // Call dependencies first.
104 if ( TRUE ) libblis_test_gemm_deps( params, op );
106 // Execute the test driver for each implementation requested.
107 if ( op->front_seq == ENABLE )
108 {
109 libblis_test_op_driver( params,
110 op,
111 BLIS_TEST_SEQ_FRONT_END,
112 op_str,
113 p_types,
114 o_types,
115 thresh,
116 libblis_test_gemm_experiment );
117 }
118 }
122 void libblis_test_gemm_experiment( test_params_t* params,
123 test_op_t* op,
124 iface_t iface,
125 num_t datatype,
126 char* pc_str,
127 char* sc_str,
128 unsigned int p_cur,
129 perf_t* perf,
130 double* resid )
131 {
132 unsigned int n_repeats = params->n_repeats;
133 unsigned int i,j;
135 double time_min = 1e9;
136 double time;
138 dim_t m, n, k;
140 trans_t transa;
141 trans_t transb;
143 obj_t kappa;
144 obj_t alpha, a, b, beta;
146 #ifdef BLIS_ENABLE_MULTITHREAD_TEST
147 dim_t test_way = bli_read_nway_from_env( "BLIS_LB_NT" );
148 obj_t c[test_way];
149 obj_t c_save[test_way];
150 double resid_local[test_way];
151 #else
152 obj_t c, c_save;
153 #endif
156 // Map the dimension specifier to actual dimensions.
157 m = libblis_test_get_dim_from_prob_size( op->dim_spec[0], p_cur );
158 n = libblis_test_get_dim_from_prob_size( op->dim_spec[1], p_cur );
159 k = libblis_test_get_dim_from_prob_size( op->dim_spec[2], p_cur );
161 // Map parameter characters to BLIS constants.
162 bli_param_map_char_to_blis_trans( pc_str[0], &transa );
163 bli_param_map_char_to_blis_trans( pc_str[1], &transb );
165 // Create test scalars.
166 bli_obj_scalar_init_detached( datatype, &kappa );
167 bli_obj_scalar_init_detached( datatype, &alpha );
168 bli_obj_scalar_init_detached( datatype, &beta );
170 // Create test operands (vectors and/or matrices).
171 libblis_test_mobj_create( params, datatype, transa,
172 sc_str[0], m, k, &a );
173 //printf("Created object a, buffer address is 0x%x.\n", (unsigned int)bli_obj_buffer(a));
175 libblis_test_mobj_create( params, datatype, transb,
176 sc_str[1], k, n, &b );
177 //printf("Created object b, buffer address is 0x%x.\n", (unsigned int)bli_obj_buffer(b));
179 #ifdef BLIS_ENABLE_MULTITHREAD_TEST
180 for(i = 0; i < test_way; i++)
181 {
182 libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE,
183 sc_str[2], m, n, &c[i] );
184 libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE,
185 sc_str[2], m, n, &c_save[i] );
186 }
187 #else
188 libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE,
189 sc_str[2], m, n, &c );
190 libblis_test_mobj_create( params, datatype, BLIS_NO_TRANSPOSE,
191 sc_str[2], m, n, &c_save );
192 #endif
193 //printf("Created object c, buffer address is 0x%x.\n", (unsigned int)bli_obj_buffer(c[0]));
194 //printf("Created object c_save, buffer address is 0x%x.\n", (unsigned int)bli_obj_buffer(c_save[0]));
196 // Set alpha and beta.
197 #ifdef BLIS_ENABLE_MULTITHREAD_TEST
198 if ( bli_obj_is_real( c[0] ) )
199 #else
200 if ( bli_obj_is_real( c ) )
201 #endif
202 {
203 bli_setsc( 1.2, 0.0, &alpha );
204 bli_setsc( -1.0, 0.0, &beta );
205 }
206 else
207 {
208 bli_setsc( 1.2, 0.8, &alpha );
209 bli_setsc( -1.0, 1.0, &beta );
210 }
212 // //dnparikh: setting beta & alpha for sgemm tests
213 //bli_setsc(0.0, 0.0, &beta);
214 //bli_setsc(1.0, 0.0, &alpha);
216 // Randomize A, B, and C, and save C.
217 bli_randm( &a );
218 bli_randm( &b );
220 #ifdef BLIS_ENABLE_MULTITHREAD_TEST
221 for(i = 0; i < test_way; i++)
222 {
223 bli_randm( &c[i] );
224 bli_copym( &c[i], &c_save[i] );
225 }
226 #else
227 bli_randm( &c );
228 bli_copym( &c, &c_save );
230 #endif
233 // Normalize by k.
234 bli_setsc( 1.0/( double )k, 0.0, &kappa );
235 bli_scalm( &kappa, &a );
236 bli_scalm( &kappa, &b );
238 // Apply the parameters.
239 bli_obj_set_conjtrans( transa, a );
240 bli_obj_set_conjtrans( transb, b );
242 //bli_printm( "c_save = [", &c_save[0], "%f", "];" );
244 // Repeat the experiment n_repeats times and record results.
245 for ( i = 0; i < n_repeats; ++i )
246 {
247 #ifdef BLIS_ENABLE_MULTITHREAD_TEST
248 // Need only one call to initialize the CBLAS OpenCL kernel
249 bli_copym( &c_save[0], &c[0] );
250 libblis_test_gemm_impl( iface, &alpha, &a, &b, &beta, &c[0] );
252 //but need to re-initialize C for each of iteration of n_repeats
253 for(j = 0; j < test_way; j++)
254 {
255 bli_copym( &c_save[j], &c[j] );
256 }
258 #else
259 bli_copym( &c_save, &c );
260 libblis_test_gemm_impl( iface, &alpha, &a, &b, &beta, &c );
261 bli_copym( &c_save, &c );
262 #endif
264 time = bli_clock();
266 // bli_printm( "a = [", &a, "%f", "];" );
267 // bli_printm( "b = [", &b, "%f", "];" );
268 // bli_printm( "c = [", &c, "%f", "];" );
270 #ifdef BLIS_ENABLE_MULTITHREAD_TEST
271 #pragma omp parallel num_threads(test_way)
272 {
273 #pragma omp for
274 for(j = 0; j < test_way; j++)
275 {
276 libblis_test_gemm_impl( iface, &alpha, &a, &b, &beta, &c[j] );
277 }
278 }
279 #else
280 libblis_test_gemm_impl( iface, &alpha, &a, &b, &beta, &c );
281 #endif
283 // bli_printm( "a_after = [", &a, "%f", "];" );
284 // bli_printm( "b_after = [", &b, "%f", "];" );
285 // bli_printm( "c_after = [", &c, "%f", "];" );
287 time_min = bli_clock_min_diff( time_min, time );
288 }
290 // Estimate the performance of the best experiment repeat.
291 #ifdef BLIS_ENABLE_MULTITHREAD_TEST
292 perf->gflops = ( 2.0 * m * n * k ) / time_min * test_way / FLOPS_PER_UNIT_PERF;
293 if ( bli_obj_is_complex( c[0] ) ) perf->gflops *= 4.0;
294 #else
295 perf->gflops = ( 2.0 * m * n * k ) / time_min / FLOPS_PER_UNIT_PERF;
296 if ( bli_obj_is_complex( c ) ) perf->gflops *= 4.0;
297 #endif
298 perf->time = time_min;
300 // Perform checks.
301 #ifdef BLIS_ENABLE_MULTITHREAD_TEST
302 // Check output of each thread, and send max residue to main
303 for(i = 0; i < test_way; i++)
304 {
305 libblis_test_gemm_check( &alpha, &a, &b, &beta, &c[i], &c_save[i], &resid_local[i] );
306 libblis_test_check_empty_problem( &c[i], perf, &resid_local[i] );
308 if(i == 0)
309 {
310 *resid = resid_local[i];
311 }
312 else if (resid_local[i] > *resid)
313 {
314 *resid = resid_local[i];
315 }
316 }
318 #else
319 libblis_test_gemm_check( &alpha, &a, &b, &beta, &c, &c_save, resid );
321 // Zero out performance and residual if output matrix is empty.
322 libblis_test_check_empty_problem( &c, perf, resid );
323 #endif
325 // Free the test objects.
326 bli_obj_free( &a );
327 bli_obj_free( &b );
328 #ifdef BLIS_ENABLE_MULTITHREAD_TEST
329 for(i = 0; i < test_way; i++)
330 {
331 bli_obj_free( &c[i] );
332 bli_obj_free( &c_save[i] );
333 }
334 #else
335 bli_obj_free( &c );
336 bli_obj_free( &c_save );
337 #endif
338 }
342 void libblis_test_gemm_impl( iface_t iface,
343 obj_t* alpha,
344 obj_t* a,
345 obj_t* b,
346 obj_t* beta,
347 obj_t* c )
348 {
349 switch ( iface )
350 {
351 case BLIS_TEST_SEQ_FRONT_END:
352 ;
353 #ifdef CBLAS
354 enum CBLAS_ORDER order;
355 enum CBLAS_TRANSPOSE transA, transB;
356 int m, n, k;
357 int lda, ldb, ldc;
359 m = bli_obj_length( *c );
360 k = bli_obj_width_after_trans( *a );
361 n = bli_obj_width( *c );
363 // printf("m %d k %d n %d \n", m, k, n);
364 // printf("A %d %d\n", bli_obj_length( *a ), bli_obj_width( *a) );
365 // printf("B %d %d\n", bli_obj_length( *b ), bli_obj_width( *b) );
366 // printf("C %d %d\n", bli_obj_length( *c ), bli_obj_width( *c) );
367 // printf("A %d %d\n", bli_obj_row_stride( *a ), bli_obj_col_stride( *a) );
368 // printf("B %d %d\n", bli_obj_row_stride( *b ), bli_obj_col_stride( *b) );
369 // printf("C %d %d\n", bli_obj_row_stride( *c ), bli_obj_col_stride( *c) );
371 if(bli_obj_is_row_stored( *c ))
372 {
373 order = CblasRowMajor;
374 lda = bli_obj_row_stride(*a);
375 ldb = bli_obj_row_stride(*b);
376 ldc = bli_obj_row_stride(*c);
377 }
378 else if(bli_obj_is_col_stored( *c ))
379 {
380 order = CblasColMajor;
381 lda = bli_obj_col_stride(*a);
382 ldb = bli_obj_col_stride(*b);
383 ldc = bli_obj_col_stride(*c);
384 }
385 else
386 {
387 bli_check_error_code( BLIS_INVALID_DIM_STRIDE_COMBINATION );
388 return;
389 }
391 if(bli_obj_has_notrans( *a ) && bli_obj_has_noconj( *a ))
392 {
393 transA = CblasNoTrans;
394 }
395 else if(bli_obj_has_trans( *a ) && bli_obj_has_noconj( *a ))
396 {
397 transA = CblasTrans;
398 }
399 else if(bli_obj_has_trans( *a ) && bli_obj_has_conj( *a ))
400 {
401 transA = CblasConjTrans;
402 }
403 else
404 {
405 bli_check_error_code( BLIS_INVALID_TRANS );
406 return;
407 }
409 if(bli_obj_has_notrans( *b ) && bli_obj_has_noconj( *b ))
410 {
411 transB = CblasNoTrans;
412 }
413 else if(bli_obj_has_trans( *b ) && bli_obj_has_noconj( *b ))
414 {
415 transB = CblasTrans;
416 }
417 else if(bli_obj_has_trans( *b ) && bli_obj_has_conj( *b ))
418 {
419 transB = CblasConjTrans;
420 }
421 else
422 {
423 bli_check_error_code( BLIS_INVALID_TRANS );
424 return;
425 }
427 if (bli_obj_is_float( *a ))
428 {
429 float *cblas_alpha, *cblas_beta;
430 float *cblas_a, *cblas_b, *cblas_c;
432 //obj_t *c_save;
435 cblas_alpha = (float *) bli_obj_buffer( *alpha );
436 cblas_beta = (float *) bli_obj_buffer( *beta );
437 cblas_a = (float *) bli_obj_buffer( *a );
438 cblas_b = (float *) bli_obj_buffer( *b );
439 cblas_c = (float *) bli_obj_buffer( *c );
441 //printf("test_gemm %d %d %d %d %d, 0x%x, 0x%x, 0x%x\n", order, transA, transB, lda, ldb, (unsigned int)cblas_a,(unsigned int)cblas_b,(unsigned int)cblas_c);
442 //printf("Start sgemm for (m,k,n) = (%d,%d,%d) \n", m, k, n);
443 cblas_sgemm(order, transA, transB, m, n, k, *cblas_alpha, cblas_a, lda, cblas_b, ldb, *cblas_beta, cblas_c, ldc);
444 //printf("sgemm for (m,k,n) = (%d,%d,%d) finished.\n", m, k, n);
446 }
447 else if (bli_obj_is_double( *a ))
448 {
449 double *cblas_alpha, *cblas_beta;
450 double *cblas_a, *cblas_b, *cblas_c;
452 cblas_alpha = (double *) bli_obj_buffer( *alpha );
453 cblas_beta = (double *) bli_obj_buffer( *beta );
454 cblas_a = (double *) bli_obj_buffer( *a );
455 cblas_b = (double *) bli_obj_buffer( *b );
456 cblas_c = (double *) bli_obj_buffer( *c );
458 //printf("test_gemm %d %d %d %d %d, 0x%x, 0x%x, 0x%x\n", order, transA, transB, lda, ldb, (unsigned int)cblas_a,(unsigned int)cblas_b,(unsigned int)cblas_c);
459 cblas_dgemm(order, transA, transB, m, n, k, *cblas_alpha, cblas_a, lda, cblas_b, ldb, *cblas_beta, cblas_c, ldc);
461 }
462 else if (bli_obj_is_scomplex( *a ))
463 {
464 void *cblas_alpha, *cblas_beta;
465 void *cblas_a, *cblas_b, *cblas_c;
467 cblas_alpha = bli_obj_buffer( *alpha );
468 cblas_beta = bli_obj_buffer( *beta );
469 cblas_a = bli_obj_buffer( *a );
470 cblas_b = bli_obj_buffer( *b );
471 cblas_c = bli_obj_buffer( *c );
473 cblas_cgemm(order, transA, transB, m, n, k, cblas_alpha, cblas_a, lda, cblas_b, ldb, cblas_beta, cblas_c, ldc);
474 }
475 else if (bli_obj_is_dcomplex( *a ))
476 {
477 void *cblas_alpha, *cblas_beta;
478 void *cblas_a, *cblas_b, *cblas_c;
480 cblas_alpha = bli_obj_buffer( *alpha );
481 cblas_beta = bli_obj_buffer( *beta );
482 cblas_a = bli_obj_buffer( *a );
483 cblas_b = bli_obj_buffer( *b );
484 cblas_c = bli_obj_buffer( *c );
486 //printf("Start zgemm for (m,k,n) = (%d,%d,%d) \n", m, k, n);
487 cblas_zgemm(order, transA, transB, m, n, k, cblas_alpha, cblas_a, lda, cblas_b, ldb, cblas_beta, cblas_c, ldc);
488 //printf("zgemm for (m,k,n) = (%d,%d,%d) finished.\n", m, k, n);
489 }
490 #else
491 bli_gemm( alpha, a, b, beta, c );
492 //bli_gemm4m( alpha, a, b, beta, c );
493 //bli_gemm3m( alpha, a, b, beta, c );
494 #endif
495 break;
497 default:
498 libblis_test_printf_error( "Invalid interface type.\n" );
499 }
500 }
504 void libblis_test_gemm_check( obj_t* alpha,
505 obj_t* a,
506 obj_t* b,
507 obj_t* beta,
508 obj_t* c,
509 obj_t* c_orig,
510 double* resid )
511 {
512 num_t dt = bli_obj_datatype( *c );
513 num_t dt_real = bli_obj_datatype_proj_to_real( *c );
515 dim_t m = bli_obj_length( *c );
516 dim_t n = bli_obj_width( *c );
517 dim_t k = bli_obj_width_after_trans( *a );
519 obj_t kappa, norm;
520 obj_t t, v, w, z;
522 double junk;
524 //
525 // Pre-conditions:
526 // - a is randomized.
527 // - b is randomized.
528 // - c_orig is randomized.
529 // Note:
530 // - alpha and beta should have non-zero imaginary components in the
531 // complex cases in order to more fully exercise the implementation.
532 //
533 // Under these conditions, we assume that the implementation for
534 //
535 // C := beta * C_orig + alpha * transa(A) * transb(B)
536 //
537 // is functioning correctly if
538 //
539 // normf( v - z )
540 //
541 // is negligible, where
542 //
543 // v = C * t
544 // z = ( beta * C_orig + alpha * transa(A) * transb(B) ) * t
545 // = beta * C_orig * t + alpha * transa(A) * transb(B) * t
546 // = beta * C_orig * t + alpha * transa(A) * w
547 // = beta * C_orig * t + z
548 //
550 bli_obj_scalar_init_detached( dt, &kappa );
551 bli_obj_scalar_init_detached( dt_real, &norm );
553 bli_obj_create( dt, n, 1, 0, 0, &t );
554 bli_obj_create( dt, m, 1, 0, 0, &v );
555 bli_obj_create( dt, k, 1, 0, 0, &w );
556 bli_obj_create( dt, m, 1, 0, 0, &z );
558 bli_randv( &t );
559 bli_setsc( 1.0/( double )n, 0.0, &kappa );
560 bli_scalv( &kappa, &t );
562 bli_gemv( &BLIS_ONE, c, &t, &BLIS_ZERO, &v );
564 bli_gemv( &BLIS_ONE, b, &t, &BLIS_ZERO, &w );
565 bli_gemv( alpha, a, &w, &BLIS_ZERO, &z );
566 bli_gemv( beta, c_orig, &t, &BLIS_ONE, &z );
568 bli_subv( &z, &v );
569 bli_normfv( &v, &norm );
570 bli_getsc( &norm, resid, &junk );
572 bli_obj_free( &t );
573 bli_obj_free( &v );
574 bli_obj_free( &w );
575 bli_obj_free( &z );
576 }