]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - ti-machine-learning/ti-machine-learning.git/blob - debian/ti-timl/usr/src/timl/src/common/cnn/timlCNNClassifyTopNBatchMode.c
modified
[ti-machine-learning/ti-machine-learning.git] / debian / ti-timl / usr / src / timl / src / common / cnn / timlCNNClassifyTopNBatchMode.c
1 /*****************************************************************************/
2 /*!
3  * \file timlCNNClassifyTopNBatchMode.c
4  */
5 /* Copyright (C) 2015 Texas Instruments Incorporated - http://www.ti.com/
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions
9  * are met:
10  *
11  *    Redistributions of source code must retain the above copyright
12  *    notice, this list of conditions and the following disclaimer.
13  *
14  *    Redistributions in binary form must reproduce the above copyright
15  *    notice, this list of conditions and the following disclaimer in the
16  *    documentation and/or other materials provided with the
17  *    distribution.
18  *
19  *    Neither the name of Texas Instruments Incorporated nor the names of
20  *    its contributors may be used to endorse or promote products derived
21  *    from this software without specific prior written permission.
22  *
23  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
26  * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
27  * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
28  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
29  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
30  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
31  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
32  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34  *
35  ******************************************************************************/
38 /*******************************************************************************
39  *
40  * INCLUDES
41  *
42  ******************************************************************************/
44 #include "../api/timl.h"
47 /******************************************************************************/
48 /*!
49  * \ingroup       cnn
50  * \brief         Batch classification
51  * \param[in,out] cnn     CNN
52  * \param[in]     data    Data batch
53  * \param[in]     dim     Data dimension
54  * \param[in]     num     Data number
55  * \param[out]    label  Label array ptr, size = num*topN
56  * \param[out]    percent Percent array ptr, size = num*topN
57  * \param[out]    topN    Output the top N labels and the corresponding percentage
58  * \return        Error code
59  */
60 /******************************************************************************/
62 int timlCNNClassifyTopNBatchMode(timlConvNeuralNetwork *cnn, float *data, int dim, int num, int *label, float *percent, int topN)
63 {
64    int i;
65    int j;
66    int t;
67    int k;
68    int outputDim;
69    int *index;
70    int err;
72    err       = 0;
73    outputDim = cnn->tail->row*cnn->tail->col*cnn->tail->channel;
74    index     = malloc(sizeof(int)*outputDim); // multiple copies for each thread
76    for (j = 0; j < num; j++) {
77       err = timlCNNForwardPropagation(cnn, data + j*dim, dim); // fp
78       if (err) return err;
79       err = timlUtilVectorSortIndexFloat(cnn->tail->featureMap, index, outputDim); // sort
80       if (err) return err;
81       for (i = 0; i < topN; i++) {
82          label[j*topN + i] = index[outputDim - i - 1]; // record label
83          if (percent != NULL) {
84             timlUtilBLASscopy(1, cnn->tail->featureMap + index[outputDim - i - 1], percent + j*topN + i, cnn->deviceId, cnn->threadId);
85          }
86       }
87    }
88    free(index);
90    return err;
91 }