1 /******************************************************************************/\r
2 /*!\r
3 * \file timlCNNConvBackPropagation.c\r
4 */\r
5 /* Copyright (C) 2015 Texas Instruments Incorporated - http://www.ti.com/\r
6 *\r
7 * Redistribution and use in source and binary forms, with or without\r
8 * modification, are permitted provided that the following conditions\r
9 * are met:\r
10 *\r
11 * Redistributions of source code must retain the above copyright\r
12 * notice, this list of conditions and the following disclaimer.\r
13 *\r
14 * Redistributions in binary form must reproduce the above copyright\r
15 * notice, this list of conditions and the following disclaimer in the\r
16 * documentation and/or other materials provided with the\r
17 * distribution.\r
18 *\r
19 * Neither the name of Texas Instruments Incorporated nor the names of\r
20 * its contributors may be used to endorse or promote products derived\r
21 * from this software without specific prior written permission.\r
22 *\r
23 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS\r
24 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT\r
25 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR\r
26 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT\r
27 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,\r
28 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT\r
29 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,\r
30 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY\r
31 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\r
32 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\r
33 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\r
34 *\r
35 ******************************************************************************/\r
36 \r
37 \r
38 /*******************************************************************************\r
39 *\r
40 * INCLUDES\r
41 *\r
42 ******************************************************************************/\r
43 \r
44 #include "../api/timl.h"\r
45 \r
46 \r
47 /******************************************************************************/\r
48 /*!\r
49 * \ingroup cnn\r
50 * \brief Back propagate the gradient from the conv layer to the previous layer\r
51 * \details layer->prev->delta[i] = sum_{j}(layer->delta[j] conv2full layer->kernel[i, j])\r
52 * \param[in] layer Layer ptr\r
53 * \return Error code\r
54 */\r
55 /******************************************************************************/\r
56 \r
57 int timlCNNConvBackPropagation(timlCNNLayer *layer)\r
58 {\r
59 int M;\r
60 int K;\r
61 int N;\r
62 int b;\r
63 int err;\r
64 int prevRow;\r
65 int prevCol;\r
66 int prevChannel;\r
67 int row;\r
68 int col;\r
69 int channel;\r
70 int kernelRow;\r
71 int kernelCol;\r
72 int deviceId;\r
73 int threadId;\r
74 \r
75 // init\r
76 prevRow = layer->prev->row;\r
77 prevCol = layer->prev->col;\r
78 prevChannel = layer->prev->channel;\r
79 row = layer->row;\r
80 col = layer->col;\r
81 channel = layer->channel;\r
82 kernelRow = layer->convParams.kernelRow;\r
83 kernelCol = layer->convParams.kernelCol;\r
84 deviceId = layer->cnn->deviceId;\r
85 threadId = layer->cnn->threadId;\r
86 M = channel;\r
87 K = kernelRow*kernelCol*prevChannel;\r
88 N = row*col;\r
89 err = 0;\r
90 \r
91 // kernelGrad = delta * prevFeatureMapReshape' -- (M*N)*(N*K)\r
92 \r
93 for (b = 0; b < layer->batchSize; b++) {\r
94 timlUtilConv2ImageReshape(layer->convParams.prevFeatureMapReshape, layer->prev->featureMap + b*prevRow*prevCol*prevChannel, layer->convParams.prevFeatureMapReshapeIndex, prevChannel, prevRow*prevCol, kernelRow*kernelCol*row*col, deviceId, threadId);\r
95 #pragma omp critical\r
96 {\r
97 timlUtilBLASsgemm(CblasNoTrans, CblasTrans, M, K, N, 1.0, layer->delta + b*M*N, layer->convParams.prevFeatureMapReshape, 1.0, layer->convParams.kernelGradAccum, deviceId, threadId);\r
98 timlUtilBLASsgemv(CblasNoTrans, M, N, 1.0, layer->delta + b*M*N, layer->convParams.biasMultiplier, 1.0, layer->convParams.biasGradAccum, deviceId, threadId);\r
99 }\r
100 }\r
101 \r
102 // back propagate delta\r
103 if (layer->prev->delta != NULL) {\r
104 for (b = 0; b < layer->batchSize; b++) {\r
105 // reset prevDelta to 0\r
106 timlUtilVectorResetFloat(layer->prev->delta + b*prevRow*prevCol*prevChannel, prevRow*prevCol*prevChannel, 0.0, deviceId, threadId);\r
107 // prevDeltaTemp = kernel' * delta -- (K*M)(M*N)\r
108 timlUtilBLASsgemm(CblasTrans, CblasNoTrans, K, N, M, 1.0, layer->convParams.kernel, layer->delta + b*M*N, 0.0, layer->convParams.prevFeatureMapReshape, deviceId, threadId);\r
109 // reshape prevDeltaTemp to prevDelta\r
110 timlUtilConv2ImageReshapeBack(layer->prev->delta + b*prevRow*prevCol*prevChannel, layer->convParams.prevFeatureMapReshape, layer->convParams.prevFeatureMapReshapeIndex, prevChannel, prevRow*prevCol, kernelRow*kernelCol*row*col, deviceId, threadId);\r
111 }\r
112 }\r
113 \r
114 return err;\r
115 }\r