]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - ti-machine-learning/ti-machine-learning.git/blob - src/app/cnn/scene/appCNNSceneSupervisedTraining.c
5c4a79fddd1d32b9f1e589be1297844421981628
[ti-machine-learning/ti-machine-learning.git] / src / app / cnn / scene / appCNNSceneSupervisedTraining.c
1 /******************************************************************************/
2 /*!
3  * \file appCNNSceneSupervisedTraining.c
4  */
5 /* Copyright (C) 2015 Texas Instruments Incorporated - http://www.ti.com/
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions
9  * are met:
10  *
11  *    Redistributions of source code must retain the above copyright
12  *    notice, this list of conditions and the following disclaimer.
13  *
14  *    Redistributions in binary form must reproduce the above copyright
15  *    notice, this list of conditions and the following disclaimer in the
16  *    documentation and/or other materials provided with the
17  *    distribution.
18  *
19  *    Neither the name of Texas Instruments Incorporated nor the names of
20  *    its contributors may be used to endorse or promote products derived
21  *    from this software without specific prior written permission.
22  *
23  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
26  * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
27  * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
28  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
29  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
30  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
31  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
32  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34  *
35  ******************************************************************************/
38 /*******************************************************************************
39  *
40  * INCLUDES
41  *
42  ******************************************************************************/
44 #include "appCNNScene.h"
47 /******************************************************************************/
48 /*!
49  * \ingroup       appCNNScene
50  * \brief         Supervised training on the dataset
51  * \param[in,out] cnn     CNN
52  * \param[in]     dataSet Data set
53  * \return        Error code
54  */
55 /******************************************************************************/
57 int appCNNSceneSupervisedTraining(timlConvNeuralNetwork *cnn, appCNNSceneDataSet *dataSet)
58 {
60    int          i;
61    int          j;
62    int          err;
63    timlCNNLayer *bpStartLayer;
64    int          label;
65    int          iter;
66    int          patchDim;
67    int          *imageIdx;
68    int          *rowIdx;
69    int          *colIdx;
70    float        *patch;
71    int          epoch;
72    int          batchSize;
73    int          batchNum;
74    float        *cost;
75    float        *batchCost;
76    int          batchIndex;
78    // init
79    err        = 0;
80    iter       = dataSet->row*dataSet->col*dataSet->num;
81    patchDim   = dataSet->channel*dataSet->patchSize*dataSet->patchSize;
82    imageIdx   = malloc(sizeof(int)*iter);
83    rowIdx     = malloc(sizeof(int)*iter);
84    colIdx     = malloc(sizeof(int)*iter);
85    patch      = malloc(sizeof(float)*patchDim);
86    epoch      = cnn->params.epoch;
87    batchSize  = cnn->params.batchSize;
88    batchNum   = iter/batchSize;
89    cost       = malloc(sizeof(float)*batchSize);
90    batchCost  = malloc(sizeof(float)*batchNum*epoch);
91    batchIndex = 0;
93    // shuffle the training pixels
94    appCNNSceneShuffleIdx(imageIdx, rowIdx, colIdx, dataSet);
96    // training loop
97    cnn->params.count = 0;
98    for (i = 0; i < epoch; i++) {
99       cnn->params.count = 0;
100       for (j = 0; j < iter; j++) {
101          label = appCNNSceneGetLabel(imageIdx[j], rowIdx[j], colIdx[j], dataSet);
102          if (label != -1) {
103             cnn->params.count += 1;
104             appCNNSceneGetPatch(imageIdx[j], rowIdx[j], colIdx[j], dataSet, patch);
105             err = timlCNNForwardPropagation(cnn, patch, patchDim);
106             timlCNNCostWithLabel(cnn, label, cost + j%batchSize, &bpStartLayer);
107             err = timlCNNBackPropagation(cnn, bpStartLayer);
108          }
109          else {
110             cost[j%batchSize] = 0.0;
111          }
112          if ((j + 1)%batchSize == 0) { // update parameters once each batch
113             batchCost[batchIndex + i*batchNum] = timlUtilVectorSumFloat(cost, batchSize)/(double)cnn->params.count;
114             timlCNNUpdateParams(cnn);
115             printf("epoch = %d, batch = %d, cost = %f\n", i, batchIndex, batchCost[batchIndex + i*batchNum]);
116             batchIndex += 1;
117          }
118       }
119    }
121    free(imageIdx);
122    free(rowIdx);
123    free(colIdx);
124    free(cost);
125    free(batchCost);
127    return err;