]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - ti-machine-learning/ti-machine-learning.git/blob - src/app/cnn/scene/appCNNSceneClassify.c
1. Enable network state write/read
[ti-machine-learning/ti-machine-learning.git] / src / app / cnn / scene / appCNNSceneClassify.c
1 /******************************************************************************/
2 /*!
3  * \file appCNNSceneClassify.c
4  */
5 /* Copyright (C) 2015 Texas Instruments Incorporated - http://www.ti.com/
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions
9  * are met:
10  *
11  *    Redistributions of source code must retain the above copyright
12  *    notice, this list of conditions and the following disclaimer.
13  *
14  *    Redistributions in binary form must reproduce the above copyright
15  *    notice, this list of conditions and the following disclaimer in the
16  *    documentation and/or other materials provided with the
17  *    distribution.
18  *
19  *    Neither the name of Texas Instruments Incorporated nor the names of
20  *    its contributors may be used to endorse or promote products derived
21  *    from this software without specific prior written permission.
22  *
23  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
26  * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
27  * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
28  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
29  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
30  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
31  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
32  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34  *
35  ******************************************************************************/
38 /*******************************************************************************
39  *
40  * INCLUDES
41  *
42  ******************************************************************************/
44 #include "appCNNScene.h"
47 /******************************************************************************/
48 /*!
49  * \ingroup       appCNNScene
50  * \brief         Pixel label classification
51  * \param[in,out] cnn         CNN
52  * \param[in]     image       Image
53  * \param[in,out] labelMatrix Generated label matrix
54  * \param[in]     scale       Down scaling factor of the label matrix
55  * \return        Error code
56  */
57 /******************************************************************************/
59 int appCNNSceneClassify(timlConvNeuralNetwork *cnn, float *image, int row, int col, int channel, int *labelMatrix, int scale)
60 {
61    int   i;
62    int   j;
63    int   p;
64    int   m;
65    int   k;
66    int   err;
67    float imageMean;
68    float imageDeviation;
69    int   paddedRow;
70    int   paddedCol;
71    int   paddedDim;
72    float *paddedImage;
73    int   resolutionLossRow;
74    int   resolutionLossCol;
75    int   rowStart;
76    int   rowEnd;
77    int   colStart;
78    int   colEnd;
79    int   rowDown;
80    int   colDown;
82    // init
83    err               = 0;
84    paddedRow         = cnn->head->row;
85    paddedCol         = cnn->head->col;
86    paddedDim         = paddedRow*paddedCol*channel;
87    paddedImage       = malloc(sizeof(float)*paddedDim);
88    resolutionLossRow = row/cnn->tail->row;
89    resolutionLossCol = col/cnn->tail->col;
91    // image normalization (per image)
92    for (k = 0; k < channel; k++) {
93       imageMean = 0.0;
94       imageDeviation = 0.0;
95       for (i = 0; i < row * col; i++) {
96          imageMean += image[i + k*row*col];
97       }
98       imageMean /= row*col;
99       for (i = 0; i < row*col; i++) {
100          image[i + k*row*col] -= imageMean;
101       }
102       for (i = 0; i < row * col; i++) {
103          imageDeviation += image[i + k*row*col] * image[i + k*row*col];
104       }
105       imageDeviation /= row*col;
106       imageDeviation = sqrtf(imageDeviation);
107       for (i = 0; i < row*col; i++) {
108          image[i + k*row*col] /= imageDeviation;
109       }
110    }
112    // main loop over each pixel on the image
113    for (m = -resolutionLossRow/2; m < resolutionLossRow/2; m += scale) {
114       for (k = -resolutionLossCol/2; k < resolutionLossCol/2; k += scale) {
115          rowStart = (paddedRow - row)/2 - m;
116          rowEnd = rowStart + row - 1;
117          colStart = (paddedCol - col)/2 - k;
118          colEnd = colStart + col - 1;
120          // zero padding
121          for (i = 0; i < paddedRow; i++) {
122             for (j = 0; j < paddedCol; j++) {
123                if (i < rowStart || i > rowEnd || j < colStart || j > colEnd) {
124                   for (p = 0; p < channel; p++)
125                      paddedImage[j + i*paddedCol + p*paddedRow*paddedCol] = 0.0;
126                }
127                else {
128                   for (p = 0; p < channel; p++)
129                      paddedImage[j + i*paddedCol + p*paddedRow*paddedCol] = image[j - colStart + (i - rowStart)*col + p*row*col];
130                }
131             }
132          }
134          // cnn Forward Propagation
135          err = timlCNNLoadImage(cnn, paddedImage, paddedRow, paddedCol, channel, 1);
136          err = timlCNNForwardPropagation(cnn);
138          // labeling
139          for (i = 0; i < cnn->tail->row; i++) {
140             for (j = 0; j < cnn->tail->col; j++) {
141                labelMatrix[k + resolutionLossCol/2 + j*resolutionLossCol + (m + resolutionLossRow/2 + i*resolutionLossRow)*col] = cnn->tail->accuracyParams.label[i*cnn->tail->col + j];
142             }
143          }
145       }
146    }
148    // up-sample the label matrix
149    for (i = 0; i < row; i++) {
150       for (j = 0; j < col; j++) {
151          rowDown = i/scale;
152          colDown = j/scale;
153          labelMatrix[j + i*col] = labelMatrix[colDown*scale + rowDown*scale*col];
154       }
155    }
157    free(paddedImage);
159    return err;