]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - ti-machine-learning/ti-machine-learning.git/blob - src/common/cnn/timlCNNConvBackPropagation.c
1. Enable network state write/read
[ti-machine-learning/ti-machine-learning.git] / src / common / cnn / timlCNNConvBackPropagation.c
1 /******************************************************************************/\r
2 /*!\r
3  * \file timlCNNConvBackPropagation.c\r
4  */\r
5 /* Copyright (C) 2015 Texas Instruments Incorporated - http://www.ti.com/\r
6  *\r
7  * Redistribution and use in source and binary forms, with or without\r
8  * modification, are permitted provided that the following conditions\r
9  * are met:\r
10  *\r
11  *    Redistributions of source code must retain the above copyright\r
12  *    notice, this list of conditions and the following disclaimer.\r
13  *\r
14  *    Redistributions in binary form must reproduce the above copyright\r
15  *    notice, this list of conditions and the following disclaimer in the\r
16  *    documentation and/or other materials provided with the\r
17  *    distribution.\r
18  *\r
19  *    Neither the name of Texas Instruments Incorporated nor the names of\r
20  *    its contributors may be used to endorse or promote products derived\r
21  *    from this software without specific prior written permission.\r
22  *\r
23  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS\r
24  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT\r
25  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR\r
26  * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT\r
27  * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,\r
28  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT\r
29  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,\r
30  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY\r
31  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\r
32  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\r
33  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\r
34  *\r
35  ******************************************************************************/\r
36 \r
37 \r
38 /*******************************************************************************\r
39  *\r
40  * INCLUDES\r
41  *\r
42  ******************************************************************************/\r
43 \r
44 #include "../api/timl.h"\r
45 \r
46 \r
47 /******************************************************************************/\r
48 /*!\r
49  * \ingroup       cnn\r
50  * \brief         Back propagate the gradient from the conv layer to the previous layer\r
51  * \details       layer->prev->delta[i] = sum_{j}(layer->delta[j] conv2full layer->kernel[i, j])\r
52  * \param[in]     layer Layer ptr\r
53  * \return        Error code\r
54  */\r
55 /******************************************************************************/\r
56 \r
57 int timlCNNConvBackPropagation(timlCNNLayer *layer)\r
58 {\r
59    int M;\r
60    int K;\r
61    int N;\r
62    int b;\r
63    int err;\r
64    int prevRow;\r
65    int prevCol;\r
66    int prevChannel;\r
67    int row;\r
68    int col;\r
69    int channel;\r
70    int kernelRow;\r
71    int kernelCol;\r
72    int deviceId;\r
73    int threadId;\r
74 \r
75    // init\r
76    prevRow     = layer->prev->row;\r
77    prevCol     = layer->prev->col;\r
78    prevChannel = layer->prev->channel;\r
79    row         = layer->row;\r
80    col         = layer->col;\r
81    channel     = layer->channel;\r
82    kernelRow   = layer->convParams.kernelRow;\r
83    kernelCol   = layer->convParams.kernelCol;\r
84    deviceId    = layer->cnn->deviceId;\r
85    threadId    = layer->cnn->threadId;\r
86    M           = channel;\r
87    K           = kernelRow*kernelCol*prevChannel;\r
88    N           = row*col;\r
89    err         = 0;\r
90 \r
91    // kernelGrad = delta * prevFeatureMapReshape' -- (M*N)*(N*K)\r
92 \r
93    for (b = 0; b < layer->batchSize; b++) {\r
94       timlUtilConv2ImageReshape(layer->convParams.prevFeatureMapReshape, layer->prev->featureMap + b*prevRow*prevCol*prevChannel, layer->convParams.prevFeatureMapReshapeIndex, prevChannel, prevRow*prevCol, kernelRow*kernelCol*row*col, deviceId, threadId);\r
95       #pragma omp critical\r
96       {\r
97          timlUtilBLASsgemm(CblasNoTrans, CblasTrans, M, K, N, 1.0, layer->delta + b*M*N, layer->convParams.prevFeatureMapReshape, 1.0, layer->convParams.kernelGradAccum, deviceId, threadId);\r
98          timlUtilBLASsgemv(CblasNoTrans, M, N, 1.0, layer->delta + b*M*N, layer->convParams.biasMultiplier, 1.0, layer->convParams.biasGradAccum, deviceId, threadId);\r
99       }\r
100    }\r
101 \r
102    // back propagate delta\r
103    if (layer->prev->delta != NULL) {\r
104       for (b = 0; b < layer->batchSize; b++) {\r
105          // reset prevDelta to 0\r
106          timlUtilVectorResetFloat(layer->prev->delta + b*prevRow*prevCol*prevChannel, prevRow*prevCol*prevChannel, 0.0, deviceId, threadId);\r
107          // prevDeltaTemp = kernel' * delta -- (K*M)(M*N)\r
108          timlUtilBLASsgemm(CblasTrans, CblasNoTrans, K, N, M, 1.0, layer->convParams.kernel, layer->delta + b*M*N, 0.0, layer->convParams.prevFeatureMapReshape, deviceId, threadId);\r
109          // reshape prevDeltaTemp to prevDelta\r
110          timlUtilConv2ImageReshapeBack(layer->prev->delta + b*prevRow*prevCol*prevChannel, layer->convParams.prevFeatureMapReshape, layer->convParams.prevFeatureMapReshapeIndex, prevChannel, prevRow*prevCol, kernelRow*kernelCol*row*col, deviceId, threadId);\r
111       }\r
112    }\r
113 \r
114    return err;\r
115 }\r