Enqueue multiple frames at device side
[tidl/tidl-api.git] / tidl_api / src / ocl_device.cpp
index 5768627521a8f33ed46e508b13eb4cba6d9363ab..508c5498a352d43ab99c24390ffbd9010d3d1d88 100644 (file)
@@ -262,7 +262,7 @@ bool EveDevice::BuildProgramFromBinary(const std::string& kernel_names,
 Kernel::Kernel(Device* device, const std::string& name,
                const KernelArgs& args, uint8_t device_index):
            name_m(name), device_m(device), device_index_m(device_index),
-           is_running_m(false)
+           num_running_contexts_m(0)
 {
     TRACE::print("Creating kernel %s\n", name.c_str());
     cl_int err;
@@ -304,45 +304,52 @@ Kernel::Kernel(Device* device, const std::string& name,
     }
 }
 
-Kernel& Kernel::RunAsync()
+bool Kernel::UpdateScalarArg(uint32_t index, size_t size, const void *value)
+{
+    cl_int ret = clSetKernelArg(kernel_m, index, size, value);
+    return ret == CL_SUCCESS;
+}
+
+Kernel& Kernel::RunAsync(uint32_t context_idx)
 {
     // Execute kernel
-    TRACE::print("\tKernel: device %d executing %s\n", device_index_m,
-                                                       name_m.c_str());
+    TRACE::print("\tKernel: device %d executing %s, context %d\n",
+                 device_index_m, name_m.c_str(), context_idx);
     cl_int ret = clEnqueueTask(device_m->queue_m[device_index_m],
-                               kernel_m, 0, 0, &event_m);
+                               kernel_m, 0, 0, &event_m[context_idx]);
     errorCheck(ret, __LINE__);
-    is_running_m = true;
+    __sync_fetch_and_add(&num_running_contexts_m, 1);
 
     return *this;
 }
 
 
-bool Kernel::Wait(float *host_elapsed_ms)
+bool Kernel::Wait(float *host_elapsed_ms, uint32_t context_idx)
 {
     // Wait called without a corresponding RunAsync
-    if (!is_running_m)
+    if (num_running_contexts_m == 0)
         return false;
 
-    TRACE::print("\tKernel: waiting...\n");
-    cl_int ret = clWaitForEvents(1, &event_m);
+    TRACE::print("\tKernel: waiting context %d...\n", context_idx);
+    cl_int ret = clWaitForEvents(1, &event_m[context_idx]);
     errorCheck(ret, __LINE__);
 
     if (host_elapsed_ms != nullptr)
     {
         cl_ulong t_que, t_end;
-        clGetEventProfilingInfo(event_m, CL_PROFILING_COMMAND_QUEUED,
+        clGetEventProfilingInfo(event_m[context_idx],
+                                CL_PROFILING_COMMAND_QUEUED,
                                 sizeof(cl_ulong), &t_que, nullptr);
-        clGetEventProfilingInfo(event_m, CL_PROFILING_COMMAND_END,
+        clGetEventProfilingInfo(event_m[context_idx], CL_PROFILING_COMMAND_END,
                                 sizeof(cl_ulong), &t_end, nullptr);
         *host_elapsed_ms = (t_end - t_que) / 1.0e6;  // nano to milli seconds
     }
 
-    ret = clReleaseEvent(event_m);
+    ret = clReleaseEvent(event_m[context_idx]);
     errorCheck(ret, __LINE__);
     TRACE::print("\tKernel: finished execution\n");
 
-    is_running_m = false;
+    __sync_fetch_and_sub(&num_running_contexts_m, 1);
     return true;
 }
 
@@ -355,11 +362,11 @@ void EventCallback(cl_event event, cl_int exec_status, void *user_data)
     if (CallbackWrapper)  CallbackWrapper(user_data);
 }
 
-bool Kernel::AddCallback(void *user_data)
+bool Kernel::AddCallback(void *user_data, uint32_t context_idx)
 {
-    if (! is_running_m)  return false;
-    return clSetEventCallback(event_m, CL_COMPLETE, EventCallback, user_data)
-           == CL_SUCCESS;
+    if (num_running_contexts_m == 0)  return false;
+    return clSetEventCallback(event_m[context_idx], CL_COMPLETE, EventCallback,
+                              user_data) == CL_SUCCESS;
 }
 
 Kernel::~Kernel()