update the version number
[ti-machine-learning/ti-machine-learning.git] / src / common / cnn / timlCNNClassifyTopNBatchModeOpenMP.c
1 /******************************************************************************/
2 /*!
3  * \file timlCNNClassifyTopNBatchModeOpenMP.c
4  */
5 /* Copyright (C) 2015 Texas Instruments Incorporated - http://www.ti.com/
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions
9  * are met:
10  *
11  *    Redistributions of source code must retain the above copyright
12  *    notice, this list of conditions and the following disclaimer.
13  *
14  *    Redistributions in binary form must reproduce the above copyright
15  *    notice, this list of conditions and the following disclaimer in the
16  *    documentation and/or other materials provided with the
17  *    distribution.
18  *
19  *    Neither the name of Texas Instruments Incorporated nor the names of
20  *    its contributors may be used to endorse or promote products derived
21  *    from this software without specific prior written permission.
22  *
23  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
26  * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
27  * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
28  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
29  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
30  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
31  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
32  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34  *
35  ******************************************************************************/
38 /*******************************************************************************
39  *
40  * INCLUDES
41  *
42  ******************************************************************************/
44 #include "../api/timl.h"
47 /******************************************************************************/
48 /*!
49  * \ingroup       cnn
50  * \brief         Batch classification using openmp
51  * \param[in,out] cnn     CNN
52  * \param[in]     data    Data batch
53  * \param[in]     dim     Data dimension
54  * \param[in]     num     Data number
55  * \param[out]    label   Label array ptr, size = num*topN
56  * \param[out]    percent Percent array ptr, size = num*topN
57  * \param[out]    topN    Output the top N labels and the corresponding percentage
58  * \return        Error   code
59  */
60 /******************************************************************************/
62 int timlCNNClassifyTopNBatchModeOpenMP(timlConvNeuralNetwork *cnn, float *data, int dim, int num, int *label, float *percent, int topN)
63 {
64    int                   err;
65    int                   i;
66    int                   j;
67    int                   t;
68    int                   k;
69    int                   thread;
70    struct timespec       startTime;
71    struct timespec       endTime;
72    long                  testingTime;
73    int                   outputDim;
74    int                   *index;
75    timlConvNeuralNetwork **cnnTeam;
77    // init
78    err       = 0;
79    thread    = omp_get_max_threads();
80    outputDim = cnn->tail->row*cnn->tail->col*cnn->tail->channel;
81    index     = malloc(sizeof(int)*outputDim*thread); // multiple copies for each thread
82    cnnTeam   = malloc(sizeof(timlConvNeuralNetwork*)*thread);
84    // create cnnTeam
85    cnnTeam[0] = cnn;
86    for (i = 1; i < thread; i++) {
87       cnnTeam[i] = timlCNNShareParams(cnn, 0);
88       if (cnnTeam[i] == NULL) {
89          return ERROR_CNN_TEAM_ALLOCATION;
90       }
91    }
93    // batch iterations
94    clock_gettime(CLOCK_REALTIME, &startTime);
95    #pragma omp parallel num_threads(thread) private(i, j, t)
96    {
97       #pragma omp for
98       for (j = 0; j < num; j++) {
99          t = omp_get_thread_num();
100          timlCNNForwardPropagation(cnnTeam[t], data + j*dim, dim); // fp
101          timlUtilVectorSortIndexFloat(cnnTeam[t]->tail->featureMap, index + t*outputDim, outputDim); // sort
102          for (i = 0; i < topN; i++) {
103             label[j*topN + i] = index[t*outputDim + outputDim - i - 1]; // record label
104             if (percent != NULL) {
105                timlUtilBLASscopy(1, cnnTeam[t]->tail->featureMap + index[t*outputDim + outputDim - i - 1], percent + j*topN + i, cnnTeam[t]->deviceId, cnnTeam[t]->threadId);
106             }
107          }
108       }
109    }
110    clock_gettime(CLOCK_REALTIME, &endTime);
111    testingTime = timlUtilDiffTime(startTime, endTime);
112 //   printf("Loop testing time = %.2fs.\n", testingTime/1000000.0);
114    // clean up
115    for (i = 1; i < thread; i++) {
116       timlCNNDelete(cnnTeam[i]);
117    }
118    free(cnnTeam);
119    free(index);
121    return err;