[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / models / mobilenetv3.py
1 #################################################################################
2 # Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
3 # All Rights Reserved.
4 #
5 # Redistribution and use in source and binary forms, with or without
6 # modification, are permitted provided that the following conditions are met:
7 #
8 # * Redistributions of source code must retain the above copyright notice, this
9 # list of conditions and the following disclaimer.
10 #
11 # * Redistributions in binary form must reproduce the above copyright notice,
12 # this list of conditions and the following disclaimer in the documentation
13 # and/or other materials provided with the distribution.
14 #
15 # * Neither the name of the copyright holder nor the names of its
16 # contributors may be used to endorse or promote products derived from
17 # this software without specific prior written permission.
18 #
19 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 #
30 #################################################################################
31 # Some parts of the code are borrowed from: https://github.com/pytorch/vision
32 # with the following license:
33 #
34 # BSD 3-Clause License
35 #
36 # Copyright (c) Soumith Chintala 2016,
37 # All rights reserved.
38 #
39 # Redistribution and use in source and binary forms, with or without
40 # modification, are permitted provided that the following conditions are met:
41 #
42 # * Redistributions of source code must retain the above copyright notice, this
43 # list of conditions and the following disclaimer.
44 #
45 # * Redistributions in binary form must reproduce the above copyright notice,
46 # this list of conditions and the following disclaimer in the documentation
47 # and/or other materials provided with the distribution.
48 #
49 # * Neither the name of the copyright holder nor the names of its
50 # contributors may be used to endorse or promote products derived from
51 # this software without specific prior written permission.
52 #
53 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
54 # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
55 # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
56 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
57 # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
58 # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
59 # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
60 # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
61 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
62 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
63 #
64 #################################################################################
66 import torch
68 from functools import partial
69 from torch import nn, Tensor
70 from torch.nn import functional as F
71 from typing import Any, Callable, Dict, List, Optional, Sequence
73 from .utils import load_state_dict_from_url
74 from .mobilenetv2 import _make_divisible, ConvBNActivation
76 from ... import xnn
78 __all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"]
# Checkpoint download locations (hosted on download.pytorch.org), keyed by
# architecture name. Only architectures listed here can be built with
# pretrained=True.
model_urls = dict(
    mobilenet_v3_large="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
    mobilenet_v3_small="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
)
87 ###################################################
def get_config():
    """Build the default model configuration for MobileNetV3.

    Returns:
        A fresh ``xnn.utils.ConfigNode`` carrying the fields
        ``input_channels`` (3), ``num_classes`` (1000), ``width_mult`` (1.0),
        ``strides`` (None) and ``enable_fp16`` (False).
    """
    defaults = (
        ("input_channels", 3),
        ("num_classes", 1000),
        ("width_mult", 1.),
        # None is replaced by (2,2,2,2,2) by the code that consumes the config
        ("strides", None),
        ("enable_fp16", False),
    )
    model_config = xnn.utils.ConfigNode()
    for field_name, field_value in defaults:
        setattr(model_config, field_name, field_value)
    return model_config
class SqueezeExcitation(nn.Module):
    """Channel gating block.

    The input is globally average-pooled to 1x1, passed through two 1x1
    convolutions with a ReLU in between, squashed with a hard-sigmoid and
    finally multiplied back onto the input.
    """

    def __init__(self, input_channels: int, squeeze_factor: int = 4):
        """
        Args:
            input_channels (int): Channels of the tensor being gated
            squeeze_factor (int): Divisor applied to ``input_channels`` to get the
                bottleneck width (the result is passed through ``_make_divisible(..., 8)``)
        """
        super().__init__()
        bottleneck = _make_divisible(input_channels // squeeze_factor, 8)
        # attribute names are part of the state_dict layout - keep them as they are
        self.fc1 = nn.Conv2d(input_channels, bottleneck, 1)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(bottleneck, input_channels, 1)

    def _scale(self, input: Tensor, inplace: bool) -> Tensor:
        """Compute the per-channel gate for ``input``.

        ``inplace`` is forwarded to ``F.hardsigmoid``.
        """
        pooled = F.adaptive_avg_pool2d(input, 1)
        gate = self.fc2(self.relu(self.fc1(pooled)))
        return F.hardsigmoid(gate, inplace=inplace)

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input`` multiplied by its gate (broadcast from the pooled 1x1 map)."""
        return self._scale(input, True) * input
class InvertedResidualConfig:
    """Settings of a single inverted residual block.

    The three channel counts are stored after scaling by ``width_mult``
    (see ``adjust_channels``); the remaining values are stored as given.
    """

    def __init__(self, input_channels: int, kernel: int, expanded_channels: int, out_channels: int, use_se: bool,
                 activation: str, stride: int, dilation: int, width_mult: float):
        """
        Args:
            input_channels (int): Unscaled channels entering the block
            kernel (int): Kernel size of the depthwise convolution
            expanded_channels (int): Unscaled channels of the expanded representation
            out_channels (int): Unscaled channels leaving the block
            use_se (bool): Whether a squeeze-excitation layer is inserted
            activation (str): "HS" selects hard-swish; any other value leaves ``use_hs`` False
            stride (int): Stride of the block
            dilation (int): Dilation of the depthwise convolution
            width_mult (float): Multiplier applied to the three channel counts
        """
        scaled = partial(self.adjust_channels, width_mult=width_mult)

        # channel counts, scaled by the width multiplier
        self.input_channels = scaled(input_channels)
        self.expanded_channels = scaled(expanded_channels)
        self.out_channels = scaled(out_channels)

        # values taken over unchanged
        self.kernel = kernel
        self.stride = stride
        self.dilation = dilation
        self.use_se = use_se
        self.use_hs = activation == "HS"

    @staticmethod
    def adjust_channels(channels: int, width_mult: float):
        """Scale ``channels`` by ``width_mult`` and pass the product through ``_make_divisible(..., 8)``."""
        widened = channels * width_mult
        return _make_divisible(widened, 8)
class InvertedResidual(nn.Module):
    """Inverted residual block: optional 1x1 expansion, depthwise convolution,
    optional squeeze-excitation and a 1x1 projection without activation.

    The input is added to the output when the configured stride is 1 and the
    input and output channel counts match.
    """

    def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module],
                 se_layer: Callable[..., nn.Module] = SqueezeExcitation):
        """
        Args:
            cnf (InvertedResidualConfig): Settings of this block
            norm_layer (Callable[..., nn.Module]): Normalization layer handed to every ConvBNActivation
            se_layer (Callable[..., nn.Module]): Factory called with the expanded channel count
                when ``cnf.use_se`` is set

        Raises:
            ValueError: if ``cnf.stride`` is outside [1, 2]
        """
        super().__init__()
        if not (1 <= cnf.stride <= 2):
            raise ValueError('illegal stride value')

        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

        hidden = cnf.expanded_channels
        act = nn.Hardswish if cnf.use_hs else nn.ReLU
        # every stage shares the same normalization layer
        conv_bn = partial(ConvBNActivation, norm_layer=norm_layer)

        stages: List[nn.Module] = []

        # expand - skipped when the block does not widen its input
        if hidden != cnf.input_channels:
            stages.append(conv_bn(cnf.input_channels, hidden, kernel_size=1, activation_layer=act))

        # depthwise - a dilated block never strides
        dw_stride = cnf.stride
        if cnf.dilation > 1:
            dw_stride = 1
        stages.append(conv_bn(hidden, hidden, kernel_size=cnf.kernel, stride=dw_stride,
                              dilation=cnf.dilation, groups=hidden, activation_layer=act))
        if cnf.use_se:
            stages.append(se_layer(hidden))

        # project - linear, hence the Identity activation
        stages.append(conv_bn(hidden, cnf.out_channels, kernel_size=1, activation_layer=nn.Identity))

        self.block = nn.Sequential(*stages)
        self.out_channels = cnf.out_channels
        self._is_cn = cnf.stride > 1

    def forward(self, input: Tensor) -> Tensor:
        """Run the block; the skip connection is accumulated in place onto the block output."""
        out = self.block(input)
        if not self.use_res_connect:
            return out
        out += input
        return out
class MobileNetV3(nn.Module):
    """MobileNet V3 classifier: a convolutional feature extractor, global
    average pooling and a two-layer linear head."""

    def __init__(
            self,
            inverted_residual_setting: List[InvertedResidualConfig],
            last_channel: int,
            block: Optional[Callable[..., nn.Module]] = None,
            norm_layer: Optional[Callable[..., nn.Module]] = None,
            activation_layer: Optional[Callable[..., nn.Module]] = nn.Hardswish,
            **kwargs: Dict
    ) -> None:
        """
        MobileNet V3 main class

        Args:
            inverted_residual_setting (List[InvertedResidualConfig]): Network structure
            last_channel (int): The number of channels on the penultimate layer
            block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for mobilenet.
                None selects InvertedResidual
            norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use.
                None selects BatchNorm2d with eps=0.001 and momentum=0.01
            activation_layer (Optional[Callable[..., nn.Module]]): Activation used by the first convolution,
                the last convolution and the classifier. None selects nn.Hardswish
            **kwargs: Only the key 'model_config' is read; it is merged over the defaults of get_config().
                The fields used here are input_channels, num_classes, enable_fp16 and the first entry
                of strides. There is no separate num_classes argument - the number of classes comes
                from model_config. Any other keyword is accepted and ignored

        Raises:
            ValueError: if inverted_residual_setting is empty
            TypeError: if inverted_residual_setting is not a sequence of InvertedResidualConfig
        """
        model_config = get_config()
        if 'model_config' in kwargs:
            model_config = model_config.merge_from(kwargs['model_config'])
        strides = model_config.strides if (model_config.strides is not None) else (2, 2, 2, 2, 2)
        super().__init__()
        self.num_classes = model_config.num_classes
        self.enable_fp16 = model_config.enable_fp16

        if not inverted_residual_setting:
            raise ValueError("The inverted_residual_setting should not be empty")
        elif not (isinstance(inverted_residual_setting, Sequence) and
                  all(isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting)):
            raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]")

        if block is None:
            block = InvertedResidual

        if norm_layer is None:
            norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)

        # activation_layer is Optional like block/norm_layer, so None has to
        # resolve to the default instead of failing when the classifier calls it
        if activation_layer is None:
            activation_layer = nn.Hardswish

        layers: List[nn.Module] = []

        # building first layer - only the first stride is applied here
        firstconv_output_channels = inverted_residual_setting[0].input_channels
        layers.append(ConvBNActivation(model_config.input_channels, firstconv_output_channels, kernel_size=3,
                                       stride=strides[0], norm_layer=norm_layer,
                                       activation_layer=activation_layer))

        # building inverted residual blocks
        for cnf in inverted_residual_setting:
            layers.append(block(cnf, norm_layer))

        # building last several layers
        lastconv_input_channels = inverted_residual_setting[-1].out_channels
        lastconv_output_channels = 6 * lastconv_input_channels
        layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
                                       norm_layer=norm_layer, activation_layer=activation_layer))

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Linear(lastconv_output_channels, last_channel),
            activation_layer(inplace=False),
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(last_channel, model_config.num_classes),
        )

        # weight initialisation, by layer type
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Features -> global average pool -> flatten from dim 1 -> classifier."""
        x = self.features(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        x = self.classifier(x)

        return x

    # NOTE(review): auto_fp16 presumably consults self.enable_fp16 - confirm in xnn.utils
    @xnn.utils.auto_fp16
    def forward(self, x: Tensor) -> Tensor:
        """Return the classifier output for the input batch ``x``."""
        return self._forward_impl(x)
def _mobilenet_v3_conf(arch: str, use_se: bool = True, hs_type: str='HS', **params: Dict[str, Any]):
    """Build the block settings and classifier width for a MobileNetV3 variant.

    Args:
        arch (str): "mobilenet_v3_large", "mobilenet_v3_lite_large",
            "mobilenet_v3_small" or "mobilenet_v3_lite_small"
        use_se (bool): Value used for the blocks that carry squeeze-excitation in the tables below
        hs_type (str): Activation tag used for the blocks marked hard-swish in the tables below
        **params: Optional '_reduced_tail', '_dilated', '_width_mult' switches and an
            optional 'model_config' merged over get_config()

    Returns:
        Tuple of (list of InvertedResidualConfig, last_channel)

    Raises:
        ValueError: if ``arch`` is not one of the supported names
    """
    # non-public config parameters
    tail_div = 2 if params.pop('_reduced_tail', False) else 1
    dilation = 2 if params.pop('_dilated', False) else 1
    width_mult = params.pop('_width_mult', 1.0)

    model_config = get_config()
    if 'model_config' in params:
        model_config = model_config.merge_from(params['model_config'])
        # NOTE(review): because of max(), a model_config.width_mult smaller than
        # '_width_mult' (1.0 unless given) has no effect - confirm this is intended
        width_mult = max(width_mult, model_config.width_mult)

    strides = model_config.strides
    if strides is None:
        strides = (2,2,2,2,2)

    bneck = partial(InvertedResidualConfig, width_mult=width_mult)
    adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)

    # row layout: input_channels, kernel, expanded_channels, out_channels,
    #             use_se, activation, stride, dilation
    if arch in ("mobilenet_v3_large", "mobilenet_v3_lite_large"):
        tail_out = 160 // tail_div
        tail_exp = 960 // tail_div
        rows = [
            (16, 3, 16, 16, False, "RE", 1, 1),
            (16, 3, 64, 24, False, "RE", strides[1], 1),  # C1
            (24, 3, 72, 24, False, "RE", 1, 1),
            (24, 5, 72, 40, use_se, "RE", strides[2], 1),  # C2
            (40, 5, 120, 40, use_se, "RE", 1, 1),
            (40, 5, 120, 40, use_se, "RE", 1, 1),
            (40, 3, 240, 80, False, hs_type, strides[3], 1),  # C3
            (80, 3, 200, 80, False, hs_type, 1, 1),
            (80, 3, 184, 80, False, hs_type, 1, 1),
            (80, 3, 184, 80, False, hs_type, 1, 1),
            (80, 3, 480, 112, use_se, hs_type, 1, 1),
            (112, 3, 672, 112, use_se, hs_type, 1, 1),
            (112, 5, 672, tail_out, use_se, hs_type, strides[4], dilation),  # C4
            (tail_out, 5, tail_exp, tail_out, use_se, hs_type, 1, dilation),
            (tail_out, 5, tail_exp, tail_out, use_se, hs_type, 1, dilation),
        ]
        classifier_width = 1280 // tail_div  # C5
    elif arch in ("mobilenet_v3_small", "mobilenet_v3_lite_small"):
        tail_out = 96 // tail_div
        tail_exp = 576 // tail_div
        rows = [
            (16, 3, 16, 16, use_se, "RE", strides[1], 1),  # C1
            (16, 3, 72, 24, False, "RE", strides[2], 1),  # C2
            (24, 3, 88, 24, False, "RE", 1, 1),
            (24, 5, 96, 40, use_se, hs_type, strides[3], 1),  # C3
            (40, 5, 240, 40, use_se, hs_type, 1, 1),
            (40, 5, 240, 40, use_se, hs_type, 1, 1),
            (40, 5, 120, 48, use_se, hs_type, 1, 1),
            (48, 5, 144, 48, use_se, hs_type, 1, 1),
            (48, 5, 288, tail_out, use_se, hs_type, strides[4], dilation),  # C4
            (tail_out, 5, tail_exp, tail_out, use_se, hs_type, 1, dilation),
            (tail_out, 5, tail_exp, tail_out, use_se, hs_type, 1, dilation),
        ]
        classifier_width = 1024 // tail_div  # C5
    else:
        raise ValueError("Unsupported model type {}".format(arch))

    inverted_residual_setting = [bneck(*row) for row in rows]
    last_channel = adjust_channels(classifier_width)
    return inverted_residual_setting, last_channel
def _mobilenet_v3_model(
        arch: str,
        inverted_residual_setting: List[InvertedResidualConfig],
        last_channel: int,
        pretrained: bool,
        progress: bool,
        **kwargs: Any
):
    """Instantiate MobileNetV3 and optionally load the checkpoint registered for ``arch``.

    Args:
        arch (str): Key looked up in ``model_urls`` when ``pretrained`` is set
        inverted_residual_setting (List[InvertedResidualConfig]): Network structure
        last_channel (int): The number of channels on the penultimate layer
        pretrained (bool): If True, download and load the checkpoint for ``arch``
        progress (bool): Forwarded to ``load_state_dict_from_url``
        **kwargs: Forwarded to the MobileNetV3 constructor

    Raises:
        ValueError: if ``pretrained`` is set and ``model_urls`` has no entry for ``arch``
    """
    model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
    if not pretrained:
        return model

    checkpoint_url = model_urls.get(arch)
    if checkpoint_url is None:
        raise ValueError("No checkpoint is available for model type {}".format(arch))

    weights = load_state_dict_from_url(checkpoint_url, progress=progress)
    model.load_state_dict(weights)
    return model
def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
    """
    Constructs a large MobileNetV3 architecture from
    `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Passed both to the configuration builder and to the model constructor
    """
    arch = "mobilenet_v3_large"
    block_settings, classifier_width = _mobilenet_v3_conf(arch, **kwargs)
    return _mobilenet_v3_model(arch, block_settings, classifier_width,
                               pretrained=pretrained, progress=progress, **kwargs)
def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
    """
    Constructs a small MobileNetV3 architecture from
    `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Passed both to the configuration builder and to the model constructor
    """
    arch = "mobilenet_v3_small"
    block_settings, classifier_width = _mobilenet_v3_conf(arch, **kwargs)
    return _mobilenet_v3_model(arch, block_settings, classifier_width,
                               pretrained=pretrained, progress=progress, **kwargs)