Add option to specify object classes list file
[tidl/tidl-api.git] / examples / common / utils.cpp
index a4c978b94f2258e2b6ca759772ea94a39ee6b05a..04b7b97d6e1f054a1021fa67db55f07f4d29edad 100644 (file)
@@ -35,27 +35,68 @@ using namespace tidl;
 
 using boost::format;
 using std::string;
+using std::istream;
+using std::ostream;
+using std::vector;
 
-bool ReadFrame(ExecutionObject* eo, int frame_idx,
+
+// Create an Executor with the specified type and number of EOs
+Executor* CreateExecutor(DeviceType dt, int num, const Configuration& c,
+                         int layer_group_id)
+{
+    if (num == 0) return nullptr;
+
+    DeviceIds ids;
+    for (int i = 0; i < num; i++)
+        ids.insert(static_cast<DeviceId>(i));
+
+    return new Executor(dt, ids, c, layer_group_id);
+}
+static bool read_frame_helper(char* ptr, size_t size, istream& input_file);
+
+bool ReadFrame(ExecutionObject*     eo,
+               int                  frame_idx,
                const Configuration& configuration,
-               std::istream& input_file)
+               std::istream&        input_file)
+{
+    if (frame_idx >= configuration.numFrames)
+        return false;
+
+    // Note: Frame index is used by the EO for debug messages only
+    eo->SetFrameIndex(frame_idx);
+
+    return read_frame_helper(eo->GetInputBufferPtr(),
+                             eo->GetInputBufferSizeInBytes(),
+                             input_file);
+}
+
+bool ReadFrame(ExecutionObjectPipeline* eop,
+               int                      frame_idx,
+               const Configuration&     configuration,
+               std::istream&            input_file)
 {
     if (frame_idx >= configuration.numFrames)
         return false;
 
-    assert (eo->GetInputBufferPtr() != nullptr);
+    // Note: Frame index is used by the EOP for debug messages only
+    eop->SetFrameIndex(frame_idx);
+
+    return read_frame_helper(eop->GetInputBufferPtr(),
+                             eop->GetInputBufferSizeInBytes(),
+                             input_file);
+}
+
+bool read_frame_helper(char* ptr, size_t size, istream& input_file)
+{
+    assert (ptr != nullptr);
     assert (input_file.good());
 
-    input_file.read(eo->GetInputBufferPtr(),
-                    eo->GetInputBufferSizeInBytes());
+    input_file.read(ptr, size);
     assert (input_file.good());
 
     if (input_file.eof())
         return false;
 
-    // Note: Frame index is used by the EO for debug messages only
-    eo->SetFrameIndex(frame_idx);
-
     // Wrap-around : if EOF is reached, start reading from the beginning.
     if (input_file.peek() == EOF)
         input_file.seekg(0, input_file.beg);
@@ -66,7 +107,8 @@ bool ReadFrame(ExecutionObject* eo, int frame_idx,
     return false;
 }
 
-bool WriteFrame(const ExecutionObject* eo, std::ostream& output_file)
+
+bool WriteFrame(const ExecutionObject* eo, ostream& output_file)
 {
     output_file.write(eo->GetOutputBufferPtr(),
                       eo->GetOutputBufferSizeInBytes());
@@ -80,14 +122,11 @@ bool WriteFrame(const ExecutionObject* eo, std::ostream& output_file)
 
 void ReportTime(const ExecutionObject* eo)
 {
-    double elapsed_host   = eo->GetHostProcessTimeInMilliSeconds();
     double elapsed_device = eo->GetProcessTimeInMilliSeconds();
-    double overhead = 100 - (elapsed_device/elapsed_host*100);
 
-    std::cout << format("frame[%3d]: Time on %s: %4.2f ms, host: %4.2f ms"
-                        " API overhead: %2.2f %%\n")
+    std::cout << format("frame[%3d]: Time on %s: %4.2f ms\n")
                         % eo->GetFrameIndex() % eo->GetDeviceName()
-                        % elapsed_device % elapsed_host % overhead;
+                        % elapsed_device;
 }
 
 // Compare output against reference output
@@ -101,10 +140,14 @@ bool CheckFrame(const ExecutionObject *eo, const char *ref_output)
     return false;
 }
 
+bool CheckFrame(const ExecutionObjectPipeline *eop, const char *ref_output)
+{
+    if (std::memcmp(static_cast<const void*>(ref_output),
+               static_cast<const void*>(eop->GetOutputBufferPtr()),
+               eop->GetOutputBufferSizeInBytes()) == 0)
+        return true;
 
-namespace tidl {
-std::size_t GetBinaryFileSize (const std::string &F);
-bool        ReadBinary        (const std::string &F, char* buffer, int size);
+    return false;
 }
 
 // Read file into a buffer.
@@ -128,3 +171,60 @@ const char* ReadReferenceOutput(const string& name)
 
     return buffer;
 }
+
+// Allocate input and output memory for each EO
+void AllocateMemory(const vector<ExecutionObject *>& eos)
+{
+    // Allocate input and output buffers for each execution object
+    for (auto eo : eos)
+    {
+        size_t in_size  = eo->GetInputBufferSizeInBytes();
+        size_t out_size = eo->GetOutputBufferSizeInBytes();
+        void*  in_ptr   = malloc(in_size);
+        void*  out_ptr  = malloc(out_size);
+        assert(in_ptr != nullptr && out_ptr != nullptr);
+
+        ArgInfo in  = { ArgInfo(in_ptr,  in_size)};
+        ArgInfo out = { ArgInfo(out_ptr, out_size)};
+        eo->SetInputOutputBuffer(in, out);
+    }
+}
+
+// Free the input and output memory associated with each EO
+void FreeMemory(const vector<ExecutionObject *>& eos)
+{
+    for (auto eo : eos)
+    {
+        free(eo->GetInputBufferPtr());
+        free(eo->GetOutputBufferPtr());
+    }
+}
+
+// Allocate input and output memory for each EOP
+void AllocateMemory(const vector<ExecutionObjectPipeline *>& eops)
+{
+    // Allocate input and output buffers for each execution object
+    for (auto eop : eops)
+    {
+        size_t in_size  = eop->GetInputBufferSizeInBytes();
+        size_t out_size = eop->GetOutputBufferSizeInBytes();
+        void*  in_ptr   = malloc(in_size);
+        void*  out_ptr  = malloc(out_size);
+        assert(in_ptr != nullptr && out_ptr != nullptr);
+
+        ArgInfo in  = { ArgInfo(in_ptr,  in_size)};
+        ArgInfo out = { ArgInfo(out_ptr, out_size)};
+        eop->SetInputOutputBuffer(in, out);
+    }
+}
+
+// Free the input and output memory associated with each EOP
+void FreeMemory(const vector<ExecutionObjectPipeline *>& eops)
+{
+    for (auto eop : eops)
+    {
+        free(eop->GetInputBufferPtr());
+        free(eop->GetOutputBufferPtr());
+    }
+}
+