1 /*
3 BLIS
4 An object-based framework for developing high-performance BLAS-like
5 libraries.
7 Copyright (C) 2014, The University of Texas at Austin
9 Redistribution and use in source and binary forms, with or without
10 modification, are permitted provided that the following conditions are
11 met:
12 - Redistributions of source code must retain the above copyright
13 notice, this list of conditions and the following disclaimer.
14 - Redistributions in binary form must reproduce the above copyright
15 notice, this list of conditions and the following disclaimer in the
16 documentation and/or other materials provided with the distribution.
17 - Neither the name of The University of Texas at Austin nor the names
18 of its contributors may be used to endorse or promote products
19 derived from this software without specific prior written permission.
21 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25 HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26 SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27 LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28 DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29 THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31 OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33 */
35 #include "blis.h"
37 #define FUNCPTR_T gemm_fp
39 typedef void (*FUNCPTR_T)(
40 dim_t m,
41 dim_t n,
42 dim_t k,
43 void* alpha,
44 void* a, inc_t cs_a, inc_t pd_a, inc_t ps_a,
45 void* b, inc_t rs_b, inc_t pd_b, inc_t ps_b,
46 void* beta,
47 void* c, inc_t rs_c, inc_t cs_c,
48 void* gemm_ukr
49 );
51 static FUNCPTR_T GENARRAY(ftypes,gemm_ker_var5);
54 void bli_gemm_ker_var5( obj_t* a,
55 obj_t* b,
56 obj_t* c,
57 gemm_t* cntl,
58 gemm_thrinfo_t* thread )
59 {
60 num_t dt_exec = bli_obj_execution_datatype( *c );
62 dim_t m = bli_obj_length( *c );
63 dim_t n = bli_obj_width( *c );
64 dim_t k = bli_obj_width( *a );
66 void* buf_a = bli_obj_buffer_at_off( *a );
67 inc_t cs_a = bli_obj_col_stride( *a );
68 inc_t pd_a = bli_obj_panel_dim( *a );
69 inc_t ps_a = bli_obj_panel_stride( *a );
71 void* buf_b = bli_obj_buffer_at_off( *b );
72 inc_t rs_b = bli_obj_row_stride( *b );
73 inc_t pd_b = bli_obj_panel_dim( *b );
74 inc_t ps_b = bli_obj_panel_stride( *b );
76 void* buf_c = bli_obj_buffer_at_off( *c );
77 inc_t rs_c = bli_obj_row_stride( *c );
78 inc_t cs_c = bli_obj_col_stride( *c );
80 obj_t scalar_a;
81 obj_t scalar_b;
83 void* buf_alpha;
84 void* buf_beta;
86 FUNCPTR_T f;
88 func_t* gemm_ukrs;
89 void* gemm_ukr;
92 // Detach and multiply the scalars attached to A and B.
93 bli_obj_scalar_detach( a, &scalar_a );
94 bli_obj_scalar_detach( b, &scalar_b );
95 bli_mulsc( &scalar_a, &scalar_b );
97 // Grab the addresses of the internal scalar buffers for the scalar
98 // merged above and the scalar attached to C.
99 buf_alpha = bli_obj_internal_scalar_buffer( scalar_b );
100 buf_beta = bli_obj_internal_scalar_buffer( *c );
102 // Index into the type combination array to extract the correct
103 // function pointer.
104 f = ftypes[dt_exec];
106 // Extract from the control tree node the func_t object containing
107 // the gemm micro-kernel function addresses, and then query the
108 // function address corresponding to the current datatype.
109 gemm_ukrs = cntl_gemm_ukrs( cntl );
110 gemm_ukr = bli_func_obj_query( dt_exec, gemm_ukrs );
112 // Invoke the function.
113 f( m,
114 n,
115 k,
116 buf_alpha,
117 buf_a, cs_a, pd_a, ps_a,
118 buf_b, rs_b, pd_b, ps_b,
119 buf_beta,
120 buf_c, rs_c, cs_c,
121 gemm_ukr );
122 }
125 #undef GENTFUNC
126 #define GENTFUNC( ctype, ch, varname, ukrtype ) \
127 \
128 void PASTEMAC(ch,varname)( \
129 dim_t m, \
130 dim_t n, \
131 dim_t k, \
132 void* alpha, \
133 void* a, inc_t cs_a, inc_t pd_a, inc_t ps_a, \
134 void* b, inc_t rs_b, inc_t pd_b, inc_t ps_b, \
135 void* beta, \
136 void* c, inc_t rs_c, inc_t cs_c, \
137 void* gemm_ukr \
138 ) \
139 { \
140 /* Cast the micro-kernel address to its function pointer type. */ \
141 PASTECH(ch,ukrtype) gemm_ukr_cast = gemm_ukr; \
142 \
143 /* Temporary buffer for incremental packing of B. */ \
144 ctype bp[ PASTEMAC(ch,maxkc) * \
145 /* !!!! NOTE: This packnr actually needs to be something like maxpacknr
146 if it is to be guaranteed to work in all situations !!!! The right
147 place to define maxpackmr/nr would be in bli_kernel_post_macro_defs.h */ \
148 PASTEMAC(ch,packnr) ] \
149 __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
150 \
151 /* Temporary C buffer for edge cases. */ \
152 ctype ct[ PASTEMAC(ch,maxmr) * \
153 PASTEMAC(ch,maxnr) ] \
154 __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
155 const inc_t rs_ct = 1; \
156 const inc_t cs_ct = PASTEMAC(ch,maxmr); \
157 \
158 /* Alias some constants to simpler names. */ \
159 const dim_t MR = pd_a; \
160 const dim_t NR = pd_b; \
161 const dim_t PACKNR = rs_b; \
162 \
163 ctype* restrict one = PASTEMAC(ch,1); \
164 ctype* restrict zero = PASTEMAC(ch,0); \
165 ctype* restrict a_cast = a; \
166 ctype* restrict b_cast = b; \
167 ctype* restrict c_cast = c; \
168 ctype* restrict alpha_cast = alpha; \
169 ctype* restrict beta_cast = beta; \
170 ctype* restrict b1; \
171 ctype* restrict c1; \
172 ctype* restrict b2; \
173 \
174 dim_t m_iter, m_left; \
175 dim_t n_iter, n_left; \
176 dim_t i, j; \
177 dim_t m_cur; \
178 dim_t n_cur; \
179 inc_t rstep_a; \
180 inc_t cstep_b; \
181 inc_t rstep_c, cstep_c; \
182 auxinfo_t aux; \
183 \
184 /*
185 Assumptions/assertions:
186 rs_a == 1
187 cs_a == PACKMR
188 pd_a == MR
189 ps_a == stride to next micro-panel of A
190 rs_b == PACKNR
191 cs_b == 1
192 pd_b == NR
193 ps_b == stride to next micro-panel of B
194 rs_c == (no assumptions)
195 cs_c == (no assumptions)
196 */ \
197 \
198 /* If any dimension is zero, return immediately. */ \
199 if ( bli_zero_dim3( m, n, k ) ) return; \
200 \
201 /* Clear the temporary C buffer in case it has any infs or NaNs. */ \
202 PASTEMAC(ch,set0s_mxn)( MR, NR, \
203 ct, rs_ct, cs_ct ); \
204 \
205 /* Compute number of primary and leftover components of the m and n
206 dimensions. */ \
207 n_iter = n / NR; \
208 n_left = n % NR; \
209 \
210 m_iter = m / MR; \
211 m_left = m % MR; \
212 \
213 if ( n_left ) ++n_iter; \
214 if ( m_left ) ++m_iter; \
215 \
216 /* Determine some increments used to step through A, B, and C. */ \
217 rstep_a = ps_a; \
218 \
219 cstep_b = ps_b; \
220 \
221 rstep_c = rs_c * MR; \
222 cstep_c = cs_c * NR; \
223 \
224 /* Save the panel strides of A and B to the auxinfo_t object. */ \
225 bli_auxinfo_set_ps_a( ps_a, aux ); \
226 bli_auxinfo_set_ps_b( ps_b, aux ); \
227 \
228 b1 = b_cast; \
229 c1 = c_cast; \
230 \
231 /* Since we pack micro-panels of B incrementaly, one at a time, the
232 address of the next micro-panel of B remains constant. */ \
233 b2 = bp; \
234 \
235 /* Save address of next panel of B to the auxinfo_t object. */ \
236 bli_auxinfo_set_next_b( b2, aux ); \
237 \
238 /* Loop over the n dimension (NR columns at a time). */ \
239 for ( j = 0; j < n_iter; ++j ) \
240 { \
241 ctype* restrict a1; \
242 ctype* restrict c11; \
243 \
244 a1 = a_cast; \
245 c11 = c1; \
246 \
247 n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \
248 \
249 /* Incrementally pack a single micro-panel of B. */ \
250 PASTEMAC(ch,packm_cxk)( BLIS_NO_CONJUGATE, \
251 n_cur, \
252 k, \
253 one, \
254 b1, 1, rs_b, \
255 bp, PACKNR ); \
256 \
257 /* Loop over the m dimension (MR rows at a time). */ \
258 for ( i = 0; i < m_iter; ++i ) \
259 { \
260 ctype* restrict a2; \
261 \
262 m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \
263 \
264 /* Compute the addresses of the next panels of A and B. */ \
265 a2 = a1 + rstep_a; \
266 if ( bli_is_last_iter( i, m_iter ) ) \
267 { \
268 a2 = a_cast; \
269 } \
270 \
271 /* Save address of next panel of A to the auxinfo_t object. */ \
272 bli_auxinfo_set_next_a( a2, aux ); \
273 \
274 /* Handle interior and edge cases separately. */ \
275 if ( m_cur == MR && n_cur == NR ) \
276 { \
277 /* Invoke the gemm micro-kernel. */ \
278 gemm_ukr_cast( k, \
279 alpha_cast, \
280 a1, \
281 bp, \
282 beta_cast, \
283 c11, rs_c, cs_c, \
284 &aux ); \
285 } \
286 else \
287 { \
288 /* Invoke the gemm micro-kernel. */ \
289 gemm_ukr_cast( k, \
290 alpha_cast, \
291 a1, \
292 bp, \
293 zero, \
294 ct, rs_ct, cs_ct, \
295 &aux ); \
296 \
297 /* Scale the bottom edge of C and add the result from above. */ \
298 PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
299 ct, rs_ct, cs_ct, \
300 beta_cast, \
301 c11, rs_c, cs_c ); \
302 } \
303 \
304 a1 += rstep_a; \
305 c11 += rstep_c; \
306 } \
307 \
308 b1 += cstep_b; \
309 c1 += cstep_c; \
310 } \
311 \
312 /*PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var5: b1", k, NR, b1, NR, 1, "%4.1f", "" ); \
313 PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var5: a1", MR, k, a1, 1, MR, "%4.1f", "" );*/ \
314 }
316 INSERT_GENTFUNC_BASIC( gemm_ker_var5, gemm_ukr_t )