[ti-machine-learning/ti-machine-learning.git] / src / common / cnn / timlCNNClassifyTopNTeamModeOpenMP.c
1 /******************************************************************************/
2 /*!
3 * \file timlCNNClassifyTopNTeamModeOpenMP.c
4 */
5 /* Copyright (C) 2015 Texas Instruments Incorporated - http://www.ti.com/
6 *
7 * Redistribution and use in source and binary forms, with or without
8 * modification, are permitted provided that the following conditions
9 * are met:
10 *
11 * Redistributions of source code must retain the above copyright
12 * notice, this list of conditions and the following disclaimer.
13 *
14 * Redistributions in binary form must reproduce the above copyright
15 * notice, this list of conditions and the following disclaimer in the
16 * documentation and/or other materials provided with the
17 * distribution.
18 *
19 * Neither the name of Texas Instruments Incorporated nor the names of
20 * its contributors may be used to endorse or promote products derived
21 * from this software without specific prior written permission.
22 *
23 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
26 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
27 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
28 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
29 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
30 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
31 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
32 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34 *
35 ******************************************************************************/
38 /*******************************************************************************
39 *
40 * INCLUDES
41 *
42 ******************************************************************************/
44 #include "../api/timl.h"
47 /******************************************************************************/
48 /*!
49 * \ingroup cnn
50 * \brief Batch classification using openmp
51 * \details This is the same function as timlCNNBatchClassifyOpenMP but avoids creating and deleting the cnn team each time the function is called
52 * \param[in,out] cnnTeam An array of CNNs that shares the same parameters
53 * \param[in] num Size of the CNN array as well as the data
54 * \param[in] data Data batch
55 * \param[in] dim Data dimension
56 * \param[in,out] label Label array ptr, size = num*topN
57 * \param[in,out] percent Percent array ptr, size = num*topN
58 * \param[in,out] topN Output the top N labels and the corresponding percentage
59 * \return Error code
60 */
61 /******************************************************************************/
63 int timlCNNClassifyTopNTeamModeOpenMP(timlConvNeuralNetwork **cnnTeam, int num, float *data, int dim, int *label, float *percent, int topN)
64 {
65 int err;
66 int i;
67 int j;
68 int outputDim;
69 int *index;
70 int t;
72 outputDim = cnnTeam[0]->tail->row*cnnTeam[0]->tail->col*cnnTeam[0]->tail->channel;
73 index = malloc(sizeof(int)*outputDim*num); // multiple copies for each thread
75 #pragma omp parallel num_threads(num) private(i, j, t)
76 {
77 #pragma omp for
78 for (j = 0; j < num; j++) {
79 t = omp_get_thread_num();
80 timlCNNForwardPropagation(cnnTeam[t], data + j*dim, dim); // fp
81 timlUtilVectorSortIndexFloat(cnnTeam[t]->tail->featureMap, index + t*outputDim, outputDim); // sort
82 for (i = 0; i < topN; i++) {
83 label[j*topN + i] = index[t*outputDim + outputDim - i - 1]; // record label
84 if (percent != NULL) {
85 percent[j*topN + i] = cnnTeam[t]->tail->featureMap[index[t*outputDim + outputDim - i - 1]]; // record percent
86 }
87 }
88 }
89 }
90 free(index);
92 return err;
93 }