Subgraph example: multi-threaded batch processing
author     Yuan Zhao <yuanzhao@ti.com>
           Thu, 21 Nov 2019 05:38:12 +0000 (23:38 -0600)
committer  Yuan Zhao <yuanzhao@ti.com>
           Thu, 21 Nov 2019 05:38:12 +0000 (23:38 -0600)
- Compared different batch sizes in the subgraph execution example
- Compared the async/future implementation against the thread pool
  implementation; async/future has slightly worse (~1%) performance,
  but it is much easier to program
- The recommended inference mode is multi-threaded batch processing,
  where batch_size is obtained from TidlGetPreferredBatchSize() and
  the number of threads is set to 2 (see the sketch below)
- MCT-1223
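
  Illustrative sketch (not part of the commit): the recommended
  multi-threaded batch processing pattern with async/future, assuming the
  single-subgraph MobileNet setup from this example (1x3x224x224 float
  input, 1001-class output); FillInput() and num_frames are hypothetical
  stand-ins for the example's ReadFrame() and frame count.

    #include <future>
    #include <vector>
    #include "subgraph_runtime.h"

    // Hypothetical: fill one preprocessed 1x3x224x224 frame
    void FillInput(float *frame) { /* e.g. decode + resize an image */ }

    int main()
    {
        const int num_threads = 2;                             // recommended
        const int batch_size  = TidlGetPreferredBatchSize(1);  // 1 subgraph
        const int num_bufs    = num_threads * batch_size;
        const int num_frames  = 100;                           // hypothetical

        // One input/output buffer per frame in flight
        std::vector<std::vector<float>> in(num_bufs, std::vector<float>(3*224*224));
        std::vector<std::vector<float>> out(num_bufs, std::vector<float>(1001));
        std::vector<float*> inputs(num_bufs), outputs(num_bufs);
        for (int i = 0; i < num_bufs; i++)
        {
            inputs[i]  = in[i].data();
            outputs[i] = out[i].data();
        }

        // Pipeline batches across the two threads, as in main.cpp below
        std::vector<std::future<void>> futures(num_threads);
        for (int i = 0; i < num_frames/batch_size + num_threads; i++)
        {
            int t = i % num_threads;
            if (i >= num_threads)
                futures[t].get();          // wait for this slot's previous batch
            if (i < num_frames/batch_size)
            {
                for (int b = 0; b < batch_size; b++)
                    FillInput(inputs[t*batch_size + b]);
                futures[t] = std::async(std::launch::async,
                    [&inputs, &outputs, batch_size, t] {
                        TidlRunSubgraph(1, 0, batch_size, 1, 1,
                                        &inputs[t*batch_size],
                                        &outputs[t*batch_size]);
                    });
            }
        }
        return 0;
    }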

examples/mobilenet_subgraph/Makefile
examples/mobilenet_subgraph/main.cpp
examples/mobilenet_subgraph/thread_pool.cpp [new file with mode: 0644]
examples/mobilenet_subgraph/thread_pool.h [new file with mode: 0644]
tidl_api/inc/subgraph_runtime.h
tidl_api/src/subgraph_runtime.cpp
tidl_api/src/subgraph_runtime_impl.h

index 68f5d9df5a811ab4aaeba2b23190e98e24551379..e4a5173e7444f37aefd9e8699e3fe115583f5d14 100644 (file)
@@ -36,7 +36,7 @@ LIBS     += -ljson-c
 LIBS     += -L$(TIDL_API_DIR) -ltidl_api -ltidl_imgutil
 
 SOURCES = main.cpp ../common/object_classes.cpp ../common/utils.cpp \
-          ../common/video_utils.cpp
+          ../common/video_utils.cpp thread_pool.cpp
 
 $(EXE): $(HEADERS) $(SOURCES)
        $(CXX) $(CXXFLAGS) $(SOURCES) \
index e4e499af67eb7a377cae2278b93674231e5fe654..8a77f6576eed068ecfc3b20d506c476b26ac3aa7 100644 (file)
@@ -1,5 +1,5 @@
 /******************************************************************************
- * Copyright (c) 2018, Texas Instruments Incorporated - http://www.ti.com/
+ * Copyright (c) 2019, Texas Instruments Incorporated - http://www.ti.com/
  *   All rights reserved.
  *
  *   Redistribution and use in source and binary forms, with or without
@@ -25,6 +25,7 @@
  *   ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
  *   THE POSSIBILITY OF SUCH DAMAGE.
  *****************************************************************************/
+
 #include <signal.h>
 #include <iostream>
 #include <iomanip>
@@ -50,6 +51,7 @@
 #include "../common/object_classes.h"
 #include "imgutil.h"
 #include "../common/video_utils.h"
+#include "thread_pool.h"
 
 #include "opencv2/core.hpp"
 #include "opencv2/imgproc.hpp"
@@ -70,13 +72,32 @@ const char *default_inputs[NUM_DEFAULT_INPUTS] =
     "../test/testvecs/input/objects/cat-pet-animal-domestic-104827.jpeg"
 };
 std::unique_ptr<ObjectClasses> object_classes;
+typedef struct {
+  float **inputs;
+  float **outputs;
+} UserData;
 
 bool RunConfiguration(cmdline_opts_t& opts);
 bool ReadFrame(const cmdline_opts_t& opts, VideoCapture &cap, float** inputs,
                int batch_size);
 bool WriteFrameOutput(float *out, const cmdline_opts_t& opts);
 void DisplayHelp();
+void SubgraphUserFunc(void *user_data);
 
+const int num_printed_outputs = 4;
+bool SkipOutputs(int i, int offset, bool &skip_outputs)
+{
+    if (skip_outputs)  return true;
+    if (i >= num_printed_outputs + offset)
+    {
+        if (! skip_outputs)
+        {
+            cout << "   ... skipping outputs ..." << endl;
+            skip_outputs = true;
+        }
+    }
+    return skip_outputs;
+}
 
 int main(int argc, char *argv[])
 {
@@ -180,37 +201,123 @@ bool RunConfiguration(cmdline_opts_t& opts)
         status = false;
     }
 
-    int batch_size = 8;
-    cout << "\n##### Batch size " << batch_size << " testing ######\n" << endl;
+    // If not doing multi-threaded processing, multiply by 2 or more
+    //     for a larger batch to amortize batch initialization/teardown cost
+    int preferred_batch_size = TidlGetPreferredBatchSize(1);
+    for (int multiple = 1; multiple <= 16; multiple *= 2)
+    {
+        int batch_size = preferred_batch_size * multiple;
+        cout << "\n##### Batch size " << batch_size << " testing #####\n"
+             << endl;
+        bool skip_outputs = false;
+        try
+        {
+            float **inputs  = new float *[batch_size];
+            float **outputs = new float *[batch_size];
+            for (int i = 0; i < batch_size; i++)
+            {
+                inputs[i]  = new float[1*3*224*224];
+                outputs[i] = new float[1001];
+            }
+
+            chrono::time_point<chrono::steady_clock> tloop0, tloop1;
+            tloop0 = chrono::steady_clock::now();
+
+            ReadFrame(opts, cap, inputs, batch_size);
+            TidlRunSubgraph(1, 0, batch_size, 1, 1, inputs, outputs);
+            for (int i = 0; i < batch_size; i++)
+            {
+                if (! SkipOutputs(i, 0, skip_outputs))
+                {
+                    cout << "Frame " << i << " of " << batch_size
+                         << " output:" << endl;
+                    WriteFrameOutput(outputs[i], opts);
+                }
+            }
+
+            tloop1 = chrono::steady_clock::now();
+            chrono::duration<float> elapsed = tloop1 - tloop0;
+            cout << "Batch size " << batch_size
+                 << " time: "
+                 << setw(6) << setprecision(4)
+                 << (elapsed.count() * 1000) << "ms, fps = "
+                 << setw(6) << setprecision(4)
+                 << (batch_size / elapsed.count())
+                 << endl;
+
+            for (int i = 0; i < batch_size; i++)
+            {
+                delete [] inputs[i];
+                delete [] outputs[i];
+            }
+            delete [] inputs;
+            delete [] outputs;
+        }
+        catch (tidl::Exception &e)
+        {
+            cerr << e.what() << endl;
+            status = false;
+        }
+    }
+
+    // This is to test multithreaded inference with async/future
+    // async/future has slightly worse threading performance than a
+    //     thread pool, but it is much easier to program
+    cout << "\n##### Multithreaded inference testing (async/future) #####\n"
+         << endl;
+    int num_threads = TidlGetPreferredBatchSize(1) * 2;
+    int num_iters = 100;
     try
     {
-        float **inputs  = new float *[batch_size];
-        float **outputs = new float *[batch_size];
-        for (int i = 0; i < batch_size; i++)
+        float **inputs  = new float *[num_threads];
+        float **outputs = new float *[num_threads];
+        for (int i = 0; i < num_threads; i++)
         {
             inputs[i]  = new float[1*3*224*224];
             outputs[i] = new float[1001];
         }
+        vector<future<bool>> futures(num_threads);
+        bool skip_outputs = false;
 
         chrono::time_point<chrono::steady_clock> tloop0, tloop1;
         tloop0 = chrono::steady_clock::now();
 
-        ReadFrame(opts, cap, inputs, batch_size);
-        TidlRunSubgraph(1, 0, batch_size, 1, 1, inputs, outputs);
-        for (int i = 0; i < batch_size; i++)
+        for (int i = 0; i < num_iters + num_threads; i++)
         {
-            cout << "Frame " << i << " of " << batch_size << " output:" << endl;
-            WriteFrameOutput(outputs[i], opts);
+            int index = i % num_threads;
+            if (i >= num_threads)
+            {
+                if (futures[index].get())
+                {
+                    if (! SkipOutputs(i, num_threads, skip_outputs))
+                        WriteFrameOutput(outputs[index], opts);
+                }
+            }
+
+            if (i < num_iters)
+            {
+                ReadFrame(opts, cap, &inputs[index], 1);
+                futures[index] = std::async(std::launch::async,
+                              [inputs, outputs](int index) {
+                                  TidlRunSubgraph(1, 0, 1, 1, 1,
+                                              &inputs[index], &outputs[index]);
+                                   return true;
+                              },
+                                            index);
+            }
         }
 
         tloop1 = chrono::steady_clock::now();
         chrono::duration<float> elapsed = tloop1 - tloop0;
-        cout << "Batch size " << batch_size
-             << " time (including read/write/opencv/print/etc): "
+        cout << "Multithreaded (num_threads=" << num_threads
+             << ", batch_size=1) loop time (" << num_iters << " frames): "
              << setw(6) << setprecision(4)
-             << (elapsed.count() * 1000) << "ms" << endl;
+             << (elapsed.count() * 1000) << "ms, fps = "
+             << setw(6) << setprecision(4)
+             << (num_iters / elapsed.count())
+             << endl;
 
-        for (int i = 0; i < batch_size; i++)
+        for (int i = 0; i < num_threads; i++)
         {
             delete [] inputs[i];
             delete [] outputs[i];
@@ -224,53 +331,62 @@ bool RunConfiguration(cmdline_opts_t& opts)
         status = false;
     }
 
-    // This is only to test the multithreaded inference
-    // async/future may not be the most efficient multithreading method
-    // threading pool might have better performance
-    cout << "\n##### Multithreaded inference testing #####\n" << endl;
-    int num_threads = 8;
-    int num_iters = 8;
+    // This is to test the multithreaded inference with a thread pool
+    cout << "\n##### Multithreaded inference testing (thread pool) #####\n"
+         << endl;
     try
     {
         float **inputs  = new float *[num_threads];
         float **outputs = new float *[num_threads];
+        vector<UserData> v_data(num_threads);
         for (int i = 0; i < num_threads; i++)
         {
             inputs[i]  = new float[1*3*224*224];
             outputs[i] = new float[1001];
+            v_data[i].inputs  = &inputs[i];
+            v_data[i].outputs = &outputs[i];
         }
-        vector<future<bool>> futures(num_threads);
+        ThPool pool(num_threads, SubgraphUserFunc);
+        vector<int> th_ids(num_threads);
+        bool skip_outputs = false;
 
         chrono::time_point<chrono::steady_clock> tloop0, tloop1;
         tloop0 = chrono::steady_clock::now();
 
         for (int i = 0; i < num_iters + num_threads; i++)
         {
-          int index = i % num_threads;
-          if (i >= num_threads)
-          {
-            if (futures[index].get())
-              WriteFrameOutput(outputs[index], opts);
-          }
-
-          if (i < num_iters)
-          {
-            ReadFrame(opts, cap, &inputs[index], 1);
-            futures[index] = std::async(std::launch::async,
-                                        [inputs, outputs](int index) {
-               TidlRunSubgraph(1, 0, 1, 1, 1, &inputs[index], &outputs[index]);
-               return true;
-               },
-                                        index);
-          }
+            int index = i % num_threads;
+            if (i >= num_threads)
+            {
+                UserData *data = (UserData *) pool.Wait(th_ids[index]);
+                if (! SkipOutputs(i, num_threads, skip_outputs))
+                    WriteFrameOutput(data->outputs[0], opts);
+            }
+
+            if (i < num_iters)
+            {
+                ReadFrame(opts, cap, &inputs[index], 1);
+                th_ids[index] = pool.RunAsync(&v_data[index]);
+            }
         }
 
         tloop1 = chrono::steady_clock::now();
         chrono::duration<float> elapsed = tloop1 - tloop0;
         cout << "Multithreaded (num_threads=" << num_threads
-             << ") loop time (including read/write/opencv/print/etc): "
+             << ", batch_size=1) loop time (" << num_iters << " frames): "
+             << setw(6) << setprecision(4)
+             << (elapsed.count() * 1000) << "ms, fps = "
              << setw(6) << setprecision(4)
-             << (elapsed.count() * 1000) << "ms" << endl;
+             << (num_iters / elapsed.count())
+             << endl;
+
+        for (int i = 0; i < num_threads; i++)
+        {
+            delete [] inputs[i];
+            delete [] outputs[i];
+        }
+        delete [] inputs;
+        delete [] outputs;
     }
     catch (tidl::Exception &e)
     {
@@ -278,9 +394,89 @@ bool RunConfiguration(cmdline_opts_t& opts)
         status = false;
     }
 
+    num_threads = 2;
+    int batch_size  = preferred_batch_size;
+    // This is to test multithreaded batch inference with async/future
+    // Ideally, batch_size * num_threads should not exceed the number of available cores
+    cout << "\n##### Multithreaded batch inference testing (async/future)"
+         << " #####\n" << endl;
+    try
+    {
+        float **inputs  = new float *[num_threads * batch_size];
+        float **outputs = new float *[num_threads * batch_size];
+        for (int i = 0; i < num_threads * batch_size; i++)
+        {
+            inputs[i]  = new float[1*3*224*224];
+            outputs[i] = new float[1001];
+        }
+        vector<future<bool>> futures(num_threads);
+        bool skip_outputs = false;
+
+        chrono::time_point<chrono::steady_clock> tloop0, tloop1;
+        tloop0 = chrono::steady_clock::now();
+
+        for (int i = 0; i < num_iters/batch_size + num_threads; i++)
+        {
+            int index = i % num_threads;
+            if (i >= num_threads)
+            {
+                if (futures[index].get())
+                    if (! SkipOutputs(i*batch_size, num_threads*batch_size,
+                                      skip_outputs))
+                        for (int b = 0; b < batch_size; b++)
+                            WriteFrameOutput(outputs[index*batch_size+b], opts);
+            }
+
+            if (i < num_iters/batch_size)
+            {
+                ReadFrame(opts, cap, &inputs[index*batch_size], batch_size);
+                futures[index] = std::async(std::launch::async,
+                      [inputs, outputs, batch_size](int index) {
+                          TidlRunSubgraph(1, 0, batch_size, 1, 1,
+                                          &inputs[index*batch_size],
+                                          &outputs[index*batch_size]);
+                          return true;
+                      },
+                                            index);
+            }
+        }
+
+        tloop1 = chrono::steady_clock::now();
+        chrono::duration<float> elapsed = tloop1 - tloop0;
+        cout << "Multithreaded batch (num_threads=" << num_threads
+             << ", batch_size=" << batch_size
+             << ") loop time (" << num_iters << " frames): "
+             << setw(6) << setprecision(4)
+             << (elapsed.count() * 1000) << "ms, fps = "
+             << setw(6) << setprecision(4)
+             << (num_iters / elapsed.count())
+             << endl;
+
+        for (int i = 0; i < num_threads * batch_size; i++)
+        {
+            delete [] inputs[i];
+            delete [] outputs[i];
+        }
+        delete [] inputs;
+        delete [] outputs;
+    }
+    catch (tidl::Exception &e)
+    {
+        cerr << e.what() << endl;
+        status = false;
+    }
+
+
     return status;
 }
 
+void SubgraphUserFunc(void *user_data)
+{
+  UserData *data = (UserData *) user_data;
+  //printf("data inputs = %p, outputs = %p\n", data->inputs, data->outputs);
+  TidlRunSubgraph(1, 0, 1, 1, 1, data->inputs, data->outputs);
+  //printf("TidlRunSubgraph finished\n");
+}
 
 bool ReadFrame(const cmdline_opts_t& opts, VideoCapture &cap, float** inputs,
                int batch_size)
diff --git a/examples/mobilenet_subgraph/thread_pool.cpp b/examples/mobilenet_subgraph/thread_pool.cpp
new file mode 100644 (file)
index 0000000..ee25aea
--- /dev/null
@@ -0,0 +1,144 @@
+/******************************************************************************
+ * Copyright (c) 2019 Texas Instruments Incorporated - http://www.ti.com/
+ *  All rights reserved.
+ *
+ *  Redistribution and use in source and binary forms, with or without
+ *  modification, are permitted provided that the following conditions are met:
+ *      * Redistributions of source code must retain the above copyright
+ *        notice, this list of conditions and the following disclaimer.
+ *      * Redistributions in binary form must reproduce the above copyright
+ *        notice, this list of conditions and the following disclaimer in the
+ *        documentation and/or other materials provided with the distribution.
+ *      * Neither the name of Texas Instruments Incorporated nor the
+ *        names of its contributors may be used to endorse or promote products
+ *        derived from this software without specific prior written permission.
+ *
+ *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+ *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+ *  THE POSSIBILITY OF SUCH DAMAGE.
+ *****************************************************************************/
+
+#include "thread_pool.h"
+
+using namespace std;
+using namespace tidl;
+
+void ThFunc(int th_id, ThPool* pool)
+{
+  while (true)
+  {
+    // wait on th_id
+    pool->WaitForWork(th_id);
+
+    // check stop condition
+    if (pool->Stop())  return;
+
+    // Run user func
+    pool->RunUserFunc(th_id);
+
+    // notify completion
+    pool->NotifyCompletion(th_id);
+  }
+}
+
+ThPool::ThPool(int num_threads, UserFunc user_func) :
+        num_threads_m(num_threads),
+        user_func_m(user_func),
+        stop_m(false),
+        pool_m(num_threads),
+        pool_state_m((1ULL << num_threads) - 1),
+        v_mutex_th_m(num_threads),
+        v_cv_th_work_m(num_threads),
+        v_cv_th_completion_m(num_threads),
+        v_user_data_m(num_threads, nullptr),
+        v_completion_data_m(num_threads, nullptr)
+{
+  for (int i = 0; i < num_threads_m; i++)
+  {
+    pool_m[i] = thread(ThFunc, i, this);
+  }
+}
+
+ThPool::~ThPool()
+{
+  stop_m = true;
+  for (auto& data : v_user_data_m)  data = &stop_m;
+  for (auto& cv : v_cv_th_work_m)   cv.notify_all();
+  for (auto& th : pool_m)           th.join();
+}
+
+int ThPool::RunAsync(void *user_data)
+{
+  int th_id = -1;
+  {
+    std::unique_lock<std::mutex> lock(mutex_pool_m);
+    cv_pool_m.wait(lock, [this]{ return this->pool_state_m != 0; });
+    // find first 1 bit
+    for (int i = 0; i < num_threads_m; i++)
+      if (pool_state_m & (1 << i))
+      {
+        th_id = i;
+        break;
+      }
+    pool_state_m &= (~ (1 << th_id));
+  }
+
+  {
+    std::unique_lock<std::mutex> lock(v_mutex_th_m[th_id]);
+    v_user_data_m[th_id] = user_data;
+  }
+  v_cv_th_work_m[th_id].notify_all();
+  return th_id;
+}
+
+void* ThPool::Wait(int th_id)
+{
+  void *user_data = nullptr;
+
+  {
+    std::unique_lock<std::mutex> lock(v_mutex_th_m[th_id]);
+    v_cv_th_completion_m[th_id].wait(lock, [this, th_id]{
+                       return this->v_completion_data_m[th_id] != nullptr; });
+    user_data = v_completion_data_m[th_id];
+    v_completion_data_m[th_id] = nullptr;
+  }
+
+  {
+    std::unique_lock<std::mutex> lock(mutex_pool_m);
+    pool_state_m |= (1 << th_id);
+  }
+  cv_pool_m.notify_all();
+
+  return user_data;
+}
+
+
+void ThPool::RunUserFunc(int th_id)
+{
+  user_func_m(v_user_data_m[th_id]);
+}
+
+void ThPool::WaitForWork(int th_id)
+{
+    std::unique_lock<std::mutex> lock(v_mutex_th_m[th_id]);
+    v_cv_th_work_m[th_id].wait(lock, [this, th_id]{
+                              return this->v_user_data_m[th_id] != nullptr; });
+}
+
+void ThPool::NotifyCompletion(int th_id)
+{
+  {
+    std::unique_lock<std::mutex> lock(v_mutex_th_m[th_id]);
+    v_completion_data_m[th_id] = v_user_data_m[th_id];
+    v_user_data_m[th_id] = nullptr;
+  }
+  v_cv_th_completion_m[th_id].notify_all();
+}
diff --git a/examples/mobilenet_subgraph/thread_pool.h b/examples/mobilenet_subgraph/thread_pool.h
new file mode 100644 (file)
index 0000000..0a3f60d
--- /dev/null
@@ -0,0 +1,77 @@
+/******************************************************************************
+ * Copyright (c) 2019 Texas Instruments Incorporated - http://www.ti.com/
+ *  All rights reserved.
+ *
+ *  Redistribution and use in source and binary forms, with or without
+ *  modification, are permitted provided that the following conditions are met:
+ *      * Redistributions of source code must retain the above copyright
+ *        notice, this list of conditions and the following disclaimer.
+ *      * Redistributions in binary form must reproduce the above copyright
+ *        notice, this list of conditions and the following disclaimer in the
+ *        documentation and/or other materials provided with the distribution.
+ *      * Neither the name of Texas Instruments Incorporated nor the
+ *        names of its contributors may be used to endorse or promote products
+ *        derived from this software without specific prior written permission.
+ *
+ *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+ *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+ *  THE POSSIBILITY OF SUCH DAMAGE.
+ *****************************************************************************/
+
+#pragma once
+
+#include <vector>
+#include <mutex>
+#include <condition_variable>
+#include <thread>
+
+using namespace std;
+
+namespace tidl {
+
+#define TIDL_MAX_NUM_THREADS  32
+
+typedef void(*UserFunc)(void *user_data);
+
+class ThPool {
+  public:
+    ThPool(int num_threads, UserFunc user_func);
+    ~ThPool();
+    // returns th_id that can be used for Wait()
+    int RunAsync(void* user_data);
+    void* Wait(int th_id);
+
+    // Run by threaded function
+    bool Stop()  { return stop_m; }
+    void RunUserFunc(int th_id);
+    void WaitForWork(int th_id);
+    void NotifyCompletion(int th_id);
+
+  private:
+
+    int                num_threads_m;
+    UserFunc           user_func_m;
+    bool               stop_m;
+    vector<thread>     pool_m;
+    mutex              mutex_pool_m;
+    condition_variable cv_pool_m;
+    // bit vector for availability, up to 32 threads, 1: avail, 0: not avail
+    int32_t            pool_state_m;
+
+    vector<mutex>              v_mutex_th_m;
+    vector<condition_variable> v_cv_th_work_m;
+    vector<condition_variable> v_cv_th_completion_m;
+
+    vector<void *>             v_user_data_m;
+    vector<void *>             v_completion_data_m;
+};
+
+}  // namespace tidl
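
For reference, a minimal usage sketch of the ThPool API above (not part of
the commit): RunAsync() dispatches user_data to an idle worker and returns a
thread id, Wait() blocks until that worker finishes and returns the same
user_data pointer. Job and DoWork below are hypothetical placeholders.

    #include <cstdio>
    #include "thread_pool.h"

    struct Job { int frame_id; };              // hypothetical user payload

    void DoWork(void *user_data)               // runs on a pool worker thread
    {
        Job *job = static_cast<Job *>(user_data);
        std::printf("processing frame %d\n", job->frame_id);
    }

    int main()
    {
        tidl::ThPool pool(2, DoWork);          // 2 worker threads
        Job jobs[2] = {{0}, {1}};

        int id0 = pool.RunAsync(&jobs[0]);     // returns th_id for Wait()
        int id1 = pool.RunAsync(&jobs[1]);

        Job *done0 = static_cast<Job *>(pool.Wait(id0));   // blocks until done
        Job *done1 = static_cast<Job *>(pool.Wait(id1));
        std::printf("finished frames %d and %d\n",
                    done0->frame_id, done1->frame_id);
        return 0;
    }
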
index b4fc2b70ec1b06278c6b9bb46d80f3c3bc1d5529..65db5b5dc2e1c14dc77a039080b172d0bb66a1d8 100644 (file)
 
 extern "C" {
 
+//! @brief Top level API to get the preferred batch_size for a subgraph
+//!        Best performance comes from processing the preferred batch_size
+//!        with multi-threaded (num_threads = 2) processing
+//! @param total_subgraphs  total number of TIDL subgraphs in the whole inference
+//! @return preferred batch size
+extern int TidlGetPreferredBatchSize(int total_subgraphs);
+
 //! @brief Top level API to initialize a TIDL subgraph on device
 //!        If not invoked ahead of time, TidlRunSubgraph() will call this
 //!        function before any inference
index 342acd8fa9f3fe6ff343abd2a9e54bf40f418cda..24b378e241b4d6b7ac5e881aa35969bbaf6c22d5 100644 (file)
@@ -73,6 +73,11 @@ void TVM_TidlFunction(int total_subgraphs, int subgraph_id,
 // Singleton ResM .cpp
 using namespace tidl;
 
+int TidlGetPreferredBatchSize(int total_subgraphs)
+{
+  ResM& res = ResM::Instance(total_subgraphs);
+  return res.GetNumEs();
+}
 
 void TidlInitSubgraph(int total_subgraphs, int subgraph_id)
 {
index a792757da062ba87233fc0d0fbf5771f88f3b370..9738dbbac7cbf65671d263880043050582c4926b 100644 (file)
@@ -60,6 +60,7 @@ class ResM {
     Configuration&           GetConfiguration(uint32_t subgraph_id);
     const SubgraphDataConv&  GetInConv(uint32_t subgraph_id);
     const SubgraphDataConv&  GetOutConv(uint32_t subgraph_id);
+    uint32_t                 GetNumEs() { return num_es_per_subgraph_m; }
 
 
   private: