Enqueue multiple frames at device side
[tidl/tidl-api.git] / tidl_api / src / execution_object.cpp
index 00d6804d03ad2f481e034e78d1ba589c09eaf1e1..9bc9f05d61f1cd06574f986fc175ef95c1347b7c 100644 (file)
@@ -55,14 +55,14 @@ class ExecutionObject::Impl
              int    layers_group_id);
         ~Impl() {}
 
-        bool RunAsync(CallType ct);
-        bool Wait    (CallType ct);
-        bool AddCallback(CallType ct, void *user_data);
+        bool RunAsync(CallType ct, uint32_t context_idx);
+        bool Wait    (CallType ct, uint32_t context_idx);
+        bool AddCallback(CallType ct, void *user_data, uint32_t context_idx);
 
-        uint64_t GetProcessCycles() const;
+        uint64_t GetProcessCycles(uint32_t context_idx) const;
         int  GetLayersGroupId() const;
-        void AcquireLock();
-        void ReleaseLock();
+        void AcquireContext(uint32_t& context_idx);
+        void ReleaseContext(uint32_t  context_idx);
 
         Device*                         device_m;
         // Index of the OpenCL device/queue used by this EO
@@ -75,8 +75,8 @@ class ExecutionObject::Impl
 
         size_t                          in_size_m;
         size_t                          out_size_m;
-        IODeviceArgInfo                 in_m;
-        IODeviceArgInfo                 out_m;
+        IODeviceArgInfo                 in_m[tidl::internal::NUM_CONTEXTS];
+        IODeviceArgInfo                 out_m[tidl::internal::NUM_CONTEXTS];
 
         // Frame being processed by the EO
         int                             current_frame_idx_m;
@@ -96,7 +96,7 @@ class ExecutionObject::Impl
         size_t                            trace_buf_params_sz_m;
 
         // host time tracking: eo start to finish
-        float host_time_m;
+        float host_time_m[tidl::internal::NUM_CONTEXTS];
 
     private:
         void SetupInitializeKernel(const DeviceArgInfo& create_arg,
@@ -104,8 +104,8 @@ class ExecutionObject::Impl
         void EnableOutputBufferTrace();
         void SetupProcessKernel();
 
-        void HostWriteNetInput();
-        void HostReadNetOutput();
+        void HostWriteNetInput(uint32_t context_idx);
+        void HostReadNetOutput(uint32_t context_idx);
         void ComputeInputOutputSizes();
 
         std::unique_ptr<Kernel>         k_initialize_m;
@@ -113,7 +113,8 @@ class ExecutionObject::Impl
         std::unique_ptr<Kernel>         k_cleanup_m;
 
         // Guarding sole access to input/output for one frame during execution
-        bool                            is_idle_m;
+        // Encoding: context at bit index, bit value: 0 for idle, 1 for busy
+        uint32_t                        idle_encoding_m;
         std::mutex                      mutex_access_m;
         std::condition_variable         cv_access_m;
 
@@ -155,8 +156,6 @@ ExecutionObject::Impl::Impl(Device* d, uint8_t device_index,
     shared_process_params_m(nullptr, &__free_ddr),
     in_size_m(0),
     out_size_m(0),
-    in_m(),
-    out_m(),
     current_frame_idx_m(0),
     layers_group_id_m(layers_group_id),
     num_network_layers_m(0),
@@ -165,7 +164,7 @@ ExecutionObject::Impl::Impl(Device* d, uint8_t device_index,
     k_initialize_m(nullptr),
     k_process_m(nullptr),
     k_cleanup_m(nullptr),
-    is_idle_m(true),
+    idle_encoding_m(0),  // all contexts are idle
     configuration_m(configuration)
 {
     device_name_m = device_m->GetDeviceName() + std::to_string(device_index_m);
@@ -189,7 +188,7 @@ ExecutionObject::~ExecutionObject() = default;
 
 char* ExecutionObject::GetInputBufferPtr() const
 {
-    return static_cast<char *>(pimpl_m->in_m.GetArg().ptr());
+    return static_cast<char *>(pimpl_m->in_m[0].GetArg().ptr());
 }
 
 size_t ExecutionObject::GetInputBufferSizeInBytes() const
@@ -199,7 +198,7 @@ size_t ExecutionObject::GetInputBufferSizeInBytes() const
 
 char* ExecutionObject::GetOutputBufferPtr() const
 {
-    return static_cast<char *>(pimpl_m->out_m.GetArg().ptr());
+    return static_cast<char *>(pimpl_m->out_m[0].GetArg().ptr());
 }
 
 size_t ExecutionObject::GetOutputBufferSizeInBytes() const
@@ -217,59 +216,89 @@ int ExecutionObject::GetFrameIndex() const
     return pimpl_m->current_frame_idx_m;
 }
 
-void ExecutionObject::SetInputOutputBuffer(const ArgInfo& in, const ArgInfo& out)
+void ExecutionObject::SetInputOutputBuffer(const ArgInfo& in,
+                                           const ArgInfo& out)
+{
+    SetInputOutputBuffer(in, out, 0);
+}
+
+void ExecutionObject::SetInputOutputBuffer(const ArgInfo& in,
+                                      const ArgInfo& out, uint32_t context_idx)
 {
     assert(in.ptr()  != nullptr && in.size()  >= pimpl_m->in_size_m);
     assert(out.ptr() != nullptr && out.size() >= pimpl_m->out_size_m);
 
-    pimpl_m->in_m  = IODeviceArgInfo(in);
-    pimpl_m->out_m = IODeviceArgInfo(out);
+    pimpl_m->in_m[context_idx]  = IODeviceArgInfo(in);
+    pimpl_m->out_m[context_idx] = IODeviceArgInfo(out);
 }
 
 void ExecutionObject::SetInputOutputBuffer(const IODeviceArgInfo* in,
-                                           const IODeviceArgInfo* out)
+                                           const IODeviceArgInfo* out,
+                                           uint32_t context_idx)
 {
-    pimpl_m->in_m  = *in;
-    pimpl_m->out_m = *out;
+    pimpl_m->in_m[context_idx]  = *in;
+    pimpl_m->out_m[context_idx] = *out;
 }
 
 bool ExecutionObject::ProcessFrameStartAsync()
 {
-    TRACE::print("-> ExecutionObject::ProcessFrameStartAsync()\n");
+    return ProcessFrameStartAsync(0);
+}
+
+bool ExecutionObject::ProcessFrameStartAsync(uint32_t context_idx)
+{
+    TRACE::print("-> ExecutionObject::ProcessFrameStartAsync(%d)\n",
+                 context_idx);
     assert(GetInputBufferPtr() != nullptr && GetOutputBufferPtr() != nullptr);
-    return pimpl_m->RunAsync(ExecutionObject::CallType::PROCESS);
+    return pimpl_m->RunAsync(ExecutionObject::CallType::PROCESS, context_idx);
 }
 
 bool ExecutionObject::ProcessFrameWait()
 {
-    TRACE::print("-> ExecutionObject::ProcessFrameWait()\n");
-    return pimpl_m->Wait(ExecutionObject::CallType::PROCESS);
+    return ProcessFrameWait(0);
+}
+
+bool ExecutionObject::ProcessFrameWait(uint32_t context_idx)
+{
+    TRACE::print("-> ExecutionObject::ProcessFrameWait(%d)\n", context_idx);
+    return pimpl_m->Wait(ExecutionObject::CallType::PROCESS, context_idx);
 }
 
 bool ExecutionObject::RunAsync (CallType ct)
 {
-    return pimpl_m->RunAsync(ct);
+    return pimpl_m->RunAsync(ct, 0);
 }
 
 bool ExecutionObject::Wait (CallType ct)
 {
-    return pimpl_m->Wait(ct);
+    return pimpl_m->Wait(ct, 0);
 }
 
-bool ExecutionObject::AddCallback(CallType ct, void *user_data)
+bool ExecutionObject::AddCallback(CallType ct, void *user_data,
+                                  uint32_t context_idx)
 {
-    return pimpl_m->AddCallback(ct, user_data);
+    return pimpl_m->AddCallback(ct, user_data, context_idx);
 }
 
 float ExecutionObject::GetProcessTimeInMilliSeconds() const
+{
+    return GetProcessTimeInMilliSeconds(0);
+}
+
+float ExecutionObject::GetProcessTimeInMilliSeconds(uint32_t context_idx) const
 {
     float frequency = pimpl_m->device_m->GetFrequencyInMhz() * 1000000;
-    return ((float)pimpl_m->GetProcessCycles()) / frequency * 1000;
+    return ((float)pimpl_m->GetProcessCycles(context_idx)) / frequency * 1000;
 }
 
 float ExecutionObject::GetHostProcessTimeInMilliSeconds() const
 {
-    return pimpl_m->host_time_m;
+    return GetHostProcessTimeInMilliSeconds(0);
+}
+
+float ExecutionObject::GetHostProcessTimeInMilliSeconds(uint32_t context_idx) const
+{
+    return pimpl_m->host_time_m[context_idx];
 }
 
 void
@@ -299,14 +328,14 @@ const std::string& ExecutionObject::GetDeviceName() const
     return pimpl_m->device_name_m;
 }
 
-void ExecutionObject::AcquireLock()
+void ExecutionObject::AcquireContext(uint32_t& context_idx)
 {
-    pimpl_m->AcquireLock();
+    pimpl_m->AcquireContext(context_idx);
 }
 
-void ExecutionObject::ReleaseLock()
+void ExecutionObject::ReleaseContext(uint32_t context_idx)
 {
-    pimpl_m->ReleaseLock();
+    pimpl_m->ReleaseContext(context_idx);
 }
 
 //
@@ -334,7 +363,7 @@ ExecutionObject::Impl::SetupInitializeKernel(const DeviceArgInfo& create_arg,
     shared_initialize_params_m->tidlHeapSize =configuration_m.NETWORK_HEAP_SIZE;
     shared_initialize_params_m->l2HeapSize   = tidl::internal::DMEM1_SIZE;
     shared_initialize_params_m->l1HeapSize   = tidl::internal::DMEM0_SIZE;
-    shared_initialize_params_m->enableInternalInput = 0;
+    shared_initialize_params_m->numContexts  = tidl::internal::NUM_CONTEXTS;
 
     // Set up execution trace specified in the configuration
     EnableExecutionTrace(configuration_m,
@@ -392,16 +421,19 @@ void ExecutionObject::Impl::EnableOutputBufferTrace()
 void
 ExecutionObject::Impl::SetupProcessKernel()
 {
-    shared_process_params_m.reset(malloc_ddr<OCL_TIDL_ProcessParams>());
-    shared_process_params_m->enableInternalInput =
-                               shared_initialize_params_m->enableInternalInput;
-    shared_process_params_m->cycles = 0;
+    shared_process_params_m.reset(malloc_ddr<OCL_TIDL_ProcessParams>(
+               tidl::internal::NUM_CONTEXTS * sizeof(OCL_TIDL_ProcessParams)));
 
     // Set up execution trace specified in the configuration
-    EnableExecutionTrace(configuration_m,
-                         &shared_process_params_m->enableTrace);
+    for (int i = 0; i < tidl::internal::NUM_CONTEXTS; i++)
+    {
+        OCL_TIDL_ProcessParams *p_params = shared_process_params_m.get() + i;
+        EnableExecutionTrace(configuration_m, &p_params->enableTrace);
+    }
 
+    uint32_t context_idx = 0;
     KernelArgs args = { DeviceArgInfo(shared_process_params_m.get(),
+                                      tidl::internal::NUM_CONTEXTS *
                                       sizeof(OCL_TIDL_ProcessParams),
                                       DeviceArgInfo::Kind::BUFFER),
                         DeviceArgInfo(tidl_extmem_heap_m.get(),
@@ -409,8 +441,10 @@ ExecutionObject::Impl::SetupProcessKernel()
                                       DeviceArgInfo::Kind::BUFFER),
                         DeviceArgInfo(trace_buf_params_m.get(),
                                       trace_buf_params_sz_m,
-                                      DeviceArgInfo::Kind::BUFFER)
-
+                                      DeviceArgInfo::Kind::BUFFER),
+                        DeviceArgInfo(&context_idx,
+                                      sizeof(uint32_t),
+                                      DeviceArgInfo::Kind::SCALAR)
                       };
 
     k_process_m.reset(new Kernel(device_m,
@@ -452,20 +486,25 @@ static size_t writeDataS8(char *writePtr, const char *ptr, int n, int width,
 //
 // Copy from host buffer to TIDL device buffer
 //
-void ExecutionObject::Impl::HostWriteNetInput()
+void ExecutionObject::Impl::HostWriteNetInput(uint32_t context_idx)
 {
-    const char*     readPtr  = (const char *) in_m.GetArg().ptr();
-    const PipeInfo& pipe     = in_m.GetPipe();
+    const char*     readPtr  = (const char *) in_m[context_idx].GetArg().ptr();
+    const PipeInfo& pipe     = in_m[context_idx].GetPipe();
+    OCL_TIDL_ProcessParams *p_params = shared_process_params_m.get()
+                                       + context_idx;
 
     for (unsigned int i = 0; i < shared_initialize_params_m->numInBufs; i++)
     {
         OCL_TIDL_BufParams *inBuf = &shared_initialize_params_m->inBufs[i];
+        uint32_t context_size = inBuf->bufPlaneWidth * inBuf->bufPlaneHeight;
+                 context_size = (context_size + OCL_TIDL_CACHE_ALIGN - 1) &
+                                (~(OCL_TIDL_CACHE_ALIGN - 1));
+        char *inBufAddr = tidl_extmem_heap_m.get() + inBuf->bufPlaneBufOffset
+                          + context_idx * context_size;
 
-        if (shared_process_params_m->enableInternalInput == 0)
-        {
             readPtr += readDataS8(
                 readPtr,
-                (char *) tidl_extmem_heap_m.get() + inBuf->bufPlaneBufOffset
+                (char *) inBufAddr
                     + inBuf->bufPlaneWidth * OCL_TIDL_MAX_PAD_SIZE
                     + OCL_TIDL_MAX_PAD_SIZE,
                 inBuf->numROIs,
@@ -475,32 +514,34 @@ void ExecutionObject::Impl::HostWriteNetInput()
                 inBuf->bufPlaneWidth,
                 ((inBuf->bufPlaneWidth * inBuf->bufPlaneHeight) /
                  inBuf->numChannels));
-        }
-        else
-        {
-            shared_process_params_m->inBufAddr[i] = pipe.bufAddr_m[i];
-        }
 
-        shared_process_params_m->inDataQ[i]   = pipe.dataQ_m[i];
+        p_params->dataQ[i] = pipe.dataQ_m[i];
     }
 }
 
 //
 // Copy from TIDL device buffer into host buffer
 //
-void ExecutionObject::Impl::HostReadNetOutput()
+void ExecutionObject::Impl::HostReadNetOutput(uint32_t context_idx)
 {
-    char* writePtr = (char *) out_m.GetArg().ptr();
-    PipeInfo& pipe = out_m.GetPipe();
+    char* writePtr = (char *) out_m[context_idx].GetArg().ptr();
+    PipeInfo& pipe = out_m[context_idx].GetPipe();
+    OCL_TIDL_ProcessParams *p_params = shared_process_params_m.get()
+                                       + context_idx;
 
     for (unsigned int i = 0; i < shared_initialize_params_m->numOutBufs; i++)
     {
         OCL_TIDL_BufParams *outBuf = &shared_initialize_params_m->outBufs[i];
+        uint32_t context_size = outBuf->bufPlaneWidth * outBuf->bufPlaneHeight;
+                 context_size = (context_size + OCL_TIDL_CACHE_ALIGN - 1) &
+                                (~(OCL_TIDL_CACHE_ALIGN - 1));
+        char *outBufAddr = tidl_extmem_heap_m.get() + outBuf->bufPlaneBufOffset
+                           + context_idx * context_size;
         if (writePtr != nullptr)
         {
             writePtr += writeDataS8(
                 writePtr,
-                (char *) tidl_extmem_heap_m.get() + outBuf->bufPlaneBufOffset
+                (char *) outBufAddr
                     + outBuf->bufPlaneWidth * OCL_TIDL_MAX_PAD_SIZE
                     + OCL_TIDL_MAX_PAD_SIZE,
                 outBuf->numChannels,
@@ -511,12 +552,8 @@ void ExecutionObject::Impl::HostReadNetOutput()
                  outBuf->numChannels));
         }
 
-        pipe.dataQ_m[i]   = shared_process_params_m->outDataQ[i];
-        pipe.bufAddr_m[i] = shared_initialize_params_m->bufAddrBase
-                           + outBuf->bufPlaneBufOffset;
+        pipe.dataQ_m[i]   = p_params->dataQ[i];
     }
-    shared_process_params_m->bytesWritten = writePtr -
-                                            (char *) out_m.GetArg().ptr();
 }
 
 void ExecutionObject::Impl::ComputeInputOutputSizes()
@@ -550,7 +587,7 @@ void ExecutionObject::Impl::ComputeInputOutputSizes()
 }
 
 
-bool ExecutionObject::Impl::RunAsync(CallType ct)
+bool ExecutionObject::Impl::RunAsync(CallType ct, uint32_t context_idx)
 {
     switch (ct)
     {
@@ -564,14 +601,19 @@ bool ExecutionObject::Impl::RunAsync(CallType ct)
             std::chrono::time_point<std::chrono::steady_clock> t1, t2;
             t1 = std::chrono::steady_clock::now();
 
-            shared_process_params_m->frameIdx = current_frame_idx_m;
-            shared_process_params_m->bytesWritten = 0;
-            HostWriteNetInput();
-            k_process_m->RunAsync();
+            OCL_TIDL_ProcessParams *p_params = shared_process_params_m.get()
+                                               + context_idx;
+            p_params->frameIdx = current_frame_idx_m;
+            HostWriteNetInput(context_idx);
+            {
+                std::unique_lock<std::mutex> lock(mutex_access_m);
+                k_process_m->UpdateScalarArg(3, sizeof(uint32_t), &context_idx);
+                k_process_m->RunAsync(context_idx);
+            }
 
             t2 = std::chrono::steady_clock::now();
             std::chrono::duration<float> elapsed = t2 - t1;
-            host_time_m = elapsed.count() * 1000;
+            host_time_m[context_idx] = elapsed.count() * 1000;
             break;
         }
         case CallType::CLEANUP:
@@ -586,7 +628,7 @@ bool ExecutionObject::Impl::RunAsync(CallType ct)
     return true;
 }
 
-bool ExecutionObject::Impl::Wait(CallType ct)
+bool ExecutionObject::Impl::Wait(CallType ct, uint32_t context_idx)
 {
     switch (ct)
     {
@@ -609,13 +651,15 @@ bool ExecutionObject::Impl::Wait(CallType ct)
             bool has_work = k_process_m->Wait(&host_elapsed_ms);
             if (has_work)
             {
-                if (shared_process_params_m->errorCode != OCL_TIDL_SUCCESS)
-                    throw Exception(shared_process_params_m->errorCode,
+                OCL_TIDL_ProcessParams *p_params = shared_process_params_m.get()
+                                                   + context_idx;
+                if (p_params->errorCode != OCL_TIDL_SUCCESS)
+                    throw Exception(p_params->errorCode,
                                     __FILE__, __FUNCTION__, __LINE__);
 
                 std::chrono::time_point<std::chrono::steady_clock> t1, t2;
                 t1 = std::chrono::steady_clock::now();
-                HostReadNetOutput();
+                HostReadNetOutput(context_idx);
                 t2 = std::chrono::steady_clock::now();
                 std::chrono::duration<float> elapsed = t2 - t1;
                 host_time_m += elapsed.count() * 1000 + host_elapsed_ms;
@@ -635,13 +679,14 @@ bool ExecutionObject::Impl::Wait(CallType ct)
     return false;
 }
 
-bool ExecutionObject::Impl::AddCallback(CallType ct, void *user_data)
+bool ExecutionObject::Impl::AddCallback(CallType ct, void *user_data,
+                                        uint32_t context_idx)
 {
     switch (ct)
     {
         case CallType::PROCESS:
         {
-            return k_process_m->AddCallback(user_data);
+            return k_process_m->AddCallback(user_data, context_idx);
             break;
         }
         default:
@@ -651,7 +696,7 @@ bool ExecutionObject::Impl::AddCallback(CallType ct, void *user_data)
     return false;
 }
 
-uint64_t ExecutionObject::Impl::GetProcessCycles() const
+uint64_t ExecutionObject::Impl::GetProcessCycles(uint32_t context_idx) const
 {
     uint8_t factor = 1;
 
@@ -659,7 +704,9 @@ uint64_t ExecutionObject::Impl::GetProcessCycles() const
     if (device_m->type() == CL_DEVICE_TYPE_CUSTOM)
         factor = 2;
 
-    return shared_process_params_m.get()->cycles * factor;
+    OCL_TIDL_ProcessParams *p_params = shared_process_params_m.get() +
+                                       context_idx;
+    return p_params->cycles * factor;
 }
 
 //
@@ -786,15 +833,26 @@ LayerOutput::~LayerOutput()
     delete[] data_m;
 }
 
-void ExecutionObject::Impl::AcquireLock()
+void ExecutionObject::Impl::AcquireContext(uint32_t& context_idx)
 {
     std::unique_lock<std::mutex> lock(mutex_access_m);
-    cv_access_m.wait(lock, [this]{ return this->is_idle_m; });
-    is_idle_m = false;
+    cv_access_m.wait(lock, [this]{ return this->idle_encoding_m <
+                                   (1 << tidl::internal::NUM_CONTEXTS) - 1; });
+
+    for (uint32_t i = 0; i < tidl::internal::NUM_CONTEXTS; i++)
+        if (((1 << i) & idle_encoding_m) == 0)
+        {
+            context_idx = i;
+            break;
+        }
+    idle_encoding_m |= (1 << context_idx);  // mark the bit as busy
 }
 
-void ExecutionObject::Impl::ReleaseLock()
+void ExecutionObject::Impl::ReleaseContext(uint32_t context_idx)
 {
-    is_idle_m = true;
+    {
+        std::unique_lock<std::mutex> lock(mutex_access_m);
+        idle_encoding_m &= (~(1 << context_idx));  // mark the bit as free
+    }
     cv_access_m.notify_all();
 }