Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/xvision/models/mobilenetv3.py
docs - added deprecation notice
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / models / mobilenetv3.py
1 #################################################################################
2 # Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
3 # All Rights Reserved.
4 #
5 # Redistribution and use in source and binary forms, with or without
6 # modification, are permitted provided that the following conditions are met:
7 #
8 # * Redistributions of source code must retain the above copyright notice, this
9 #   list of conditions and the following disclaimer.
10 #
11 # * Redistributions in binary form must reproduce the above copyright notice,
12 #   this list of conditions and the following disclaimer in the documentation
13 #   and/or other materials provided with the distribution.
14 #
15 # * Neither the name of the copyright holder nor the names of its
16 #   contributors may be used to endorse or promote products derived from
17 #   this software without specific prior written permission.
18 #
19 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 #
30 #################################################################################
31 # Some parts of the code are borrowed from: https://github.com/pytorch/vision
32 # with the following license:
33 #
34 # BSD 3-Clause License
35 #
36 # Copyright (c) Soumith Chintala 2016,
37 # All rights reserved.
38 #
39 # Redistribution and use in source and binary forms, with or without
40 # modification, are permitted provided that the following conditions are met:
41 #
42 # * Redistributions of source code must retain the above copyright notice, this
43 #   list of conditions and the following disclaimer.
44 #
45 # * Redistributions in binary form must reproduce the above copyright notice,
46 #   this list of conditions and the following disclaimer in the documentation
47 #   and/or other materials provided with the distribution.
48 #
49 # * Neither the name of the copyright holder nor the names of its
50 #   contributors may be used to endorse or promote products derived from
51 #   this software without specific prior written permission.
52 #
53 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
54 # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
55 # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
56 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
57 # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
58 # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
59 # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
60 # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
61 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
62 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
63 #
64 #################################################################################
66 import torch
68 from functools import partial
69 from torch import nn, Tensor
70 from torch.nn import functional as F
71 from typing import Any, Callable, Dict, List, Optional, Sequence
73 from .utils import load_state_dict_from_url
74 from .mobilenetv2 import _make_divisible, ConvBNActivation
76 from ... import xnn
# Names exported by ``from ... import *`` on this module.
__all__ = [
    "MobileNetV3",
    "mobilenet_v3_large",
    "mobilenet_v3_small",
]
# Download locations of the ImageNet checkpoints, keyed by the ``arch`` string
# that ``_mobilenet_v3_model`` looks up when ``pretrained`` is requested.
model_urls = dict(
    mobilenet_v3_large="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
    mobilenet_v3_small="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
)
###################################################
def get_config():
    """Return the default MobileNetV3 model configuration.

    The builders and ``MobileNetV3`` start from this node and, when a
    ``model_config`` keyword is supplied, merge it on top via ``merge_from``.

    Fields:
        input_channels: channel count of the network input (3).
        num_classes: output size of the final classifier layer (1000).
        width_mult: channel width multiplier (1.0).
        strides: per-stage strides; ``None`` makes the consumers fall back to
            ``(2, 2, 2, 2, 2)``.
        enable_fp16: copied onto the model as ``enable_fp16`` (False).
    """
    cfg = xnn.utils.ConfigNode()
    cfg.input_channels = 3
    cfg.num_classes = 1000
    cfg.width_mult = 1.0
    cfg.strides = None  # consumers substitute (2, 2, 2, 2, 2)
    cfg.enable_fp16 = False
    return cfg
class SqueezeExcitation(nn.Module):
    """Squeeze-and-Excitation gate.

    The input is average-pooled to 1x1 spatially, passed through a 1x1-conv
    bottleneck (``fc1`` -> ReLU -> ``fc2``) and a hard-sigmoid; the resulting
    gate multiplies the input. Assumes an NCHW input (TODO confirm for callers
    outside this file).

    Args:
        input_channels: channel count of the tensor being gated.
        squeeze_factor: reduction ratio of the bottleneck; the bottleneck width
            is ``_make_divisible(input_channels // squeeze_factor, 8)``.
    """

    def __init__(self, input_channels: int, squeeze_factor: int = 4):
        super().__init__()
        hidden = _make_divisible(input_channels // squeeze_factor, 8)
        # Attribute names define the state_dict keys; keep them stable.
        self.fc1 = nn.Conv2d(input_channels, hidden, 1)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(hidden, input_channels, 1)

    def _scale(self, input: Tensor, inplace: bool) -> Tensor:
        """Compute the gate for ``input`` (hard-sigmoid of the bottleneck output)."""
        squeezed = self.relu(self.fc1(F.adaptive_avg_pool2d(input, 1)))
        return F.hardsigmoid(self.fc2(squeezed), inplace=inplace)

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input`` rescaled by its gate."""
        gate = self._scale(input, True)
        return gate * input
class InvertedResidualConfig:
    """Record describing one inverted-residual block.

    Channel counts are scaled by ``width_mult`` through :meth:`adjust_channels`
    when the record is built; the other fields are stored as given.

    Args:
        input_channels: block input channels before width scaling.
        kernel: kernel size of the depthwise convolution.
        expanded_channels: expansion (hidden) channels before width scaling.
        out_channels: block output channels before width scaling.
        use_se: whether the block gets a squeeze-excitation layer.
        activation: activation tag; only ``"HS"`` sets ``use_hs`` to True.
        stride: block stride.
        dilation: dilation of the depthwise convolution.
        width_mult: channel width multiplier.
    """

    def __init__(self, input_channels: int, kernel: int, expanded_channels: int, out_channels: int, use_se: bool,
                 activation: str, stride: int, dilation: int, width_mult: float):
        scale = self.adjust_channels
        self.input_channels = scale(input_channels, width_mult)
        self.kernel = kernel
        self.expanded_channels = scale(expanded_channels, width_mult)
        self.out_channels = scale(out_channels, width_mult)
        self.use_se = use_se
        self.use_hs = (activation == "HS")
        self.stride, self.dilation = stride, dilation

    @staticmethod
    def adjust_channels(channels: int, width_mult: float):
        """Scale ``channels`` by ``width_mult`` and round with ``_make_divisible(..., 8)``."""
        scaled = channels * width_mult
        return _make_divisible(scaled, 8)
class InvertedResidual(nn.Module):
    """Inverted-residual (bottleneck) block built from ``ConvBNActivation`` layers.

    Layer order inside ``self.block``:
      1. 1x1 expansion, only when ``expanded_channels != input_channels``;
      2. depthwise ``cnf.kernel`` conv (its stride becomes 1 when ``dilation > 1``);
      3. ``se_layer`` when ``cnf.use_se``;
      4. 1x1 projection with ``nn.Identity`` as activation.
    The input is added back when the stride is 1 and input/output channel
    counts match.

    Args:
        cnf: block description.
        norm_layer: normalization layer factory passed to every conv.
        se_layer: squeeze-excitation factory, called with the expanded width.

    Raises:
        ValueError: if ``cnf.stride`` is not within ``[1, 2]``.
    """

    def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module],
                 se_layer: Callable[..., nn.Module] = SqueezeExcitation):
        super().__init__()
        if not (1 <= cnf.stride <= 2):
            raise ValueError('illegal stride value')

        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

        act = nn.Hardswish if cnf.use_hs else nn.ReLU
        hidden = cnf.expanded_channels
        # Dilated blocks do not downsample: the depthwise stride is dropped.
        dw_stride = 1 if cnf.dilation > 1 else cnf.stride

        def conv_bn(in_ch: int, out_ch: int, **kw) -> nn.Module:
            # Every conv in the block shares the same normalization factory.
            return ConvBNActivation(in_ch, out_ch, norm_layer=norm_layer, **kw)

        # Append order fixes the "block.<i>" state_dict keys; do not reorder.
        layers: List[nn.Module] = []
        if hidden != cnf.input_channels:
            layers.append(conv_bn(cnf.input_channels, hidden, kernel_size=1, activation_layer=act))
        layers.append(conv_bn(hidden, hidden, kernel_size=cnf.kernel, stride=dw_stride,
                              dilation=cnf.dilation, groups=hidden, activation_layer=act))
        if cnf.use_se:
            layers.append(se_layer(hidden))
        layers.append(conv_bn(hidden, cnf.out_channels, kernel_size=1, activation_layer=nn.Identity))

        self.block = nn.Sequential(*layers)
        self.out_channels = cnf.out_channels
        self._is_cn = cnf.stride > 1

    def forward(self, input: Tensor) -> Tensor:
        """Run the block, adding the input back when ``use_res_connect`` is set."""
        out = self.block(input)
        if self.use_res_connect:
            out += input
        return out
class MobileNetV3(nn.Module):
    """MobileNetV3 classifier.

    Structure: stem 3x3 conv -> stack of inverted-residual blocks -> 1x1 conv
    (6x the last block's channels) -> global average pool -> classifier
    (Linear -> activation -> Dropout(0.2) -> Linear).
    """

    def __init__(
            self,
            inverted_residual_setting: List[InvertedResidualConfig],
            last_channel: int,
            block: Optional[Callable[..., nn.Module]] = None,
            norm_layer: Optional[Callable[..., nn.Module]] = None,
            activation_layer: Optional[Callable[..., nn.Module]] = nn.Hardswish,
            **kwargs: Any
    ) -> None:
        """
        MobileNet V3 main class

        Args:
            inverted_residual_setting (List[InvertedResidualConfig]): Network structure,
                one entry per inverted-residual block.
            last_channel (int): Width of the hidden classifier layer (penultimate layer).
            block (Optional[Callable[..., nn.Module]]): Module implementing the inverted
                residual building block; ``None`` selects ``InvertedResidual``.
            norm_layer (Optional[Callable[..., nn.Module]]): Normalization layer factory;
                ``None`` selects ``BatchNorm2d(eps=0.001, momentum=0.01)``.
            activation_layer (Optional[Callable[..., nn.Module]]): Activation used by the
                stem, the last conv and the classifier; ``None`` selects ``nn.Hardswish``.
            **kwargs: Only ``model_config`` is read. It is merged over ``get_config()``
                and supplies ``input_channels``, ``num_classes`` (there is no separate
                ``num_classes`` argument), ``enable_fp16`` and ``strides`` (only
                ``strides[0]``, the stem stride, is used here). Other keys are ignored.

        Raises:
            ValueError: if ``inverted_residual_setting`` is empty.
            TypeError: if it is not a sequence of ``InvertedResidualConfig``.
        """
        model_config = get_config()
        if 'model_config' in kwargs:
            model_config = model_config.merge_from(kwargs['model_config'])
        #
        strides = model_config.strides if (model_config.strides is not None) else (2, 2, 2, 2, 2)
        super().__init__()
        self.num_classes = model_config.num_classes
        self.enable_fp16 = model_config.enable_fp16

        if not inverted_residual_setting:
            raise ValueError("The inverted_residual_setting should not be empty")
        elif not (isinstance(inverted_residual_setting, Sequence) and
                  all(isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting)):
            raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]")

        if block is None:
            block = InvertedResidual

        if norm_layer is None:
            norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)

        # The annotation allows None; without this default the classifier below
        # would call None(inplace=False) and crash.
        if activation_layer is None:
            activation_layer = nn.Hardswish

        layers: List[nn.Module] = []

        # building first layer (stem); its stride is strides[0]
        firstconv_output_channels = inverted_residual_setting[0].input_channels
        layers.append(ConvBNActivation(model_config.input_channels, firstconv_output_channels, kernel_size=3,
                                       stride=strides[0], norm_layer=norm_layer,
                                       activation_layer=activation_layer))

        # building inverted residual blocks
        for cnf in inverted_residual_setting:
            layers.append(block(cnf, norm_layer))

        # building last several layers
        lastconv_input_channels = inverted_residual_setting[-1].out_channels
        lastconv_output_channels = 6 * lastconv_input_channels
        layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
                                       norm_layer=norm_layer, activation_layer=activation_layer))

        # Attribute names/order define the state_dict layout used by pretrained checkpoints.
        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Linear(lastconv_output_channels, last_channel),
            activation_layer(inplace=False),
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(last_channel, model_config.num_classes),
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """features -> global average pool -> flatten -> classifier."""
        x = self.features(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        x = self.classifier(x)

        return x

    # NOTE(review): auto_fp16 presumably consults self.enable_fp16 — confirm in xnn.utils.
    @xnn.utils.auto_fp16
    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
def _mobilenet_v3_conf(arch: str, use_se: bool = True, hs_type: str='HS', **params: Any):
    """Build the block configuration of a MobileNetV3 variant.

    Args:
        arch: ``"mobilenet_v3_large"``, ``"mobilenet_v3_lite_large"``,
            ``"mobilenet_v3_small"`` or ``"mobilenet_v3_lite_small"``.
        use_se: value passed as ``use_se`` to the blocks that may carry
            squeeze-excitation (the remaining blocks always have it off).
        hs_type: activation tag for the later blocks (the early blocks always
            use ``"RE"``); ``InvertedResidualConfig`` treats only ``"HS"`` as
            hard-swish.
        **params: builder keyword arguments. Recognized keys:
            ``_reduced_tail`` halves the channels of the final stage and of
            ``last_channel``; ``_dilated`` uses dilation 2 in the final stage;
            ``_width_mult`` sets the channel multiplier; ``model_config`` is
            merged over ``get_config()`` and supplies ``strides`` and, when
            ``_width_mult`` is not given, ``width_mult``.

    Returns:
        ``(inverted_residual_setting, last_channel)``.

    Raises:
        ValueError: if ``arch`` is not supported.
    """
    # non-public config parameters
    reduce_divider = 2 if params.pop('_reduced_tail', False) else 1
    dilation = 2 if params.pop('_dilated', False) else 1
    explicit_width_mult = params.pop('_width_mult', None)

    model_config = get_config()
    if 'model_config' in params:
        model_config = model_config.merge_from(params['model_config'])
    #
    # An explicit _width_mult wins; otherwise use the (merged) config value, which
    # defaults to 1.0. The former max(_width_mult=1.0, model_config.width_mult)
    # silently discarded any width_mult below 1.0.
    width_mult = explicit_width_mult if explicit_width_mult is not None else model_config.width_mult
    # strides[0] is the stem stride (used by MobileNetV3); strides[1:5] are used below.
    strides = model_config.strides if (model_config.strides is not None) else (2, 2, 2, 2, 2)

    bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
    adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)

    if arch in ("mobilenet_v3_large", "mobilenet_v3_lite_large"):
        inverted_residual_setting = [
            bneck_conf(16, 3, 16, 16, False, "RE", 1, 1),
            bneck_conf(16, 3, 64, 24, False, "RE", strides[1], 1),  # C1
            bneck_conf(24, 3, 72, 24, False, "RE", 1, 1),
            bneck_conf(24, 5, 72, 40, use_se, "RE", strides[2], 1),  # C2
            bneck_conf(40, 5, 120, 40, use_se, "RE", 1, 1),
            bneck_conf(40, 5, 120, 40, use_se, "RE", 1, 1),
            bneck_conf(40, 3, 240, 80, False, hs_type, strides[3], 1),  # C3
            bneck_conf(80, 3, 200, 80, False, hs_type, 1, 1),
            bneck_conf(80, 3, 184, 80, False, hs_type, 1, 1),
            bneck_conf(80, 3, 184, 80, False, hs_type, 1, 1),
            bneck_conf(80, 3, 480, 112, use_se, hs_type, 1, 1),
            bneck_conf(112, 3, 672, 112, use_se, hs_type, 1, 1),
            bneck_conf(112, 5, 672, 160 // reduce_divider, use_se, hs_type, strides[4], dilation),  # C4
            bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, use_se, hs_type, 1, dilation),
            bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, use_se, hs_type, 1, dilation),
        ]
        last_channel = adjust_channels(1280 // reduce_divider)  # C5
    elif arch in ("mobilenet_v3_small", "mobilenet_v3_lite_small"):
        inverted_residual_setting = [
            bneck_conf(16, 3, 16, 16, use_se, "RE", strides[1], 1),  # C1
            bneck_conf(16, 3, 72, 24, False, "RE", strides[2], 1),  # C2
            bneck_conf(24, 3, 88, 24, False, "RE", 1, 1),
            bneck_conf(24, 5, 96, 40, use_se, hs_type, strides[3], 1),  # C3
            bneck_conf(40, 5, 240, 40, use_se, hs_type, 1, 1),
            bneck_conf(40, 5, 240, 40, use_se, hs_type, 1, 1),
            bneck_conf(40, 5, 120, 48, use_se, hs_type, 1, 1),
            bneck_conf(48, 5, 144, 48, use_se, hs_type, 1, 1),
            bneck_conf(48, 5, 288, 96 // reduce_divider, use_se, hs_type, strides[4], dilation),  # C4
            bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, use_se, hs_type, 1, dilation),
            bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, use_se, hs_type, 1, dilation),
        ]
        last_channel = adjust_channels(1024 // reduce_divider)  # C5
    else:
        raise ValueError("Unsupported model type {}".format(arch))

    return inverted_residual_setting, last_channel
def _mobilenet_v3_model(
    arch: str,
    inverted_residual_setting: List[InvertedResidualConfig],
    last_channel: int,
    pretrained: bool,
    progress: bool,
    **kwargs: Any
):
    """Create a ``MobileNetV3`` and optionally load its pretrained weights.

    Args:
        arch: key into ``model_urls``; only consulted when ``pretrained`` is set.
        inverted_residual_setting: forwarded to ``MobileNetV3``.
        last_channel: forwarded to ``MobileNetV3``.
        pretrained: load the checkpoint registered for ``arch``.
        progress: display a download progress bar.
        **kwargs: forwarded to ``MobileNetV3``.

    Raises:
        ValueError: if ``pretrained`` is set and ``arch`` has no checkpoint URL.
    """
    net = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
    if not pretrained:
        return net

    url = model_urls.get(arch)
    if url is None:
        raise ValueError("No checkpoint is available for model type {}".format(arch))
    net.load_state_dict(load_state_dict_from_url(url, progress=progress))
    return net
def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
    """
    Constructs a large MobileNetV3 architecture from
    `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Passed both to the block-configuration builder (e.g.
            ``model_config``, ``_width_mult``, ``_reduced_tail``, ``_dilated``)
            and to the ``MobileNetV3`` constructor.
    """
    arch_name = "mobilenet_v3_large"
    blocks, head_width = _mobilenet_v3_conf(arch_name, **kwargs)
    return _mobilenet_v3_model(arch_name, blocks, head_width, pretrained, progress, **kwargs)
def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
    """
    Constructs a small MobileNetV3 architecture from
    `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Passed both to the block-configuration builder (e.g.
            ``model_config``, ``_width_mult``, ``_reduced_tail``, ``_dilated``)
            and to the ``MobileNetV3`` constructor.
    """
    arch_name = "mobilenet_v3_small"
    blocks, head_width = _mobilenet_v3_conf(arch_name, **kwargs)
    return _mobilenet_v3_model(arch_name, blocks, head_width, pretrained, progress, **kwargs)