Reduce complexity of ssd_multibox example
[tidl/tidl-api.git] / tinn_api / src / execution_object.cpp
index 46606833e140e9b39482ba5446f5838fe389eed9..dbdb90293206cf804b7fc9458d1cae596c4d444a 100644 (file)
@@ -45,7 +45,8 @@ class ExecutionObject::Impl
         Impl(Device* d, uint8_t device_index,
              const ArgInfo& create_arg,
              const ArgInfo& param_heap_arg,
-             size_t extmem_heap_size);
+             size_t extmem_heap_size,
+             bool   internal_input);
         ~Impl() {}
 
         bool RunAsync(CallType ct);
@@ -54,6 +55,7 @@ class ExecutionObject::Impl
         bool SetupProcessKernel(const ArgInfo& in, const ArgInfo& out);
         void HostWriteNetInput();
         void HostReadNetOutput();
+        void ComputeInputOutputSizes();
 
         Device*                         device_m;
         std::unique_ptr<Kernel>         k_initialize_m;
@@ -64,6 +66,8 @@ class ExecutionObject::Impl
         up_malloc_ddr<OCL_TIDL_InitializeParams> shared_initialize_params_m;
         up_malloc_ddr<OCL_TIDL_ProcessParams>    shared_process_params_m;
 
+        size_t                          in_size_m;
+        size_t                          out_size_m;
         ArgInfo                         in_m;
         ArgInfo                         out_m;
 
@@ -79,13 +83,15 @@ ExecutionObject::ExecutionObject(Device* d,
                                  uint8_t device_index,
                                  const ArgInfo& create_arg,
                                  const ArgInfo& param_heap_arg,
-                                 size_t extmem_heap_size)
+                                 size_t extmem_heap_size,
+                                 bool   internal_input)
 {
     pimpl_m = std::unique_ptr<ExecutionObject::Impl>
               { new ExecutionObject::Impl(d, device_index,
                                           create_arg,
                                           param_heap_arg,
-                                          extmem_heap_size) };
+                                          extmem_heap_size,
+                                          internal_input) };
 }
 
 
@@ -93,7 +99,8 @@ ExecutionObject::Impl::Impl(Device* d,
                                  uint8_t device_index,
                                  const ArgInfo& create_arg,
                                  const ArgInfo& param_heap_arg,
-                                 size_t extmem_heap_size):
+                                 size_t extmem_heap_size,
+                                 bool   internal_input):
     device_m(d),
     k_initialize_m(nullptr),
     k_process_m(nullptr),
@@ -101,6 +108,8 @@ ExecutionObject::Impl::Impl(Device* d,
     tidl_extmem_heap_m (nullptr, &__free_ddr),
     shared_initialize_params_m(nullptr, &__free_ddr),
     shared_process_params_m(nullptr, &__free_ddr),
+    in_size_m(0),
+    out_size_m(0),
     in_m(nullptr, 0),
     out_m(nullptr, 0),
     device_index_m(device_index),
@@ -124,6 +133,7 @@ ExecutionObject::Impl::Impl(Device* d,
     shared_initialize_params_m->l2HeapSize   = tinn::internal::DMEM1_SIZE;
     shared_initialize_params_m->l1HeapSize   = tinn::internal::DMEM0_SIZE;
     shared_initialize_params_m->enableTrace  = OCL_TIDL_TRACE_OFF;
+    shared_initialize_params_m->enableInternalInput = internal_input ? 1 : 0;
 
     // Setup kernel arguments for initialize
     KernelArgs args = { create_arg,
@@ -152,7 +162,8 @@ char* ExecutionObject::GetInputBufferPtr() const
 
 size_t ExecutionObject::GetInputBufferSizeInBytes() const
 {
-    return pimpl_m->in_m.size();
+    if (pimpl_m->in_m.ptr() == nullptr)  return pimpl_m->in_size_m;
+    else                                 return pimpl_m->in_m.size();
 }
 
 char* ExecutionObject::GetOutputBufferPtr() const
@@ -162,7 +173,8 @@ char* ExecutionObject::GetOutputBufferPtr() const
 
 size_t ExecutionObject::GetOutputBufferSizeInBytes() const
 {
-    return pimpl_m->shared_process_params_m.get()->bytesWritten;
+    if (pimpl_m->out_m.ptr() == nullptr)  return pimpl_m->out_size_m;
+    else           return pimpl_m->shared_process_params_m.get()->bytesWritten;
 }
 
 void  ExecutionObject::SetFrameIndex(int idx)
@@ -177,9 +189,6 @@ int ExecutionObject::GetFrameIndex() const
 
 void ExecutionObject::SetInputOutputBuffer(const ArgInfo& in, const ArgInfo& out)
 {
-    assert (in.ptr() != nullptr && in.size() > 0);
-    assert (out.ptr() != nullptr && out.size() > 0);
-
     pimpl_m->SetupProcessKernel(in, out);
 }
 
@@ -231,8 +240,13 @@ ExecutionObject::Impl::SetupProcessKernel(const ArgInfo& in, const ArgInfo& out)
 
     shared_process_params_m.reset(malloc_ddr<OCL_TIDL_ProcessParams>());
     shared_process_params_m->enableTrace = OCL_TIDL_TRACE_OFF;
+    shared_process_params_m->enableInternalInput = 
+                               shared_initialize_params_m->enableInternalInput;
     shared_process_params_m->cycles = 0;
 
+    if (shared_process_params_m->enableInternalInput == 0)
+        assert(in.ptr() != nullptr && in.size() > 0);
+
     KernelArgs args = { ArgInfo(shared_process_params_m.get(),
                                 sizeof(OCL_TIDL_ProcessParams)),
                         in,
@@ -280,46 +294,97 @@ static size_t writeDataS8(char *writePtr, const char *ptr, int n, int width,
 
 void ExecutionObject::Impl::HostWriteNetInput()
 {
-    char* readPtr = (char *) in_m.ptr();
+    char* readPtr  = (char *) in_m.ptr();
+    PipeInfo *pipe = in_m.GetPipe();
+
     for (unsigned int i = 0; i < shared_initialize_params_m->numInBufs; i++)
     {
         OCL_TIDL_BufParams *inBuf = &shared_initialize_params_m->inBufs[i];
-        readPtr += readDataS8(
-            readPtr,
-            (char *) tidl_extmem_heap_m.get() + inBuf->bufPlaneBufOffset
-                + inBuf->bufPlaneWidth * OCL_TIDL_MAX_PAD_SIZE
-                + OCL_TIDL_MAX_PAD_SIZE,
-            inBuf->numROIs,
-            inBuf->numChannels,
-            inBuf->ROIWidth,
-            inBuf->ROIHeight,
-            inBuf->bufPlaneWidth,
-            inBuf->bufPlaneWidth
-                * (inBuf->ROIHeight + 2 * OCL_TIDL_MAX_PAD_SIZE) );
+
+        if (shared_process_params_m->enableInternalInput == 0)
+        {
+            readPtr += readDataS8(
+                readPtr,
+                (char *) tidl_extmem_heap_m.get() + inBuf->bufPlaneBufOffset
+                    + inBuf->bufPlaneWidth * OCL_TIDL_MAX_PAD_SIZE
+                    + OCL_TIDL_MAX_PAD_SIZE,
+                inBuf->numROIs,
+                inBuf->numChannels,
+                inBuf->ROIWidth,
+                inBuf->ROIHeight,
+                inBuf->bufPlaneWidth,
+                ((inBuf->bufPlaneWidth * inBuf->bufPlaneHeight) /
+                 inBuf->numChannels));
+        }
+        else
+        {
+            shared_process_params_m->inBufAddr[i] = pipe->bufAddr_m[i];
+        }
+
+        shared_process_params_m->inDataQ[i]   = pipe->dataQ_m[i];
     }
 }
 
 void ExecutionObject::Impl::HostReadNetOutput()
 {
     char* writePtr = (char *) out_m.ptr();
+    PipeInfo *pipe = out_m.GetPipe();
+
     for (unsigned int i = 0; i < shared_initialize_params_m->numOutBufs; i++)
     {
         OCL_TIDL_BufParams *outBuf = &shared_initialize_params_m->outBufs[i];
-        writePtr += writeDataS8(
-            writePtr,
-            (char *) tidl_extmem_heap_m.get() + outBuf->bufPlaneBufOffset
-                + outBuf->bufPlaneWidth * OCL_TIDL_MAX_PAD_SIZE
-                + OCL_TIDL_MAX_PAD_SIZE,
-            outBuf->numChannels,
-            outBuf->ROIWidth,
-            outBuf->ROIHeight,
-            outBuf->bufPlaneWidth,
-            ((outBuf->bufPlaneWidth * outBuf->bufPlaneHeight)/
-             outBuf->numChannels));
+        if (writePtr != nullptr)
+        {
+            writePtr += writeDataS8(
+                writePtr,
+                (char *) tidl_extmem_heap_m.get() + outBuf->bufPlaneBufOffset
+                    + outBuf->bufPlaneWidth * OCL_TIDL_MAX_PAD_SIZE
+                    + OCL_TIDL_MAX_PAD_SIZE,
+                outBuf->numChannels,
+                outBuf->ROIWidth,
+                outBuf->ROIHeight,
+                outBuf->bufPlaneWidth,
+                ((outBuf->bufPlaneWidth * outBuf->bufPlaneHeight)/
+                 outBuf->numChannels));
+        }
+
+        pipe->dataQ_m[i]   = shared_process_params_m->outDataQ[i];
+        pipe->bufAddr_m[i] = shared_initialize_params_m->bufAddrBase
+                           + outBuf->bufPlaneBufOffset;
     }
     shared_process_params_m->bytesWritten = writePtr - (char *) out_m.ptr();
 }
 
+void ExecutionObject::Impl::ComputeInputOutputSizes()
+{
+    if (shared_initialize_params_m->errorCode != OCL_TIDL_SUCCESS)  return;
+
+    if (shared_initialize_params_m->numInBufs > OCL_TIDL_MAX_IN_BUFS ||
+        shared_initialize_params_m->numOutBufs > OCL_TIDL_MAX_OUT_BUFS)
+    {
+        std::cout << "Num input/output bufs ("
+                  << shared_initialize_params_m->numInBufs << ", "
+                  << shared_initialize_params_m->numOutBufs
+                  << ") exceeded limit!" << std::endl;
+        shared_initialize_params_m->errorCode = OCL_TIDL_INIT_FAIL;
+        return;
+    }
+
+    in_size_m  = 0;
+    out_size_m = 0;
+    for (unsigned int i = 0; i < shared_initialize_params_m->numInBufs; i++)
+    {
+        OCL_TIDL_BufParams *inBuf = &shared_initialize_params_m->inBufs[i];
+        in_size_m += inBuf->numROIs * inBuf->numChannels * inBuf->ROIWidth *
+                     inBuf->ROIHeight;
+    }
+    for (unsigned int i = 0; i < shared_initialize_params_m->numOutBufs; i++)
+    {
+        OCL_TIDL_BufParams *outBuf = &shared_initialize_params_m->outBufs[i];
+        out_size_m += outBuf->numChannels * outBuf->ROIWidth *outBuf->ROIHeight;
+    }
+}
+
 
 bool ExecutionObject::Impl::RunAsync(CallType ct)
 {
@@ -360,6 +425,7 @@ bool ExecutionObject::Impl::Wait(CallType ct)
 
             if (has_work)
             {
+                ComputeInputOutputSizes();
                 if (shared_initialize_params_m->errorCode != OCL_TIDL_SUCCESS)
                     throw Exception(shared_initialize_params_m->errorCode,
                                     __FILE__, __FUNCTION__, __LINE__);
@@ -371,11 +437,10 @@ bool ExecutionObject::Impl::Wait(CallType ct)
             bool has_work = k_process_m->Wait();
             if (has_work)
             {
-                HostReadNetOutput();
-
                 if (shared_process_params_m->errorCode != OCL_TIDL_SUCCESS)
                     throw Exception(shared_process_params_m->errorCode,
                                     __FILE__, __FUNCTION__, __LINE__);
+                HostReadNetOutput();
             }
 
             return has_work;