]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - ti-machine-learning/ti-machine-learning.git/blob - src/common/cnn/timlCNNSupervisedTrainingWithLabelTeamModeOpenMP.c
Remove deleted files
[ti-machine-learning/ti-machine-learning.git] / src / common / cnn / timlCNNSupervisedTrainingWithLabelTeamModeOpenMP.c
1 /******************************************************************************/
2 /*!
3  * \file timlCNNSupervisedTrainingWithLabelTeamModeOpenMP.c
4  */
5 /* Copyright (C) 2015 Texas Instruments Incorporated - http://www.ti.com/
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions
9  * are met:
10  *
11  *    Redistributions of source code must retain the above copyright
12  *    notice, this list of conditions and the following disclaimer.
13  *
14  *    Redistributions in binary form must reproduce the above copyright
15  *    notice, this list of conditions and the following disclaimer in the
16  *    documentation and/or other materials provided with the
17  *    distribution.
18  *
19  *    Neither the name of Texas Instruments Incorporated nor the names of
20  *    its contributors may be used to endorse or promote products derived
21  *    from this software without specific prior written permission.
22  *
23  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
26  * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
27  * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
28  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
29  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
30  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
31  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
32  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34  *
35  ******************************************************************************/
38 /*******************************************************************************
39  *
40  * INCLUDES
41  *
42  ******************************************************************************/
44 #include "../api/timl.h"
47 /******************************************************************************/
48 /*!
49  * \ingroup cnn
50  * \brief supervised training with label using openmp
51  * \param[in] cnnTeam     CNN team
52  * \param[in] teamNum     CNN team number
53  * \param[in] data        data batch
54  * \param[in] label       label batch
55  * \param[in] dataDim     Data dimension
56  * \param[in] labelDim    Label dimension
57  * \param[in] batchUpdate data batch size
58  * \return error code
59  */
60 /******************************************************************************/
62 int timlCNNSupervisedTrainingWithLabelTeamModeOpenMP(timlConvNeuralNetwork **cnnTeam, int teamNum, float *image, int row, int col, int channel, int *label, int labelRow, int labelCol, int batchUpdate)
63 {
64    int          i;
65    int          t;
66    int          thread;
67    int          err;
68    int          iter;
69    float        cost;
70    int          cnnBatchSize;
71    float        batchCost;
72    int          dataDim;
73    int          labelDim;
74    timlConvNeuralNetwork *cnn;
76    err    = 0;
77    cnn = cnnTeam[0];
78    cnnBatchSize = cnn->params.batchSize;
79    thread = omp_get_max_threads();
80    iter = batchUpdate/cnnBatchSize;
81    batchCost = 0;
82    dataDim      = row*col*channel;
83    labelDim     = labelRow*labelCol;
85    if (thread > teamNum) { // more thread than cnn copies
86       thread = teamNum;
87    }
89    #pragma omp parallel num_threads(thread) private(t, i, err, cost)
90    {
91       #pragma omp for ordered reduction(+:batchCost)
92       for (i = 0; i < iter; i++) {
93          t = omp_get_thread_num(); // get thread id
94          err = timlCNNLoadImage(cnnTeam[t], image + i*cnnBatchSize*dataDim, row, col, channel, cnnBatchSize);
95          err = timlCNNLoadLabel(cnnTeam[t], label + i*cnnBatchSize*labelDim, labelRow, labelCol, cnnBatchSize);
96          err = timlCNNForwardPropagation(cnnTeam[t]);
97          err = timlCNNBackPropagation(cnnTeam[t]);
98          timlUtilBLASsasum(cnnBatchSize*labelDim, cnnTeam[t]->tail->softmaxCostParams.cost, &cost, cnnTeam[t]->deviceId, cnnTeam[t]->threadId);
99          batchCost += cost;
100       }
101    }
103    // update params
104    timlCNNUpdateParams(cnn);
105    batchCost = batchCost/(float)batchUpdate;
106    printf("batch cost = %f\n", batchCost);
107    return err;