/*

   BLIS
   An object-based framework for developing high-performance BLAS-like
   libraries.

   Copyright (C) 2014, The University of Texas at Austin

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions are
   met:
    - Redistributions of source code must retain the above copyright
      notice, this list of conditions and the following disclaimer.
    - Redistributions in binary form must reproduce the above copyright
      notice, this list of conditions and the following disclaimer in the
      documentation and/or other materials provided with the distribution.
    - Neither the name of The University of Texas at Austin nor the names
      of its contributors may be used to endorse or promote products
      derived from this software without specific prior written permission.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

*/
#include "blis.h"

#define FUNCPTR_T herk_fp

typedef void (*FUNCPTR_T)(
                           doff_t  diagoffc,
                           pack_t  schema_a,
                           pack_t  schema_b,
                           dim_t   m,
                           dim_t   n,
                           dim_t   k,
                           void*   alpha,
                           void*   a, inc_t cs_a, inc_t pd_a, inc_t ps_a,
                           void*   b, inc_t rs_b, inc_t pd_b, inc_t ps_b,
                           void*   beta,
                           void*   c, inc_t rs_c, inc_t cs_c,
                           void*   gemm_ukr,
                           herk_thrinfo_t* thread
                         );

static FUNCPTR_T GENARRAY(ftypes,herk_u_ker_var2);
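
// Note: GENARRAY expands (roughly) to a static array of typed function
// pointers indexed by the BLIS datatype value, along the lines of
//
//   static FUNCPTR_T ftypes[BLIS_NUM_FP_TYPES] =
//   {
//       bli_sherk_u_ker_var2, bli_cherk_u_ker_var2,
//       bli_dherk_u_ker_var2, bli_zherk_u_ker_var2
//   };
//
// The exact ordering and array length follow BLIS's num_t encoding; the
// expansion above is only an illustrative sketch.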

void bli_herk_u_ker_var2( obj_t*  a,
                          obj_t*  b,
                          obj_t*  c,
                          gemm_t* cntl,
                          herk_thrinfo_t* thread )
{
    num_t     dt_exec   = bli_obj_execution_datatype( *c );

    doff_t    diagoffc  = bli_obj_diag_offset( *c );

    pack_t    schema_a  = bli_obj_pack_schema( *a );
    pack_t    schema_b  = bli_obj_pack_schema( *b );

    dim_t     m         = bli_obj_length( *c );
    dim_t     n         = bli_obj_width( *c );
    dim_t     k         = bli_obj_width( *a );

    void*     buf_a     = bli_obj_buffer_at_off( *a );
    inc_t     cs_a      = bli_obj_col_stride( *a );
    inc_t     pd_a      = bli_obj_panel_dim( *a );
    inc_t     ps_a      = bli_obj_panel_stride( *a );

    void*     buf_b     = bli_obj_buffer_at_off( *b );
    inc_t     rs_b      = bli_obj_row_stride( *b );
    inc_t     pd_b      = bli_obj_panel_dim( *b );
    inc_t     ps_b      = bli_obj_panel_stride( *b );

    void*     buf_c     = bli_obj_buffer_at_off( *c );
    inc_t     rs_c      = bli_obj_row_stride( *c );
    inc_t     cs_c      = bli_obj_col_stride( *c );

    obj_t     scalar_a;
    obj_t     scalar_b;

    void*     buf_alpha;
    void*     buf_beta;

    FUNCPTR_T f;

    func_t*   gemm_ukrs;
    void*     gemm_ukr;

    // Detach and multiply the scalars attached to A and B.
    bli_obj_scalar_detach( a, &scalar_a );
    bli_obj_scalar_detach( b, &scalar_b );
    bli_mulsc( &scalar_a, &scalar_b );

    // Grab the addresses of the internal scalar buffers for the scalar
    // merged above and the scalar attached to C.
    buf_alpha = bli_obj_internal_scalar_buffer( scalar_b );
    buf_beta  = bli_obj_internal_scalar_buffer( *c );

    // Index into the type combination array to extract the correct
    // function pointer.
    f = ftypes[dt_exec];

    // Extract from the control tree node the func_t object containing
    // the gemm micro-kernel function addresses, and then query the
    // function address corresponding to the current datatype.
    gemm_ukrs = cntl_gemm_ukrs( cntl );
    gemm_ukr  = bli_func_obj_query( dt_exec, gemm_ukrs );

    // Invoke the function.
    f( diagoffc,
       schema_a,
       schema_b,
       m,
       n,
       k,
       buf_alpha,
       buf_a, cs_a, pd_a, ps_a,
       buf_b, rs_b, pd_b, ps_b,
       buf_beta,
       buf_c, rs_c, cs_c,
       gemm_ukr,
       thread );
}
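
// Two macro-templated implementations of the typed kernels follow: a TI C66x
// variant that stages micro-panels of A and B in L1 and panels of C in L2 via
// IDMA/EDMA (compiled when BLIS_ENABLE_C66X_MEM_POOLS, BLIS_ENABLE_C66X_EDMA,
// and BLIS_ENABLE_C66X_IDMA are all defined), and a generic variant that
// operates directly on the packed operands.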
#ifdef BLIS_ENABLE_C66X_MEM_POOLS

#if defined (BLIS_ENABLE_C66X_EDMA) && defined (BLIS_ENABLE_C66X_IDMA)

#undef  GENTFUNC
#define GENTFUNC( ctype, ch, varname, ukrtype ) \
\
void PASTEMAC(ch,varname)( \
                           doff_t  diagoffc, \
                           pack_t  schema_a, \
                           pack_t  schema_b, \
                           dim_t   m, \
                           dim_t   n, \
                           dim_t   k, \
                           void*   alpha, \
                           void*   a, inc_t cs_a, inc_t pd_a, inc_t ps_a, \
                           void*   b, inc_t rs_b, inc_t pd_b, inc_t ps_b, \
                           void*   beta, \
                           void*   c, inc_t rs_c, inc_t cs_c, \
                           void*   gemm_ukr, \
                           herk_thrinfo_t* thread \
                         ) \
{ \
    /* Cast the micro-kernel address to its function pointer type. */ \
    PASTECH(ch,ukrtype) gemm_ukr_cast = (PASTECH(ch,ukrtype)) gemm_ukr; \
\
    /* Temporary C buffer for edge cases. */ \
    ctype ct[ PASTEMAC(ch,maxmr) * \
              PASTEMAC(ch,maxnr) ] \
              __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
    const inc_t rs_ct = 1; \
    const inc_t cs_ct = PASTEMAC(ch,maxmr); \
\
    /* Alias some constants to simpler names. */ \
    const dim_t MR     = pd_a; \
    const dim_t NR     = pd_b; \
    const dim_t PACKMR = cs_a; \
    const dim_t PACKNR = rs_b; \
\
    ctype* restrict zero       = PASTEMAC(ch,0); \
    ctype* restrict a_cast     = a; \
    ctype* restrict b_cast     = b; \
    ctype* restrict c_cast     = c; \
    ctype* restrict alpha_cast = alpha; \
    ctype* restrict beta_cast  = beta; \
    ctype* restrict b1; \
    ctype* restrict c1; \
\
    doff_t diagoffc_ij; \
    dim_t  m_iter, m_left; \
    dim_t  n_iter, n_left; \
    dim_t  m_cur; \
    dim_t  n_cur; \
    dim_t  n_next; \
    dim_t  i, j, jp; \
    inc_t  rstep_a; \
    inc_t  cstep_b; \
    /*inc_t  rstep_c;*/ \
    inc_t  cstep_c; \
    inc_t  rstep_c11, rs_c11, cs_c11; \
    inc_t  istep_a; \
    inc_t  istep_b; \
    auxinfo_t aux; \
\
    herk_thrinfo_t* caucus = herk_thread_sub_herk( thread ); \
    dim_t jr_num_threads = thread_n_way( thread ); \
    dim_t jr_thread_id   = thread_work_id( thread ); \
    dim_t ir_num_threads = thread_n_way( caucus ); \
    dim_t ir_thread_id   = thread_work_id( caucus ); \
\
    mem_t b1_L1_mem; \
    /* Note: b1_L1 is deliberately NOT restrict-qualified; memcpy'ing into a
       restrict-qualified b1_L1 caused the gemm residual to become nonzero. */ \
    ctype* b1_L1; \
\
    mem_t a1_L1_mem, a2_L1_mem; \
    ctype *a1_L1, *a2_L1, *temp; \
\
    mem_t c0_L2_mem, c1_L2_mem, c2_L2_mem; \
    ctype *cNew0, *cNew1, *cNew2, *cNewTemp; \
\
    /* EDMA declarations. */ \
    EdmaMgr_Handle edma_handle_b  = NULL; \
    EdmaMgr_Handle edma_handle_c0 = NULL; \
    EdmaMgr_Handle edma_handle_c1 = NULL; \
\
    /*
       Assumptions/assertions:
         rs_a == 1
         cs_a == PACKMR
         pd_a == MR
         ps_a == stride to next micro-panel of A
         rs_b == PACKNR
         cs_b == 1
         pd_b == NR
         ps_b == stride to next micro-panel of B
         rs_c == (no assumptions)
         cs_c == (no assumptions)
    */ \
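\
    /* For illustration (derived from the assumptions above): a packed
       micro-panel of A holds an MR x k block stored column by column with
       column stride PACKMR, and a packed micro-panel of B holds a k x NR
       block stored row by row with row stride PACKNR. */ \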
\
    /* If any dimension is zero, return immediately. */ \
    if ( bli_zero_dim3( m, n, k ) ) return; \
\
    /* Safeguard: If the current panel of C is entirely below the diagonal,
       it is not stored. So we do nothing. */ \
    if ( bli_is_strictly_below_diag_n( diagoffc, m, n ) ) return; \
\
    /* If there is a zero region to the left of where the diagonal of C
       intersects the top edge of the panel, adjust the pointers to C and B
       and treat this case as if the diagonal offset were zero. */ \
    if ( diagoffc > 0 ) \
    { \
        jp       = diagoffc / NR; \
        j        = jp * NR; \
        n        = n - j; \
        diagoffc = diagoffc % NR; \
        c_cast   = c_cast + (j  )*cs_c; \
        b_cast   = b_cast + (jp )*ps_b; \
    } \
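\
    /* Worked example (values chosen for illustration only): with NR = 4 and
       diagoffc = 10, the first jp = 10/4 = 2 column micro-panels (j = 8
       columns) contain no stored elements and are skipped; the diagonal
       offset relative to the shifted panel becomes 10 % 4 = 2. */ \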
\
    /* If there is a zero region below where the diagonal of C intersects
       the right edge of the panel, shrink it to prevent "no-op" iterations
       from executing. */ \
    if ( -diagoffc + n < m ) \
    { \
        m = -diagoffc + n; \
    } \
\
    /* Clear the temporary C buffer in case it has any infs or NaNs. */ \
    PASTEMAC(ch,set0s_mxn)( MR, NR, \
                            ct, rs_ct, cs_ct ); \
\
    /* Compute number of primary and leftover components of the m and n
       dimensions. */ \
    n_iter = n / NR; \
    n_left = n % NR; \
\
    m_iter = m / MR; \
    m_left = m % MR; \
\
    if ( n_left ) ++n_iter; \
    if ( m_left ) ++m_iter; \
\
    /* Determine some increments used to step through A, B, and C. */ \
    rstep_a = ps_a; \
\
    cstep_b = ps_b; \
\
    /*rstep_c = rs_c * MR;*/ \
    cstep_c = cs_c * NR; \
\
    /* When a panel of C (MC x NR) is staged in L2, these are the strides used
       to traverse it. */ \
    rstep_c11 = MR; /* stride to the next MR x NR sub-panel within the MC x NR panel */ \
    rs_c11    = 1; \
    cs_c11    = (m%2 == 0) ? m : m+1; /*(m_iter-ir_thread_id)*MR;*/ /* column stride within the L2 panel, padded to an even value */ \
\
    istep_a = PACKMR * k; \
    istep_b = PACKNR * k; \
\
    /* Save the pack schemas of A and B to the auxinfo_t object. */ \
    bli_auxinfo_set_schema_a( schema_a, aux ); \
    bli_auxinfo_set_schema_b( schema_b, aux ); \
\
    /* Save the imaginary stride of A and B to the auxinfo_t object. */ \
    bli_auxinfo_set_is_a( istep_a, aux ); \
    bli_auxinfo_set_is_b( istep_b, aux ); \
\
    b1 = b_cast; \
    c1 = c_cast; \
\
    /* Acquire a buffer for B in L1. */ \
    bli_mem_acquire_m( k*NR*sizeof(ctype), BLIS_BUFFER_FOR_B_PANEL_L1, &b1_L1_mem ); \
    b1_L1 = bli_mem_buffer( &b1_L1_mem ); \
    b1_L1 = (ctype *) ((char *) b1_L1_mem.buf + PASTEMAC(ch,bank)); \
\
    /* Acquire two buffers for A in L1. */ \
    bli_mem_acquire_m( k*MR*sizeof(ctype), BLIS_BUFFER_FOR_A_BLOCK_L1, &a1_L1_mem ); \
    a1_L1 = bli_mem_buffer( &a1_L1_mem ); \
\
    bli_mem_acquire_m( k*MR*sizeof(ctype), BLIS_BUFFER_FOR_A_BLOCK_L1, &a2_L1_mem ); \
    a2_L1 = bli_mem_buffer( &a2_L1_mem ); \
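\
    /* The two A buffers are used for double buffering: while the micro-kernel
       consumes the micro-panel currently in a1_L1, the IDMA engine fills
       a2_L1 with the next micro-panel; the pointers are swapped at the top of
       each iteration of the inner (ir) loop below. */ \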
\
    /* Acquire buffers for C (MC x NR) in L2. */ \
    bli_mem_acquire_m( cs_c11*NR*sizeof(ctype), BLIS_BUFFER_FOR_C_PANEL_L2, &c0_L2_mem ); \
    cNew0 = bli_mem_buffer( &c0_L2_mem ); \
\
    bli_mem_acquire_m( cs_c11*NR*sizeof(ctype), BLIS_BUFFER_FOR_C_PANEL_L2, &c1_L2_mem ); \
    cNew1 = bli_mem_buffer( &c1_L2_mem ); \
\
    bli_mem_acquire_m( cs_c11*NR*sizeof(ctype), BLIS_BUFFER_FOR_C_PANEL_L2, &c2_L2_mem ); \
    cNew2 = bli_mem_buffer( &c2_L2_mem ); \
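\
    /* The three C buffers implement triple buffering: cNew1 holds the panel
       currently being updated by the micro-kernel, cNew0 receives the EDMA
       prefetch of the next panel of C, and cNew2 holds the just-completed
       panel while it is written back to memory. The roles rotate at the end
       of each iteration of the outer (jr) loop. */ \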
\
    /* Acquire EDMA handles from the pool. */ \
    bli_dma_channel_acquire( &(edma_handle_b), CSL_chipReadDNUM() ); \
    if ( edma_handle_b == NULL ) \
    { \
        printf("ker_var2 Failed to alloc edma handle CoreID %d \n", CSL_chipReadDNUM()); \
    } \
    bli_dma_channel_acquire( &(edma_handle_c0), CSL_chipReadDNUM() ); \
    if ( edma_handle_c0 == NULL ) \
    { \
        printf("ker_var2 Failed to alloc edma handle for C0 CoreID %d \n", CSL_chipReadDNUM()); \
    } \
    bli_dma_channel_acquire( &(edma_handle_c1), CSL_chipReadDNUM() ); \
    if ( edma_handle_c1 == NULL ) \
    { \
        printf("ker_var2 Failed to alloc edma handle for C1 CoreID %d \n", CSL_chipReadDNUM()); \
    } \
\
    /* Initiate the first C transfer. For C we must transfer an m x n_cur
       tile: for smaller matrices (m_iter-ir_thread_id)*MR may differ from m,
       and using it would cause incorrect values of C to be written back. */ \
    n_cur = ( bli_is_not_edge_f( jr_thread_id, n_iter, n_left ) ? NR : n_left ); \
    if ( cs_c*sizeof(ctype) < BLIS_C66X_MAXDMASTRIDE ) \
    { \
        EdmaMgr_copy2D2DSep( edma_handle_c0, c_cast+jr_thread_id*cstep_c, \
                             cNew1, m*sizeof(ctype), \
                             n_cur, cs_c*sizeof(ctype), cs_c11*sizeof(ctype) ); \
    } \
    else \
    { \
        dim_t ii; \
        ctype *ptr_source; \
        ctype *ptr_dest; \
        ptr_source = c_cast+jr_thread_id*cstep_c; \
        ptr_dest   = cNew1; \
        for ( ii = 0; ii < n_cur; ii++ ) \
        { \
            memcpy( ptr_dest, ptr_source, m*sizeof(ctype) ); \
            ptr_source += cs_c; \
            ptr_dest   += cs_c11; \
        } \
    } \
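\
    /* Note: the EDMA 2D copy is only used when the column stride of C fits
       within BLIS_C66X_MAXDMASTRIDE; otherwise the panel is copied column by
       column with memcpy. The same fallback pattern recurs for the prefetch
       and write-back transfers inside the loop below. */ \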
    /* Loop over the n dimension (NR columns at a time). */ \
    for ( j = jr_thread_id; j < n_iter; j += jr_num_threads ) \
    { \
        ctype* restrict a1; \
        ctype* restrict c11; \
        ctype* restrict b2; \
\
        b1 = b_cast + j * cstep_b; \
        c1 = c_cast + j * cstep_c; \
\
        n_cur  = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \
        n_next = ( bli_is_not_edge_f( j+jr_num_threads, n_iter, n_left ) ? NR : n_left ); \
\
        m_cur = ( bli_is_not_edge_f( ir_thread_id, m_iter, m_left ) ? MR : m_left ); \
\
        /* Initialize our next panel of B to be the current panel of B. */ \
        b2 = b1; \
\
        EdmaMgr_copy1D1D( edma_handle_b, b1, b1_L1, k*NR*sizeof(ctype) ); \
        idma1_setup( a2_L1, a_cast + ir_thread_id * rstep_a, k*MR*sizeof(ctype), 0, 0, 7 ); \
\
        /* Wait for the previous C transfer to complete and initiate the next
           transfer. */ \
        EdmaMgr_wait( edma_handle_c0 ); \
        if ( j < (n_iter-jr_num_threads) ) /* no transfer for last iteration */ \
        { \
            if ( cs_c*sizeof(ctype) < BLIS_C66X_MAXDMASTRIDE ) \
            { \
                EdmaMgr_copy2D2DSep( edma_handle_c0, c1+jr_num_threads*cstep_c, \
                                     cNew0, m*sizeof(ctype), /*(m_iter-ir_thread_id)*sizeof(ctype)*MR,*/ \
                                     n_next, cs_c*sizeof(ctype), \
                                     cs_c11*sizeof(ctype) ); \
            } \
            else \
            { \
                dim_t ii; \
                ctype *ptr_source; \
                ctype *ptr_dest; \
                ptr_source = c1+jr_num_threads*cstep_c; \
                ptr_dest   = cNew0; \
                for ( ii = 0; ii < n_next; ii++ ) \
                { \
                    memcpy( ptr_dest, ptr_source, m*sizeof(ctype) ); \
                    ptr_source += cs_c; \
                    ptr_dest   += cs_c11; \
                } \
            } \
        } \
\
        /* Interior loop over the m dimension (MR rows at a time). */ \
        for ( i = ir_thread_id; i < m_iter; i += ir_num_threads ) \
        { \
            ctype* restrict a2; \
\
            a1  = a_cast + i * rstep_a; \
            c11 = cNew1  + i * rstep_c11; \
            /*c11 = c1 + i * rstep_c;*/ \
\
            /* Compute the diagonal offset for the submatrix at (i,j). */ \
            diagoffc_ij = diagoffc - (doff_t)j*NR + (doff_t)i*MR; \
\
            m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \
\
            /* Compute the addresses of the next panels of A and B. */ \
            a2 = herk_get_next_a_micropanel( caucus, a1, rstep_a ); \
            temp  = a1_L1; \
            a1_L1 = a2_L1; \
            a2_L1 = temp; \
            /*a1 = a2;  Make the next panel the current panel for the next iteration. */ \
            while ( !idma1_done() ) { ; } \
            if ( bli_is_last_iter( i, m_iter, ir_thread_id, ir_num_threads ) ) \
            { \
                a2 = a_cast; \
                b2 = herk_get_next_b_micropanel( thread, b1, cstep_b ); \
                if ( bli_is_last_iter( j, n_iter, jr_thread_id, jr_num_threads ) ) \
                    b2 = b_cast; \
            } \
            else \
            { \
                /* Start next panel. */ \
                idma1_setup( a2_L1, a2, k*MR*sizeof(ctype), 0, 0, 7 ); \
            } \
            if ( i == ir_thread_id ) \
            { \
                EdmaMgr_wait( edma_handle_b ); \
            } \
\
            /* Save the addresses of the next panels of A and B to the
               auxinfo_t object. */ \
            bli_auxinfo_set_next_a( a2, aux ); \
            bli_auxinfo_set_next_b( b2, aux ); \
\
            /* If the diagonal intersects the current MR x NR submatrix, we
               compute it in the temporary buffer and then add in only the
               elements on or above the diagonal.
               Otherwise, if the submatrix is strictly above the diagonal,
               we compute and store as we normally would.
               And if we're strictly below the diagonal, we do nothing and
               continue. */ \
            if ( bli_intersects_diag_n( diagoffc_ij, m_cur, n_cur ) ) \
            { \
                /* Invoke the gemm micro-kernel. */ \
                gemm_ukr_cast( k, \
                               alpha_cast, \
                               a1_L1, \
                               b1_L1, \
                               zero, \
                               ct, rs_ct, cs_ct, \
                               &aux ); \
\
                /* Scale C and add the result to only the stored part. */ \
                PASTEMAC(ch,xpbys_mxn_u)( diagoffc_ij, \
                                          m_cur, n_cur, \
                                          ct, rs_ct, cs_ct, \
                                          beta_cast, \
                                          c11, rs_c11, cs_c11 /*rs_c, cs_c*/ ); \
            } \
            else if ( bli_is_strictly_above_diag_n( diagoffc_ij, m_cur, n_cur ) ) \
            { \
                /* Handle interior and edge cases separately. */ \
                if ( m_cur == MR && n_cur == NR ) \
                { \
                    /* Invoke the gemm micro-kernel. */ \
                    gemm_ukr_cast( k, \
                                   alpha_cast, \
                                   a1_L1, \
                                   b1_L1, \
                                   beta_cast, \
                                   c11, rs_c11, cs_c11 /*rs_c, cs_c*/, \
                                   &aux ); \
                } \
                else \
                { \
                    /* Invoke the gemm micro-kernel. */ \
                    gemm_ukr_cast( k, \
                                   alpha_cast, \
                                   a1_L1, \
                                   b1_L1, \
                                   zero, \
                                   ct, rs_ct, cs_ct, \
                                   &aux ); \
\
                    /* Scale the edge of C and add the result. */ \
                    PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
                                            ct, rs_ct, cs_ct, \
                                            beta_cast, \
                                            c11, rs_c11, cs_c11 /*rs_c, cs_c*/ ); \
                } \
            } \
        } \
\
        /* Circularly shift the C buffers. */ \
        cNewTemp = cNew0; \
        cNew0    = cNew2; \
        cNew2    = cNew1; \
        cNew1    = cNewTemp; \
\
        if ( j != jr_thread_id ) /* wait for save of C to complete; skip first iteration */ \
        { \
            EdmaMgr_wait( edma_handle_c1 ); \
        } \
\
        /* Save the updated C panel. */ \
        if ( cs_c*sizeof(ctype) < BLIS_C66X_MAXDMASTRIDE ) \
        { \
            EdmaMgr_copy2D2DSep( edma_handle_c1, cNew2, c1, m*sizeof(ctype), \
                                 n_cur, cs_c11*sizeof(ctype), cs_c*sizeof(ctype) ); \
        } \
        else \
        { \
            dim_t ii; \
            ctype *ptr_source; \
            ctype *ptr_dest; \
            ptr_source = cNew2; \
            ptr_dest   = c1; \
            for ( ii = 0; ii < n_cur; ii++ ) \
            { \
                memcpy( ptr_dest, ptr_source, m*sizeof(ctype) ); \
                ptr_source += cs_c11; \
                ptr_dest   += cs_c; \
            } \
        } \
    } \
\
    bli_mem_release( &c2_L2_mem ); \
    bli_mem_release( &c1_L2_mem ); \
    bli_mem_release( &c0_L2_mem ); \
    bli_mem_release( &a2_L1_mem ); \
    bli_mem_release( &a1_L1_mem ); \
    bli_mem_release( &b1_L1_mem ); \
    if ( edma_handle_b != NULL ) \
    { \
        bli_dma_channel_release( edma_handle_b, CSL_chipReadDNUM() ); \
        edma_handle_b = NULL; \
    } \
    if ( edma_handle_c0 != NULL ) \
    { \
        bli_dma_channel_release( edma_handle_c0, CSL_chipReadDNUM() ); \
        edma_handle_c0 = NULL; \
    } \
    if ( edma_handle_c1 != NULL ) \
    { \
        EdmaMgr_wait( edma_handle_c1 ); /* wait for the final save of C to complete */ \
        bli_dma_channel_release( edma_handle_c1, CSL_chipReadDNUM() ); \
        edma_handle_c1 = NULL; \
    } \
}

INSERT_GENTFUNC_BASIC( herk_u_ker_var2, gemm_ukr_t )
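
// INSERT_GENTFUNC_BASIC instantiates the GENTFUNC template above once per
// floating-point datatype (s, d, c, z), producing the typed functions
// (e.g. bli_dherk_u_ker_var2) referenced by the ftypes array. This follows
// the usual BLIS macro conventions and is meant as orientation, not as an
// exact expansion.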

#else
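
// Note: no typed kernels are generated in this branch (memory pools enabled
// but EDMA and IDMA not both available).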

#endif

#else

#undef  GENTFUNC
#define GENTFUNC( ctype, ch, varname, ukrtype ) \
\
void PASTEMAC(ch,varname)( \
                           doff_t  diagoffc, \
                           pack_t  schema_a, \
                           pack_t  schema_b, \
                           dim_t   m, \
                           dim_t   n, \
                           dim_t   k, \
                           void*   alpha, \
                           void*   a, inc_t cs_a, inc_t pd_a, inc_t ps_a, \
                           void*   b, inc_t rs_b, inc_t pd_b, inc_t ps_b, \
                           void*   beta, \
                           void*   c, inc_t rs_c, inc_t cs_c, \
                           void*   gemm_ukr, \
                           herk_thrinfo_t* thread \
                         ) \
{ \
    /* Cast the micro-kernel address to its function pointer type. */ \
    PASTECH(ch,ukrtype) gemm_ukr_cast = (PASTECH(ch,ukrtype)) gemm_ukr; \
\
    /* Temporary C buffer for edge cases. */ \
    ctype ct[ PASTEMAC(ch,maxmr) * \
              PASTEMAC(ch,maxnr) ] \
              __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
    const inc_t rs_ct = 1; \
    const inc_t cs_ct = PASTEMAC(ch,maxmr); \
\
    /* Alias some constants to simpler names. */ \
    const dim_t MR     = pd_a; \
    const dim_t NR     = pd_b; \
    const dim_t PACKMR = cs_a; \
    const dim_t PACKNR = rs_b; \
\
    ctype* restrict zero       = PASTEMAC(ch,0); \
    ctype* restrict a_cast     = a; \
    ctype* restrict b_cast     = b; \
    ctype* restrict c_cast     = c; \
    ctype* restrict alpha_cast = alpha; \
    ctype* restrict beta_cast  = beta; \
    ctype* restrict b1; \
    ctype* restrict c1; \
\
    doff_t diagoffc_ij; \
    dim_t  m_iter, m_left; \
    dim_t  n_iter, n_left; \
    dim_t  m_cur; \
    dim_t  n_cur; \
    dim_t  i, j, jp; \
    inc_t  rstep_a; \
    inc_t  cstep_b; \
    inc_t  rstep_c, cstep_c; \
    inc_t  istep_a; \
    inc_t  istep_b; \
    auxinfo_t aux; \
\
    herk_thrinfo_t* caucus = herk_thread_sub_herk( thread ); \
    dim_t jr_num_threads = thread_n_way( thread ); \
    dim_t jr_thread_id   = thread_work_id( thread ); \
    dim_t ir_num_threads = thread_n_way( caucus ); \
    dim_t ir_thread_id   = thread_work_id( caucus ); \
\
    /*
       Assumptions/assertions:
         rs_a == 1
         cs_a == PACKMR
         pd_a == MR
         ps_a == stride to next micro-panel of A
         rs_b == PACKNR
         cs_b == 1
         pd_b == NR
         ps_b == stride to next micro-panel of B
         rs_c == (no assumptions)
         cs_c == (no assumptions)
    */ \
\
    /* If any dimension is zero, return immediately. */ \
    if ( bli_zero_dim3( m, n, k ) ) return; \
\
    /* Safeguard: If the current panel of C is entirely below the diagonal,
       it is not stored. So we do nothing. */ \
    if ( bli_is_strictly_below_diag_n( diagoffc, m, n ) ) return; \
\
    /* If there is a zero region to the left of where the diagonal of C
       intersects the top edge of the panel, adjust the pointers to C and B
       and treat this case as if the diagonal offset were zero. */ \
    if ( diagoffc > 0 ) \
    { \
        jp       = diagoffc / NR; \
        j        = jp * NR; \
        n        = n - j; \
        diagoffc = diagoffc % NR; \
        c_cast   = c_cast + (j  )*cs_c; \
        b_cast   = b_cast + (jp )*ps_b; \
    } \
\
    /* If there is a zero region below where the diagonal of C intersects
       the right edge of the panel, shrink it to prevent "no-op" iterations
       from executing. */ \
    if ( -diagoffc + n < m ) \
    { \
        m = -diagoffc + n; \
    } \
\
    /* Clear the temporary C buffer in case it has any infs or NaNs. */ \
    PASTEMAC(ch,set0s_mxn)( MR, NR, \
                            ct, rs_ct, cs_ct ); \
\
    /* Compute number of primary and leftover components of the m and n
       dimensions. */ \
    n_iter = n / NR; \
    n_left = n % NR; \
\
    m_iter = m / MR; \
    m_left = m % MR; \
\
    if ( n_left ) ++n_iter; \
    if ( m_left ) ++m_iter; \
\
    /* Determine some increments used to step through A, B, and C. */ \
    rstep_a = ps_a; \
\
    cstep_b = ps_b; \
\
    rstep_c = rs_c * MR; \
    cstep_c = cs_c * NR; \
\
    istep_a = PACKMR * k; \
    istep_b = PACKNR * k; \
\
    /* Save the pack schemas of A and B to the auxinfo_t object. */ \
    bli_auxinfo_set_schema_a( schema_a, aux ); \
    bli_auxinfo_set_schema_b( schema_b, aux ); \
\
    /* Save the imaginary stride of A and B to the auxinfo_t object. */ \
    bli_auxinfo_set_is_a( istep_a, aux ); \
    bli_auxinfo_set_is_b( istep_b, aux ); \
\
    b1 = b_cast; \
    c1 = c_cast; \
\
    /* Loop over the n dimension (NR columns at a time). */ \
    for ( j = jr_thread_id; j < n_iter; j += jr_num_threads ) \
    { \
        ctype* restrict a1; \
        ctype* restrict c11; \
        ctype* restrict b2; \
\
        b1 = b_cast + j * cstep_b; \
        c1 = c_cast + j * cstep_c; \
\
        n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \
\
        /* Initialize our next panel of B to be the current panel of B. */ \
        b2 = b1; \
\
        /* Interior loop over the m dimension (MR rows at a time). */ \
        for ( i = ir_thread_id; i < m_iter; i += ir_num_threads ) \
        { \
            ctype* restrict a2; \
\
            a1  = a_cast + i * rstep_a; \
            c11 = c1     + i * rstep_c; \
\
            /* Compute the diagonal offset for the submatrix at (i,j). */ \
            diagoffc_ij = diagoffc - (doff_t)j*NR + (doff_t)i*MR; \
\
            m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \
\
            /* Compute the addresses of the next panels of A and B. */ \
            a2 = herk_get_next_a_micropanel( caucus, a1, rstep_a ); \
            if ( bli_is_last_iter( i, m_iter, ir_thread_id, ir_num_threads ) ) \
            { \
                a2 = a_cast; \
                b2 = herk_get_next_b_micropanel( thread, b1, cstep_b ); \
                if ( bli_is_last_iter( j, n_iter, jr_thread_id, jr_num_threads ) ) \
                    b2 = b_cast; \
            } \
\
            /* Save the addresses of the next panels of A and B to the
               auxinfo_t object. */ \
            bli_auxinfo_set_next_a( a2, aux ); \
            bli_auxinfo_set_next_b( b2, aux ); \
\
            /* If the diagonal intersects the current MR x NR submatrix, we
               compute it in the temporary buffer and then add in only the
               elements on or above the diagonal.
               Otherwise, if the submatrix is strictly above the diagonal,
               we compute and store as we normally would.
               And if we're strictly below the diagonal, we do nothing and
               continue. */ \
            if ( bli_intersects_diag_n( diagoffc_ij, m_cur, n_cur ) ) \
            { \
                /* Invoke the gemm micro-kernel. */ \
                gemm_ukr_cast( k, \
                               alpha_cast, \
                               a1, \
                               b1, \
                               zero, \
                               ct, rs_ct, cs_ct, \
                               &aux ); \
\
                /* Scale C and add the result to only the stored part. */ \
                PASTEMAC(ch,xpbys_mxn_u)( diagoffc_ij, \
                                          m_cur, n_cur, \
                                          ct, rs_ct, cs_ct, \
                                          beta_cast, \
                                          c11, rs_c, cs_c ); \
            } \
            else if ( bli_is_strictly_above_diag_n( diagoffc_ij, m_cur, n_cur ) ) \
            { \
                /* Handle interior and edge cases separately. */ \
                if ( m_cur == MR && n_cur == NR ) \
                { \
                    /* Invoke the gemm micro-kernel. */ \
                    gemm_ukr_cast( k, \
                                   alpha_cast, \
                                   a1, \
                                   b1, \
                                   beta_cast, \
                                   c11, rs_c, cs_c, \
                                   &aux ); \
                } \
                else \
                { \
                    /* Invoke the gemm micro-kernel. */ \
                    gemm_ukr_cast( k, \
                                   alpha_cast, \
                                   a1, \
                                   b1, \
                                   zero, \
                                   ct, rs_ct, cs_ct, \
                                   &aux ); \
\
                    /* Scale the edge of C and add the result. */ \
                    PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
                                            ct, rs_ct, cs_ct, \
                                            beta_cast, \
                                            c11, rs_c, cs_c ); \
                } \
            } \
        } \
    } \
}

INSERT_GENTFUNC_BASIC( herk_u_ker_var2, gemm_ukr_t )

#endif