aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYuan Zhao2019-11-20 23:38:12 -0600
committerYuan Zhao2019-11-20 23:38:12 -0600
commitf6e0c49d2ec377adee1e94363a626a83497729c3 (patch)
tree8034a3eb18a14e5e73dcda6ad7c1ae7e22390a62
parent94d4720f20949699ba551ddf585d62831d8e50ed (diff)
downloadtidl-api-f6e0c49d2ec377adee1e94363a626a83497729c3.tar.gz
tidl-api-f6e0c49d2ec377adee1e94363a626a83497729c3.tar.xz
tidl-api-f6e0c49d2ec377adee1e94363a626a83497729c3.zip
Subgraph example: multi-threaded batch processing
- Compared different batch sizes in the subgraph execution example - Compared the async/future implementation against the thread pool implementation: async/future has slightly worse (~1%) performance, but it is much easier to program - The recommended inference mode is multi-threaded batch processing, where batch_size can be obtained from TidlGetPreferredBatchSize() and the number of threads can be set to 2. - MCT-1223
-rw-r--r--examples/mobilenet_subgraph/Makefile2
-rw-r--r--examples/mobilenet_subgraph/main.cpp278
-rw-r--r--examples/mobilenet_subgraph/thread_pool.cpp144
-rw-r--r--examples/mobilenet_subgraph/thread_pool.h77
-rw-r--r--tidl_api/inc/subgraph_runtime.h7
-rw-r--r--tidl_api/src/subgraph_runtime.cpp5
-rw-r--r--tidl_api/src/subgraph_runtime_impl.h1
7 files changed, 472 insertions, 42 deletions
diff --git a/examples/mobilenet_subgraph/Makefile b/examples/mobilenet_subgraph/Makefile
index 68f5d9d..e4a5173 100644
--- a/examples/mobilenet_subgraph/Makefile
+++ b/examples/mobilenet_subgraph/Makefile
@@ -36,7 +36,7 @@ LIBS += -ljson-c
36LIBS += -L$(TIDL_API_DIR) -ltidl_api -ltidl_imgutil 36LIBS += -L$(TIDL_API_DIR) -ltidl_api -ltidl_imgutil
37 37
38SOURCES = main.cpp ../common/object_classes.cpp ../common/utils.cpp \ 38SOURCES = main.cpp ../common/object_classes.cpp ../common/utils.cpp \
39 ../common/video_utils.cpp 39 ../common/video_utils.cpp thread_pool.cpp
40 40
41$(EXE): $(HEADERS) $(SOURCES) 41$(EXE): $(HEADERS) $(SOURCES)
42 $(CXX) $(CXXFLAGS) $(SOURCES) \ 42 $(CXX) $(CXXFLAGS) $(SOURCES) \
diff --git a/examples/mobilenet_subgraph/main.cpp b/examples/mobilenet_subgraph/main.cpp
index e4e499a..8a77f65 100644
--- a/examples/mobilenet_subgraph/main.cpp
+++ b/examples/mobilenet_subgraph/main.cpp
@@ -1,5 +1,5 @@
1/****************************************************************************** 1/******************************************************************************
2 * Copyright (c) 2018, Texas Instruments Incorporated - http://www.ti.com/ 2 * Copyright (c) 2019, Texas Instruments Incorporated - http://www.ti.com/
3 * All rights reserved. 3 * All rights reserved.
4 * 4 *
5 * Redistribution and use in source and binary forms, with or without 5 * Redistribution and use in source and binary forms, with or without
@@ -25,6 +25,7 @@
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26 * THE POSSIBILITY OF SUCH DAMAGE. 26 * THE POSSIBILITY OF SUCH DAMAGE.
27 *****************************************************************************/ 27 *****************************************************************************/
28
28#include <signal.h> 29#include <signal.h>
29#include <iostream> 30#include <iostream>
30#include <iomanip> 31#include <iomanip>
@@ -50,6 +51,7 @@
50#include "../common/object_classes.h" 51#include "../common/object_classes.h"
51#include "imgutil.h" 52#include "imgutil.h"
52#include "../common/video_utils.h" 53#include "../common/video_utils.h"
54#include "thread_pool.h"
53 55
54#include "opencv2/core.hpp" 56#include "opencv2/core.hpp"
55#include "opencv2/imgproc.hpp" 57#include "opencv2/imgproc.hpp"
@@ -70,13 +72,32 @@ const char *default_inputs[NUM_DEFAULT_INPUTS] =
70 "../test/testvecs/input/objects/cat-pet-animal-domestic-104827.jpeg" 72 "../test/testvecs/input/objects/cat-pet-animal-domestic-104827.jpeg"
71}; 73};
72std::unique_ptr<ObjectClasses> object_classes; 74std::unique_ptr<ObjectClasses> object_classes;
// Argument block passed (as void *) to SubgraphUserFunc() by the thread
// pool test: each field points at one slot of the caller-owned arrays of
// input / output buffer pointers.
struct UserData
{
    float **inputs;
    float **outputs;
};
73 79
74bool RunConfiguration(cmdline_opts_t& opts); 80bool RunConfiguration(cmdline_opts_t& opts);
75bool ReadFrame(const cmdline_opts_t& opts, VideoCapture &cap, float** inputs, 81bool ReadFrame(const cmdline_opts_t& opts, VideoCapture &cap, float** inputs,
76 int batch_size); 82 int batch_size);
77bool WriteFrameOutput(float *out, const cmdline_opts_t& opts); 83bool WriteFrameOutput(float *out, const cmdline_opts_t& opts);
78void DisplayHelp(); 84void DisplayHelp();
85void SubgraphUserFunc(void *user_data);
79 86
// Number of per-frame results each test prints before the remaining ones
// are suppressed.
const int num_printed_outputs = 4;

// Decides whether the result with running index i should be suppressed.
//   i            - running index of the result about to be printed
//   offset       - index of the first result in this test; printing is cut
//                  off at index num_printed_outputs + offset
//   skip_outputs - in/out latch owned by the caller; set to true the first
//                  time the cut-off is reached and never cleared here
// Returns true when the result must not be printed.  The notice is written
// to cout exactly once, when the latch flips.
bool SkipOutputs(int i, int offset, bool &skip_outputs)
{
    // Once suppression has started it stays on, whatever i is
    if (skip_outputs) return true;

    if (i >= num_printed_outputs + offset)
    {
        std::cout << " ... skipping outputs ..." << std::endl;
        skip_outputs = true;
    }
    return skip_outputs;
}
80 101
81int main(int argc, char *argv[]) 102int main(int argc, char *argv[])
82{ 103{
@@ -180,37 +201,123 @@ bool RunConfiguration(cmdline_opts_t& opts)
180 status = false; 201 status = false;
181 } 202 }
182 203
183 int batch_size = 8; 204 // If not doing multi-threaded processing, multiply by 2 or more
184 cout << "\n##### Batch size " << batch_size << " testing ######\n" << endl; 205 // for a larger batch to amortize batch initialization/tear down cost
206 int preferred_batch_size = TidlGetPreferredBatchSize(1);
207 for (int multiple = 1; multiple <= 16; multiple *= 2)
208 {
209 int batch_size = preferred_batch_size * multiple;
210 cout << "\n##### Batch size " << batch_size << " testing ######\n"
211 << endl;
212 bool skip_outputs = false;
213 try
214 {
215 float **inputs = new float *[batch_size];
216 float **outputs = new float *[batch_size];
217 for (int i = 0; i < batch_size; i++)
218 {
219 inputs[i] = new float[1*3*224*224];
220 outputs[i] = new float[1001];
221 }
222
223 chrono::time_point<chrono::steady_clock> tloop0, tloop1;
224 tloop0 = chrono::steady_clock::now();
225
226 ReadFrame(opts, cap, inputs, batch_size);
227 TidlRunSubgraph(1, 0, batch_size, 1, 1, inputs, outputs);
228 for (int i = 0; i < batch_size; i++)
229 {
230 if (! SkipOutputs(i, 0, skip_outputs))
231 {
232 cout << "Frame " << i << " of " << batch_size
233 << " output:" << endl;
234 WriteFrameOutput(outputs[i], opts);
235 }
236 }
237
238 tloop1 = chrono::steady_clock::now();
239 chrono::duration<float> elapsed = tloop1 - tloop0;
240 cout << "Batch size " << batch_size
241 << " time: "
242 << setw(6) << setprecision(4)
243 << (elapsed.count() * 1000) << "ms, fps = "
244 << setw(6) << setprecision(4)
245 << (batch_size / elapsed.count())
246 << endl;
247
248 for (int i = 0; i < batch_size; i++)
249 {
250 delete [] inputs[i];
251 delete [] outputs[i];
252 }
253 delete [] inputs;
254 delete [] outputs;
255 }
256 catch (tidl::Exception &e)
257 {
258 cerr << e.what() << endl;
259 status = false;
260 }
261 }
262
263 // This is to test the multithreaded inference with async/future
264 // async/future has slightly worse threading performance than
265 // thread pool, however, it is much easier to program
266 cout << "\n##### Multithreaded inference testing (async/future) #####\n"
267 << endl;
268 int num_threads = TidlGetPreferredBatchSize(1) * 2;
269 int num_iters = 100;
185 try 270 try
186 { 271 {
187 float **inputs = new float *[batch_size]; 272 float **inputs = new float *[num_threads];
188 float **outputs = new float *[batch_size]; 273 float **outputs = new float *[num_threads];
189 for (int i = 0; i < batch_size; i++) 274 for (int i = 0; i < num_threads; i++)
190 { 275 {
191 inputs[i] = new float[1*3*224*224]; 276 inputs[i] = new float[1*3*224*224];
192 outputs[i] = new float[1001]; 277 outputs[i] = new float[1001];
193 } 278 }
279 vector<future<bool>> futures(num_threads);
280 bool skip_outputs = false;
194 281
195 chrono::time_point<chrono::steady_clock> tloop0, tloop1; 282 chrono::time_point<chrono::steady_clock> tloop0, tloop1;
196 tloop0 = chrono::steady_clock::now(); 283 tloop0 = chrono::steady_clock::now();
197 284
198 ReadFrame(opts, cap, inputs, batch_size); 285 for (int i = 0; i < num_iters + num_threads; i++)
199 TidlRunSubgraph(1, 0, batch_size, 1, 1, inputs, outputs);
200 for (int i = 0; i < batch_size; i++)
201 { 286 {
202 cout << "Frame " << i << " of " << batch_size << " output:" << endl; 287 int index = i % num_threads;
203 WriteFrameOutput(outputs[i], opts); 288 if (i >= num_threads)
289 {
290 if (futures[index].get())
291 {
292 if (! SkipOutputs(i, num_threads, skip_outputs))
293 WriteFrameOutput(outputs[index], opts);
294 }
295 }
296
297 if (i < num_iters)
298 {
299 ReadFrame(opts, cap, &inputs[index], 1);
300 futures[index] = std::async(std::launch::async,
301 [inputs, outputs](int index) {
302 TidlRunSubgraph(1, 0, 1, 1, 1,
303 &inputs[index], &outputs[index]);
304 return true;
305 },
306 index);
307 }
204 } 308 }
205 309
206 tloop1 = chrono::steady_clock::now(); 310 tloop1 = chrono::steady_clock::now();
207 chrono::duration<float> elapsed = tloop1 - tloop0; 311 chrono::duration<float> elapsed = tloop1 - tloop0;
208 cout << "Batch size " << batch_size 312 cout << "Multithreaded (num_threads=" << num_threads
209 << " time (including read/write/opencv/print/etc): " 313 << ", batch_size=1) loop time (" << num_iters << " frames): "
210 << setw(6) << setprecision(4) 314 << setw(6) << setprecision(4)
211 << (elapsed.count() * 1000) << "ms" << endl; 315 << (elapsed.count() * 1000) << "ms, fps = "
316 << setw(6) << setprecision(4)
317 << (num_iters / elapsed.count())
318 << endl;
212 319
213 for (int i = 0; i < batch_size; i++) 320 for (int i = 0; i < num_threads; i++)
214 { 321 {
215 delete [] inputs[i]; 322 delete [] inputs[i];
216 delete [] outputs[i]; 323 delete [] outputs[i];
@@ -224,53 +331,62 @@ bool RunConfiguration(cmdline_opts_t& opts)
224 status = false; 331 status = false;
225 } 332 }
226 333
227 // This is only to test the multithreaded inference 334 // This is to test the multithreaded inference with a thread pool
228 // async/future may not be the most efficient multithreading method 335 cout << "\n##### Multithreaded inference testing (thread pool) #####\n"
229 // threading pool might have better performance 336 << endl;
230 cout << "\n##### Multithreaded inference testing #####\n" << endl;
231 int num_threads = 8;
232 int num_iters = 8;
233 try 337 try
234 { 338 {
235 float **inputs = new float *[num_threads]; 339 float **inputs = new float *[num_threads];
236 float **outputs = new float *[num_threads]; 340 float **outputs = new float *[num_threads];
341 vector<UserData> v_data(num_threads);
237 for (int i = 0; i < num_threads; i++) 342 for (int i = 0; i < num_threads; i++)
238 { 343 {
239 inputs[i] = new float[1*3*224*224]; 344 inputs[i] = new float[1*3*224*224];
240 outputs[i] = new float[1001]; 345 outputs[i] = new float[1001];
346 v_data[i].inputs = &inputs[i];
347 v_data[i].outputs = &outputs[i];
241 } 348 }
242 vector<future<bool>> futures(num_threads); 349 ThPool pool(num_threads, SubgraphUserFunc);
350 vector<int> th_ids(num_threads);
351 bool skip_outputs = false;
243 352
244 chrono::time_point<chrono::steady_clock> tloop0, tloop1; 353 chrono::time_point<chrono::steady_clock> tloop0, tloop1;
245 tloop0 = chrono::steady_clock::now(); 354 tloop0 = chrono::steady_clock::now();
246 355
247 for (int i = 0; i < num_iters + num_threads; i++) 356 for (int i = 0; i < num_iters + num_threads; i++)
248 { 357 {
249 int index = i % num_threads; 358 int index = i % num_threads;
250 if (i >= num_threads) 359 if (i >= num_threads)
251 { 360 {
252 if (futures[index].get()) 361 UserData *data = (UserData *) pool.Wait(th_ids[index]);
253 WriteFrameOutput(outputs[index], opts); 362 if (! SkipOutputs(i, num_threads, skip_outputs))
254 } 363 WriteFrameOutput(data->outputs[0], opts);
255 364 }
256 if (i < num_iters) 365
257 { 366 if (i < num_iters)
258 ReadFrame(opts, cap, &inputs[index], 1); 367 {
259 futures[index] = std::async(std::launch::async, 368 ReadFrame(opts, cap, &inputs[index], 1);
260 [inputs, outputs](int index) { 369 th_ids[index] = pool.RunAsync(&v_data[index]);
261 TidlRunSubgraph(1, 0, 1, 1, 1, &inputs[index], &outputs[index]); 370 }
262 return true;
263 },
264 index);
265 }
266 } 371 }
267 372
268 tloop1 = chrono::steady_clock::now(); 373 tloop1 = chrono::steady_clock::now();
269 chrono::duration<float> elapsed = tloop1 - tloop0; 374 chrono::duration<float> elapsed = tloop1 - tloop0;
270 cout << "Multithreaded (num_threads=" << num_threads 375 cout << "Multithreaded (num_threads=" << num_threads
271 << ") loop time (including read/write/opencv/print/etc): " 376 << ", batch_size=1) loop time (" << num_iters << " frames): "
377 << setw(6) << setprecision(4)
378 << (elapsed.count() * 1000) << "ms, fps = "
272 << setw(6) << setprecision(4) 379 << setw(6) << setprecision(4)
273 << (elapsed.count() * 1000) << "ms" << endl; 380 << (num_iters / elapsed.count())
381 << endl;
382
383 for (int i = 0; i < num_threads; i++)
384 {
385 delete [] inputs[i];
386 delete [] outputs[i];
387 }
388 delete [] inputs;
389 delete [] outputs;
274 } 390 }
275 catch (tidl::Exception &e) 391 catch (tidl::Exception &e)
276 { 392 {
@@ -278,9 +394,89 @@ bool RunConfiguration(cmdline_opts_t& opts)
278 status = false; 394 status = false;
279 } 395 }
280 396
397 num_threads = 2;
398 int batch_size = preferred_batch_size;
399 // This is to test the multithreaded batch inference with async/future
400 // Ideally, batch_size * num_threads <= number of threads
401 cout << "\n##### Multithreaded batch inference testing (async/future)"
402 << " #####\n" << endl;
403 try
404 {
405 float **inputs = new float *[num_threads * batch_size];
406 float **outputs = new float *[num_threads * batch_size];
407 for (int i = 0; i < num_threads * batch_size; i++)
408 {
409 inputs[i] = new float[1*3*224*224];
410 outputs[i] = new float[1001];
411 }
412 vector<future<bool>> futures(num_threads);
413 bool skip_outputs = false;
414
415 chrono::time_point<chrono::steady_clock> tloop0, tloop1;
416 tloop0 = chrono::steady_clock::now();
417
418 for (int i = 0; i < num_iters/batch_size + num_threads; i++)
419 {
420 int index = i % num_threads;
421 if (i >= num_threads)
422 {
423 if (futures[index].get())
424 if (! SkipOutputs(i*batch_size, num_threads*batch_size,
425 skip_outputs))
426 for (int b = 0; b < batch_size; b++)
427 WriteFrameOutput(outputs[index*batch_size+b], opts);
428 }
429
430 if (i < num_iters/batch_size)
431 {
432 ReadFrame(opts, cap, &inputs[index*batch_size], batch_size);
433 futures[index] = std::async(std::launch::async,
434 [inputs, outputs, batch_size](int index) {
435 TidlRunSubgraph(1, 0, batch_size, 1, 1,
436 &inputs[index*batch_size],
437 &outputs[index*batch_size]);
438 return true;
439 },
440 index);
441 }
442 }
443
444 tloop1 = chrono::steady_clock::now();
445 chrono::duration<float> elapsed = tloop1 - tloop0;
446 cout << "Multithreaded batch (num_threads=" << num_threads
447 << ", batch_size=" << batch_size
448 << ") loop time (" << num_iters << " frames): "
449 << setw(6) << setprecision(4)
450 << (elapsed.count() * 1000) << "ms, fps = "
451 << setw(6) << setprecision(4)
452 << (num_iters / elapsed.count())
453 << endl;
454
455 for (int i = 0; i < num_threads * batch_size; i++)
456 {
457 delete [] inputs[i];
458 delete [] outputs[i];
459 }
460 delete [] inputs;
461 delete [] outputs;
462 }
463 catch (tidl::Exception &e)
464 {
465 cerr << e.what() << endl;
466 status = false;
467 }
468
469
281 return status; 470 return status;
282} 471}
283 472
473void SubgraphUserFunc(void *user_data)
474{
475 UserData *data = (UserData *) user_data;
476 //printf("data inputs = %p, outputs = %p\n", data->inputs, data->outputs);
477 TidlRunSubgraph(1, 0, 1, 1, 1, data->inputs, data->outputs);
478 //printf("TidlRunSubgraph finished\n");
479}
284 480
285bool ReadFrame(const cmdline_opts_t& opts, VideoCapture &cap, float** inputs, 481bool ReadFrame(const cmdline_opts_t& opts, VideoCapture &cap, float** inputs,
286 int batch_size) 482 int batch_size)
diff --git a/examples/mobilenet_subgraph/thread_pool.cpp b/examples/mobilenet_subgraph/thread_pool.cpp
new file mode 100644
index 0000000..ee25aea
--- /dev/null
+++ b/examples/mobilenet_subgraph/thread_pool.cpp
@@ -0,0 +1,144 @@
1/******************************************************************************
2 * Copyright (c) 2019 Texas Instruments Incorporated - http://www.ti.com/
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * * Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * * Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * * Neither the name of Texas Instruments Incorporated nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
20 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26 * THE POSSIBILITY OF SUCH DAMAGE.
27 *****************************************************************************/
28
29#include "thread_pool.h"
30
31using namespace std;
32using namespace tidl;
33
34void ThFunc(int th_id, ThPool* pool)
35{
36 while (true)
37 {
38 // wait on th_id
39 pool->WaitForWork(th_id);
40
41 // check stop condition
42 if (pool->Stop()) return;
43
44 // Run user func
45 pool->RunUserFunc(th_id);
46
47 // notify completition
48 pool->NotifyCompletion(th_id);
49 }
50}
51
52ThPool::ThPool(int num_threads, UserFunc user_func) :
53 num_threads_m(num_threads),
54 user_func_m(user_func),
55 stop_m(false),
56 pool_m(num_threads),
57 pool_state_m((1ULL << num_threads) - 1),
58 v_mutex_th_m(num_threads),
59 v_cv_th_work_m(num_threads),
60 v_cv_th_completion_m(num_threads),
61 v_user_data_m(num_threads, nullptr),
62 v_completion_data_m(num_threads, nullptr)
63{
64 for (int i = 0; i < num_threads_m; i++)
65 {
66 pool_m[i] = thread(ThFunc, i, this);
67 }
68}
69
70ThPool::~ThPool()
71{
72 stop_m = true;
73 for (auto& data : v_user_data_m) data = &stop_m;
74 for (auto& cv : v_cv_th_work_m) cv.notify_all();
75 for (auto& th : pool_m) th.join();
76}
77
78int ThPool::RunAsync(void *user_data)
79{
80 int th_id = -1;
81 {
82 std::unique_lock<std::mutex> lock(mutex_pool_m);
83 cv_pool_m.wait(lock, [this]{ return this->pool_state_m != 0; });
84 // find first 1 bit
85 for (int i = 0; i < num_threads_m; i++)
86 if (pool_state_m & (1 << i))
87 {
88 th_id = i;
89 break;
90 }
91 pool_state_m &= (~ (1 << th_id));
92 }
93
94 {
95 std::unique_lock<std::mutex> lock(v_mutex_th_m[th_id]);
96 v_user_data_m[th_id] = user_data;
97 }
98 v_cv_th_work_m[th_id].notify_all();
99 return th_id;
100}
101
102void* ThPool::Wait(int th_id)
103{
104 void *user_data = nullptr;
105
106 {
107 std::unique_lock<std::mutex> lock(v_mutex_th_m[th_id]);
108 v_cv_th_completion_m[th_id].wait(lock, [this, th_id]{
109 return this->v_completion_data_m[th_id] != nullptr; });
110 user_data = v_completion_data_m[th_id];
111 v_completion_data_m[th_id] = nullptr;
112 }
113
114 {
115 std::unique_lock<std::mutex> lock(mutex_pool_m);
116 pool_state_m |= (1 << th_id);
117 }
118 cv_pool_m.notify_all();
119
120 return user_data;
121}
122
123
124void ThPool::RunUserFunc(int th_id)
125{
126 user_func_m(v_user_data_m[th_id]);
127}
128
129void ThPool::WaitForWork(int th_id)
130{
131 std::unique_lock<std::mutex> lock(v_mutex_th_m[th_id]);
132 v_cv_th_work_m[th_id].wait(lock, [this, th_id]{
133 return this->v_user_data_m[th_id] != nullptr; });
134}
135
136void ThPool::NotifyCompletion(int th_id)
137{
138 {
139 std::unique_lock<std::mutex> lock(v_mutex_th_m[th_id]);
140 v_completion_data_m[th_id] = v_user_data_m[th_id];
141 v_user_data_m[th_id] = nullptr;
142 }
143 v_cv_th_completion_m[th_id].notify_all();
144}
diff --git a/examples/mobilenet_subgraph/thread_pool.h b/examples/mobilenet_subgraph/thread_pool.h
new file mode 100644
index 0000000..0a3f60d
--- /dev/null
+++ b/examples/mobilenet_subgraph/thread_pool.h
@@ -0,0 +1,77 @@
1/******************************************************************************
2 * Copyright (c) 2019 Texas Instruments Incorporated - http://www.ti.com/
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * * Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * * Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * * Neither the name of Texas Instruments Incorporated nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
20 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26 * THE POSSIBILITY OF SUCH DAMAGE.
27 *****************************************************************************/
28
29#pragma once
30
31#include <vector>
32#include <mutex>
33#include <condition_variable>
34#include <thread>
35
36using namespace std;
37
namespace tidl {

// Width of the availability bit vector below, i.e. the largest number of
// workers the pool can track.
#define TIDL_MAX_NUM_THREADS 32

// Signature of the function every worker runs; it receives the pointer
// that was passed to ThPool::RunAsync().
using UserFunc = void (*)(void *user_data);

// Fixed-size pool of worker threads that all run the same user function.
// Each worker has its own data slot, mutex and condition variables; a
// shared bit vector records which workers are free.
class ThPool {
    public:
        // Starts num_threads workers that run user_func on posted data
        ThPool(int num_threads, UserFunc user_func);
        ~ThPool();

        // Posts user_data to a free worker (blocking while none is free);
        // returns th_id that can be used for Wait()
        int   RunAsync(void* user_data);
        // Blocks until worker th_id has completed; returns the user_data
        // of the completed job
        void* Wait(int th_id);

        // The remaining members are run by the threaded function (ThFunc)
        bool  Stop() { return stop_m; }
        void  RunUserFunc(int th_id);
        void  WaitForWork(int th_id);
        void  NotifyCompletion(int th_id);

    private:

        int                                  num_threads_m;
        UserFunc                             user_func_m;
        bool                                 stop_m;
        std::vector<std::thread>             pool_m;
        std::mutex                           mutex_pool_m;
        std::condition_variable              cv_pool_m;
        // bit vector for availability, up to 32 threads,
        // 1: avail, 0: not avail
        int32_t                              pool_state_m;

        // Per-worker synchronization, indexed by th_id
        std::vector<std::mutex>              v_mutex_th_m;
        std::vector<std::condition_variable> v_cv_th_work_m;
        std::vector<std::condition_variable> v_cv_th_completion_m;

        // Per-worker posted data and completion record, indexed by th_id
        std::vector<void *>                  v_user_data_m;
        std::vector<void *>                  v_completion_data_m;
};

} // namespace tidl
diff --git a/tidl_api/inc/subgraph_runtime.h b/tidl_api/inc/subgraph_runtime.h
index b4fc2b7..65db5b5 100644
--- a/tidl_api/inc/subgraph_runtime.h
+++ b/tidl_api/inc/subgraph_runtime.h
@@ -32,6 +32,13 @@
32 32
33extern "C" { 33extern "C" {
34 34
35//! @brief Top level API to get preferred batch_size for a subgraph
36//! Best performance comes with preferred batch_size processing
37//! plus multi-threaded (num_threads = 2) processing
38//! @param total_subgraphs total number of TIDL subgraphs in whole inference
39//! @return preferred batch size
40extern int TidlGetPreferredBatchSize(int total_subgraphs);
41
35//! @brief Top level API to initialize a TIDL subgraph on device 42//! @brief Top level API to initialize a TIDL subgraph on device
36//! If not invoked ahead of time, TidlRunSubgraph() will call this 43//! If not invoked ahead of time, TidlRunSubgraph() will call this
37//! function before any inference 44//! function before any inference
diff --git a/tidl_api/src/subgraph_runtime.cpp b/tidl_api/src/subgraph_runtime.cpp
index 342acd8..24b378e 100644
--- a/tidl_api/src/subgraph_runtime.cpp
+++ b/tidl_api/src/subgraph_runtime.cpp
@@ -73,6 +73,11 @@ void TVM_TidlFunction(int total_subgraphs, int subgraph_id,
73// Singleton ResM .cpp 73// Singleton ResM .cpp
74using namespace tidl; 74using namespace tidl;
75 75
76int TidlGetPreferredBatchSize(int total_subgraphs)
77{
78 ResM& res = ResM::Instance(total_subgraphs);
79 return res.GetNumEs();
80}
76 81
77void TidlInitSubgraph(int total_subgraphs, int subgraph_id) 82void TidlInitSubgraph(int total_subgraphs, int subgraph_id)
78{ 83{
diff --git a/tidl_api/src/subgraph_runtime_impl.h b/tidl_api/src/subgraph_runtime_impl.h
index a792757..9738dbb 100644
--- a/tidl_api/src/subgraph_runtime_impl.h
+++ b/tidl_api/src/subgraph_runtime_impl.h
@@ -60,6 +60,7 @@ class ResM {
60 Configuration& GetConfiguration(uint32_t subgraph_id); 60 Configuration& GetConfiguration(uint32_t subgraph_id);
61 const SubgraphDataConv& GetInConv(uint32_t subgraph_id); 61 const SubgraphDataConv& GetInConv(uint32_t subgraph_id);
62 const SubgraphDataConv& GetOutConv(uint32_t subgraph_id); 62 const SubgraphDataConv& GetOutConv(uint32_t subgraph_id);
63 uint32_t GetNumEs() { return num_es_per_subgraph_m; }
63 64
64 65
65 private: 66 private: