/******************************************************************************/
/*!
 * \file timlCNNAddLinearLayer.c
 */
/* Copyright (C) 2015 Texas Instruments Incorporated - http://www.ti.com/
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 *    Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *
 *    Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the
 *    distribution.
 *
 *    Neither the name of Texas Instruments Incorporated nor the names of
 *    its contributors may be used to endorse or promote products derived
 *    from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/
/*******************************************************************************
 *
 * INCLUDES
 *
 ******************************************************************************/

#include "../api/timl.h"
/*******************************************************************************/
/**
 * \ingroup       cnn
 * \brief         Add a linear (fully-connected) layer to the tail of the CNN
 * \param[in,out] cnn    CNN to extend
 * \param[in]     dim    Output 1D feature map dimension (number of output units)
 * \param[in]     params Linear layer parameters
 * \return        Error code (0 on success)
 */
/*******************************************************************************/
int timlCNNAddLinearLayer(timlConvNeuralNetwork *cnn, int dim, timlCNNLinearParams params)
{
   int          prevDim;
   timlCNNLayer *prev;
   timlCNNLayer *linearLayer;
   // error checking
   if (cnn == NULL) {
      return ERROR_CNN_NULL_PTR;
   }
   if (cnn->tail == NULL) {
      return ERROR_CNN_EMPTY;
   }
   if (dim <= 0) {
      return ERROR_CNN_LINEAR_LAYER_DIM;
   }
   // the linear layer fully connects to the flattened output of the current tail layer
   prev    = cnn->tail;
   prevDim = prev->row * prev->col * prev->channel;
   // allocate linear layer
   if (timlUtilMallocHost((void**)&linearLayer, sizeof(timlCNNLayer))) {
      return ERROR_CNN_LAYER_ALLOCATION;
   }
//   linearLayer->inputParams     = timlCNNInputParamsDefault();
//   linearLayer->convParams      = timlCNNConvParamsDefault();
//   linearLayer->poolingParams   = timlCNNPoolingParamsDefault();
//   linearLayer->normParams      = timlCNNNormParamsDefault();
//   linearLayer->nonlinearParams = timlCNNNonlinearParamsDefault();
   // load the user-supplied parameters
   linearLayer->linearParams = params;
   // override parameters: clear all weight/bias buffer pointers and set the derived dimensions
   linearLayer->linearParams.weight          = NULL;
   linearLayer->linearParams.weightInc       = NULL;
   linearLayer->linearParams.weightGradAccum = NULL;
   linearLayer->linearParams.bias            = NULL;
   linearLayer->linearParams.biasInc         = NULL;
   linearLayer->linearParams.biasGradAccum   = NULL;
   linearLayer->linearParams.biasMultiplier  = NULL;
   linearLayer->linearParams.prevDim         = prevDim;
   linearLayer->linearParams.dim             = dim;
   linearLayer->linearParams.shared          = false;
   // initialize the layer descriptor: a linear layer outputs a 1 x 1 x dim feature map
   linearLayer->type           = CNN_Linear;
   linearLayer->featureMap     = NULL;
   linearLayer->delta          = NULL;
   linearLayer->row            = 1;
   linearLayer->col            = 1;
   linearLayer->channel        = dim;
   linearLayer->batchSize      = prev->batchSize;
   linearLayer->maxBatchSize   = prev->maxBatchSize;
   linearLayer->allocatorLevel = cnn->params.allocatorLevel;
   linearLayer->phase          = cnn->params.phase;
   // link the linear layer to the tail of the network
   linearLayer->cnn        = cnn;
   linearLayer->id         = prev->id + 1;
   linearLayer->prev       = prev;
   linearLayer->prev->next = linearLayer;
   linearLayer->next       = NULL;
   cnn->tail               = linearLayer;

   return 0;
}
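
/*
 * Usage sketch (illustrative, not part of the original file): appending a
 * 10-way fully-connected output layer to a network that already contains at
 * least one layer. timlCNNLinearParamsDefault() is assumed to exist by
 * analogy with the other *ParamsDefault() helpers referenced above; check
 * timl.h for the exact name and signature.
 *
 *    timlCNNLinearParams linearParams = timlCNNLinearParamsDefault();
 *    int err = timlCNNAddLinearLayer(cnn, 10, linearParams);
 *    if (err != 0) {
 *       // possible errors: ERROR_CNN_NULL_PTR, ERROR_CNN_EMPTY,
 *       // ERROR_CNN_LINEAR_LAYER_DIM, ERROR_CNN_LAYER_ALLOCATION
 *    }
 */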