]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - dense-linear-algebra-libraries/linalg.git/blob - blis/frame/3/trsm/bli_trsm_lu_ker_var2.c
Consolidate all git repos of linalg into one.
[dense-linear-algebra-libraries/linalg.git] / blis / frame / 3 / trsm / bli_trsm_lu_ker_var2.c
1 /*
3    BLIS    
4    An object-based framework for developing high-performance BLAS-like
5    libraries.
7    Copyright (C) 2014, The University of Texas at Austin
9    Redistribution and use in source and binary forms, with or without
10    modification, are permitted provided that the following conditions are
11    met:
12     - Redistributions of source code must retain the above copyright
13       notice, this list of conditions and the following disclaimer.
14     - Redistributions in binary form must reproduce the above copyright
15       notice, this list of conditions and the following disclaimer in the
16       documentation and/or other materials provided with the distribution.
17     - Neither the name of The University of Texas at Austin nor the names
18       of its contributors may be used to endorse or promote products
19       derived from this software without specific prior written permission.
21    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22    "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24    A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25    HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26    SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27    LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28    DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29    THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30    (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31    OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33 */
35 #include "blis.h"
37 #define FUNCPTR_T gemm_fp
39 typedef void (*FUNCPTR_T)(
40                            doff_t  diagoffa,
41                            pack_t  schema_a,
42                            pack_t  schema_b,
43                            dim_t   m,
44                            dim_t   n,
45                            dim_t   k,
46                            void*   alpha1,
47                            void*   a, inc_t cs_a, inc_t pd_a, inc_t ps_a,
48                            void*   b, inc_t rs_b, inc_t pd_b, inc_t ps_b,
49                            void*   alpha2,
50                            void*   c, inc_t rs_c, inc_t cs_c,
51                            void*   gemmtrsm_ukr,
52                            void*   gemm_ukr,
53                            trsm_thrinfo_t* thread
54                          );
56 static FUNCPTR_T GENARRAY(ftypes,trsm_lu_ker_var2);
59 void bli_trsm_lu_ker_var2( obj_t*  a,
60                            obj_t*  b,
61                            obj_t*  c,
62                            trsm_t* cntl,
63                            trsm_thrinfo_t* thread )
64 {
65         num_t     dt_exec   = bli_obj_execution_datatype( *c );
67         doff_t    diagoffa  = bli_obj_diag_offset( *a );
69         pack_t    schema_a  = bli_obj_pack_schema( *a );
70         pack_t    schema_b  = bli_obj_pack_schema( *b );
72         dim_t     m         = bli_obj_length( *c );
73         dim_t     n         = bli_obj_width( *c );
74         dim_t     k         = bli_obj_width( *a );
76         void*     buf_a     = bli_obj_buffer_at_off( *a );
77         inc_t     cs_a      = bli_obj_col_stride( *a );
78         inc_t     pd_a      = bli_obj_panel_dim( *a );
79         inc_t     ps_a      = bli_obj_panel_stride( *a );
81         void*     buf_b     = bli_obj_buffer_at_off( *b );
82         inc_t     rs_b      = bli_obj_row_stride( *b );
83         inc_t     pd_b      = bli_obj_panel_dim( *b );
84         inc_t     ps_b      = bli_obj_panel_stride( *b );
86         void*     buf_c     = bli_obj_buffer_at_off( *c );
87         inc_t     rs_c      = bli_obj_row_stride( *c );
88         inc_t     cs_c      = bli_obj_col_stride( *c );
90         void*     buf_alpha1;
91         void*     buf_alpha2;
93         FUNCPTR_T f;
95         func_t*   gemmtrsm_ukrs;
96         func_t*   gemm_ukrs;
97         void*     gemmtrsm_ukr;
98         void*     gemm_ukr;
101         // Grab the address of the internal scalar buffer for the scalar
102         // attached to B. This will be the alpha scalar used in the gemmtrsm
103         // subproblems (ie: the scalar that would be applied to the packed
104         // copy of B prior to it being updated by the trsm subproblem). This
105         // scalar may be unit, if for example it was applied during packing.
106         buf_alpha1 = bli_obj_internal_scalar_buffer( *b );
108         // Grab the address of the internal scalar buffer for the scalar
109         // attached to C. This will be the "beta" scalar used in the gemm-only
110         // subproblems that correspond to micro-panels that do not intersect
111         // the diagonal. We need this separate scalar because it's possible
112         // that the alpha attached to B was reset, if it was applied during
113         // packing.
114         buf_alpha2 = bli_obj_internal_scalar_buffer( *c );
116         // Index into the type combination array to extract the correct
117         // function pointer.
118         f = ftypes[dt_exec];
120         // Extract from the control tree node the func_t objects containing
121         // the gemmtrsm and gemm micro-kernel function addresses, and then
122         // query the function addresses corresponding to the current datatype.
123         gemmtrsm_ukrs = cntl_gemmtrsm_u_ukrs( cntl );
124         gemm_ukrs     = cntl_gemm_ukrs( cntl );
125         gemmtrsm_ukr  = bli_func_obj_query( dt_exec, gemmtrsm_ukrs );
126         gemm_ukr      = bli_func_obj_query( dt_exec, gemm_ukrs );
128         // Invoke the function.
129         f( diagoffa,
130            schema_a,
131            schema_b,
132            m,
133            n,
134            k,
135            buf_alpha1,
136            buf_a, cs_a, pd_a, ps_a,
137            buf_b, rs_b, pd_b, ps_b,
138            buf_alpha2,
139            buf_c, rs_c, cs_c,
140            gemmtrsm_ukr,
141            gemm_ukr,
142            thread );
146 #undef  GENTFUNC
147 #define GENTFUNC( ctype, ch, varname, gemmtrsmtype, gemmtype ) \
149 void PASTEMAC(ch,varname)( \
150                            doff_t  diagoffa, \
151                            pack_t  schema_a, \
152                            pack_t  schema_b, \
153                            dim_t   m, \
154                            dim_t   n, \
155                            dim_t   k, \
156                            void*   alpha1, \
157                            void*   a, inc_t cs_a, inc_t pd_a, inc_t ps_a, \
158                            void*   b, inc_t rs_b, inc_t pd_b, inc_t ps_b, \
159                            void*   alpha2, \
160                            void*   c, inc_t rs_c, inc_t cs_c, \
161                            void*   gemmtrsm_ukr, \
162                            void*   gemm_ukr,  \
163                            trsm_thrinfo_t* thread \
164                          ) \
165 { \
166         /* Cast the micro-kernels' addresses to their function pointer types. */ \
167         PASTECH(ch,gemmtrsmtype) gemmtrsm_ukr_cast = (PASTECH(ch,gemmtrsmtype)) gemmtrsm_ukr; \
168         PASTECH(ch,gemmtype)     gemm_ukr_cast     = (PASTECH(ch,gemmtype)) gemm_ukr; \
170         /* Temporary C buffer for edge cases. */ \
171         ctype           ct[ PASTEMAC(ch,maxmr) * \
172                             PASTEMAC(ch,maxnr) ] \
173                             __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
174         const inc_t     rs_ct      = 1; \
175         const inc_t     cs_ct      = PASTEMAC(ch,maxmr); \
177         /* Alias some constants to simpler names. */ \
178         const dim_t     MR         = pd_a; \
179         const dim_t     NR         = pd_b; \
180         const dim_t     PACKMR     = cs_a; \
181         const dim_t     PACKNR     = rs_b; \
183         ctype* restrict zero        = PASTEMAC(ch,0); \
184         ctype* restrict minus_one   = PASTEMAC(ch,m1); \
185         ctype* restrict a_cast      = a; \
186         ctype* restrict b_cast      = b; \
187         ctype* restrict c_cast      = c; \
188         ctype* restrict alpha1_cast = alpha1; \
189         ctype* restrict alpha2_cast = alpha2; \
190         ctype* restrict b1; \
191         ctype* restrict c1; \
193         doff_t          diagoffa_i; \
194         dim_t           k_full; \
195         dim_t           m_iter, m_left; \
196         dim_t           n_iter, n_left; \
197         dim_t           m_cur; \
198         dim_t           n_cur; \
199         dim_t           k_a1112; \
200         dim_t           k_a11; \
201         dim_t           k_a12; \
202         dim_t           off_a11; \
203         dim_t           off_a12; \
204         dim_t           i, j, ib; \
205         inc_t           rstep_a; \
206         inc_t           cstep_b; \
207         inc_t           rstep_c, cstep_c; \
208         inc_t           istep_a; \
209         inc_t           istep_b; \
210         inc_t           off_scl; \
211         inc_t           ss_a_num; \
212         inc_t           ss_a_den; \
213         inc_t           ps_a_cur; \
214         auxinfo_t       aux; \
216         /*
217            Assumptions/assertions:
218              rs_a == 1
219              cs_a == PACKMR
220              pd_a == MR
221              ps_a == stride to next micro-panel of A
222              rs_b == PACKNR
223              cs_b == 1
224              pd_b == NR
225              ps_b == stride to next micro-panel of B
226              rs_c == (no assumptions)
227              cs_c == (no assumptions)
228         */ \
230         /* If any dimension is zero, return immediately. */ \
231         if ( bli_zero_dim3( m, n, k ) ) return; \
233         /* Safeguard: If matrix A is below the diagonal, it is implicitly zero.
234            So we do nothing. */ \
235         if ( bli_is_strictly_below_diag_n( diagoffa, m, k ) ) return; \
237         /* Compute k_full as k inflated up to a multiple of MR. This is
238            needed because some parameter combinations of trsm reduce k
239            to advance past zero regions in the triangular matrix, and
240            when computing the imaginary stride of B (the non-triangular
241            matrix), which is used by 3m and 4m implementations, we need
242            this unreduced value of k. */ \
243         k_full = ( k % MR != 0 ? k + MR - ( k % MR ) : k ); \
245         /* Compute indexing scaling factor for for 4m or 3m. This is
246            needed because one of the packing register blocksizes (PACKMR
247            or PACKNR) is used to index into the micro-panels of the non-
248            triangular matrix when computing with a diagonal-intersecting
249            micro-panel of the triangular matrix. In the case of 4m or 3m,
250            real values are stored in both sub-panels, and so the indexing
251            needs to occur in units of real values. The value computed
252            here is divided into the complex pointer offset to cause the
253            pointer to be advanced by the correct value. */ \
254         if ( bli_is_4m_packed( schema_a ) || \
255              bli_is_3m_packed( schema_a ) || \
256              bli_is_rih_packed( schema_a ) ) off_scl = 2; \
257         else                                 off_scl = 1; \
259         /* Compute the storage stride. Usually this is just PACKMR (for A
260            or PACKNR (for B). However, in the case of 3m, we need to scale
261            the offset by 3/2. Since it's possible we may need to scale
262            the packing dimension by a non-integer value, we break up the
263            scaling factor into numerator and denominator. */ \
264         if ( bli_is_3m_packed( schema_a ) ) { ss_a_num = 3*PACKMR; \
265                                               ss_a_den = 2; } \
266         else                                { ss_a_num = 1*PACKMR; \
267                                               ss_a_den = 1; } \
269         /* If there is a zero region to the left of where the diagonal of A
270            intersects the top edge of the block, adjust the pointer to B and
271            treat this case as if the diagonal offset were zero. Note that we
272            don't need to adjust the pointer to A since packm would have simply
273            skipped over the region that was not stored. */ \
274         if ( diagoffa > 0 ) \
275         { \
276                 i        = diagoffa; \
277                 k        = k - i; \
278                 diagoffa = 0; \
279                 b_cast   = b_cast + ( i * PACKNR ) / off_scl; \
280         } \
282         /* If there is a zero region below where the diagonal of A intersects the
283            right side of the block, shrink it to prevent "no-op" iterations from
284            executing. */ \
285         if ( -diagoffa + k < m ) \
286         { \
287                 m = -diagoffa + k; \
288         } \
290         /* Check the k dimension, which needs to be a multiple of MR. If k
291            isn't a multiple of MR, we adjust it higher to satisfy the micro-
292            kernel, which is expecting to perform an MR x MR triangular solve.
293            This adjustment of k is consistent with what happened when A was
294            packed: all of its bottom/right edges were zero-padded, and
295            furthermore, the panel that stores the bottom-right corner of the
296            matrix has its diagonal extended into the zero-padded region (as
297            identity). This allows the trsm of that bottom-right panel to
298            proceed without producing any infs or NaNs that would infect the
299            "good" values of the corresponding block of B. */ \
300         if ( k % MR != 0 ) k += MR - ( k % MR ); \
302         /* NOTE: We don't need to check that m is a multiple of PACKMR since we
303            know that the underlying buffer was already allocated to have an m
304            dimension that is a multiple of PACKMR, with the region between the
305            last row and the next multiple of MR zero-padded accordingly. */ \
307         /* Clear the temporary C buffer in case it has any infs or NaNs. */ \
308         PASTEMAC(ch,set0s_mxn)( MR, NR, \
309                                 ct, rs_ct, cs_ct ); \
311         /* Compute number of primary and leftover components of the m and n
312        dimensions. */ \
313         n_iter = n / NR; \
314         n_left = n % NR; \
316         m_iter = m / MR; \
317         m_left = m % MR; \
319         if ( n_left ) ++n_iter; \
320         if ( m_left ) ++m_iter; \
322         /* Determine some increments used to step through A, B, and C. */ \
323         rstep_a = ps_a; \
325         cstep_b = ps_b; \
327         rstep_c = rs_c * MR; \
328         cstep_c = cs_c * NR; \
330         istep_a = PACKMR * k; \
331         istep_b = PACKNR * k_full; \
333         /* Save the pack schemas of A and B to the auxinfo_t object. */ \
334         bli_auxinfo_set_schema_a( schema_a, aux ); \
335         bli_auxinfo_set_schema_b( schema_b, aux ); \
337         /* Save the imaginary stride of B to the auxinfo_t object. */ \
338         bli_auxinfo_set_is_b( istep_b, aux ); \
340         b1 = b_cast; \
341         c1 = c_cast; \
343         /* Loop over the n dimension (NR columns at a time). */ \
344         for ( j = 0; j < n_iter; ++j ) \
345         { \
346                 if( trsm_my_iter( j, thread ) ) { \
348                 ctype* restrict a1; \
349                 ctype* restrict c11; \
350                 ctype* restrict b2; \
352                 a1  = a_cast; \
353                 c11 = c1 + (m_iter-1)*rstep_c; \
355                 n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \
357                 /* Initialize our next panel of B to be the current panel of B. */ \
358                 b2 = b1; \
360                 /* Loop over the m dimension (MR rows at a time). */ \
361                 for ( ib = 0; ib < m_iter; ++ib ) \
362                 { \
363                         i          = m_iter - 1 - ib; \
364                         diagoffa_i = diagoffa + ( doff_t )i*MR; \
366                         m_cur = ( bli_is_not_edge_b( ib, m_iter, m_left ) ? MR : m_left ); \
368                         /* If the current panel of A intersects the diagonal, use a
369                            special micro-kernel that performs a fused gemm and trsm.
370                            If the current panel of A resides above the diagonal, use a
371                            a regular gemm micro-kernel. Otherwise, if it is below the
372                            diagonal, it was not packed (because it is implicitly zero)
373                            and so we do nothing. */ \
374                         if ( bli_intersects_diag_n( diagoffa_i, MR, k ) ) \
375                         { \
376                                 ctype* restrict a11; \
377                                 ctype* restrict a12; \
378                                 ctype* restrict b11; \
379                                 ctype* restrict b21; \
380                                 ctype* restrict a2; \
382                                 /* Compute various offsets into and lengths of parts of A. */ \
383                                 off_a11 = diagoffa_i; \
384                                 k_a1112 = k - off_a11;; \
385                                 k_a11   = MR; \
386                                 k_a12   = k_a1112 - MR; \
387                                 off_a12 = off_a11 + k_a11; \
389                                 /* Compute the panel stride for the current diagonal-
390                                    intersecting micro-panel. */ \
391                                 ps_a_cur = ( k_a1112 * ss_a_num ) / ss_a_den; \
393                                 /* Compute the addresses of the triangular block A11 and the
394                                    panel A12. */ \
395                                 a11 = a1; \
396                                 a12 = a1 + ( k_a11 * PACKMR ) / off_scl; \
398                                 /* Compute the addresses of the panel B01 and the block
399                                    B11. */ \
400                                 b11 = b1 + ( off_a11 * PACKNR ) / off_scl; \
401                                 b21 = b1 + ( off_a12 * PACKNR ) / off_scl; \
403                                 /* Compute the addresses of the next panels of A and B. */ \
404                                 a2 = a1 + ps_a_cur; \
405                                 if ( bli_is_last_iter( ib, m_iter, 0, 1 ) ) \
406                                 { \
407                                         a2 = a_cast; \
408                                         b2 = b1; \
409                                         /*if ( bli_is_last_iter( j, n_iter, 0, 1 ) ) */\
410                                         if ( j + thread_num_threads(thread) >= n_iter ) \
411                                                 b2 = b_cast; \
412                                 } \
414                                 /* Save addresses of next panels of A and B to the auxinfo_t
415                                    object. */ \
416                                 bli_auxinfo_set_next_a( a2, aux ); \
417                                 bli_auxinfo_set_next_b( b2, aux ); \
419                                 /* Save the 4m/3m imaginary stride of A to the auxinfo_t
420                                    object. */ \
421                                 bli_auxinfo_set_is_a( PACKMR * k_a1112, aux ); \
423                                 /* Handle interior and edge cases separately. */ \
424                                 if ( m_cur == MR && n_cur == NR ) \
425                                 { \
426                                         /* Invoke the fused gemm/trsm micro-kernel. */ \
427                                         gemmtrsm_ukr_cast( k_a12, \
428                                                            alpha1_cast, \
429                                                            a12, \
430                                                            a11, \
431                                                            b21, \
432                                                            b11, \
433                                                            c11, rs_c, cs_c, \
434                                                            &aux ); \
435                                 } \
436                                 else \
437                                 { \
438                                         /* Invoke the fused gemm/trsm micro-kernel. */ \
439                                         gemmtrsm_ukr_cast( k_a12, \
440                                                            alpha1_cast, \
441                                                            a12, \
442                                                            a11, \
443                                                            b21, \
444                                                            b11, \
445                                                            ct, rs_ct, cs_ct, \
446                                                            &aux ); \
448                                         /* Copy the result to the bottom edge of C. */ \
449                                         PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \
450                                                                 ct,  rs_ct, cs_ct, \
451                                                                 c11, rs_c,  cs_c ); \
452                                 } \
454                                 a1 += ps_a_cur; \
455                         } \
456                         else if ( bli_is_strictly_above_diag_n( diagoffa_i, MR, k ) ) \
457                         { \
458                                 ctype* restrict a2; \
460                                 /* Compute the addresses of the next panels of A and B. */ \
461                                 a2 = a1 + rstep_a; \
462                                 if ( bli_is_last_iter( ib, m_iter, 0, 1 ) ) \
463                                 { \
464                                         a2 = a_cast; \
465                                         b2 = b1; \
466                                         /*if ( bli_is_last_iter( j, n_iter, 0, 1 ) ) */\
467                                         if ( j + thread_num_threads(thread) >= n_iter ) \
468                                                 b2 = b_cast; \
469                                 } \
471                                 /* Save addresses of next panels of A and B to the auxinfo_t
472                                    object. */ \
473                                 bli_auxinfo_set_next_a( a2, aux ); \
474                                 bli_auxinfo_set_next_b( b2, aux ); \
476                                 /* Save the 4m/3m imaginary stride of A to the auxinfo_t
477                                    object. */ \
478                                 bli_auxinfo_set_is_a( istep_a, aux ); \
480                                 /* Handle interior and edge cases separately. */ \
481                                 if ( m_cur == MR && n_cur == NR ) \
482                                 { \
483                                         /* Invoke the gemm micro-kernel. */ \
484                                         gemm_ukr_cast( k, \
485                                                        minus_one, \
486                                                        a1, \
487                                                        b1, \
488                                                        alpha2_cast, \
489                                                        c11, rs_c, cs_c, \
490                                                        &aux ); \
491                                 } \
492                                 else \
493                                 { \
494                                         /* Invoke the gemm micro-kernel. */ \
495                                         gemm_ukr_cast( k, \
496                                                        minus_one, \
497                                                        a1, \
498                                                        b1, \
499                                                        zero, \
500                                                        ct, rs_ct, cs_ct, \
501                                                        &aux ); \
503                                         /* Add the result to the edge of C. */ \
504                                         PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
505                                                                 ct,  rs_ct, cs_ct, \
506                                                                 alpha2_cast, \
507                                                                 c11, rs_c,  cs_c ); \
508                                 } \
510                                 a1 += rstep_a; \
511                         } \
513                         c11 -= rstep_c; \
514                 } \
515                 } \
517                 b1 += cstep_b; \
518                 c1 += cstep_c; \
519         } \
521 /*
522 PASTEMAC(ch,fprintm)( stdout, "trsm_lu_ker_var2: a1 (diag)", MR, k_a1112, a1, 1, MR, "%5.2f", "" ); \
523 PASTEMAC(ch,fprintm)( stdout, "trsm_lu_ker_var2: b11 (diag)", MR, NR, b11, NR, 1, "%6.3f", "" ); \
524 printf( "m_iter     = %lu\n", m_iter ); \
525 printf( "m_cur      = %lu\n", m_cur ); \
526 printf( "k          = %lu\n", k ); \
527 printf( "diagoffa_i = %lu\n", diagoffa_i ); \
528 printf( "off_a1112  = %lu\n", off_a1112 ); \
529 printf( "k_a1112    = %lu\n", k_a1112 ); \
530 printf( "k_a12      = %lu\n", k_a12 ); \
531 printf( "k_a11      = %lu\n", k_a11 ); \
532 printf( "rs_c,cs_c  = %lu %lu\n", rs_c, cs_c ); \
533 printf( "rs_ct,cs_ct= %lu %lu\n", rs_ct, cs_ct ); \
534 */ \
536 /*
537 PASTEMAC(ch,fprintm)( stdout, "trsm_lu_ker_var2: b11 after (diag)", MR, NR, b11, NR, 1, "%5.2f", "" ); \
538 PASTEMAC(ch,fprintm)( stdout, "trsm_lu_ker_var2: b11 after (diag)", MR, NR, b11, NR, 1, "%5.2f", "" ); \
539 PASTEMAC(ch,fprintm)( stdout, "trsm_lu_ker_var2: ct after (diag)", m_cur, n_cur, ct, rs_ct, cs_ct, "%5.2f", "" ); \
540 */ \
543 INSERT_GENTFUNC_BASIC2( trsm_lu_ker_var2, gemmtrsm_ukr_t, gemm_ukr_t )