Add 3 preloaded images to the CIFAR10 database for testing purposes
[ti-machine-learning/ti-machine-learning.git] / src / app / cnn / class / cifar10 / appCNNClassCIFAR10Testing.c
1 /******************************************************************************/
2 /*!
3  * \file appCNNClassCIFAR10Testing.c
4  */
5 /* Copyright (C) 2015 Texas Instruments Incorporated - http://www.ti.com/
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions
9  * are met:
10  *
11  *    Redistributions of source code must retain the above copyright
12  *    notice, this list of conditions and the following disclaimer.
13  *
14  *    Redistributions in binary form must reproduce the above copyright
15  *    notice, this list of conditions and the following disclaimer in the
16  *    documentation and/or other materials provided with the
17  *    distribution.
18  *
19  *    Neither the name of Texas Instruments Incorporated nor the names of
20  *    its contributors may be used to endorse or promote products derived
21  *    from this software without specific prior written permission.
22  *
23  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
26  * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
27  * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
28  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
29  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
30  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
31  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
32  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34  *
35  ******************************************************************************/
38 /*******************************************************************************
39  *
40  * INCLUDES
41  *
42  ******************************************************************************/
44 #include "../appCNNClass.h"
47 /*******************************************************************************
48  *
49  * DEFINES
50  *
51  ******************************************************************************/
/* Common prefix of every database path, relative to the directory the
 * example is launched from. The paths below are assembled from it by
 * adjacent string-literal concatenation, so each one still expands to a
 * single string literal. */
#define DATABASE_ROOT    "../../../../database"

/* Trained CIFAR10 CNN model loaded with timlCNNReadFromFile(). */
#define MODEL_PATH       DATABASE_ROOT "/model/cifar10/databaseModelCIFAR10.m"
/* CIFAR10 database directory; the preloaded images and labels live here. */
#define DATABASE_PATH    DATABASE_ROOT "/cifar10"
/* sprintf() pattern for one preloaded image; the %1d conversion receives
 * the image index, producing 0.jpg, 1.jpg, ... */
#define IMAGE_PATH       DATABASE_PATH "/%1d.jpg"
/* Text file holding one integer class label per preloaded image. */
#define LABEL_PATH       DATABASE_PATH "/label.txt"

/* Number of preloaded test images. */
#define IMAGE_NUM        3
/* Number of top-ranked class predictions requested per image. */
#define TOP_N            1
/* Per-image geometry; the product of the three is the number of floats
 * each image occupies in the test batch. */
#define IMAGE_ROW        32
#define IMAGE_COL        32
#define IMAGE_CHANNEL    3
64 /*******************************************************************************
65  *
66  * main()
67  *
68  ******************************************************************************/
/*
 * Program entry point: runs the CIFAR10 testing example and hands its
 * status code back as the process exit status.
 */
int main()
{
   const int status = appCNNClassCIFAR10Testing();

   return status;
}
76 /******************************************************************************/
77 /*!
78  * \ingroup appCNNClass
79  * \brief   CIFAR10 testing example
80  */
81 /******************************************************************************/
83 int appCNNClassCIFAR10Testing()
84 {
85    int              err;
86    int              classifyNum;
87    float            classifyPercent;
88    int              dim;
89    long             mem;
90    struct timespec  startTime;
91    struct timespec  endTime;
92    long             testingTime;
93    int              topN;
94    int              *label;
95    timlUtilImageSet training;
96    timlUtilImageSet testing;
97    timlUtilImage    image;
98    char             str[TIML_UTIL_MAX_STR];
99    int              i;
100    FILE             *fp;
101    int              read;
103    // init
104    err         = 0;
105    dim         = IMAGE_ROW*IMAGE_COL*IMAGE_CHANNEL;
106    classifyNum = 0;
107    topN        = TOP_N;
109    setbuf(stdout, NULL); // do not buffer the console output
111    // read the CNN config file
112    printf("1. Read the CNN config\n");
113    timlConvNeuralNetwork *cnn = timlCNNReadFromFile(MODEL_PATH, 0);
114    timlCNNSetMode(cnn, Util_Test);
115    mem = timlCNNMemory(cnn);
116    timlCNNPrint(cnn);
117    printf("CNN memory allocation = %.10f MB.\n", (float)mem/1024.0/1024.0);
118    printf("CNN parameter #       = %lu.\n", timlCNNGetParamsNum(cnn));
120 //   // read CIFAR10 database
121 //   printf("2. Read CIFAR10 database\n");
122 //   timlUtilReadCIFAR10(DATABASE_PATH, &training, &testing);
124    testing.data  = malloc(sizeof(float)*IMAGE_ROW*IMAGE_COL*IMAGE_CHANNEL*IMAGE_NUM);
125    testing.label = malloc(sizeof(int)*IMAGE_NUM);
126    testing.num   = IMAGE_NUM;
127    // read labels
128    fp = fopen(LABEL_PATH, "rt");
129    for (i = 0; i < IMAGE_NUM; i++) {
130       read = fscanf(fp, "%d", testing.label + i);
131    }
132    fclose(fp);
134    // read images
135    for (i = 0; i < IMAGE_NUM; i++) {
136       sprintf(str, IMAGE_PATH, i);
137       image = timlUtilReadJPEG(str);
138       cblas_scopy(dim, image.data, 1, testing.data + i*dim, 1);
139       free(image.data);
140    }
142    // testing
143    printf("3. Start testing\n");
144    label = malloc(sizeof(int)*topN*testing.num);
145    clock_gettime(CLOCK_REALTIME, &startTime);
146    timlCNNClassifyTopNBatchMode(cnn, testing.data, dim, testing.num, label, NULL, topN);
147    clock_gettime(CLOCK_REALTIME, &endTime);
148    testingTime = timlUtilDiffTime(startTime, endTime);
149    classifyNum = timlUtilClassifyAccuracy(label, topN, testing.num, testing.label);
150    classifyPercent = (float)classifyNum/(float)testing.num;
151    printf("Testing time      = %.3f s\n", testingTime/1000000.0);
152    printf("Classify accuracy = %.3f %%\n", classifyPercent*100.00);
154    // cleaning
155    printf("4. Clean up\n");
156 //   free(training.data);
157 //   free(training.label);
158    free(testing.data);
159    free(testing.label);
160    free(label);
161    timlCNNDelete(cnn);
163    return err;