]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - tidl/tidl-api.git/blob - examples/pybind/imagenet.py
Merge tag 'v01.03.03' into develop
[tidl/tidl-api.git] / examples / pybind / imagenet.py
1 #!/usr/bin/python3
3 # Copyright (c) 2019 Texas Instruments Incorporated - http://www.ti.com/
4 # All rights reserved.
5 #
6 # Redistribution and use in source and binary forms, with or without
7 # modification, are permitted provided that the following conditions are met:
8 # * Redistributions of source code must retain the above copyright
9 # notice, this list of conditions and the following disclaimer.
10 # * Redistributions in binary form must reproduce the above copyright
11 # notice, this list of conditions and the following disclaimer in the
12 # documentation and/or other materials provided with the distribution.
13 # * Neither the name of Texas Instruments Incorporated nor the
14 # names of its contributors may be used to endorse or promote products
15 # derived from this software without specific prior written permission.
16 #
17 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18 # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21 # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22 # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23 # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24 # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25 # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26 # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
27 # THE POSSIBILITY OF SUCH DAMAGE.
29 """ Process each frame using a single ExecutionObject.
30     Increase throughput by using multiple ExecutionObjects.
31 """
33 import os
34 import argparse
35 import json
36 import heapq
37 import logging
38 import numpy as np
39 import cv2
41 from tidl import DeviceId, DeviceType, Configuration, TidlError
42 from tidl import Executor, ExecutionObjectPipeline
43 from tidl import allocate_memory, free_memory
46 def main():
47     """Read the configuration and run the network"""
48     #logging.basicConfig(level=logging.INFO)
50     args = parse_args()
52     config_file = '../test/testvecs/config/infer/tidl_config_j11_v2.txt'
53     labels_file = '../imagenet/imagenet_objects.json'
55     configuration = Configuration()
56     configuration.read_from_file(config_file)
58     if os.path.isfile(args.input_file):
59         configuration.in_data = args.input_file
60     else:
61         print('Input image {} does not exist'.format(args.input_file))
62         return
63     print('Input: {}'.format(args.input_file))
65     num_eve = Executor.get_num_devices(DeviceType.EVE)
66     num_dsp = Executor.get_num_devices(DeviceType.DSP)
68     if num_eve == 0 and num_dsp == 0:
69         print('No TIDL API capable devices available')
70         return
72     # use 1 EVE or DSP since input is a single image
73     # If input is a stream of images, feel free to use all EVEs and/or DSPs
74     if num_eve > 0:
75         num_eve = 1
76         num_dsp = 0
77     else:
78         num_dsp = 1
80     run(num_eve, num_dsp, configuration, labels_file)
82     return
85 DESCRIPTION = 'Run the imagenet network on input image.'
86 DEFAULT_INFILE = '../test/testvecs/input/objects/cat-pet-animal-domestic-104827.jpeg'
88 def parse_args():
89     """Parse input arguments"""
91     parser = argparse.ArgumentParser(description=DESCRIPTION)
92     parser.add_argument('-i', '--input_file',
93                         default=DEFAULT_INFILE,
94                         help='input image file (that OpenCV can read)')
95     args = parser.parse_args()
97     return args
99 PIPELINE_DEPTH = 2
101 def run(num_eve, num_dsp, configuration, labels_file):
102     """ Run the network on the specified device type and number of devices"""
104     logging.info('Running network across {} EVEs, {} DSPs'.format(num_eve,
105                                                                   num_dsp))
107     dsp_device_ids = set([DeviceId.ID0, DeviceId.ID1,
108                           DeviceId.ID2, DeviceId.ID3][0:num_dsp])
109     eve_device_ids = set([DeviceId.ID0, DeviceId.ID1,
110                           DeviceId.ID2, DeviceId.ID3][0:num_eve])
112     # Heap sizes for this network determined using Configuration.showHeapStats
113     configuration.param_heap_size = (3 << 20)
114     configuration.network_heap_size = (20 << 20)
117     try:
118         logging.info('TIDL API: performing one time initialization ...')
120         # Collect all EOs from EVE and DSP executors
121         eos = []
123         if eve_device_ids:
124             eve = Executor(DeviceType.EVE, eve_device_ids, configuration, 1)
125             for i in range(eve.get_num_execution_objects()):
126                 eos.append(eve.at(i))
128         if dsp_device_ids:
129             dsp = Executor(DeviceType.DSP, dsp_device_ids, configuration, 1)
130             for i in range(dsp.get_num_execution_objects()):
131                 eos.append(dsp.at(i))
133         eops = []
134         num_eos = len(eos)
135         for j in range(PIPELINE_DEPTH):
136             for i in range(num_eos):
137                 eops.append(ExecutionObjectPipeline([eos[i]]))
139         allocate_memory(eops)
141         # open labels file
142         with open(labels_file) as json_file:
143             labels_data = json.load(json_file)
145         configuration.num_frames = 1
146         logging.info('TIDL API: processing {} input frames ...'.format(
147                                                      configuration.num_frames))
149         num_eops = len(eops)
150         for frame_index in range(configuration.num_frames+num_eops):
151             eop = eops[frame_index % num_eops]
153             if eop.process_frame_wait():
154                 process_output(eop, labels_data)
156             if read_frame(eop, frame_index, configuration):
157                 eop.process_frame_start_async()
159         free_memory(eops)
161     except TidlError as err:
162         print(err)
164 def read_frame(eo, frame_index, configuration):
165     """Read a frame into the ExecutionObject input buffer"""
167     if frame_index >= configuration.num_frames:
168         return False
170     # Read into the EO's input buffer
171     arg_info = eo.get_input_buffer()
172     np_arg = np.asarray(arg_info)
174     img = cv2.imread(configuration.in_data)
175     resized = cv2.resize(img, (224, 224), interpolation=cv2.INTER_AREA)
176     b_frame, g_frame, r_frame = cv2.split(resized)
177     np_arg[0*224*224:1*224*224] = np.reshape(b_frame, 224*224)
178     np_arg[1*224*224:2*224*224] = np.reshape(g_frame, 224*224)
179     np_arg[2*224*224:3*224*224] = np.reshape(r_frame, 224*224)
181     eo.set_frame_index(frame_index)
183     return True
185 def process_output(eo, labels_data):
186     """Display the inference result using labels."""
188     # keep top k predictions in heap
189     k = 5
190     # output predictions with probability of 10/255 or higher
191     threshold = 10
193     out_buffer = eo.get_output_buffer()
194     output_array = np.asarray(out_buffer)
196     k_heap = []
197     for i in range(k):
198         heapq.heappush(k_heap, (output_array[i], i))
200     for i in range(k, out_buffer.size()):
201         if output_array[i] > k_heap[0][0]:
202             heapq.heappushpop(k_heap, (output_array[i], i))
204     k_sorted = []
205     for i in range(k):
206         k_sorted.insert(0, heapq.heappop(k_heap))
208     for i in range(k):
209         if k_sorted[i][0] > threshold:
210             print('{}: {},   prob = {:5.2f}%'.format(i+1, \
211                              labels_data['objects'][k_sorted[i][1]]['label'], \
212                              k_sorted[i][0]/255.0*100))
214     return 0
216 if __name__ == '__main__':
217     main()