/******************************************************************************/
/*!
 * \file testCNNSimpleTraining.c
 */
/* Copyright (C) 2015 Texas Instruments Incorporated - http://www.ti.com/
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 *    Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *
 *    Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the
 *    distribution.
 *
 *    Neither the name of Texas Instruments Incorporated nor the names of
 *    its contributors may be used to endorse or promote products derived
 *    from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/


/*******************************************************************************
 *
 * INCLUDES
 *
 ******************************************************************************/

#include "testCNN.h"


/*******************************************************************************
 *
 * DEFINES
 *
 ******************************************************************************/

#define DATABASE_PATH    "./src/database/mnist"
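
/*
 * Batching scheme (inferred from the parameter names and the training loop
 * below): BATCH_SIZE images are propagated through the network together,
 * and the weights are updated once per BATCH_UPDATE images, so BATCH_UPDATE
 * must be a multiple of BATCH_SIZE; here both are 100, i.e. one update per
 * batch.
 */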
#define BATCH_SIZE       100    // cnn batch size
#define MAX_BATCH_SIZE   100    // cnn max batch size
#define BATCH_UPDATE     100    // cnn batch update size (must be a multiple of the batch size)
#define TEST_NUM         10000  // number of MNIST test images
#define TRAIN_NUM        60000  // number of MNIST training images
#define IMAGE_ROW        28     // MNIST image height in pixels
#define IMAGE_COL        28     // MNIST image width in pixels
#define IMAGE_CHANNEL    1      // MNIST images are single-channel (grayscale)
#define TOP_N            1      // top-1 accuracy: the highest-scoring class must match the label


/******************************************************************************/
/*!
 * \ingroup testCNN
 * \brief   Simple training function test
 * \return  Error code
 */
/******************************************************************************/

int testCNNSimpleTraining()
{
   int                   i;
   int                   dim;
   int                   classifyNum;
   float                 classifyPercent;
   struct timespec       startTime;
   struct timespec       endTime;
   long                  trainingTime;
   long                  testingTime;
   int                   err;
   int                   batchNum;
   timlUtilImageSet      training;
   timlUtilImageSet      testing;
   timlCNNInputParams    inputParams;
   timlCNNTrainingParams trainingParams;
   timlConvNeuralNetwork *cnn;

   dim            = IMAGE_ROW*IMAGE_COL*IMAGE_CHANNEL;
   classifyNum    = 0;
   err            = 0;

   setbuf(stdout, NULL); // do not buffer the console output

   printf("[Test] CNN simple training\n");
   printf("1. Build up the CNN\n");
   trainingParams              = timlCNNTrainingParamsDefault();
   trainingParams.batchSize    = BATCH_SIZE;
   trainingParams.maxBatchSize = MAX_BATCH_SIZE;
   trainingParams.batchUpdate  = BATCH_UPDATE;
   trainingParams.learningRate = 0.1;
   cnn = timlCNNCreateConvNeuralNetwork(trainingParams);
   inputParams       = timlCNNInputParamsDefault();
   inputParams.scale = 1.0/256.0; // map 8-bit pixel values into [0, 1)
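
   /*
    * Layer stack. Shapes below assume the conv arguments are (kernel rows,
    * kernel cols, row stride, col stride, feature maps) and that the default
    * conv params use no padding:
    *   input    28x28x1, pixels scaled by 1/256
    *   conv     5x5 kernel, stride 1, 6 maps  -> 24x24x6
    *   relu
    *   maxpool  4x4 window, stride 4          -> 6x6x6
    *   norm, dropout (rate 0.5)
    *   linear                                 -> 10 class scores
    *   softmax cost (training objective)
    */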
   timlCNNAddInputLayer(cnn, IMAGE_ROW, IMAGE_COL, IMAGE_CHANNEL, inputParams);            // input layer
   timlCNNAddConvLayer(cnn, 5, 5, 1, 1, 6, timlCNNConvParamsDefault());                    // conv layer
   timlCNNAddNonlinearLayer(cnn, Util_Relu);                                               // relu layer
   timlCNNAddPoolingLayer(cnn, 4, 4, 4, 4, CNN_MaxPooling, timlCNNPoolingParamsDefault()); // max pooling layer
   timlCNNAddNormLayer(cnn, timlCNNNormParamsDefault());                                   // norm layer
   timlCNNAddDropoutLayer(cnn, 0.5);                                                       // dropout layer
   timlCNNAddLinearLayer(cnn, 10, timlCNNLinearParamsDefault());                           // linear layer
   timlCNNAddSoftmaxCostLayer(cnn);                                                        // softmax cost layer
   timlCNNInitialize(cnn);
   timlCNNReset(cnn);
   timlCNNPrint(cnn);

   printf("2. Load the MNIST database\n");
   err = timlUtilReadMNIST(DATABASE_PATH, &training, &testing);
   if (err) {
      printf("MNIST database reading error\n");
      return err; // abort instead of training on an unloaded database
   }

   // training
   printf("3. Start training\n");
   timlCNNSetMode(cnn, Util_Train);
   batchNum = TRAIN_NUM/cnn->params.batchUpdate; // 60000/100 = 600 update batches (one epoch)
   clock_gettime(CLOCK_REALTIME, &startTime);
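   /*
    * One pass over the training set: each iteration feeds batchUpdate (100)
    * labeled images; the data pointer advances batchUpdate*dim elements per
    * batch (dim = 28*28*1 values per image) and the label pointer advances
    * batchUpdate entries.
    */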
   for (i = 0; i < batchNum; i++) {
      printf("Batch id = %d, ", i);
      timlCNNSupervisedTrainingWithLabel(
         cnn,
         training.data + i*cnn->params.batchUpdate*dim,
         IMAGE_ROW, IMAGE_COL, IMAGE_CHANNEL,
         training.label + i*cnn->params.batchUpdate,
         1, 1, cnn->params.batchUpdate);
   }
   clock_gettime(CLOCK_REALTIME, &endTime);
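   // timlUtilDiffTime appears to return elapsed microseconds, hence the 1e6 divisor when printing seconds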
   trainingTime = timlUtilDiffTime(startTime, endTime);
   printf("Training time = %.3f s.\n", trainingTime/1000000.0);

   // testing
   printf("4. Start testing\n");
   // swap the softmax cost layer (training objective) for an accuracy layer
   timlCNNDeleteLayer(cnn->tail);
   timlCNNAddAccuracyLayer(cnn, TOP_N);
   timlCNNAccuracyInitialize(cnn->tail);
   timlCNNSetMode(cnn, Util_Test);

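   // run the classifier over the test set in 100-image batches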
   timlCNNSetBatchSize(cnn, BATCH_SIZE);
   clock_gettime(CLOCK_REALTIME, &startTime);
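   // classifyNum receives the count of correctly classified test images (top-1)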
   timlCNNClassifyAccuracy(
      cnn,
      testing.data, IMAGE_ROW, IMAGE_COL, IMAGE_CHANNEL,
      testing.label, 1, 1, testing.num,
      &classifyNum);
   clock_gettime(CLOCK_REALTIME, &endTime);
   testingTime = timlUtilDiffTime(startTime, endTime);
   classifyPercent = (float)classifyNum/(float)testing.num;
   printf("Testing time = %.3f s.\nSuccess percent = %.3f %%.\n", testingTime/1000000.0, classifyPercent*100.0);

   printf("5. Clean up\n");
   free(training.data);
   free(training.label);
   free(training.mean);
   free(testing.data);
   free(testing.label);
   timlCNNDelete(cnn);

   return err;
}