1 /*
3 BLIS
4 An object-based framework for developing high-performance BLAS-like
5 libraries.
7 Copyright (C) 2014, The University of Texas at Austin
9 Redistribution and use in source and binary forms, with or without
10 modification, are permitted provided that the following conditions are
11 met:
12 - Redistributions of source code must retain the above copyright
13 notice, this list of conditions and the following disclaimer.
14 - Redistributions in binary form must reproduce the above copyright
15 notice, this list of conditions and the following disclaimer in the
16 documentation and/or other materials provided with the distribution.
17 - Neither the name of The University of Texas at Austin nor the names
18 of its contributors may be used to endorse or promote products
19 derived from this software without specific prior written permission.
21 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25 HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26 SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27 LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28 DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29 THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31 OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33 */
35 #include "blis.h"
#define FUNCPTR_T gemm_fp

/* Signature shared by every datatype-specific instance of this variant
   (generated below via the GENTFUNC macro and INSERT_GENTFUNC_BASIC2).
   bli_trsm_lu_ker_var2() looks one of these up by execution datatype
   and forwards all of its queried parameters to it. */
typedef void (*FUNCPTR_T)(
                           doff_t          diagoffa,
                           pack_t          schema_a,
                           pack_t          schema_b,
                           dim_t           m,
                           dim_t           n,
                           dim_t           k,
                           void*           alpha1,
                           void*           a, inc_t cs_a, inc_t pd_a, inc_t ps_a,
                           void*           b, inc_t rs_b, inc_t pd_b, inc_t ps_b,
                           void*           alpha2,
                           void*           c, inc_t rs_c, inc_t cs_c,
                           void*           gemmtrsm_ukr,
                           void*           gemm_ukr,
                           trsm_thrinfo_t* thread
                         );

/* Table of the above function pointers, indexed by datatype and populated
   with the generated trsm_lu_ker_var2 instances. */
static FUNCPTR_T GENARRAY(ftypes,trsm_lu_ker_var2);
59 void bli_trsm_lu_ker_var2( obj_t* a,
60 obj_t* b,
61 obj_t* c,
62 trsm_t* cntl,
63 trsm_thrinfo_t* thread )
64 {
65 num_t dt_exec = bli_obj_execution_datatype( *c );
67 doff_t diagoffa = bli_obj_diag_offset( *a );
69 pack_t schema_a = bli_obj_pack_schema( *a );
70 pack_t schema_b = bli_obj_pack_schema( *b );
72 dim_t m = bli_obj_length( *c );
73 dim_t n = bli_obj_width( *c );
74 dim_t k = bli_obj_width( *a );
76 void* buf_a = bli_obj_buffer_at_off( *a );
77 inc_t cs_a = bli_obj_col_stride( *a );
78 inc_t pd_a = bli_obj_panel_dim( *a );
79 inc_t ps_a = bli_obj_panel_stride( *a );
81 void* buf_b = bli_obj_buffer_at_off( *b );
82 inc_t rs_b = bli_obj_row_stride( *b );
83 inc_t pd_b = bli_obj_panel_dim( *b );
84 inc_t ps_b = bli_obj_panel_stride( *b );
86 void* buf_c = bli_obj_buffer_at_off( *c );
87 inc_t rs_c = bli_obj_row_stride( *c );
88 inc_t cs_c = bli_obj_col_stride( *c );
90 void* buf_alpha1;
91 void* buf_alpha2;
93 FUNCPTR_T f;
95 func_t* gemmtrsm_ukrs;
96 func_t* gemm_ukrs;
97 void* gemmtrsm_ukr;
98 void* gemm_ukr;
101 // Grab the address of the internal scalar buffer for the scalar
102 // attached to B. This will be the alpha scalar used in the gemmtrsm
103 // subproblems (ie: the scalar that would be applied to the packed
104 // copy of B prior to it being updated by the trsm subproblem). This
105 // scalar may be unit, if for example it was applied during packing.
106 buf_alpha1 = bli_obj_internal_scalar_buffer( *b );
108 // Grab the address of the internal scalar buffer for the scalar
109 // attached to C. This will be the "beta" scalar used in the gemm-only
110 // subproblems that correspond to micro-panels that do not intersect
111 // the diagonal. We need this separate scalar because it's possible
112 // that the alpha attached to B was reset, if it was applied during
113 // packing.
114 buf_alpha2 = bli_obj_internal_scalar_buffer( *c );
116 // Index into the type combination array to extract the correct
117 // function pointer.
118 f = ftypes[dt_exec];
120 // Extract from the control tree node the func_t objects containing
121 // the gemmtrsm and gemm micro-kernel function addresses, and then
122 // query the function addresses corresponding to the current datatype.
123 gemmtrsm_ukrs = cntl_gemmtrsm_u_ukrs( cntl );
124 gemm_ukrs = cntl_gemm_ukrs( cntl );
125 gemmtrsm_ukr = bli_func_obj_query( dt_exec, gemmtrsm_ukrs );
126 gemm_ukr = bli_func_obj_query( dt_exec, gemm_ukrs );
128 // Invoke the function.
129 f( diagoffa,
130 schema_a,
131 schema_b,
132 m,
133 n,
134 k,
135 buf_alpha1,
136 buf_a, cs_a, pd_a, ps_a,
137 buf_b, rs_b, pd_b, ps_b,
138 buf_alpha2,
139 buf_c, rs_c, cs_c,
140 gemmtrsm_ukr,
141 gemm_ukr,
142 thread );
143 }
#undef  GENTFUNC
#define GENTFUNC( ctype, ch, varname, gemmtrsmtype, gemmtype ) \
\
void PASTEMAC(ch,varname)( \
                           doff_t          diagoffa, \
                           pack_t          schema_a, \
                           pack_t          schema_b, \
                           dim_t           m, \
                           dim_t           n, \
                           dim_t           k, \
                           void*           alpha1, \
                           void*           a, inc_t cs_a, inc_t pd_a, inc_t ps_a, \
                           void*           b, inc_t rs_b, inc_t pd_b, inc_t ps_b, \
                           void*           alpha2, \
                           void*           c, inc_t rs_c, inc_t cs_c, \
                           void*           gemmtrsm_ukr, \
                           void*           gemm_ukr, \
                           trsm_thrinfo_t* thread \
                         ) \
{ \
	/* Cast the micro-kernels' addresses to their function pointer types. */ \
	PASTECH(ch,gemmtrsmtype) gemmtrsm_ukr_cast = (PASTECH(ch,gemmtrsmtype)) gemmtrsm_ukr; \
	PASTECH(ch,gemmtype)     gemm_ukr_cast     = (PASTECH(ch,gemmtype)) gemm_ukr; \
\
	/* Temporary C buffer for edge cases. */ \
	ctype           ct[ PASTEMAC(ch,maxmr) * \
	                    PASTEMAC(ch,maxnr) ] \
	                    __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
	const inc_t     rs_ct      = 1; \
	const inc_t     cs_ct      = PASTEMAC(ch,maxmr); \
\
	/* Alias some constants to simpler names. */ \
	const dim_t     MR         = pd_a; \
	const dim_t     NR         = pd_b; \
	const dim_t     PACKMR     = cs_a; \
	const dim_t     PACKNR     = rs_b; \
\
	ctype* restrict zero       = PASTEMAC(ch,0); \
	ctype* restrict minus_one  = PASTEMAC(ch,m1); \
	ctype* restrict a_cast     = a; \
	ctype* restrict b_cast     = b; \
	ctype* restrict c_cast     = c; \
	ctype* restrict alpha1_cast = alpha1; \
	ctype* restrict alpha2_cast = alpha2; \
	ctype* restrict b1; \
	ctype* restrict c1; \
\
	doff_t          diagoffa_i; \
	dim_t           k_full; \
	dim_t           m_iter, m_left; \
	dim_t           n_iter, n_left; \
	dim_t           m_cur; \
	dim_t           n_cur; \
	dim_t           k_a1112; \
	dim_t           k_a11; \
	dim_t           k_a12; \
	dim_t           off_a11; \
	dim_t           off_a12; \
	dim_t           i, j, ib; \
	inc_t           rstep_a; \
	inc_t           cstep_b; \
	inc_t           rstep_c, cstep_c; \
	inc_t           istep_a; \
	inc_t           istep_b; \
	inc_t           off_scl; \
	inc_t           ss_a_num; \
	inc_t           ss_a_den; \
	inc_t           ps_a_cur; \
	auxinfo_t       aux; \
\
	/*
	   Assumptions/assertions:
	     rs_a == 1
	     cs_a == PACKMR
	     pd_a == MR
	     ps_a == stride to next micro-panel of A
	     rs_b == PACKNR
	     cs_b == 1
	     pd_b == NR
	     ps_b == stride to next micro-panel of B
	     rs_c == (no assumptions)
	     cs_c == (no assumptions)
	*/ \
\
	/* If any dimension is zero, return immediately. */ \
	if ( bli_zero_dim3( m, n, k ) ) return; \
\
	/* Safeguard: If matrix A is below the diagonal, it is implicitly zero.
	   So we do nothing. */ \
	if ( bli_is_strictly_below_diag_n( diagoffa, m, k ) ) return; \
\
	/* Compute k_full as k inflated up to a multiple of MR. This is
	   needed because some parameter combinations of trsm reduce k
	   to advance past zero regions in the triangular matrix, and
	   when computing the imaginary stride of B (the non-triangular
	   matrix), which is used by 3m and 4m implementations, we need
	   this unreduced value of k. */ \
	k_full = ( k % MR != 0 ? k + MR - ( k % MR ) : k ); \
\
	/* Compute indexing scaling factor for 4m or 3m. This is
	   needed because one of the packing register blocksizes (PACKMR
	   or PACKNR) is used to index into the micro-panels of the non-
	   triangular matrix when computing with a diagonal-intersecting
	   micro-panel of the triangular matrix. In the case of 4m or 3m,
	   real values are stored in both sub-panels, and so the indexing
	   needs to occur in units of real values. The value computed
	   here is divided into the complex pointer offset to cause the
	   pointer to be advanced by the correct value. */ \
	if ( bli_is_4m_packed( schema_a ) || \
	     bli_is_3m_packed( schema_a ) || \
	     bli_is_rih_packed( schema_a ) ) off_scl = 2; \
	else                                 off_scl = 1; \
\
	/* Compute the storage stride. Usually this is just PACKMR (for A
	   or PACKNR (for B). However, in the case of 3m, we need to scale
	   the offset by 3/2. Since it's possible we may need to scale
	   the packing dimension by a non-integer value, we break up the
	   scaling factor into numerator and denominator. */ \
	if ( bli_is_3m_packed( schema_a ) ) { ss_a_num = 3*PACKMR; \
	                                      ss_a_den = 2; } \
	else                                { ss_a_num = 1*PACKMR; \
	                                      ss_a_den = 1; } \
\
	/* If there is a zero region to the left of where the diagonal of A
	   intersects the top edge of the block, adjust the pointer to B and
	   treat this case as if the diagonal offset were zero. Note that we
	   don't need to adjust the pointer to A since packm would have simply
	   skipped over the region that was not stored. */ \
	if ( diagoffa > 0 ) \
	{ \
		i        = diagoffa; \
		k        = k - i; \
		diagoffa = 0; \
		b_cast   = b_cast + ( i * PACKNR ) / off_scl; \
	} \
\
	/* If there is a zero region below where the diagonal of A intersects the
	   right side of the block, shrink it to prevent "no-op" iterations from
	   executing. */ \
	if ( -diagoffa + k < m ) \
	{ \
		m = -diagoffa + k; \
	} \
\
	/* Check the k dimension, which needs to be a multiple of MR. If k
	   isn't a multiple of MR, we adjust it higher to satisfy the micro-
	   kernel, which is expecting to perform an MR x MR triangular solve.
	   This adjustment of k is consistent with what happened when A was
	   packed: all of its bottom/right edges were zero-padded, and
	   furthermore, the panel that stores the bottom-right corner of the
	   matrix has its diagonal extended into the zero-padded region (as
	   identity). This allows the trsm of that bottom-right panel to
	   proceed without producing any infs or NaNs that would infect the
	   "good" values of the corresponding block of B. */ \
	if ( k % MR != 0 ) k += MR - ( k % MR ); \
\
	/* NOTE: We don't need to check that m is a multiple of PACKMR since we
	   know that the underlying buffer was already allocated to have an m
	   dimension that is a multiple of PACKMR, with the region between the
	   last row and the next multiple of MR zero-padded accordingly. */ \
\
	/* Clear the temporary C buffer in case it has any infs or NaNs. */ \
	PASTEMAC(ch,set0s_mxn)( MR, NR, \
	                        ct, rs_ct, cs_ct ); \
\
	/* Compute number of primary and leftover components of the m and n
	   dimensions. */ \
	n_iter = n / NR; \
	n_left = n % NR; \
\
	m_iter = m / MR; \
	m_left = m % MR; \
\
	if ( n_left ) ++n_iter; \
	if ( m_left ) ++m_iter; \
\
	/* Determine some increments used to step through A, B, and C. */ \
	rstep_a = ps_a; \
\
	cstep_b = ps_b; \
\
	rstep_c = rs_c * MR; \
	cstep_c = cs_c * NR; \
\
	istep_a = PACKMR * k; \
	istep_b = PACKNR * k_full; \
\
	/* Save the pack schemas of A and B to the auxinfo_t object. */ \
	bli_auxinfo_set_schema_a( schema_a, aux ); \
	bli_auxinfo_set_schema_b( schema_b, aux ); \
\
	/* Save the imaginary stride of B to the auxinfo_t object. */ \
	bli_auxinfo_set_is_b( istep_b, aux ); \
\
	b1 = b_cast; \
	c1 = c_cast; \
\
	/* Loop over the n dimension (NR columns at a time). */ \
	for ( j = 0; j < n_iter; ++j ) \
	{ \
		/* Only process the column panels assigned to this thread. */ \
		if( trsm_my_iter( j, thread ) ) { \
\
		ctype* restrict a1; \
		ctype* restrict c11; \
		ctype* restrict b2; \
\
		a1  = a_cast; \
		c11 = c1 + (m_iter-1)*rstep_c; \
\
		n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \
\
		/* Initialize our next panel of B to be the current panel of B. */ \
		b2 = b1; \
\
		/* Loop over the m dimension (MR rows at a time), bottom to top,
		   as required by the backward substitution of a lower-upper
		   (left, upper-triangular) trsm. */ \
		for ( ib = 0; ib < m_iter; ++ib ) \
		{ \
			i          = m_iter - 1 - ib; \
			diagoffa_i = diagoffa + ( doff_t )i*MR; \
\
			m_cur = ( bli_is_not_edge_b( ib, m_iter, m_left ) ? MR : m_left ); \
\
			/* If the current panel of A intersects the diagonal, use a
			   special micro-kernel that performs a fused gemm and trsm.
			   If the current panel of A resides above the diagonal, use a
			   regular gemm micro-kernel. Otherwise, if it is below the
			   diagonal, it was not packed (because it is implicitly zero)
			   and so we do nothing. */ \
			if ( bli_intersects_diag_n( diagoffa_i, MR, k ) ) \
			{ \
				ctype* restrict a11; \
				ctype* restrict a12; \
				ctype* restrict b11; \
				ctype* restrict b21; \
				ctype* restrict a2; \
\
				/* Compute various offsets into and lengths of parts of A. */ \
				off_a11 = diagoffa_i; \
				k_a1112 = k - off_a11; \
				k_a11   = MR; \
				k_a12   = k_a1112 - MR; \
				off_a12 = off_a11 + k_a11; \
\
				/* Compute the panel stride for the current diagonal-
				   intersecting micro-panel. */ \
				ps_a_cur = ( k_a1112 * ss_a_num ) / ss_a_den; \
\
				/* Compute the addresses of the triangular block A11 and the
				   panel A12. */ \
				a11 = a1; \
				a12 = a1 + ( k_a11 * PACKMR ) / off_scl; \
\
				/* Compute the addresses of the block B11 and the panel
				   B21. */ \
				b11 = b1 + ( off_a11 * PACKNR ) / off_scl; \
				b21 = b1 + ( off_a12 * PACKNR ) / off_scl; \
\
				/* Compute the addresses of the next panels of A and B. */ \
				a2 = a1 + ps_a_cur; \
				if ( bli_is_last_iter( ib, m_iter, 0, 1 ) ) \
				{ \
					a2 = a_cast; \
					b2 = b1; \
					/*if ( bli_is_last_iter( j, n_iter, 0, 1 ) ) */\
					if ( j + thread_num_threads(thread) >= n_iter ) \
						b2 = b_cast; \
				} \
\
				/* Save addresses of next panels of A and B to the auxinfo_t
				   object. */ \
				bli_auxinfo_set_next_a( a2, aux ); \
				bli_auxinfo_set_next_b( b2, aux ); \
\
				/* Save the 4m/3m imaginary stride of A to the auxinfo_t
				   object. */ \
				bli_auxinfo_set_is_a( PACKMR * k_a1112, aux ); \
\
				/* Handle interior and edge cases separately. */ \
				if ( m_cur == MR && n_cur == NR ) \
				{ \
					/* Invoke the fused gemm/trsm micro-kernel. */ \
					gemmtrsm_ukr_cast( k_a12, \
					                   alpha1_cast, \
					                   a12, \
					                   a11, \
					                   b21, \
					                   b11, \
					                   c11, rs_c, cs_c, \
					                   &aux ); \
				} \
				else \
				{ \
					/* Invoke the fused gemm/trsm micro-kernel. */ \
					gemmtrsm_ukr_cast( k_a12, \
					                   alpha1_cast, \
					                   a12, \
					                   a11, \
					                   b21, \
					                   b11, \
					                   ct, rs_ct, cs_ct, \
					                   &aux ); \
\
					/* Copy the result to the bottom edge of C. */ \
					PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \
					                        ct,  rs_ct, cs_ct, \
					                        c11, rs_c,  cs_c ); \
				} \
\
				a1 += ps_a_cur; \
			} \
			else if ( bli_is_strictly_above_diag_n( diagoffa_i, MR, k ) ) \
			{ \
				ctype* restrict a2; \
\
				/* Compute the addresses of the next panels of A and B. */ \
				a2 = a1 + rstep_a; \
				if ( bli_is_last_iter( ib, m_iter, 0, 1 ) ) \
				{ \
					a2 = a_cast; \
					b2 = b1; \
					/*if ( bli_is_last_iter( j, n_iter, 0, 1 ) ) */\
					if ( j + thread_num_threads(thread) >= n_iter ) \
						b2 = b_cast; \
				} \
\
				/* Save addresses of next panels of A and B to the auxinfo_t
				   object. */ \
				bli_auxinfo_set_next_a( a2, aux ); \
				bli_auxinfo_set_next_b( b2, aux ); \
\
				/* Save the 4m/3m imaginary stride of A to the auxinfo_t
				   object. */ \
				bli_auxinfo_set_is_a( istep_a, aux ); \
\
				/* Handle interior and edge cases separately. */ \
				if ( m_cur == MR && n_cur == NR ) \
				{ \
					/* Invoke the gemm micro-kernel. */ \
					gemm_ukr_cast( k, \
					               minus_one, \
					               a1, \
					               b1, \
					               alpha2_cast, \
					               c11, rs_c, cs_c, \
					               &aux ); \
				} \
				else \
				{ \
					/* Invoke the gemm micro-kernel. */ \
					gemm_ukr_cast( k, \
					               minus_one, \
					               a1, \
					               b1, \
					               zero, \
					               ct, rs_ct, cs_ct, \
					               &aux ); \
\
					/* Add the result to the edge of C. */ \
					PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
					                        ct,  rs_ct, cs_ct, \
					                        alpha2_cast, \
					                        c11, rs_c,  cs_c ); \
				} \
\
				a1 += rstep_a; \
			} \
\
			c11 -= rstep_c; \
		} \
		} \
\
		b1 += cstep_b; \
		c1 += cstep_c; \
	} \
}
543 INSERT_GENTFUNC_BASIC2( trsm_lu_ker_var2, gemmtrsm_ukr_t, gemm_ukr_t )