Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - dense-linear-algebra-libraries/linalg.git/blob - blis/frame/2/trmv/bli_trmv_unf_var1.c
Consolidate all git repos of linalg into one.
[dense-linear-algebra-libraries/linalg.git] / blis / frame / 2 / trmv / bli_trmv_unf_var1.c
/*

   BLIS
   An object-based framework for developing high-performance BLAS-like
   libraries.

   Copyright (C) 2014, The University of Texas at Austin

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions are
   met:
    - Redistributions of source code must retain the above copyright
      notice, this list of conditions and the following disclaimer.
    - Redistributions in binary form must reproduce the above copyright
      notice, this list of conditions and the following disclaimer in the
      documentation and/or other materials provided with the distribution.
    - Neither the name of The University of Texas at Austin nor the names
      of its contributors may be used to endorse or promote products
      derived from this software without specific prior written permission.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

*/
#include "blis.h"
37 #define FUNCPTR_T trmv_fp
39 typedef void (*FUNCPTR_T)(
40                            uplo_t  uplo,
41                            trans_t trans,
42                            diag_t  diag,
43                            dim_t   m,
44                            void*   alpha,
45                            void*   a, inc_t rs_a, inc_t cs_a,
46                            void*   x, inc_t incx
47                          );
49 // If some mixed datatype functions will not be compiled, we initialize
50 // the corresponding elements of the function array to NULL.
51 #ifdef BLIS_ENABLE_MIXED_PRECISION_SUPPORT
52 static FUNCPTR_T GENARRAY2_ALL(ftypes,trmv_unf_var1);
53 #else
54 #ifdef BLIS_ENABLE_MIXED_DOMAIN_SUPPORT
55 static FUNCPTR_T GENARRAY2_EXT(ftypes,trmv_unf_var1);
56 #else
57 static FUNCPTR_T GENARRAY2_MIN(ftypes,trmv_unf_var1);
58 #endif
59 #endif
62 void bli_trmv_unf_var1( obj_t*  alpha,
63                         obj_t*  a,
64                         obj_t*  x,
65                         trmv_t* cntl )
66 {
67         num_t     dt_a      = bli_obj_datatype( *a );
68         num_t     dt_x      = bli_obj_datatype( *x );
70         uplo_t    uplo      = bli_obj_uplo( *a );
71         trans_t   trans     = bli_obj_conjtrans_status( *a );
72         diag_t    diag      = bli_obj_diag( *a );
74         dim_t     m         = bli_obj_length( *a );
76         void*     buf_a     = bli_obj_buffer_at_off( *a );
77         inc_t     rs_a      = bli_obj_row_stride( *a );
78         inc_t     cs_a      = bli_obj_col_stride( *a );
80         void*     buf_x     = bli_obj_buffer_at_off( *x );
81         inc_t     incx      = bli_obj_vector_inc( *x );
83         num_t     dt_alpha;
84         void*     buf_alpha;
86         FUNCPTR_T f;
88         // The datatype of alpha MUST be the type union of a and x. This is to
89         // prevent any unnecessary loss of information during computation.
90         dt_alpha  = bli_datatype_union( dt_a, dt_x );
91         buf_alpha = bli_obj_buffer_for_1x1( dt_alpha, *alpha );
93         // Index into the type combination array to extract the correct
94         // function pointer.
95         f = ftypes[dt_a][dt_x];
97         // Invoke the function.
98         f( uplo,
99            trans,
100            diag,
101            m,
102            buf_alpha,
103            buf_a, rs_a, cs_a,
104            buf_x, incx );
//
// GENTFUNC2U
//
// Template for the typed implementations of this trmv variant. Each
// expansion defines PASTEMAC2(cha,chx,varname), which overwrites x with
// alpha * transa(A) * x for an m x m triangular matrix A (per the inline
// comments below).
//
//   ctype_a / cha    - element type of A and its type character.
//   ctype_x / chx    - element type of x and its type character.
//   ctype_ax / chax  - type used for alpha and for the accumulators.
//   varname          - base name of the generated function.
//   kername          - fused kernel applied to the off-diagonal panel.
//
// Outline of the generated function:
//   - A transposition is absorbed by swapping the row/column strides and
//     toggling uplo, so only an "upper" and a "lower" traversal remain;
//     conjugation is carried separately in conja.
//   - The diagonal is visited in blocks of at most b_fuse rows. For each
//     block, the f x f diagonal block A11 is applied to x1 with scalar
//     loops, and the remaining panel (A12 or A10) is applied through
//     kername with a scaling factor of one on x1.
//   - The upper case walks top-to-bottom and the lower case bottom-to-top,
//     so the entries of x that a step reads have not yet been overwritten.
//
#undef  GENTFUNC2U
#define GENTFUNC2U( ctype_a, ctype_x, ctype_ax, cha, chx, chax, varname, kername ) \
\
void PASTEMAC2(cha,chx,varname)( \
                                 uplo_t  uplo, \
                                 trans_t trans, \
                                 diag_t  diag, \
                                 dim_t   m, \
                                 void*   alpha, \
                                 void*   a, inc_t rs_a, inc_t cs_a, \
                                 void*   x, inc_t incx  \
                               ) \
{ \
        ctype_ax* alpha_cast = alpha; \
        ctype_a*  a_cast     = a; \
        ctype_x*  x_cast     = x; \
        ctype_x*  one        = PASTEMAC(chx,1); \
        ctype_a*  A10; \
        ctype_a*  A11; \
        ctype_a*  A12; \
        ctype_a*  a10t; \
        ctype_a*  alpha11; \
        ctype_a*  a12t; \
        ctype_x*  x0; \
        ctype_x*  x1; \
        ctype_x*  x2; \
        ctype_x*  x01; \
        ctype_x*  chi11; \
        ctype_x*  x21; \
        ctype_ax  alpha_alpha11_conj; \
        ctype_ax  rho1; \
        dim_t     iter, i, k, j, l; \
        dim_t     b_fuse, f; \
        dim_t     n_ahead, f_ahead; \
        inc_t     rs_at, cs_at; \
        uplo_t    uplo_trans; \
        conj_t    conja; \
\
        /* Nothing to do for an empty problem. */ \
        if ( bli_zero_dim1( m ) ) return; \
\
        /* Absorb a transposition into the strides and the uplo value. */ \
        if      ( bli_does_notrans( trans ) ) \
        { \
                rs_at = rs_a; \
                cs_at = cs_a; \
                uplo_trans = uplo; \
        } \
        else /* if ( bli_does_trans( trans ) ) */ \
        { \
                rs_at = cs_a; \
                cs_at = rs_a; \
                uplo_trans = bli_uplo_toggled( uplo ); \
        } \
\
        conja = bli_extract_conj( trans ); \
\
        /* Query the fusing factor for the dotxf implementation. */ \
        b_fuse = PASTEMAC(chax,dotxf_fusefac); \
\
        /* We reduce all of the possible cases down to just lower/upper. */ \
        if      ( bli_is_upper( uplo_trans ) ) \
        { \
                /* Walk the diagonal blocks from the top; x2 lies past the
                   current block and is only overwritten by later
                   iterations. */ \
                for ( iter = 0; iter < m; iter += f ) \
                { \
                        f        = bli_determine_blocksize_dim_f( iter, m, b_fuse ); \
                        i        = iter; \
                        n_ahead  = m - iter - f; \
                        A11      = a_cast + (i  )*rs_at + (i  )*cs_at; \
                        A12      = a_cast + (i  )*rs_at + (i+f)*cs_at; \
                        x1       = x_cast + (i  )*incx; \
                        x2       = x_cast + (i+f)*incx; \
\
                        /* x1 = alpha * A11 * x1; */ \
                        /* Rows are handled in increasing order so that x21
                           (the entries after chi11) is still unmodified
                           when it is read. */ \
                        for ( k = 0; k < f; ++k ) \
                        { \
                                l        = k; \
                                f_ahead  = f - l - 1; \
                                alpha11  = A11 + (l  )*rs_at + (l  )*cs_at; \
                                a12t     = A11 + (l  )*rs_at + (l+1)*cs_at; \
                                chi11    = x1  + (l  )*incx; \
                                x21      = x1  + (l+1)*incx; \
\
                                /* chi11 = alpha * alpha11 * chi11; */ \
                                /* For a unit diagonal alpha11 is skipped, so
                                   chi11 is scaled by alpha alone. */ \
                                PASTEMAC2(chax,chax,copys)( *alpha_cast, alpha_alpha11_conj ); \
                                if ( bli_is_nonunit_diag( diag ) ) \
                                        PASTEMAC2(cha,chax,scalcjs)( conja, *alpha11, alpha_alpha11_conj ); \
                                PASTEMAC2(chax,chx,scals)( alpha_alpha11_conj, *chi11 ); \
\
                                /* chi11 = chi11 + alpha * a12t * x21; */ \
                                PASTEMAC(chax,set0s)( rho1 ); \
                                if ( bli_is_conj( conja ) ) \
                                { \
                                        for ( j = 0; j < f_ahead; ++j ) \
                                                PASTEMAC3(cha,chx,chax,dotjs)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); \
                                } \
                                else \
                                { \
                                        for ( j = 0; j < f_ahead; ++j ) \
                                                PASTEMAC3(cha,chx,chax,dots)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); \
                                } \
                                PASTEMAC3(chax,chax,chx,axpys)( *alpha_cast, rho1, *chi11 ); \
                        } \
\
                        /* x1 = x1 + alpha * A12 * x2; */ \
                        PASTEMAC3(cha,chx,chx,kername)( conja, \
                                                        BLIS_NO_CONJUGATE, \
                                                        n_ahead, \
                                                        f, \
                                                        alpha_cast, \
                                                        A12, cs_at, rs_at, \
                                                        x2,  incx, \
                                                        one, \
                                                        x1,  incx ); \
                } \
        } \
        else /* if ( bli_is_lower( uplo_trans ) ) */ \
        { \
                /* Walk the diagonal blocks from the bottom; x0 lies before
                   the current block and is only overwritten by later
                   iterations. */ \
                for ( iter = 0; iter < m; iter += f ) \
                { \
                        f        = bli_determine_blocksize_dim_b( iter, m, b_fuse ); \
                        i        = m - iter - f; \
                        n_ahead  = i; \
                        A11      = a_cast + (i  )*rs_at + (i  )*cs_at; \
                        A10      = a_cast + (i  )*rs_at + (0  )*cs_at; \
                        x1       = x_cast + (i  )*incx; \
                        x0       = x_cast + (0  )*incx; \
\
                        /* x1 = alpha * A11 * x1; */ \
                        /* Rows are handled in decreasing order so that x01
                           (the entries before chi11) is still unmodified
                           when it is read. */ \
                        for ( k = 0; k < f; ++k ) \
                        { \
                                l        = f - k - 1; \
                                f_ahead  = l; \
                                alpha11  = A11 + (l  )*rs_at + (l  )*cs_at; \
                                a10t     = A11 + (l  )*rs_at + (0  )*cs_at; \
                                chi11    = x1  + (l  )*incx; \
                                x01      = x1  + (0  )*incx; \
\
                                /* chi11 = alpha * alpha11 * chi11; */ \
                                /* For a unit diagonal alpha11 is skipped, so
                                   chi11 is scaled by alpha alone. */ \
                                PASTEMAC2(chax,chax,copys)( *alpha_cast, alpha_alpha11_conj ); \
                                if ( bli_is_nonunit_diag( diag ) ) \
                                        PASTEMAC2(cha,chax,scalcjs)( conja, *alpha11, alpha_alpha11_conj ); \
                                PASTEMAC2(chax,chx,scals)( alpha_alpha11_conj, *chi11 ); \
\
                                /* chi11 = chi11 + alpha * a10t * x01; */ \
                                PASTEMAC(chax,set0s)( rho1 ); \
                                if ( bli_is_conj( conja ) ) \
                                { \
                                        for ( j = 0; j < f_ahead; ++j ) \
                                                PASTEMAC3(cha,chx,chax,dotjs)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); \
                                } \
                                else \
                                { \
                                        for ( j = 0; j < f_ahead; ++j ) \
                                                PASTEMAC3(cha,chx,chax,dots)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); \
                                } \
                                PASTEMAC3(chax,chax,chx,axpys)( *alpha_cast, rho1, *chi11 ); \
                        } \
\
                        /* x1 = x1 + alpha * A10 * x0; */ \
                        PASTEMAC3(cha,chx,chx,kername)( conja, \
                                                        BLIS_NO_CONJUGATE, \
                                                        n_ahead, \
                                                        f, \
                                                        alpha_cast, \
                                                        A10, cs_at, rs_at, \
                                                        x0,  incx, \
                                                        one, \
                                                        x1,  incx ); \
                } \
        } \
}
279 // Define the basic set of functions unconditionally, and then also some
280 // mixed datatype functions if requested.
281 INSERT_GENTFUNC2U_BASIC( trmv_unf_var1, DOTXF_KERNEL )
283 #ifdef BLIS_ENABLE_MIXED_DOMAIN_SUPPORT
284 INSERT_GENTFUNC2U_MIX_D( trmv_unf_var1, DOTXF_KERNEL )
285 #endif
287 #ifdef BLIS_ENABLE_MIXED_PRECISION_SUPPORT
288 INSERT_GENTFUNC2U_MIX_P( trmv_unf_var1, DOTXF_KERNEL )
289 #endif