release commit
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / models / shufflenetv2.py
1 from collections import OrderedDict
2 import torch
3 import torch.nn as nn
4 from .utils import load_state_dict_from_url
5 from ... import xnn
# Public names of this module: the ShuffleNetV2 model class and the four
# factory functions, one per channel-width variant (0.5x, 1.0x, 1.5x, 2.0x).
__all__ = [
    'ShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0',
    'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0'
]
# Pretrained checkpoint URL for each architecture key.
# A value of None means no checkpoint is registered for that key;
# _shufflenetv2() raises NotImplementedError when pretrained weights are
# requested for such an entry.
model_urls = {
    'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth',
    'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth',
    'shufflenetv2_x1.5': None,
    'shufflenetv2_x2.0': None,
}
def channel_shuffle(x, groups):
    """Interleave the channels of ``x`` across ``groups`` channel groups.

    The channel axis is split into ``groups`` consecutive groups, the group
    axis and the per-group axis are swapped, and the result is flattened back
    to four dimensions. Input channel ``g * channels_per_group + c`` ends up
    at output channel ``c * groups + g``.

    Args:
        x (torch.Tensor): 4-D tensor, unpacked here as
            (batch, channels, height, width).
        groups (int): number of channel groups. The channel count is expected
            to be a multiple of it; otherwise the reshape below fails.

    Returns:
        torch.Tensor: tensor of the same shape as ``x`` with shuffled channels.
    """
    # Query the shape on the tensor itself; going through the legacy
    # ``.data`` attribute is not needed for a shape lookup.
    batchsize, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups

    # reshape: (N, C, H, W) -> (N, groups, C // groups, H, W)
    x = x.view(batchsize, groups, channels_per_group, height, width)

    # swap the group and per-group axes; the transposed tensor must be made
    # contiguous before it can be viewed with a different shape
    x = torch.transpose(x, 1, 2).contiguous()

    # flatten back to (N, C, H, W)
    x = x.view(batchsize, -1, height, width)

    return x
class InvertedResidual(nn.Module):
    """ShuffleNetV2 building block made of one or two convolutional branches.

    With ``stride == 1`` the input is split into two halves along the channel
    axis: the first half is passed through unchanged and the second half goes
    through ``branch2``. With ``stride > 1`` the whole input is fed to both
    ``branch1`` and ``branch2``. In both cases the two results are
    concatenated along the channel axis and passed through
    ``channel_shuffle`` with two groups.

    Args:
        inp (int): number of input channels.
        oup (int): number of output channels; each branch produces
            ``oup // 2`` channels.
        stride (int): stride of the depthwise 3x3 convolutions, in [1, 3].

    Raises:
        ValueError: if ``stride`` is outside [1, 3], or if ``stride == 1``
            and ``inp`` differs from ``2 * (oup // 2)``.
    """

    def __init__(self, inp, oup, stride):
        super(InvertedResidual, self).__init__()

        if not (1 <= stride <= 3):
            raise ValueError('illegal stride value')
        self.stride = stride

        branch_features = oup // 2
        # The stride-1 path halves the input and concatenates one half with
        # the branch2 output, so the channel counts have to line up. This is
        # an explicit check (not an ``assert``) so it still runs under
        # ``python -O``.
        if self.stride == 1 and inp != branch_features << 1:
            raise ValueError(
                'stride 1 requires inp == 2 * (oup // 2), got inp={} and oup={}'.format(inp, oup))

        # branch1 only exists for the strided variant; the stride-1 variant
        # uses an identity shortcut on half of the channels instead.
        if self.stride > 1:
            self.branch1 = nn.Sequential(
                self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(inp),
                nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )

        # branch2 sees the full input when strided, otherwise only one half.
        self.branch2 = nn.Sequential(
            nn.Conv2d(inp if (self.stride > 1) else branch_features,
                      branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
        """Return a ``nn.Conv2d`` with ``groups`` equal to its input channels ``i``."""
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

    def forward(self, x):
        """Run the block on ``x`` and return the channel-shuffled result."""
        if self.stride == 1:
            # identity on the first half, branch2 on the second half
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        out = channel_shuffle(out, 2)

        return out
class ShuffleNetV2(nn.Module):
    """ShuffleNetV2 feature extractor with an optional linear classifier.

    ``features`` is an ``nn.Sequential`` whose children are named, in order,
    ``conv1``, ``maxpool``, ``stage2``, ``stage3``, ``stage4`` and ``conv5``.
    Each stage starts with a stride-2 ``InvertedResidual`` followed by
    ``repeats - 1`` stride-1 ones.

    Args:
        stages_repeats: sequence of 3 block counts, one per stage.
        stages_out_channels: sequence of 5 channel widths, used for conv1,
            stage2, stage3, stage4 and conv5 respectively.
        num_classes: output size of the classifier. When None, no classifier
            is created and ``forward`` returns the feature map.

    Raises:
        ValueError: if ``stages_repeats`` does not have 3 entries or
            ``stages_out_channels`` does not have 5 entries.
    """

    def __init__(self, stages_repeats, stages_out_channels, num_classes=1000):
        super(ShuffleNetV2, self).__init__()
        self.num_classes = num_classes

        if len(stages_repeats) != 3:
            raise ValueError('expected stages_repeats as list of 3 positive ints')
        if len(stages_out_channels) != 5:
            raise ValueError('expected stages_out_channels as list of 5 positive ints')
        self._stage_out_channels = stages_out_channels

        named_layers = OrderedDict()

        # stem: strided 3x3 convolution on the 3-channel input, then max-pool
        stem_width = self._stage_out_channels[0]
        named_layers['conv1'] = nn.Sequential(
            nn.Conv2d(3, stem_width, 3, 2, 1, bias=False),
            nn.BatchNorm2d(stem_width),
            nn.ReLU(inplace=True),
        )
        named_layers['maxpool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # stage2..stage4: one down-sampling block, then the stride-1 repeats
        prev_width = stem_width
        stage_specs = zip(stages_repeats, self._stage_out_channels[1:])
        for stage_no, (num_blocks, width) in enumerate(stage_specs, start=2):
            blocks = [InvertedResidual(prev_width, width, 2)]
            blocks.extend(InvertedResidual(width, width, 1)
                          for _ in range(num_blocks - 1))
            named_layers['stage{}'.format(stage_no)] = nn.Sequential(*blocks)
            prev_width = width

        # final 1x1 convolution up to the last configured width
        head_width = self._stage_out_channels[-1]
        named_layers['conv5'] = nn.Sequential(
            nn.Conv2d(prev_width, head_width, 1, 1, 0, bias=False),
            nn.BatchNorm2d(head_width),
            nn.ReLU(inplace=True),
        )
        self.features = torch.nn.Sequential(named_layers)

        if self.num_classes is not None:
            self.classifier = nn.Linear(head_width, num_classes)

        # weights init
        xnn.utils.module_weights_init(self)

    def forward(self, x):
        """Return class scores, or the raw feature map when ``num_classes`` is None."""
        feature_map = self.features(x)
        if self.num_classes is None:
            return feature_map

        # global average pool to 1x1, drop the spatial axes, then classify
        pooled = torch.nn.functional.adaptive_avg_pool2d(feature_map, (1, 1))
        flat = torch.flatten(pooled, 1)
        return self.classifier(flat)
def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
    """Build a ``ShuffleNetV2`` and optionally load pretrained weights into it.

    Args:
        arch (str): key into ``model_urls`` selecting the checkpoint.
        pretrained (bool): when truthy, download and load the checkpoint.
        progress (bool): forwarded to ``load_state_dict_from_url``.
        *args, **kwargs: forwarded unchanged to the ``ShuffleNetV2`` constructor.

    Returns:
        The constructed model; when ``pretrained`` is truthy, whatever
        ``xnn.utils.load_weights`` returns for it.

    Raises:
        NotImplementedError: if ``pretrained`` is truthy and the
            ``model_urls`` entry for ``arch`` is None.
    """
    net = ShuffleNetV2(*args, **kwargs)
    if not pretrained:
        return net

    checkpoint_url = model_urls[arch]
    if checkpoint_url is None:
        raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))

    weights = load_state_dict_from_url(checkpoint_url, progress=progress)
    # The pretrained model provided by torchvision and the one defined here
    # differ slightly in parameter names. Per the original note, this renaming
    # table only takes effect if the direct load fails.
    rename_rules = {
        '^conv': 'features.conv',
        '^maxpool.': 'features.maxpool.',
        '^stage': 'features.stage',
        '^fc.': 'classifier.',
    }
    return xnn.utils.load_weights(net, weights, change_names_dict=rename_rules)
def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs):
    """
    Build the ShuffleNetV2 variant with 0.5x output channels from
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the ``ShuffleNetV2`` constructor
    """
    stages_repeats = [4, 8, 4]
    stages_out_channels = [24, 48, 96, 192, 1024]
    return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress,
                         stages_repeats, stages_out_channels, **kwargs)
def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs):
    """
    Build the ShuffleNetV2 variant with 1.0x output channels from
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the ``ShuffleNetV2`` constructor
    """
    stages_repeats = [4, 8, 4]
    stages_out_channels = [24, 116, 232, 464, 1024]
    return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress,
                         stages_repeats, stages_out_channels, **kwargs)
def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs):
    """
    Build the ShuffleNetV2 variant with 1.5x output channels from
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        pretrained (bool): If True, pretrained weights are requested. The
            ``model_urls`` entry for this variant is None, so this currently
            raises NotImplementedError.
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the ``ShuffleNetV2`` constructor
    """
    stages_repeats = [4, 8, 4]
    stages_out_channels = [24, 176, 352, 704, 1024]
    return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress,
                         stages_repeats, stages_out_channels, **kwargs)
def shufflenet_v2_x2_0(pretrained=False, progress=True, **kwargs):
    """
    Build the ShuffleNetV2 variant with 2.0x output channels from
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        pretrained (bool): If True, pretrained weights are requested. The
            ``model_urls`` entry for this variant is None, so this currently
            raises NotImplementedError.
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the ``ShuffleNetV2`` constructor
    """
    stages_repeats = [4, 8, 4]
    stages_out_channels = [24, 244, 488, 976, 2048]
    return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress,
                         stages_repeats, stages_out_channels, **kwargs)