[ti-machine-learning/ti-machine-learning.git] / src / common / cnn / timlCNNSupervisedTrainingWithLabelTeamModeOpenMP.c
1 /******************************************************************************/
2 /*!
3 * \file timlCNNSupervisedTrainingWithLabelTeamModeOpenMP.c
4 */
5 /* Copyright (C) 2015 Texas Instruments Incorporated - http://www.ti.com/
6 *
7 * Redistribution and use in source and binary forms, with or without
8 * modification, are permitted provided that the following conditions
9 * are met:
10 *
11 * Redistributions of source code must retain the above copyright
12 * notice, this list of conditions and the following disclaimer.
13 *
14 * Redistributions in binary form must reproduce the above copyright
15 * notice, this list of conditions and the following disclaimer in the
16 * documentation and/or other materials provided with the
17 * distribution.
18 *
19 * Neither the name of Texas Instruments Incorporated nor the names of
20 * its contributors may be used to endorse or promote products derived
21 * from this software without specific prior written permission.
22 *
23 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
26 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
27 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
28 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
29 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
30 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
31 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
32 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34 *
35 ******************************************************************************/
38 /*******************************************************************************
39 *
40 * INCLUDES
41 *
42 ******************************************************************************/
44 #include "../api/timl.h"
47 /******************************************************************************/
48 /*!
49 * \ingroup cnn
50 * \brief supervised training with label using openmp
51 * \param[in] cnnTeam CNN team
52 * \param[in] teamNum CNN team number
53 * \param[in] data data batch
54 * \param[in] label label batch
55 * \param[in] dataDim Data dimension
56 * \param[in] labelDim Label dimension
57 * \param[in] batchUpdate data batch size
58 * \return error code
59 */
60 /******************************************************************************/
62 int timlCNNSupervisedTrainingWithLabelTeamModeOpenMP(timlConvNeuralNetwork **cnnTeam, int teamNum, float *image, int row, int col, int channel, int *label, int labelRow, int labelCol, int batchUpdate)
63 {
64 int i;
65 int t;
66 int thread;
67 int err;
68 int iter;
69 float cost;
70 int cnnBatchSize;
71 float batchCost;
72 int dataDim;
73 int labelDim;
74 timlConvNeuralNetwork *cnn;
76 err = 0;
77 cnn = cnnTeam[0];
78 cnnBatchSize = cnn->params.batchSize;
79 thread = omp_get_max_threads();
80 iter = batchUpdate/cnnBatchSize;
81 batchCost = 0;
82 dataDim = row*col*channel;
83 labelDim = labelRow*labelCol;
85 if (thread > teamNum) { // more thread than cnn copies
86 thread = teamNum;
87 }
89 #pragma omp parallel num_threads(thread) private(t, i, err, cost)
90 {
91 #pragma omp for ordered reduction(+:batchCost)
92 for (i = 0; i < iter; i++) {
93 t = omp_get_thread_num(); // get thread id
94 err = timlCNNLoadImage(cnnTeam[t], image + i*cnnBatchSize*dataDim, row, col, channel, cnnBatchSize);
95 err = timlCNNLoadLabel(cnnTeam[t], label + i*cnnBatchSize*labelDim, labelRow, labelCol, cnnBatchSize);
96 err = timlCNNForwardPropagation(cnnTeam[t]);
97 err = timlCNNBackPropagation(cnnTeam[t]);
98 timlUtilBLASsasum(cnnBatchSize*labelDim, cnnTeam[t]->tail->softmaxCostParams.cost, &cost, cnnTeam[t]->deviceId, cnnTeam[t]->threadId);
99 batchCost += cost;
100 }
101 }
103 // update params
104 timlCNNUpdateParams(cnn);
105 batchCost = batchCost/(float)batchUpdate;
106 printf("batch cost = %f\n", batchCost);
107 return err;
108 }