Modified IODeviceArgInfo to enable pipelining EOs
[tidl/tidl-api.git] / tidl_api / src / execution_object.cpp
index 6ebf653d46710e9e9828986da257ee90f3053ebf..d722ebb196669019fbc3071338397d62a9a79ba9 100644 (file)
@@ -28,6 +28,9 @@
 
 /*! \file execution_object.cpp */
 
+#include <string.h>
+#include <fstream>
+#include <climits>
 #include "executor.h"
 #include "execution_object.h"
 #include "trace.h"
 #include "parameters.h"
 #include "configuration.h"
 #include "common_defines.h"
-#include <string.h>
 #include "tidl_create_params.h"
-#include <fstream>
-#include <climits>
+#include "device_arginfo.h"
 
 using namespace tidl;
 
@@ -46,8 +47,8 @@ class ExecutionObject::Impl
 {
     public:
         Impl(Device* d, uint8_t device_index,
-             const ArgInfo& create_arg,
-             const ArgInfo& param_heap_arg,
+             const DeviceArgInfo& create_arg,
+             const DeviceArgInfo& param_heap_arg,
              size_t extmem_heap_size,
              bool   internal_input);
         ~Impl() {}
@@ -63,8 +64,8 @@ class ExecutionObject::Impl
 
         size_t                          in_size_m;
         size_t                          out_size_m;
-        ArgInfo                         in_m;
-        ArgInfo                         out_m;
+        IODeviceArgInfo                 in_m;
+        IODeviceArgInfo                 out_m;
 
         // Frame being processed by the EO
         int                             current_frame_idx_m;
@@ -81,8 +82,8 @@ class ExecutionObject::Impl
         size_t                            trace_buf_params_sz_m;
 
     private:
-        void SetupInitializeKernel(const ArgInfo& create_arg,
-                                   const ArgInfo& param_heap_arg,
+        void SetupInitializeKernel(const DeviceArgInfo& create_arg,
+                                   const DeviceArgInfo& param_heap_arg,
                                    size_t extmem_heap_size,
                                    bool   internal_input);
         void SetupProcessKernel();
@@ -109,10 +110,13 @@ ExecutionObject::ExecutionObject(Device* d,
                                  size_t extmem_heap_size,
                                  bool   internal_input)
 {
+    DeviceArgInfo create_arg_d(create_arg, DeviceArgInfo::Kind::BUFFER);
+    DeviceArgInfo param_heap_arg_d(param_heap_arg, DeviceArgInfo::Kind::BUFFER);
+
     pimpl_m = std::unique_ptr<ExecutionObject::Impl>
               { new ExecutionObject::Impl(d, device_index,
-                                          create_arg,
-                                          param_heap_arg,
+                                          create_arg_d,
+                                          param_heap_arg_d,
                                           extmem_heap_size,
                                           internal_input) };
 }
@@ -120,8 +124,8 @@ ExecutionObject::ExecutionObject(Device* d,
 
 ExecutionObject::Impl::Impl(Device* d,
                                  uint8_t device_index,
-                                 const ArgInfo& create_arg,
-                                 const ArgInfo& param_heap_arg,
+                                 const DeviceArgInfo& create_arg,
+                                 const DeviceArgInfo& param_heap_arg,
                                  size_t extmem_heap_size,
                                  bool   internal_input):
     device_m(d),
@@ -130,8 +134,8 @@ ExecutionObject::Impl::Impl(Device* d,
     shared_process_params_m(nullptr, &__free_ddr),
     in_size_m(0),
     out_size_m(0),
-    in_m(nullptr, 0),
-    out_m(nullptr, 0),
+    in_m(),
+    out_m(),
     current_frame_idx_m(0),
     num_network_layers_m(0),
     trace_buf_params_m(nullptr, &__free_ddr),
@@ -159,24 +163,28 @@ ExecutionObject::~ExecutionObject() = default;
 
 char* ExecutionObject::GetInputBufferPtr() const
 {
-    return static_cast<char *>(pimpl_m->in_m.ptr());
+    return static_cast<char *>(pimpl_m->in_m.GetArg().ptr());
 }
 
 size_t ExecutionObject::GetInputBufferSizeInBytes() const
 {
-    if (pimpl_m->in_m.ptr() == nullptr)  return pimpl_m->in_size_m;
-    else                                 return pimpl_m->in_m.size();
+    const DeviceArgInfo& arg = pimpl_m->in_m.GetArg();
+    if    (arg.ptr() == nullptr)  return pimpl_m->in_size_m;
+    else                          return arg.size();
 }
 
 char* ExecutionObject::GetOutputBufferPtr() const
 {
-    return static_cast<char *>(pimpl_m->out_m.ptr());
+    return static_cast<char *>(pimpl_m->out_m.GetArg().ptr());
 }
 
 size_t ExecutionObject::GetOutputBufferSizeInBytes() const
 {
-    if (pimpl_m->out_m.ptr() == nullptr)  return pimpl_m->out_size_m;
-    else           return pimpl_m->shared_process_params_m.get()->bytesWritten;
+    const DeviceArgInfo& arg = pimpl_m->out_m.GetArg();
+    if   (arg.ptr() == nullptr)
+        return pimpl_m->out_size_m;
+    else
+        return pimpl_m->shared_process_params_m.get()->bytesWritten;
 }
 
 void  ExecutionObject::SetFrameIndex(int idx)
@@ -194,8 +202,15 @@ void ExecutionObject::SetInputOutputBuffer(const ArgInfo& in, const ArgInfo& out
     assert(in.ptr() != nullptr && in.size() > 0);
     assert(out.ptr() != nullptr && out.size() > 0);
 
-    pimpl_m->in_m  = in;
-    pimpl_m->out_m = out;
+    pimpl_m->in_m  = IODeviceArgInfo(in);
+    pimpl_m->out_m = IODeviceArgInfo(out);
+}
+
+void ExecutionObject::SetInputOutputBuffer(const IODeviceArgInfo* in,
+                                           const IODeviceArgInfo* out)
+{
+    pimpl_m->in_m  = *in;
+    pimpl_m->out_m = *out;
 }
 
 bool ExecutionObject::ProcessFrameStartAsync()
@@ -282,8 +297,8 @@ ExecutionObject::WriteLayerOutputsToFile(const std::string& filename_prefix) con
 // Create a kernel to call the "initialize" function
 //
 void
-ExecutionObject::Impl::SetupInitializeKernel(const ArgInfo& create_arg,
-                                             const ArgInfo& param_heap_arg,
+ExecutionObject::Impl::SetupInitializeKernel(const DeviceArgInfo& create_arg,
+                                             const DeviceArgInfo& param_heap_arg,
                                              size_t extmem_heap_size,
                                              bool   internal_input)
 {
@@ -310,13 +325,17 @@ ExecutionObject::Impl::SetupInitializeKernel(const ArgInfo& create_arg,
     // Setup kernel arguments for initialize
     KernelArgs args = { create_arg,
                         param_heap_arg,
-                        ArgInfo(tidl_extmem_heap_m.get(),
-                                extmem_heap_size),
-                        ArgInfo(shared_initialize_params_m.get(),
-                                sizeof(OCL_TIDL_InitializeParams)),
+                        DeviceArgInfo(tidl_extmem_heap_m.get(),
+                                      extmem_heap_size,
+                                      DeviceArgInfo::Kind::BUFFER),
+                        DeviceArgInfo(shared_initialize_params_m.get(),
+                                      sizeof(OCL_TIDL_InitializeParams),
+                                      DeviceArgInfo::Kind::BUFFER),
                         device_m->type() == CL_DEVICE_TYPE_ACCELERATOR ?
-                            ArgInfo(nullptr, tidl::internal::DMEM1_SIZE):
-                            ArgInfo(nullptr, 4)                       };
+                            DeviceArgInfo(nullptr, tidl::internal::DMEM1_SIZE,
+                                          DeviceArgInfo::Kind::LOCAL):
+                            DeviceArgInfo(nullptr, 4,
+                                          DeviceArgInfo::Kind::LOCAL) };
 
     k_initialize_m.reset(new Kernel(device_m,
                                     STRING(INIT_KERNEL), args,
@@ -335,12 +354,15 @@ ExecutionObject::Impl::SetupProcessKernel()
                                shared_initialize_params_m->enableInternalInput;
     shared_process_params_m->cycles = 0;
 
-    KernelArgs args = { ArgInfo(shared_process_params_m.get(),
-                                sizeof(OCL_TIDL_ProcessParams)),
-                        ArgInfo(tidl_extmem_heap_m.get(),
-                                shared_initialize_params_m->tidlHeapSize),
-                        ArgInfo(trace_buf_params_m.get(),
-                                trace_buf_params_sz_m)
+    KernelArgs args = { DeviceArgInfo(shared_process_params_m.get(),
+                                      sizeof(OCL_TIDL_ProcessParams),
+                                      DeviceArgInfo::Kind::BUFFER),
+                        DeviceArgInfo(tidl_extmem_heap_m.get(),
+                                      shared_initialize_params_m->tidlHeapSize,
+                                      DeviceArgInfo::Kind::BUFFER),
+                        DeviceArgInfo(trace_buf_params_m.get(),
+                                      trace_buf_params_sz_m,
+                                      DeviceArgInfo::Kind::BUFFER)
 
                       };
 
@@ -385,8 +407,8 @@ static size_t writeDataS8(char *writePtr, const char *ptr, int n, int width,
 //
 void ExecutionObject::Impl::HostWriteNetInput()
 {
-    const char*     readPtr  = (const char *) in_m.ptr();
-    const PipeInfo* pipe     = in_m.GetPipe();
+    const char*     readPtr  = (const char *) in_m.GetArg().ptr();
+    const PipeInfo& pipe     = in_m.GetPipe();
 
     for (unsigned int i = 0; i < shared_initialize_params_m->numInBufs; i++)
     {
@@ -409,10 +431,10 @@ void ExecutionObject::Impl::HostWriteNetInput()
         }
         else
         {
-            shared_process_params_m->inBufAddr[i] = pipe->bufAddr_m[i];
+            shared_process_params_m->inBufAddr[i] = pipe.bufAddr_m[i];
         }
 
-        shared_process_params_m->inDataQ[i]   = pipe->dataQ_m[i];
+        shared_process_params_m->inDataQ[i]   = pipe.dataQ_m[i];
     }
 }
 
@@ -421,8 +443,8 @@ void ExecutionObject::Impl::HostWriteNetInput()
 //
 void ExecutionObject::Impl::HostReadNetOutput()
 {
-    char* writePtr = (char *) out_m.ptr();
-    PipeInfo* pipe = out_m.GetPipe();
+    char* writePtr = (char *) out_m.GetArg().ptr();
+    PipeInfo& pipe = out_m.GetPipe();
 
     for (unsigned int i = 0; i < shared_initialize_params_m->numOutBufs; i++)
     {
@@ -442,11 +464,12 @@ void ExecutionObject::Impl::HostReadNetOutput()
                  outBuf->numChannels));
         }
 
-        pipe->dataQ_m[i]   = shared_process_params_m->outDataQ[i];
-        pipe->bufAddr_m[i] = shared_initialize_params_m->bufAddrBase
+        pipe.dataQ_m[i]   = shared_process_params_m->outDataQ[i];
+        pipe.bufAddr_m[i] = shared_initialize_params_m->bufAddrBase
                            + outBuf->bufPlaneBufOffset;
     }
-    shared_process_params_m->bytesWritten = writePtr - (char *) out_m.ptr();
+    shared_process_params_m->bytesWritten = writePtr -
+                                            (char *) out_m.GetArg().ptr();
 }
 
 void ExecutionObject::Impl::ComputeInputOutputSizes()