Add MNIST LeNet network model and test input
[tidl/tidl-api.git] / examples / mnist / main.cpp
1 /******************************************************************************
2  * Copyright (c) 2018, Texas Instruments Incorporated - http://www.ti.com/
3  *   All rights reserved.
4  *
5  *   Redistribution and use in source and binary forms, with or without
6  *   modification, are permitted provided that the following conditions are met:
7  *       * Redistributions of source code must retain the above copyright
8  *         notice, this list of conditions and the following disclaimer.
9  *       * Redistributions in binary form must reproduce the above copyright
10  *         notice, this list of conditions and the following disclaimer in the
11  *         documentation and/or other materials provided with the distribution.
12  *       * Neither the name of Texas Instruments Incorporated nor the
13  *         names of its contributors may be used to endorse or promote products
14  *         derived from this software without specific prior written permission.
15  *
16  *   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *   AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *   IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *   ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
20  *   LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *   CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *   SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *   INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *   CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *   ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26  *   THE POSSIBILITY OF SUCH DAMAGE.
27  *****************************************************************************/
28 #include <signal.h>
29 #include <iostream>
30 #include <iomanip>
31 #include <fstream>
32 #include <cassert>
33 #include <string>
34 #include <functional>
35 #include <algorithm>
36 #include <time.h>
37 #include <unistd.h>
39 #include <queue>
40 #include <vector>
41 #include <chrono>
43 #include "executor.h"
44 #include "execution_object.h"
45 #include "execution_object_pipeline.h"
46 #include "configuration.h"
47 #include "imgutil.h"
48 #include "../common/video_utils.h"
50 using namespace std;
51 using namespace tidl;
53 #define DEFAULT_CONFIG    "mnist_lenet"
54 #define DEFAULT_INPUT_IMAGES "../test/testvecs/input/digits10_images_28x28.y"
55 #define DEFAULT_INPUT_LABELS "../test/testvecs/input/digits10_labels_10x1.y"
57 uint32_t images_file_offset = 0;
58 uint32_t labels_file_offset = 0;
59 uint32_t num_frames_file    = 0;
60 uint32_t num_wrong_results  = 0;
63 Executor* CreateExecutor(DeviceType dt, uint32_t num, const Configuration& c);
64 bool RunConfiguration(cmdline_opts_t& opts);
65 bool ReadFrame(ExecutionObjectPipeline& eop,
66                uint32_t frame_idx, const Configuration& c,
67                const cmdline_opts_t& opts, ifstream &ifs);
68 bool WriteFrameOutput(const ExecutionObjectPipeline &eop, ifstream &ifs_labels);
69 void DisplayHelp();
72 int main(int argc, char *argv[])
73 {
74     // Catch ctrl-c to ensure a clean exit
75     signal(SIGABRT, exit);
76     signal(SIGTERM, exit);
78     // If there are no devices capable of offloading TIDL on the SoC, exit
79     uint32_t num_eves = Executor::GetNumDevices(DeviceType::EVE);
80     uint32_t num_dsps = Executor::GetNumDevices(DeviceType::DSP);
81     if (num_eves == 0 && num_dsps == 0)
82     {
83         cout << "TI DL not supported on this SoC." << endl;
84         return EXIT_SUCCESS;
85     }
87     // Process arguments
88     cmdline_opts_t opts;
89     opts.config = DEFAULT_CONFIG;
90     if (num_eves != 0) { opts.num_eves = 1;  opts.num_dsps = 0; }
91     else               { opts.num_eves = 0;  opts.num_dsps = 1; }
92     if (! ProcessArgs(argc, argv, opts))
93     {
94         DisplayHelp();
95         exit(EXIT_SUCCESS);
96     }
97     assert(opts.num_dsps != 0 || opts.num_eves != 0);
98     if (opts.num_dsps != 0)
99     {
100         cout << "MNIST network not supported on DSP yet." << endl;
101         exit(EXIT_SUCCESS);
102     }
104     if (opts.input_file.empty())
105     {
106         opts.input_file               = DEFAULT_INPUT_IMAGES;
107         opts.object_classes_list_file = DEFAULT_INPUT_LABELS;
108     }
110     // if inputs are MNIST data set: skip MNIST header
111     string& s_images = opts.input_file;
112     if (s_images.size() > 10 &&
113         s_images.compare(s_images.size() - 10, 10, "idx3-ubyte") == 0)
114         images_file_offset = 16;
115     string& s_labels = opts.object_classes_list_file;
116     if (s_labels.size() > 10 &&
117         s_labels.compare(s_labels.size() - 10, 10, "idx1-ubyte") == 0)
118         labels_file_offset = 8;
120     cout << "Input images: " << opts.input_file << endl;
121     if (! opts.object_classes_list_file.empty())
122         cout << "Input labels: " << opts.object_classes_list_file << endl;
124     // Run network
125     bool status = RunConfiguration(opts);
126     if (!status)
127     {
128         cout << "mnist FAILED" << endl;
129         return EXIT_FAILURE;
130     }
132     cout << "mnist PASSED" << endl;
133     return EXIT_SUCCESS;
136 bool RunConfiguration(cmdline_opts_t& opts)
138     // Read the TI DL configuration file
139     Configuration c;
140     string config_file = "../test/testvecs/config/infer/tidl_config_"
141                          + opts.config + ".txt";
142     bool status = c.ReadFromFile(config_file);
143     if (!status)
144     {
145         cerr << "Error in configuration file: " << config_file << endl;
146         return false;
147     }
148     c.enableApiTrace = opts.verbose;
150     // setup images/labels input/output
151     ifstream ifs, ifs_labels;
152     ifs.open(opts.input_file, ios::binary | ios::ate);
153     if (! ifs.good())
154     {
155         cerr << "Cannot open " << opts.input_file << endl;
156         return false;
157     }
158     num_frames_file = (((int) ifs.tellg()) - images_file_offset) /
159                       (c.inWidth * c.inHeight);
160     if (opts.num_frames == 0)
161         opts.num_frames = num_frames_file;
162     if (! opts.object_classes_list_file.empty())
163     {
164         ifs_labels.open(opts.object_classes_list_file, ios::binary);
165         if (! ifs_labels.good())
166         {
167             cerr << "Cannot open " << opts.object_classes_list_file << endl;
168             return false;
169         }
170     }
172     try
173     {
174         // Create Executors with the approriate core type, number of cores
175         // and configuration specified
176         Executor* e_eve = CreateExecutor(DeviceType::EVE, opts.num_eves, c);
177         Executor* e_dsp = CreateExecutor(DeviceType::DSP, opts.num_dsps, c);
179         // Get ExecutionObjects from Executors
180         vector<ExecutionObject*> eos;
181         for (uint32_t i = 0; i < opts.num_eves; i++) eos.push_back((*e_eve)[i]);
182         for (uint32_t i = 0; i < opts.num_dsps; i++) eos.push_back((*e_dsp)[i]);
183         uint32_t num_eos = eos.size();
185         // Use duplicate EOPs to do double buffering on frame input/output
186         //    because each EOP has its own set of input/output buffers,
187         //    so that host ReadFrame() can be overlapped with device processing
188         // Use one EO as an example, with different buffer_factor,
189         //    we have different execution behavior:
190         // If buffer_factor is set to 1 -> single buffering
191         //    we create one EOP: eop0 (eo0)
192         //    pipeline execution of multiple frames over time is as follows:
193         //    --------------------- time ------------------->
194         //    eop0: [RF][eo0.....][WF]
195         //    eop0:                   [RF][eo0.....][WF]
196         //    eop0:                                     [RF][eo0.....][WF]
197         // If buffer_factor is set to 2 -> double buffering
198         //    we create two EOPs: eop0 (eo0), eop1(eo0)
199         //    pipeline execution of multiple frames over time is as follows:
200         //    --------------------- time ------------------->
201         //    eop0: [RF][eo0.....][WF]
202         //    eop1:     [RF]      [eo0.....][WF]
203         //    eop0:                   [RF]  [eo0.....][WF]
204         //    eop1:                             [RF]  [eo0.....][WF]
205         vector<ExecutionObjectPipeline *> eops;
206         uint32_t buffer_factor = 2;  // set to 1 for single buffering
207         for (uint32_t j = 0; j < buffer_factor; j++)
208             for (uint32_t i = 0; i < num_eos; i++)
209                 eops.push_back(new ExecutionObjectPipeline({eos[i]}));
210         uint32_t num_eops = eops.size();
212         // Allocate input and output buffers for each EOP
213         AllocateMemory(eops);
215         float device_time = 0.0f;
216         chrono::time_point<chrono::steady_clock> tloop0, tloop1;
217         tloop0 = chrono::steady_clock::now();
219         // Process frames with available eops in a pipelined manner
220         // additional num_eos iterations to flush the pipeline (epilogue)
221         for (uint32_t frame_idx = 0;
222              frame_idx < opts.num_frames + num_eops; frame_idx++)
223         {
224             ExecutionObjectPipeline* eop = eops[frame_idx % num_eops];
226             // Wait for previous frame on the same eop to finish processing
227             if (eop->ProcessFrameWait())
228             {
229                 device_time +=
230                       eos[frame_idx % num_eos]->GetProcessTimeInMilliSeconds();
231                 WriteFrameOutput(*eop, ifs_labels);
232             }
234             // Read a frame and start processing it with current eop
235             if (ReadFrame(*eop, frame_idx, c, opts, ifs))
236                 eop->ProcessFrameStartAsync();
237         }
239         tloop1 = chrono::steady_clock::now();
240         chrono::duration<float> elapsed = tloop1 - tloop0;
241         cout << "Device total time: " << setw(6) << setprecision(4)
242              << device_time << "ms" << endl;
243         cout << "Loop total time (including read/write/print/etc): "
244              << setw(6) << setprecision(4)
245              << (elapsed.count() * 1000) << "ms" << endl;
246         if (opts.num_frames > 0 && ifs_labels.is_open())
247         {
248             cout << "Accuracy: " << setw(6) << setprecision(4)
249                  << (opts.num_frames-num_wrong_results)*100.f / opts.num_frames
250                  << "%" << endl;
251             if (opts.input_file == DEFAULT_INPUT_IMAGES && num_wrong_results >0)
252                 status = false;
253         }
255         FreeMemory(eops);
256         for (auto eop : eops)  delete eop;
257         delete e_eve;
258         delete e_dsp;
259     }
260     catch (tidl::Exception &e)
261     {
262         cerr << e.what() << endl;
263         status = false;
264     }
266     return status;
269 // Create an Executor with the specified type and number of EOs
270 Executor* CreateExecutor(DeviceType dt, uint32_t num, const Configuration& c)
272     if (num == 0) return nullptr;
274     DeviceIds ids;
275     for (uint32_t i = 0; i < num; i++)
276         ids.insert(static_cast<DeviceId>(i));
278     return new Executor(dt, ids, c);
281 bool ReadFrame(ExecutionObjectPipeline &eop,
282                uint32_t frame_idx, const Configuration& c,
283                const cmdline_opts_t& opts, ifstream &ifs)
285     if (frame_idx >= opts.num_frames)
286         return false;
288     eop.SetFrameIndex(frame_idx);
290     char*  frame_buffer = eop.GetInputBufferPtr();
291     assert (frame_buffer != nullptr);
292     int channel_size = c.inWidth * c.inHeight;
294     // already PreProc-ed white-on-black 28x28 frames
295     ifs.seekg(images_file_offset + (frame_idx%num_frames_file) * channel_size);
296     ifs.read(frame_buffer, channel_size);
297     return ifs.good();
300 // Display top 5 classified imagenet classes with probabilities
301 bool WriteFrameOutput(const ExecutionObjectPipeline &eop, ifstream &ifs_labels)
303     unsigned char *out = (unsigned char *) eop.GetOutputBufferPtr();
304     int out_size = eop.GetOutputBufferSizeInBytes();
306     unsigned char maxval = 0;
307     int           maxloc = -1;
308     for (int i = 0; i < out_size; i++)
309     {
310         // cout << (int) out[i] << " ";  // 10 probability outputs
311         if (out[i] > maxval)
312         {
313             maxval = out[i];
314             maxloc = i;
315         }
316     }
317     cout << maxloc << endl;
319     // check inference result against pre-determined label
320     if (ifs_labels.is_open())
321     {
322         int frame_index = eop.GetFrameIndex();
323         char label = -1;
324         ifs_labels.seekg(labels_file_offset + (frame_index % num_frames_file));
325         ifs_labels.read(&label, 1);
326         if (maxloc != (int) label)
327             num_wrong_results += 1;
328     }
330     return true;
333 void DisplayHelp()
335     cout <<
336     "Usage: mnist\n"
337     "  Will run MNIST LeNet to predict the digit in a 28x28"
338     " white-on-black image.\n  Use -c to run a"
339     "  different MNIST network. Default is mnist_lenet.\n"
340     "Optional arguments:\n"
341     " -c <config>          Valid configs: mnist_lenet\n"
342     " -e <number>          Number of eve cores to use\n"
343     " -i <images>          Path to the MNIST white-on-black images file\n"
344     " -l <labels>          Path to the MNIST labels file\n"
345     " -f <number>          Number of frames to process\n"
346     " -v                   Verbose output during execution\n"
347     " -h                   Help\n";