update the version number
[ti-machine-learning/ti-machine-learning.git] / src / common / cnn / timlCNNSupervisedTrainingWithLabelBatchModeOpenMP.c
1 /******************************************************************************/\r
2 /*!\r
3  * \file timlCNNSupervisedTrainingWithLabelBatchModeOpenMP.c\r
4  */\r
5 /* Copyright (C) 2015 Texas Instruments Incorporated - http://www.ti.com/\r
6  *\r
7  * Redistribution and use in source and binary forms, with or without\r
8  * modification, are permitted provided that the following conditions\r
9  * are met:\r
10  *\r
11  *    Redistributions of source code must retain the above copyright\r
12  *    notice, this list of conditions and the following disclaimer.\r
13  *\r
14  *    Redistributions in binary form must reproduce the above copyright\r
15  *    notice, this list of conditions and the following disclaimer in the\r
16  *    documentation and/or other materials provided with the\r
17  *    distribution.\r
18  *\r
19  *    Neither the name of Texas Instruments Incorporated nor the names of\r
20  *    its contributors may be used to endorse or promote products derived\r
21  *    from this software without specific prior written permission.\r
22  *\r
23  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS\r
24  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT\r
25  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR\r
26  * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT\r
27  * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,\r
28  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT\r
29  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,\r
30  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY\r
31  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\r
32  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\r
33  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\r
34  *\r
35  ******************************************************************************/\r
36 \r
37 \r
38 /*******************************************************************************\r
39  *\r
40  * INCLUDES\r
41  *\r
42  ******************************************************************************/\r
43 \r
44 #include "../api/timl.h"\r
45 \r
46 \r
47 /******************************************************************************/\r
48 /*!\r
49  * \ingroup cnn\r
50  * \brief supervised training with label using openmp\r
51  * \param[in,out] cnn\r
52  * \param[in] data data batch\r
53  * \param[in] label\r
54  * \param[in] dim data dimension\r
55  * \param[in] num data number\r
56  * \return error code\r
57  */\r
58 /******************************************************************************/\r
59 \r
60 int timlCNNSupervisedTrainingWithLabelBatchModeOpenMP(timlConvNeuralNetwork *cnn, float *data, int *label, int dim, int num)\r
61 {\r
62    int          i;\r
63    int          t;\r
64    int          thread;\r
65    int          err;\r
66    timlCNNLayer *bpStartLayer;\r
67    float        *cost;\r
68    float        batchCost;\r
69 \r
70    err    = 0;\r
71    cost   = malloc(sizeof(float)*num);\r
72    thread = omp_get_max_threads();\r
73 \r
74    // create cnnTeam\r
75    timlConvNeuralNetwork **cnnTeam = malloc(sizeof(timlConvNeuralNetwork*)*thread);\r
76    cnnTeam[0] = cnn;\r
77    for (i = 1; i < thread; i++) {\r
78       cnnTeam[i] = timlCNNShareParams(cnn, 0);\r
79    }\r
80 \r
81    // parallel for loop\r
82    #pragma omp parallel num_threads(thread) private(t, i, bpStartLayer, err)\r
83    {\r
84       #pragma omp for\r
85       for (i = 0; i < num; i++) {\r
86          t = omp_get_thread_num();\r
87          err = timlCNNForwardPropagation(cnnTeam[t], data + i*dim, dim);\r
88          timlCNNCostWithLabel(cnnTeam[t], label[i], cost + i, &bpStartLayer);\r
89          err = timlCNNBackPropagation(cnnTeam[t], bpStartLayer);\r
90       }\r
91    }\r
92 \r
93    // update params\r
94    cnn->params.count += num;\r
95    timlCNNUpdateParams(cnn);\r
96    batchCost = timlUtilVectorSumFloat(cost, num)/(float)num;\r
97    printf("batch = %d, cost = %f\n", cnn->params.batchCount, batchCost);\r
98    cnn->params.batchCount += 1;\r
99 \r
100    // free cnnTeam\r
101    for (i = 1; i < thread; i++) {\r
102       timlCNNDelete(cnnTeam[i]);\r
103    }\r
104    free(cnnTeam);\r
105    free(cost);\r
106 \r
107    return err;\r
108 }\r