5c4a79fddd1d32b9f1e589be1297844421981628
[ti-machine-learning/ti-machine-learning.git] / src / app / cnn / scene / appCNNSceneSupervisedTraining.c
1 /******************************************************************************/
2 /*!
3 * \file appCNNSceneSupervisedTraining.c
4 */
5 /* Copyright (C) 2015 Texas Instruments Incorporated - http://www.ti.com/
6 *
7 * Redistribution and use in source and binary forms, with or without
8 * modification, are permitted provided that the following conditions
9 * are met:
10 *
11 * Redistributions of source code must retain the above copyright
12 * notice, this list of conditions and the following disclaimer.
13 *
14 * Redistributions in binary form must reproduce the above copyright
15 * notice, this list of conditions and the following disclaimer in the
16 * documentation and/or other materials provided with the
17 * distribution.
18 *
19 * Neither the name of Texas Instruments Incorporated nor the names of
20 * its contributors may be used to endorse or promote products derived
21 * from this software without specific prior written permission.
22 *
23 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
26 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
27 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
28 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
29 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
30 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
31 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
32 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34 *
35 ******************************************************************************/
38 /*******************************************************************************
39 *
40 * INCLUDES
41 *
42 ******************************************************************************/
44 #include "appCNNScene.h"
47 /******************************************************************************/
48 /*!
49 * \ingroup appCNNScene
50 * \brief Supervised training on the dataset
51 * \param[in,out] cnn CNN
52 * \param[in] dataSet Data set
53 * \return Error code
54 */
55 /******************************************************************************/
57 int appCNNSceneSupervisedTraining(timlConvNeuralNetwork *cnn, appCNNSceneDataSet *dataSet)
58 {
60 int i;
61 int j;
62 int err;
63 timlCNNLayer *bpStartLayer;
64 int label;
65 int iter;
66 int patchDim;
67 int *imageIdx;
68 int *rowIdx;
69 int *colIdx;
70 float *patch;
71 int epoch;
72 int batchSize;
73 int batchNum;
74 float *cost;
75 float *batchCost;
76 int batchIndex;
78 // init
79 err = 0;
80 iter = dataSet->row*dataSet->col*dataSet->num;
81 patchDim = dataSet->channel*dataSet->patchSize*dataSet->patchSize;
82 imageIdx = malloc(sizeof(int)*iter);
83 rowIdx = malloc(sizeof(int)*iter);
84 colIdx = malloc(sizeof(int)*iter);
85 patch = malloc(sizeof(float)*patchDim);
86 epoch = cnn->params.epoch;
87 batchSize = cnn->params.batchSize;
88 batchNum = iter/batchSize;
89 cost = malloc(sizeof(float)*batchSize);
90 batchCost = malloc(sizeof(float)*batchNum*epoch);
91 batchIndex = 0;
93 // shuffle the training pixels
94 appCNNSceneShuffleIdx(imageIdx, rowIdx, colIdx, dataSet);
96 // training loop
97 cnn->params.count = 0;
98 for (i = 0; i < epoch; i++) {
99 cnn->params.count = 0;
100 for (j = 0; j < iter; j++) {
101 label = appCNNSceneGetLabel(imageIdx[j], rowIdx[j], colIdx[j], dataSet);
102 if (label != -1) {
103 cnn->params.count += 1;
104 appCNNSceneGetPatch(imageIdx[j], rowIdx[j], colIdx[j], dataSet, patch);
105 err = timlCNNForwardPropagation(cnn, patch, patchDim);
106 timlCNNCostWithLabel(cnn, label, cost + j%batchSize, &bpStartLayer);
107 err = timlCNNBackPropagation(cnn, bpStartLayer);
108 }
109 else {
110 cost[j%batchSize] = 0.0;
111 }
112 if ((j + 1)%batchSize == 0) { // update parameters once each batch
113 batchCost[batchIndex + i*batchNum] = timlUtilVectorSumFloat(cost, batchSize)/(double)cnn->params.count;
114 timlCNNUpdateParams(cnn);
115 printf("epoch = %d, batch = %d, cost = %f\n", i, batchIndex, batchCost[batchIndex + i*batchNum]);
116 batchIndex += 1;
117 }
118 }
119 }
121 free(imageIdx);
122 free(rowIdx);
123 free(colIdx);
124 free(cost);
125 free(batchCost);
127 return err;
129 }