/******************************************************************************/
/*!
 * \file testCNNSimpleProfile.c
 */
/* Copyright (C) 2015 Texas Instruments Incorporated - http://www.ti.com/
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 *    Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *
 *    Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the
 *    distribution.
 *
 *    Neither the name of Texas Instruments Incorporated nor the names of
 *    its contributors may be used to endorse or promote products derived
 *    from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/
/*******************************************************************************
 *
 * INCLUDES
 *
 ******************************************************************************/

#include "testCNN.h"
/*******************************************************************************
 *
 * DEFINES
 *
 ******************************************************************************/

#define IMAGE_ROW            28
#define IMAGE_COL            28
#define IMAGE_CHANNEL        1
#define ITER                 2
#define BATCH_SIZE           100
#define DATABASE_PATH        "../../database/mnist"
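
/*
 * Note: IMAGE_ROW x IMAGE_COL x IMAGE_CHANNEL = 28 x 28 x 1 = 784 values per
 * image, which matches the standard MNIST format (28 x 28 grayscale digits).
 * DATABASE_PATH is resolved relative to the directory the test is run from.
 */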

/******************************************************************************/
/*!
 * \ingroup test
 * \brief simple profile function test
 * \return error code
 */
/******************************************************************************/
int testCNNSimpleProfile()
{
   int  dim;
   int  err;
   long mem;
   int  iter;
   int  batchSize;
   timlUtilImageSet      training;
   timlUtilImageSet      testing;
   timlConvNeuralNetwork *cnn;

   iter      = ITER;
   err       = 0;
   dim       = IMAGE_ROW*IMAGE_COL*IMAGE_CHANNEL;
   batchSize = BATCH_SIZE;

   setbuf(stdout, NULL); // do not buffer the console output

   printf("[Test] CNN simple profile\n");

   // build up the CNN
   printf("1. Build CNN\n");
   cnn = timlCNNCreateConvNeuralNetwork(timlCNNTrainingParamsDefault(), 0);
   timlCNNAddInputLayer(cnn, IMAGE_ROW, IMAGE_COL, IMAGE_CHANNEL, timlCNNInputParamsDefault()); // input layer
   timlCNNAddConvLayer(cnn, 5, 5, 1, 1, 6, timlCNNConvParamsDefault());                         // conv layer
   timlCNNAddNonlinearLayer(cnn, Util_Sigmoid);                                                 // sigmoid layer
   timlCNNAddPoolingLayer(cnn, 4, 4, 4, 4, CNN_MaxPooling, timlCNNPoolingParamsDefault());      // max pooling layer
   timlCNNAddNormLayer(cnn, timlCNNNormParamsDefault());                                        // norm layer
   timlCNNAddDropoutLayer(cnn, 0.2);                                                            // dropout layer
   timlCNNAddLinearLayer(cnn, 10, timlCNNLinearParamsDefault());                                // linear layer
   timlCNNAddNonlinearLayer(cnn, Util_Softmax);                                                 // softmax layer
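
   /*
    * Illustrative feature-map sizes for this stack (assuming an unpadded
    * "valid" convolution and non-overlapping pooling; a sketch, not taken
    * from the library documentation):
    *   input:        28 x 28 x 1
    *   conv 5x5:     28 - 5 + 1 = 24   ->  24 x 24 x 6
    *   pooling 4x4:  24 / 4     = 6    ->   6 x  6 x 6  (216 units)
    *   linear:       216 -> 10 class scores, followed by softmax
    */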
   timlCNNInitialize(cnn);
   timlCNNReset(cnn);

   printf("2. Read MNIST database\n");
   timlUtilReadMNIST(DATABASE_PATH, &training, &testing);

   timlCNNSetMode(cnn, Util_Train);
   printf("3. Start profiling\n");
   timlCNNProfile(cnn, training.data, dim, batchSize, training.label, iter);

   // release the MNIST buffers and the network
   free(training.data);
   free(training.label);
   free(training.mean);
   free(testing.data);
   free(testing.label);

   timlCNNDelete(cnn);

   return err;
}
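
/*
 * A minimal standalone driver, sketched here for illustration only; it is
 * not part of the original test suite and assumes that testCNN.h declares
 * testCNNSimpleProfile(). Enable it with -DTEST_CNN_SIMPLE_PROFILE_MAIN.
 */
#ifdef TEST_CNN_SIMPLE_PROFILE_MAIN
int main(void)
{
   return testCNNSimpleProfile(); // nonzero return indicates a test error
}
#endif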