642e4ad73d1c8fd4fa54c2647d6cd75628f5d7ab
[ti-machine-learning/ti-machine-learning.git] / src / app / cnn / class / cifar10 / appCNNClassCIFAR10Testing.c
1 /******************************************************************************/
2 /*!
3  * \file appCNNClassCIFAR10Testing.c
4  */
5 /* Copyright (C) 2015 Texas Instruments Incorporated - http://www.ti.com/
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions
9  * are met:
10  *
11  *    Redistributions of source code must retain the above copyright
12  *    notice, this list of conditions and the following disclaimer.
13  *
14  *    Redistributions in binary form must reproduce the above copyright
15  *    notice, this list of conditions and the following disclaimer in the
16  *    documentation and/or other materials provided with the
17  *    distribution.
18  *
19  *    Neither the name of Texas Instruments Incorporated nor the names of
20  *    its contributors may be used to endorse or promote products derived
21  *    from this software without specific prior written permission.
22  *
23  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
26  * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
27  * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
28  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
29  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
30  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
31  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
32  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34  *
35  ******************************************************************************/
38 /*******************************************************************************
39  *
40  * INCLUDES
41  *
42  ******************************************************************************/
44 #include "../appCNNClass.h"
47 /*******************************************************************************
48  *
49  * DEFINES
50  *
51  ******************************************************************************/
53 #define MODEL_PATH       "../../../../database/model/cifar10/databaseModelCIFAR10.m"
54 #define DATABASE_PATH    "../../../../database/cifar10"
55 #define TOP_N            1
56 #define IMAGE_ROW        32
57 #define IMAGE_COL        32
58 #define IMAGE_CHANNEL    3
61 /*******************************************************************************
62  *
63  * main()
64  *
65  ******************************************************************************/
67 int main()
68 {
69    return appCNNClassCIFAR10Testing();
70 }
73 /******************************************************************************/
74 /*!
75  * \ingroup appCNNClass
76  * \brief   CIFAR10 testing example
77  */
78 /******************************************************************************/
80 int appCNNClassCIFAR10Testing()
81 {
82    int              err;
83    int              classifyNum;
84    float            classifyPercent;
85    int              dim;
86    long             mem;
87    struct timespec  startTime;
88    struct timespec  endTime;
89    long             testingTime;
90    int              topN;
91    timlUtilImageSet training;
92    timlUtilImageSet testing;
93    size_t           mem1;
94    size_t           mem2;
95    size_t           mem3;
97    // init
98    err         = 0;
99    dim         = IMAGE_ROW*IMAGE_COL*IMAGE_CHANNEL;
100    classifyNum = 0;
101    topN        = TOP_N;
103    setbuf(stdout, NULL); // do not buffer the console output
105    // read the CNN config file
106    printf("1. Read the CNN config\n");
107    timlConvNeuralNetwork *cnn = timlCNNReadFromFile(MODEL_PATH);
108    timlCNNAddAccuracyLayer(cnn, TOP_N);
109    timlCNNInitialize(cnn);
110    timlCNNLoadParamsFromFile(cnn, cnn->paramsFileName);
111    timlCNNSetMode(cnn, Util_Test);
112    timlCNNPrint(cnn);
114    mem1 = cnn->forwardMemory + cnn->backwardMemory + cnn->fixedMemory + cnn->paramsMemory;
115    mem2 = cnn->forwardMemory + cnn->fixedMemory + cnn->paramsMemory;
116    mem3 = cnn->memPoolSize + cnn->fixedMemory + cnn->paramsMemory;
117    printf("CNN level 1 memory size = %10.4f MB.\n", (float)mem1/1024.0/1024.0);
118    printf("CNN level 2 memory size = %10.4f MB.\n", (float)mem2/1024.0/1024.0);
119    printf("CNN level 3 memory size = %10.4f MB.\n", (float)mem3/1024.0/1024.0);
120    printf("CNN forward memory size = %10.4f MB.\n", (float)cnn->forwardMemory/1024.0/1024.0);
121    printf("CNN memory pool size    = %10.4f MB.\n", (float)cnn->memPoolSize/1024.0/1024.0);
122    printf("CNN params memory size  = %10.4f MB.\n", (float)cnn->paramsMemory/1024.0/1024.0);
124    // read CIFAR10 database
125    printf("2. Read CIFAR10 database\n");
126    timlUtilReadCIFAR10(DATABASE_PATH, &training, &testing);
128    // testing
129    printf("3. Start testing\n");
130    clock_gettime(CLOCK_REALTIME, &startTime);
131    timlCNNClassifyAccuracy(cnn, testing.data, IMAGE_ROW, IMAGE_COL, IMAGE_CHANNEL, testing.label, 1, 1, testing.num, &classifyNum);
132    clock_gettime(CLOCK_REALTIME, &endTime);
133    testingTime = timlUtilDiffTime(startTime, endTime);
134    classifyPercent = (float)classifyNum/(float)testing.num;
135    printf("Testing time      = %.3f s\n", testingTime/1000000.0);
136    printf("Classify accuracy = %.3f %%\n", classifyPercent*100.00);
138    // cleaning
139    printf("4. Clean up\n");
140    free(training.data);
141    free(training.label);
142    free(testing.data);
143    free(testing.label);
144    timlCNNDelete(cnn);
146    return err;