1 /*
3 BLIS
4 An object-based framework for developing high-performance BLAS-like
5 libraries.
7 Copyright (C) 2014, The University of Texas at Austin
9 Redistribution and use in source and binary forms, with or without
10 modification, are permitted provided that the following conditions are
11 met:
12 - Redistributions of source code must retain the above copyright
13 notice, this list of conditions and the following disclaimer.
14 - Redistributions in binary form must reproduce the above copyright
15 notice, this list of conditions and the following disclaimer in the
16 documentation and/or other materials provided with the distribution.
17 - Neither the name of The University of Texas at Austin nor the names
18 of its contributors may be used to endorse or promote products
19 derived from this software without specific prior written permission.
21 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25 HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26 SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27 LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28 DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29 THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31 OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33 */
35 #include "blis.h"
36 #include "assert.h"
38 void bli_setup_trsm_thrinfo_node( trsm_thrinfo_t* thread,
39 thread_comm_t* ocomm, dim_t ocomm_id,
40 thread_comm_t* icomm, dim_t icomm_id,
41 dim_t n_way, dim_t work_id,
42 packm_thrinfo_t* opackm,
43 packm_thrinfo_t* ipackm,
44 trsm_thrinfo_t* sub_trsm )
45 {
46 thread->ocomm = ocomm;
47 thread->ocomm_id = ocomm_id;
48 thread->icomm = icomm;
49 thread->icomm_id = icomm_id;
50 thread->n_way = n_way;
51 thread->work_id = work_id;
52 thread->opackm = opackm;
53 thread->ipackm = ipackm;
54 thread->sub_trsm = sub_trsm;
55 }
57 void bli_setup_trsm_single_threaded_info( trsm_thrinfo_t* thread )
58 {
59 thread->ocomm = &BLIS_SINGLE_COMM;
60 thread->ocomm_id = 0;
61 thread->icomm = &BLIS_SINGLE_COMM;
62 thread->icomm_id = 0;
63 thread->n_way = 1;
64 thread->work_id = 0;
65 thread->opackm = &BLIS_PACKM_SINGLE_THREADED;
66 thread->ipackm = &BLIS_PACKM_SINGLE_THREADED;
67 thread->sub_trsm = thread;
68 }
70 trsm_thrinfo_t* bli_create_trsm_thrinfo_node( thread_comm_t* ocomm, dim_t ocomm_id,
71 thread_comm_t* icomm, dim_t icomm_id,
72 dim_t n_way, dim_t work_id,
73 packm_thrinfo_t* opackm,
74 packm_thrinfo_t* ipackm,
75 trsm_thrinfo_t* sub_trsm )
76 {
77 trsm_thrinfo_t* thread = ( trsm_thrinfo_t* ) bli_malloc( sizeof( trsm_thrinfo_t ) );
78 bli_setup_trsm_thrinfo_node( thread, ocomm, ocomm_id,
79 icomm, icomm_id,
80 n_way, work_id,
81 opackm,
82 ipackm,
83 sub_trsm );
84 return thread;
85 }
87 void bli_trsm_thrinfo_free( trsm_thrinfo_t* thread)
88 {
89 if( thread == NULL ) return;
91 // Free Communicators
92 if( thread_am_ochief( thread ) )
93 bli_free_communicator( thread->ocomm );
95 if (thread->sub_trsm == NULL)
96 bli_free_communicator( thread->icomm );
98 // Free Sub Thrinfos
99 bli_packm_thrinfo_free( thread->opackm );
100 bli_packm_thrinfo_free( thread->ipackm );
101 bli_trsm_thrinfo_free( thread->sub_trsm );
102 bli_free( thread );
104 return;
105 }
106 void bli_trsm_thrinfo_free_paths( trsm_thrinfo_t** threads, dim_t num )
107 {
108 for( int i = 0; i < num; i++)
109 bli_trsm_thrinfo_free( threads[i] );
110 bli_free( threads );
111 }
113 trsm_thrinfo_t** bli_create_trsm_thrinfo_paths( bool_t right_sided )
114 {
115 dim_t jc_way = 1;
116 dim_t kc_way = 1;
117 dim_t ic_way = 1;
118 dim_t jr_way = 1;
119 dim_t ir_way = 1;
121 #ifdef BLIS_ENABLE_MULTITHREADING
122 dim_t jc_in = bli_read_nway_from_env( "BLIS_JC_NT" );
123 /*dim_t kc_in = bli_read_nway_from_env( "BLIS_KC_NT" );*/
124 dim_t ic_in = bli_read_nway_from_env( "BLIS_IC_NT" );
125 dim_t jr_in = bli_read_nway_from_env( "BLIS_JR_NT" );
126 dim_t ir_in = bli_read_nway_from_env( "BLIS_IR_NT" );
128 if(right_sided) {
129 ic_way = jc_in * ic_in * jr_in;
130 ir_way = ir_in;
131 }
132 else {
133 jc_way = jc_in;
134 jr_way = jr_in * ic_in * ir_in;
135 }
136 #endif
138 dim_t global_num_threads = jc_way * kc_way * ic_way * jr_way * ir_way;
139 assert( global_num_threads != 0 );
141 dim_t jc_nt = kc_way * ic_way * jr_way * ir_way;
142 dim_t kc_nt = ic_way * jr_way * ir_way;
143 dim_t ic_nt = jr_way * ir_way;
144 dim_t jr_nt = ir_way;
145 dim_t ir_nt = 1;
148 trsm_thrinfo_t** paths = (trsm_thrinfo_t**) bli_malloc( global_num_threads * sizeof( trsm_thrinfo_t* ) );
150 thread_comm_t* global_comm = bli_create_communicator( global_num_threads );
151 for( int a = 0; a < jc_way; a++ )
152 {
153 thread_comm_t* jc_comm = bli_create_communicator( jc_nt );
154 for( int b = 0; b < kc_way; b++ )
155 {
156 thread_comm_t* kc_comm = bli_create_communicator( kc_nt );
157 for( int c = 0; c < ic_way; c++ )
158 {
159 thread_comm_t* ic_comm = bli_create_communicator( ic_nt );
160 for( int d = 0; d < jr_way; d++ )
161 {
162 thread_comm_t* jr_comm = bli_create_communicator( jr_nt );
163 for( int e = 0; e < ir_way; e++)
164 {
165 thread_comm_t* ir_comm = bli_create_communicator( ir_nt );
166 dim_t ir_comm_id = 0;
167 dim_t jr_comm_id = e*ir_nt + ir_comm_id;
168 dim_t ic_comm_id = d*jr_nt + jr_comm_id;
169 dim_t kc_comm_id = c*ic_nt + ic_comm_id;
170 dim_t jc_comm_id = b*kc_nt + kc_comm_id;
171 dim_t global_comm_id = a*jc_nt + jc_comm_id;
173 trsm_thrinfo_t* ir_info = bli_create_trsm_thrinfo_node( jr_comm, jr_comm_id,
174 ir_comm, ir_comm_id,
175 ir_way, e,
176 NULL, NULL, NULL);
178 trsm_thrinfo_t* jr_info = bli_create_trsm_thrinfo_node( ic_comm, ic_comm_id,
179 jr_comm, jr_comm_id,
180 jr_way, d,
181 NULL, NULL, ir_info);
183 /*
184 packm_thrinfo_t* packb = bli_create_packm_thread_info( kc_comm, kc_comm_id,
185 ic_comm, ic_comm_id,
186 kc_nt, kc_comm_id );
188 packm_thrinfo_t* packa = bli_create_packm_thread_info( ic_comm, ic_comm_id,
189 jr_comm, jr_comm_id,
190 ic_nt, ic_comm_id );
192 trsm_thrinfo_t* ic_info = bli_create_trsm_thrinfo_node( kc_comm, kc_comm_id,
193 ic_comm, ic_comm_id,
194 ic_way, c,
195 packb, packa, jr_info);
196 */
197 //blk_var_1
198 packm_thrinfo_t* pack_ic_in = bli_create_packm_thread_info( ic_comm, ic_comm_id,
199 jr_comm, jr_comm_id,
200 ic_nt, ic_comm_id );
202 packm_thrinfo_t* pack_ic_out = bli_create_packm_thread_info( kc_comm, kc_comm_id,
203 ic_comm, ic_comm_id,
204 kc_nt, kc_comm_id );
206 trsm_thrinfo_t* ic_info = bli_create_trsm_thrinfo_node( kc_comm, kc_comm_id,
207 ic_comm, ic_comm_id,
208 ic_way, c,
209 pack_ic_out, pack_ic_in, jr_info);
211 //blk_var_3
212 packm_thrinfo_t* pack_kc_in = bli_create_packm_thread_info( kc_comm, kc_comm_id,
213 ic_comm, ic_comm_id,
214 kc_nt, kc_comm_id );
216 packm_thrinfo_t* pack_kc_out = bli_create_packm_thread_info( jc_comm, jc_comm_id,
217 jc_comm, jc_comm_id,
218 jc_nt, jc_comm_id );
220 trsm_thrinfo_t* kc_info = bli_create_trsm_thrinfo_node( jc_comm, jc_comm_id,
221 kc_comm, kc_comm_id,
222 kc_way, b,
223 pack_kc_out, pack_kc_in, ic_info);
225 //blk_var_2
226 packm_thrinfo_t* pack_jc_in = bli_create_packm_thread_info( jc_comm, jc_comm_id,
227 kc_comm, kc_comm_id,
228 jc_nt, jc_comm_id );
230 packm_thrinfo_t* pack_jc_out = bli_create_packm_thread_info( global_comm, global_comm_id,
231 jc_comm, jc_comm_id,
232 global_num_threads, global_comm_id );
235 trsm_thrinfo_t* jc_info = bli_create_trsm_thrinfo_node( global_comm, global_comm_id,
236 jc_comm, jc_comm_id,
237 jc_way, a,
238 pack_jc_out, pack_jc_in, kc_info);
239 paths[global_comm_id] = jc_info;
240 }
241 }
242 }
243 }
244 }
245 return paths;
246 }