[ti-machine-learning/ti-machine-learning.git] / src / app / cnn / class / cifar10 / appCNNClassCIFAR10Testing.c
1 /******************************************************************************/
2 /*!
3 * \file appCNNClassCIFAR10Testing.c
4 */
5 /* Copyright (C) 2015 Texas Instruments Incorporated - http://www.ti.com/
6 *
7 * Redistribution and use in source and binary forms, with or without
8 * modification, are permitted provided that the following conditions
9 * are met:
10 *
11 * Redistributions of source code must retain the above copyright
12 * notice, this list of conditions and the following disclaimer.
13 *
14 * Redistributions in binary form must reproduce the above copyright
15 * notice, this list of conditions and the following disclaimer in the
16 * documentation and/or other materials provided with the
17 * distribution.
18 *
19 * Neither the name of Texas Instruments Incorporated nor the names of
20 * its contributors may be used to endorse or promote products derived
21 * from this software without specific prior written permission.
22 *
23 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
26 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
27 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
28 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
29 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
30 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
31 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
32 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34 *
35 ******************************************************************************/
38 /*******************************************************************************
39 *
40 * INCLUDES
41 *
42 ******************************************************************************/
44 #include "../appCNNClass.h"
/*******************************************************************************
 *
 * DEFINES
 *
 ******************************************************************************/

// CNN model description file passed to timlCNNReadFromFile()
#define MODEL_PATH    "../../../../database/model/cifar10/databaseModelCIFAR10.m"
// CIFAR10 database directory; not referenced by any live code in this file
#define DATABASE_PATH "../../../../database/cifar10"
// printf-style pattern for the testing images; %1d receives the image index
// 0 .. IMAGE_NUM-1
#define IMAGE_PATH    "../../../../database/cifar10/%1d.jpg"
// Text file holding one whitespace-separated integer label per testing image
#define LABEL_PATH    "../../../../database/cifar10/label.txt"

// Integer settings are enum constants rather than macros so they are typed
// and visible in a debugger. The paths above stay macros so that IMAGE_PATH
// remains a string literal usable as a checked printf format.
enum {
   IMAGE_NUM     = 3,  // number of testing images read and classified
   TOP_N         = 1,  // top-N setting passed to timlCNNAddAccuracyLayer()
   IMAGE_ROW     = 32, // rows per image
   IMAGE_COL     = 32, // columns per image
   IMAGE_CHANNEL = 3   // channels per image
};
/*******************************************************************************
 *
 * main()
 *
 ******************************************************************************/

/*!
 * \brief Program entry point
 *
 * Runs the CIFAR10 testing example and forwards its return value as the
 * process exit status. Declared with (void), the prototype form ISO C lists
 * for main, instead of an empty, unprototyped parameter list.
 */
int main(void)
{
   return appCNNClassCIFAR10Testing();
}
76 /******************************************************************************/
77 /*!
78 * \ingroup appCNNClass
79 * \brief CIFAR10 testing example
80 */
81 /******************************************************************************/
83 int appCNNClassCIFAR10Testing()
84 {
85 int err;
86 int classifyNum;
87 float classifyPercent;
88 int dim;
89 long mem;
90 struct timespec startTime;
91 struct timespec endTime;
92 long testingTime;
93 int topN;
94 timlUtilImageSet training;
95 timlUtilImageSet testing;
96 size_t mem1;
97 size_t mem2;
98 size_t mem3;
100 // init
101 err = 0;
102 dim = IMAGE_ROW*IMAGE_COL*IMAGE_CHANNEL;
103 classifyNum = 0;
104 topN = TOP_N;
106 setbuf(stdout, NULL); // do not buffer the console output
108 // read the CNN config file
109 printf("1. Read the CNN config\n");
110 timlConvNeuralNetwork *cnn = timlCNNReadFromFile(MODEL_PATH);
111 timlCNNAddAccuracyLayer(cnn, TOP_N);
112 timlCNNInitialize(cnn);
113 timlCNNLoadParamsFromFile(cnn, cnn->paramsFileName);
114 timlCNNSetMode(cnn, Util_Test);
115 timlCNNPrint(cnn);
117 mem1 = cnn->forwardMemory + cnn->backwardMemory + cnn->fixedMemory + cnn->paramsMemory;
118 mem2 = cnn->forwardMemory + cnn->fixedMemory + cnn->paramsMemory;
119 mem3 = cnn->memPoolSize + cnn->fixedMemory + cnn->paramsMemory;
120 printf("CNN level 1 memory size = %10.4f MB.\n", (float)mem1/1024.0/1024.0);
121 printf("CNN level 2 memory size = %10.4f MB.\n", (float)mem2/1024.0/1024.0);
122 printf("CNN level 3 memory size = %10.4f MB.\n", (float)mem3/1024.0/1024.0);
123 printf("CNN forward memory size = %10.4f MB.\n", (float)cnn->forwardMemory/1024.0/1024.0);
124 printf("CNN memory pool size = %10.4f MB.\n", (float)cnn->memPoolSize/1024.0/1024.0);
125 printf("CNN params memory size = %10.4f MB.\n", (float)cnn->paramsMemory/1024.0/1024.0);
127 // read CIFAR10 database
128 // printf("2. Read CIFAR10 database\n");
129 // timlUtilReadCIFAR10(DATABASE_PATH, &training, &testing);
131 testing.data = malloc(sizeof(float)*IMAGE_ROW*IMAGE_COL*IMAGE_CHANNEL*IMAGE_NUM);
132 testing.label = malloc(sizeof(int)*IMAGE_NUM);
133 testing.num = IMAGE_NUM;
135 // read labels
136 fp = fopen(LABEL_PATH, "rt");
137 for (i = 0; i < IMAGE_NUM; i++) {
138 read = fscanf(fp, "%d", testing.label + i);
139 }
140 fclose(fp);
142 // read images
143 for (i = 0; i < IMAGE_NUM; i++) {
144 sprintf(str, IMAGE_PATH, i);
145 image = timlUtilReadJPEG(str);
146 cblas_scopy(dim, image.data, 1, testing.data + i*dim, 1);
147 free(image.data);
148 }
150 // testing
151 printf("3. Start testing\n");
152 clock_gettime(CLOCK_REALTIME, &startTime);
153 timlCNNClassifyAccuracy(cnn, testing.data, IMAGE_ROW, IMAGE_COL, IMAGE_CHANNEL, testing.label, 1, 1, testing.num, &classifyNum);
154 clock_gettime(CLOCK_REALTIME, &endTime);
155 testingTime = timlUtilDiffTime(startTime, endTime);
156 classifyPercent = (float)classifyNum/(float)testing.num;
157 printf("Testing time = %.3f s\n", testingTime/1000000.0);
158 printf("Classify accuracy = %.3f %%\n", classifyPercent*100.00);
160 // cleaning
161 printf("4. Clean up\n");
162 free(training.data);
163 free(training.label);
164 free(testing.data);
165 free(testing.label);
166 timlCNNDelete(cnn);
168 return err;
169 }