]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - tidl/tidl-api.git/blobdiff - tinn_api/src/execution_object.cpp
Compute input/output size based on network
[tidl/tidl-api.git] / tinn_api / src / execution_object.cpp
index 46606833e140e9b39482ba5446f5838fe389eed9..6a71d87b47550c9708709974f161bec95593c163 100644 (file)
@@ -54,6 +54,7 @@ class ExecutionObject::Impl
         bool SetupProcessKernel(const ArgInfo& in, const ArgInfo& out);
         void HostWriteNetInput();
         void HostReadNetOutput();
+        void ComputeInputOutputSizes();
 
         Device*                         device_m;
         std::unique_ptr<Kernel>         k_initialize_m;
@@ -64,6 +65,8 @@ class ExecutionObject::Impl
         up_malloc_ddr<OCL_TIDL_InitializeParams> shared_initialize_params_m;
         up_malloc_ddr<OCL_TIDL_ProcessParams>    shared_process_params_m;
 
+        size_t                          in_size;
+        size_t                          out_size;
         ArgInfo                         in_m;
         ArgInfo                         out_m;
 
@@ -101,6 +104,8 @@ ExecutionObject::Impl::Impl(Device* d,
     tidl_extmem_heap_m (nullptr, &__free_ddr),
     shared_initialize_params_m(nullptr, &__free_ddr),
     shared_process_params_m(nullptr, &__free_ddr),
+    in_size(0),
+    out_size(0),
     in_m(nullptr, 0),
     out_m(nullptr, 0),
     device_index_m(device_index),
@@ -152,7 +157,8 @@ char* ExecutionObject::GetInputBufferPtr() const
 
 size_t ExecutionObject::GetInputBufferSizeInBytes() const
 {
-    return pimpl_m->in_m.size();
+    if (pimpl_m->in_m.ptr() == nullptr)  return pimpl_m->in_size;
+    else                                 return pimpl_m->in_m.size();
 }
 
 char* ExecutionObject::GetOutputBufferPtr() const
@@ -162,7 +168,8 @@ char* ExecutionObject::GetOutputBufferPtr() const
 
 size_t ExecutionObject::GetOutputBufferSizeInBytes() const
 {
-    return pimpl_m->shared_process_params_m.get()->bytesWritten;
+    if (pimpl_m->out_m.ptr() == nullptr)  return pimpl_m->out_size;
+    else           return pimpl_m->shared_process_params_m.get()->bytesWritten;
 }
 
 void  ExecutionObject::SetFrameIndex(int idx)
@@ -320,6 +327,23 @@ void ExecutionObject::Impl::HostReadNetOutput()
     shared_process_params_m->bytesWritten = writePtr - (char *) out_m.ptr();
 }
 
+void ExecutionObject::Impl::ComputeInputOutputSizes()
+{
+    in_size  = 0;
+    out_size = 0;
+    for (unsigned int i = 0; i < shared_initialize_params_m->numInBufs; i++)
+    {
+        OCL_TIDL_BufParams *inBuf = &shared_initialize_params_m->inBufs[i];
+        in_size += inBuf->numROIs * inBuf->numChannels * inBuf->ROIWidth *
+                   inBuf->ROIHeight;
+    }
+    for (unsigned int i = 0; i < shared_initialize_params_m->numOutBufs; i++)
+    {
+        OCL_TIDL_BufParams *outBuf = &shared_initialize_params_m->outBufs[i];
+        out_size += outBuf->numChannels * outBuf->ROIWidth * outBuf->ROIHeight;
+    }
+}
+
 
 bool ExecutionObject::Impl::RunAsync(CallType ct)
 {
@@ -363,6 +387,7 @@ bool ExecutionObject::Impl::Wait(CallType ct)
                 if (shared_initialize_params_m->errorCode != OCL_TIDL_SUCCESS)
                     throw Exception(shared_initialize_params_m->errorCode,
                                     __FILE__, __FUNCTION__, __LINE__);
+                ComputeInputOutputSizes();
             }
             return has_work;
         }
@@ -371,11 +396,10 @@ bool ExecutionObject::Impl::Wait(CallType ct)
             bool has_work = k_process_m->Wait();
             if (has_work)
             {
-                HostReadNetOutput();
-
                 if (shared_process_params_m->errorCode != OCL_TIDL_SUCCESS)
                     throw Exception(shared_process_params_m->errorCode,
                                     __FILE__, __FUNCTION__, __LINE__);
+                HostReadNetOutput();
             }
 
             return has_work;