[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / models / mobilenetv2_ericsun_internal.py
1 #################################################################################
2 # Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
3 # All Rights Reserved.
4 #
5 # Redistribution and use in source and binary forms, with or without
6 # modification, are permitted provided that the following conditions are met:
7 #
8 # * Redistributions of source code must retain the above copyright notice, this
9 # list of conditions and the following disclaimer.
10 #
11 # * Redistributions in binary form must reproduce the above copyright notice,
12 # this list of conditions and the following disclaimer in the documentation
13 # and/or other materials provided with the distribution.
14 #
15 # * Neither the name of the copyright holder nor the names of its
16 # contributors may be used to endorse or promote products derived from
17 # this software without specific prior written permission.
18 #
19 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 #
30 #################################################################################
31 # Some parts of the code are borrowed from: https://github.com/pytorch/vision
32 # with the following license:
33 #
34 # BSD 3-Clause License
35 #
36 # Copyright (c) Soumith Chintala 2016,
37 # All rights reserved.
38 #
39 # Redistribution and use in source and binary forms, with or without
40 # modification, are permitted provided that the following conditions are met:
41 #
42 # * Redistributions of source code must retain the above copyright notice, this
43 # list of conditions and the following disclaimer.
44 #
45 # * Redistributions in binary form must reproduce the above copyright notice,
46 # this list of conditions and the following disclaimer in the documentation
47 # and/or other materials provided with the distribution.
48 #
49 # * Neither the name of the copyright holder nor the names of its
50 # contributors may be used to endorse or promote products derived from
51 # this software without specific prior written permission.
52 #
53 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
54 # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
55 # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
56 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
57 # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
58 # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
59 # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
60 # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
61 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
62 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
63 #
64 #################################################################################
65 #For MobileNetV2 model from https://github.com/ericsun99/MobileNet-V2-Pytorch
66 #
67 # BSD 2-Clause License
68 #
69 # Copyright (c) 2018, ericsun99
70 # All rights reserved.
71 #
72 # Redistribution and use in source and binary forms, with or without
73 # modification, are permitted provided that the following conditions are met:
74 #
75 # * Redistributions of source code must retain the above copyright notice, this
76 # list of conditions and the following disclaimer.
77 #
78 # * Redistributions in binary form must reproduce the above copyright notice,
79 # this list of conditions and the following disclaimer in the documentation
80 # and/or other materials provided with the distribution.
81 #
82 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
83 # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
84 # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
85 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
86 # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
87 # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
88 # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
89 # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
90 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
91 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
92 #################################################################################
95 import torch
96 import math
97 from ... import xnn
99 ###################################################
# public API of this module
__all__ = ['get_config', 'MobileNetV2EricsunBase', 'MobileNetV2Ericsun', 'mobilenet_v2_ericsun']
103 ###################################################
def get_config():
    """Return the default configuration for the Ericsun MobileNetV2 variant."""
    defaults = dict(
        input_channels=3,
        num_classes=1000,
        width_mult=1.,
        expand_ratio=6,
        strides=(2, 2, 2, 2, 2),
        activation=xnn.layers.DefaultAct2d,
        kernel_size=3,
        dropout=False,
        linear_dw=False,
        layer_setting=None,
    )
    model_config = xnn.utils.ConfigNode()
    for key, val in defaults.items():
        setattr(model_config, key, val)
    return model_config
119 ###################################################
def conv_bn(inp, oup, stride, activation, kernel_size=3):
    """Convolution followed by the default norm layer and an activation."""
    padding = kernel_size // 2
    layers = [
        torch.nn.Conv2d(inp, oup, kernel_size, stride, padding, bias=False),
        xnn.layers.DefaultNorm2d(oup),
        activation(inplace=True),
    ]
    return torch.nn.Sequential(*layers)
127 ###################################################
def conv(inp, oup, stride, activation, kernel_size=3):
    """Convolution + activation, deliberately without a normalization layer."""
    padding = kernel_size // 2
    layers = [
        torch.nn.Conv2d(inp, oup, kernel_size, stride, padding, bias=False),
        activation(inplace=True),
    ]
    return torch.nn.Sequential(*layers)
def conv_1x1_bn(inp, oup, activation, groups=1):
    """Pointwise (1x1) convolution + default norm layer + activation."""
    modules = [
        torch.nn.Conv2d(inp, oup, 1, 1, 0, groups=groups, bias=False),
        xnn.layers.DefaultNorm2d(oup),
        activation(inplace=True),
    ]
    return torch.nn.Sequential(*modules)
def width_multiplier(value, base=8, min_val=8):
    """Round `value` to the nearest multiple of `base`, clamped below at `min_val`.

    A `min_val` of None falls back to `base`; a falsy `min_val` (e.g. 0)
    disables the lower clamp entirely, matching the original behavior.
    """
    if min_val is None:
        min_val = base
    rounded = int(math.floor(float(value) / base + 0.5)) * base
    if min_val:
        rounded = max(rounded, min_val)
    return int(rounded)
150 ###################################################
class InvertedResidual(torch.nn.Module):
    """MobileNetV2 inverted-residual block: pointwise expand -> depthwise -> pointwise-linear,
    with a residual connection when spatial size and channel count are preserved."""
    def __init__(self, input_channels, output_channels, stride, expand_ratio, activation=None, kernel_size=3, linear_dw=False):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        # the skip connection is valid only when the block does not change resolution or width
        self.use_res_connect = (stride == 1 and input_channels == output_channels)
        hidden_channels = input_channels * expand_ratio

        layers = []
        # pointwise expansion
        layers += [
            torch.nn.Conv2d(input_channels, hidden_channels, 1, 1, 0, bias=False),
            xnn.layers.DefaultNorm2d(hidden_channels),
            activation(inplace=True),
        ]
        # depthwise convolution; norm/activation are bypassed when linear_dw is set
        layers += [
            torch.nn.Conv2d(hidden_channels, hidden_channels, kernel_size, stride, kernel_size//2,
                            groups=hidden_channels, bias=False),
        ]
        if linear_dw:
            layers += [xnn.layers.BypassBlock(), xnn.layers.BypassBlock()]
        else:
            layers += [xnn.layers.DefaultNorm2d(hidden_channels), activation(inplace=True)]
        # pointwise linear projection
        layers += [
            torch.nn.Conv2d(hidden_channels, output_channels, 1, 1, 0, bias=False),
            xnn.layers.DefaultNorm2d(output_channels),
        ]
        # when the depthwise stage is linear, the activation moves after the projection instead
        if linear_dw:
            layers += [activation(inplace=True)]

        self.conv = torch.nn.Sequential(*layers)
        if self.use_res_connect:
            self.add = xnn.layers.AddBlock(signed=True)

    def forward(self, x):
        out = self.conv(x)
        if self.use_res_connect:
            out = self.add((x, out))
        return out
186 ###################################################
class MobileNetV2EricsunBase(torch.nn.Module):
    """MobileNetV2 variant after https://github.com/ericsun99/MobileNet-V2-Pytorch.

    Builds the feature extractor from model_config.layer_setting and, when
    model_config.num_classes is not None, appends a 1x1 conv head, global
    average pooling and a linear classifier.
    """
    def __init__(self, ResidualBlock, model_config):
        """
        Args:
            ResidualBlock: inverted-residual block class (e.g. InvertedResidual).
            model_config: configuration node as returned by get_config().
        """
        super().__init__()
        self.model_config = model_config
        self.num_classes = self.model_config.num_classes

        # strides of the five downsampling stages
        s0 = model_config.strides[0]
        s1 = model_config.strides[1]
        s2 = model_config.strides[2]
        s3 = model_config.strides[3]
        s4 = model_config.strides[4]

        # default setting of inverted residual blocks.
        # t: expansion ratio, c: output channels, n: repeats, s: stride of first repeat.
        # NOTE: the resolved setting is written back into model_config on purpose,
        # so callers can inspect the layer configuration that was actually used.
        if self.model_config.layer_setting is None:
            expand_ratio = self.model_config.expand_ratio
            self.model_config.layer_setting = [
                # t, c, n, s
                [1, 32, 1, s0],
                [1, 16, 1, 1],
                [expand_ratio, 24, 2, s1],
                [expand_ratio, 32, 3, s2],
                [expand_ratio, 64, 4, s3],
                [expand_ratio, 96, 3, 1],
                [expand_ratio, 160, 3, s4],
                [expand_ratio, 320, 1, 1],
                [1, 1280, 1, 1],
            ]

        # building first layer: a plain conv+bn+act from the first setting entry
        stride = self.model_config.layer_setting[0][3]
        output_channels = width_multiplier(self.model_config.layer_setting[0][1]*self.model_config.width_mult)
        features = [conv_bn(self.model_config.input_channels, output_channels, stride,
                            self.model_config.activation, kernel_size=3)]
        channels = output_channels

        # building inverted residual blocks (all entries except the first and last)
        for t, c, n, s in self.model_config.layer_setting[1:-1]:
            output_channels = width_multiplier(c*self.model_config.width_mult)
            for i in range(n):
                # only the first block of each group downsamples
                stride = (s if i == 0 else 1)
                block = ResidualBlock(channels, output_channels, stride, t, self.model_config.activation,
                                      kernel_size=self.model_config.kernel_size,
                                      linear_dw=self.model_config.linear_dw)
                features.append(block)
                channels = output_channels

        # building classifier head; skipped entirely when the model is used as a
        # feature extractor (num_classes is None) - forward() checks the same flag.
        # fixed: use 'is not None' instead of '!= None' for the None comparison
        if self.model_config.num_classes is not None:
            output_channels = width_multiplier(self.model_config.layer_setting[-1][1]*self.model_config.width_mult)
            features.append(conv_1x1_bn(channels, output_channels, self.model_config.activation))
            features.append(torch.nn.AdaptiveAvgPool2d(1))
            channels = output_channels

            self.classifier = torch.nn.Sequential(
                torch.nn.Dropout(p=0.2, inplace=True) if self.model_config.dropout else xnn.layers.BypassBlock(),
                torch.nn.Linear(channels, self.model_config.num_classes),
            )

        # make it torch.nn.Sequential
        self.features = torch.nn.Sequential(*features)

        self._initialize_weights()

    def forward(self, x):
        """Run the backbone and, when num_classes is set, flatten + classify."""
        for block_id, block in enumerate(self.features):
            # TODO: Cleanup. It should not be done in this complicated way.
            # To print the correct size of features.
            if isinstance(block, torch.nn.AdaptiveAvgPool2d):
                xnn.utils.print_once('=> feature size is: ', x.size())
            #
            x = block(x)

        if self.num_classes is not None:
            x = torch.flatten(x, 1)
            x = self.classifier(x)

        return x

    def _initialize_weights(self):
        """Kaiming-normal init for convs, unit weight / zero bias for batchnorms,
        small-normal weights / zero bias for linear layers."""
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, torch.nn.BatchNorm2d):
                if m.weight is not None:
                    torch.nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, torch.nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
class MobileNetV2Ericsun(MobileNetV2EricsunBase):
    """MobileNetV2EricsunBase specialized with the InvertedResidual block."""
    def __init__(self, model_config):
        # merge user-provided overrides on top of the defaults
        merged_config = get_config().merge_from(model_config)
        super().__init__(InvertedResidual, merged_config)
291 ###################################################
class mobilenet_v2_ericsun(MobileNetV2Ericsun):
    """Factory-style wrapper that optionally loads pretrained weights."""
    def __init__(self, model_config, pretrained):
        super().__init__(model_config)
        if pretrained:
            # pretrained may be a path/checkpoint accepted by load_weights
            xnn.utils.load_weights(self, pretrained)