Added Python variant of mnist example
authorAjay Jayaraj <ajayj@ti.com>
Tue, 27 Nov 2018 20:56:19 +0000 (14:56 -0600)
committerAjay Jayaraj <ajayj@ti.com>
Tue, 27 Nov 2018 22:15:50 +0000 (16:15 -0600)
Also fix one_eo_per_frame.py to avoid creating an EVE executor if there are
no EVEs available.
(MCT-1088)

docs/source/changelog.rst
examples/pybind/mnist.py [new file with mode: 0755]
examples/pybind/one_eo_per_frame.py
examples/pybind/tidl_app_utils.py
examples/pybind/two_eo_per_frame.py
examples/pybind/two_eo_per_frame_opt.py

index 1be54ac4444f7ac96928f89b268f6dc769c7a2fb..7057b5980c4a2fa26911e134de81907564c3e135 100644 (file)
@@ -7,7 +7,7 @@ Changelog
 **Added**
 
 #. Updated API implementation to minimize TIDL API/OpenCL dispatch overhead using multiple execution contexts in the :term:`ExecutionObject`.
-   See :ref:`mnist-example` example for demonstration.
+   Refer to :ref:`mnist-example` example for details.
 
 #. Execution Graph generation
 
diff --git a/examples/pybind/mnist.py b/examples/pybind/mnist.py
new file mode 100755 (executable)
index 0000000..4029982
--- /dev/null
@@ -0,0 +1,200 @@
+#!/usr/bin/python3
+
+# Copyright (c) 2018 Texas Instruments Incorporated - http://www.ti.com/
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+# * Neither the name of Texas Instruments Incorporated nor the
+# names of its contributors may be used to endorse or promote products
+# derived from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+# THE POSSIBILITY OF SUCH DAMAGE.
+
+""" Process each frame using a single ExecutionObject.
+    Increase throughput by using multiple ExecutionObjects.
+"""
+
+import os
+import argparse
+import numpy as np
+
+from tidl import DeviceId, DeviceType, Configuration, TidlError
+from tidl import Executor, ExecutionObjectPipeline
+from tidl import allocate_memory, free_memory
+
+
+def main():
+    """Read the configuration and run the network"""
+
+    args = parse_args()
+
+    config_file = '../test/testvecs/config/infer/tidl_config_mnist_lenet.txt'
+    labels_file = '../test/testvecs/input/digits10_labels_10x1.y'
+
+    configuration = Configuration()
+    configuration.read_from_file(config_file)
+
+    num_eve = Executor.get_num_devices(DeviceType.EVE)
+    num_dsp = 0
+
+    if num_eve == 0:
+        print('MNIST network currently supported only on EVE')
+        return
+
+    run(num_eve, num_dsp, configuration, labels_file)
+
+    return
+
+
+DESCRIPTION = 'Run the mnist network on preprocessed input.'
+
+def parse_args():
+    """Parse input arguments"""
+
+    parser = argparse.ArgumentParser(description=DESCRIPTION)
+    args = parser.parse_args()
+
+    return args
+
+PIPELINE_DEPTH = 2
+
+def run(num_eve, num_dsp, configuration, labels_file):
+    """ Run the network on the specified device type and number of devices"""
+
+    print('Running network across {} EVEs, {} DSPs'.format(num_eve, num_dsp))
+
+    dsp_device_ids = set([DeviceId.ID0, DeviceId.ID1,
+                          DeviceId.ID2, DeviceId.ID3][0:num_dsp])
+    eve_device_ids = set([DeviceId.ID0, DeviceId.ID1,
+                          DeviceId.ID2, DeviceId.ID3][0:num_eve])
+
+    # Heap sizes for this network determined using Configuration.showHeapStats
+    configuration.param_heap_size = (3 << 20)
+    configuration.network_heap_size = (20 << 20)
+
+
+    try:
+        print('TIDL API: performing one time initialization ...')
+
+        # Collect all EOs from EVE and DSP executors
+        eos = []
+
+        if eve_device_ids:
+            eve = Executor(DeviceType.EVE, eve_device_ids, configuration, 1)
+            for i in range(eve.get_num_execution_objects()):
+                eos.append(eve.at(i))
+
+        if dsp_device_ids:
+            dsp = Executor(DeviceType.DSP, dsp_device_ids, configuration, 1)
+            for i in range(dsp.get_num_execution_objects()):
+                eos.append(dsp.at(i))
+
+
+        eops = []
+        num_eos = len(eos)
+        for j in range(num_eos):
+            for i in range(PIPELINE_DEPTH):
+                eops.append(ExecutionObjectPipeline([eos[i]]))
+
+        allocate_memory(eops)
+
+        # Open input, output files
+        f_in = open(configuration.in_data, 'rb')
+        f_labels = open(labels_file, 'rb')
+
+        input_size = os.path.getsize(configuration.in_data)
+        configuration.num_frames = int(input_size/(configuration.height *
+                                                   configuration.width))
+
+        print('TIDL API: processing {} input frames ...'.format(configuration.num_frames))
+
+        num_eops = len(eops)
+        num_errors = 0
+        for frame_index in range(configuration.num_frames+num_eops):
+            eop = eops[frame_index % num_eops]
+
+            if eop.process_frame_wait():
+                num_errors += process_output(eop, f_labels)
+
+            if read_frame(eop, frame_index, configuration, f_in):
+                eop.process_frame_start_async()
+
+
+        f_in.close()
+        f_labels.close()
+
+        free_memory(eops)
+
+        if num_errors == 0:
+            print("mnist PASSED")
+        else:
+            print("mnist FAILED")
+
+    except TidlError as err:
+        print(err)
+
+def read_frame(eo, frame_index, configuration, f_input):
+    """Read a frame into the ExecutionObject input buffer"""
+
+    if frame_index >= configuration.num_frames:
+        return False
+
+    # Read into the EO's input buffer
+    arg_info = eo.get_input_buffer()
+    bytes_read = f_input.readinto(arg_info)
+
+    if bytes_read == 0:
+        return False
+
+    # TIDL library requires a minimum of 2 channels. Read image data into
+    # channel 0. f_input.readinto will read twice as many bytes i.e. 2 input
+    # digits. Seek back to avoid skipping inputs.
+    f_input.seek((frame_index+1)*configuration.height*configuration.width)
+
+    eo.set_frame_index(frame_index)
+
+    return True
+
+def process_output(eo, f_labels):
+    """Display and check the inference result against labels."""
+
+    maxval = 0
+    maxloc = -1
+
+    out_buffer = eo.get_output_buffer()
+    output_array = np.asarray(out_buffer)
+    for i in range(out_buffer.size()):
+        if output_array[i] > maxval:
+            maxval = output_array[i]
+            maxloc = i
+
+    print(maxloc)
+
+    # Check inference result against label
+    frame_index = eo.get_frame_index()
+    f_labels.seek(frame_index)
+    label = ord(f_labels.read(1))
+    if maxloc != label:
+        print('Error Expected {}, got {}'.format(label, maxloc))
+        return 1
+
+    return 0
+
+if __name__ == '__main__':
+    main()
index 0b44c7b0cd90707696b20c0b981ea414c395926c..81c9e220f1d71f41c0f1236b75dcc901269a6de5 100755 (executable)
@@ -95,20 +95,21 @@ def run(num_eve, num_dsp, configuration):
     configuration.param_heap_size = (3 << 20)
     configuration.network_heap_size = (20 << 20)
 
-
     try:
         print('TIDL API: performing one time initialization ...')
 
-        eve = Executor(DeviceType.EVE, eve_device_ids, configuration, 1)
-        dsp = Executor(DeviceType.DSP, dsp_device_ids, configuration, 1)
-
         # Collect all EOs from EVE and DSP executors
         eos = []
-        for i in range(eve.get_num_execution_objects()):
-            eos.append(eve.at(i))
 
-        for i in range(dsp.get_num_execution_objects()):
-            eos.append(dsp.at(i))
+        if len(eve_device_ids) != 0:
+            eve = Executor(DeviceType.EVE, eve_device_ids, configuration, 1)
+            for i in range(eve.get_num_execution_objects()):
+                eos.append(eve.at(i))
+
+        if len(dsp_device_ids) != 0:
+            dsp = Executor(DeviceType.DSP, dsp_device_ids, configuration, 1)
+            for i in range(dsp.get_num_execution_objects()):
+                eos.append(dsp.at(i))
 
         allocate_memory(eos)
 
index 2cd2d606b4ee98e25f5447ffea5ab63c8d932194..85d1dd4a3c80350b2125cc6d06aa3d35fd87b4da 100644 (file)
@@ -33,23 +33,21 @@ from tidl import Configuration
 from tidl import Executor
 from tidl import TidlError
 
-import tidl
-
 def read_frame(eo, frame_index, c, f):
     """Read a frame into the ExecutionObject input buffer"""
 
-    if (frame_index >= c.num_frames):
+    if frame_index >= c.num_frames:
         return False
 
     # Read into the EO's input buffer
     arg_info = eo.get_input_buffer()
     bytes_read = f.readinto(arg_info)
 
-    if (bytes_read != arg_info.size()):
-        print("Expected {} bytes, read {}".format(size, bytes_read))
+    if bytes_read != arg_info.size():
+        print("Expected {} bytes, read {}".format(args_info.size(), bytes_read))
         return False
 
-    if (len(f.peek(1)) == 0):
+    if len(f.peek(1)) == 0:
         f.seek(0)
 
     eo.set_frame_index(frame_index)
index a1cd58c0dc9d547b94a6980f31c8b14ae0162499..953c691efb62921905f744be3e6d51091d49d199 100755 (executable)
@@ -61,7 +61,7 @@ def main():
     num_eve = Executor.get_num_devices(DeviceType.EVE)
 
     if num_dsp == 0 or num_eve == 0:
-        print('This example required EVEs and DSPs.')
+        print('This example requires EVEs and DSPs.')
         return
 
     enable_time_stamps("2eo_timestamp.log", 16)
index 4fe235c008924dc6470f3e20a7f927ae73086f9b..f46ca08e8cddfe123f7161fc4f79f3eae80cd710 100755 (executable)
@@ -57,7 +57,7 @@ def main():
     num_eve = Executor.get_num_devices(DeviceType.EVE)
 
     if num_dsp == 0 or num_eve == 0:
-        print('This example required EVEs and DSPs.')
+        print('This example requires EVEs and DSPs.')
         return
 
     enable_time_stamps("2eo_opt_timestamp.log", 16)