f4f4026fffb3d47d93036696a633214cf730e17f
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / models / resnet.py
1 import torch
2 import torch.nn as nn
3 import collections
4 from .utils import load_state_dict_from_url
5 from ... import xnn
# Public API of this module: the ResNet class, the torchvision-compatible
# factory functions, and the model_config-driven resnet50 variant.
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
           'wide_resnet50_2', 'wide_resnet101_2',
           'resnet50_with_model_config']
# Torchvision pretrained checkpoints; each filename carries a short hash
# of the weights (used by the download cache for integrity checking).
_WEIGHTS_BASE_URL = 'https://download.pytorch.org/models/'

model_urls = {
    name: _WEIGHTS_BASE_URL + filename for name, filename in (
        ('resnet18', 'resnet18-5c106cde.pth'),
        ('resnet34', 'resnet34-333f7ec4.pth'),
        ('resnet50', 'resnet50-19c8e357.pth'),
        ('resnet101', 'resnet101-5d3b4d8f.pth'),
        ('resnet152', 'resnet152-b121ed2d.pth'),
        ('resnext50_32x4d', 'resnext50_32x4d-7cdf4587.pth'),
        ('resnext101_32x8d', 'resnext101_32x8d-8ba56ff5.pth'),
        ('wide_resnet50_2', 'wide_resnet50_2-95faca4d.pth'),
        ('wide_resnet101_2', 'wide_resnet101_2-32ee1156.pth'),
    )
}
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 convolution; padding equals dilation so the
    spatial size is preserved for stride 1."""
    return nn.Conv2d(in_planes, out_planes,
                     kernel_size=3,
                     stride=stride,
                     padding=dilation,
                     dilation=dilation,
                     groups=groups,
                     bias=False)
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free pointwise (1x1) convolution."""
    return nn.Conv2d(in_planes, out_planes,
                     kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
    """ResNet basic residual block: two 3x3 convs with an identity shortcut.

    Used by resnet18/resnet34. The elementwise residual add goes through
    xnn.layers.AddBlock so quantization-aware tooling can hook it.
    """
    expansion = 1  # output channels == planes * expansion

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # conv1 (together with the downsample shortcut, when present) carries
        # the spatial stride. Submodules are created in the same order as the
        # reference implementation to keep state_dict / init order identical.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.add = xnn.layers.AddBlock()
        self.relu2 = nn.ReLU(inplace=True)
        self.stride = stride

    def forward(self, x):
        """conv-bn-relu, conv-bn, shortcut add, final relu."""
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu2(self.add((out, shortcut)))
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand (4x) + shortcut.

    Used by resnet50 and deeper. The elementwise residual add goes through
    xnn.layers.AddBlock so quantization-aware tooling can hook it.
    """
    expansion = 4  # output channels == planes * expansion

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # width of the middle 3x3 conv scales with base_width and groups
        # (this is what makes ResNeXt and wide-ResNet variants possible)
        width = int(planes * (base_width / 64.)) * groups
        # conv2 (together with the downsample shortcut, when present) carries
        # the spatial stride. Submodules are created in the same order as the
        # reference implementation to keep state_dict / init order identical.
        self.conv1 = conv1x1(inplanes, width)
        self.relu1 = nn.ReLU(inplace=True)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.downsample = downsample
        self.add = xnn.layers.AddBlock()
        self.relu3 = nn.ReLU(inplace=True)
        self.stride = stride

    def forward(self, x):
        """1x1-bn-relu, 3x3-bn-relu, 1x1-bn, shortcut add, final relu."""
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return self.relu3(self.add((out, shortcut)))
class ResNet(nn.Module):
    """Configurable ResNet backbone with an optional linear classifier.

    Differs from torchvision's ResNet in two ways:
    - the stem and the four residual stages are wrapped in a single
      ``self.features`` Sequential (``load_weights`` remaps torchvision
      state_dict names accordingly), and
    - extra constructor options (``input_channels``, ``strides``,
      ``width_mult``, ``fastdown``) allow reduced/embedded variants.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, input_channels=3, strides=None, width_mult=1.0, fastdown=False):
        """
        Args:
            block: residual block class (BasicBlock or Bottleneck).
            layers: sequence of 4 ints - number of blocks per stage.
            num_classes: classifier size; a falsy value (0/None) disables the
                classifier head and forward() returns feature maps instead.
            zero_init_residual: zero-init the last BN of every residual block.
            groups, width_per_group: ResNeXt / wide-ResNet options.
            replace_stride_with_dilation: 3 bools, one per stage 2-4; True
                converts that stage's stride into dilation.
            norm_layer: normalization layer class (default nn.BatchNorm2d).
            input_channels: channels of the input tensor (default 3 = RGB).
            strides: optional 5-tuple overriding the strides of (conv1,
                maxpool, stage2, stage3, stage4); default (2,2,2,2,2).
            width_mult: channel-width multiplier applied to every stage.
            fastdown: if True, stage 1 gets an extra stride of 2 (an extra
                downsampling step for high-resolution inputs).
        """
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.num_classes = num_classes
        self.inplanes = int(64 * width_mult)
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(f"replace_stride_with_dilation should be None or a 3-element tuple, got {replace_stride_with_dilation}")
        # strides of various layers
        strides = strides if (strides is not None) else (2,2,2,2,2)
        s0 = strides[0]
        s1 = strides[1]
        sf = 2 if fastdown else 1 # additional stride if fast down is true
        s2 = strides[2]
        s3 = strides[3]
        s4 = strides[4]
        self.groups = groups
        self.base_width = width_per_group
        # stem: 7x7 conv + BN + ReLU + 3x3 maxpool, as in the original ResNet
        conv1 = nn.Conv2d(input_channels, self.inplanes, kernel_size=7, stride=s0, padding=3, bias=False)
        bn1 = norm_layer(self.inplanes)
        relu = nn.ReLU(inplace=True)
        maxpool = nn.MaxPool2d(kernel_size=3, stride=s1, padding=1)
        features = [('conv1',conv1), ('bn1',bn1), ('relu',relu), ('maxpool',maxpool)]
        # four residual stages; sf applies the optional fastdown stride to stage 1.
        # note: _make_layer mutates self.inplanes/self.dilation, so call order matters.
        layer1 = self._make_layer(block, int(64*width_mult), layers[0], stride=sf)
        layer2 = self._make_layer(block, int(128*width_mult), layers[1], stride=s2,
                                       dilate=replace_stride_with_dilation[0])
        layer3 = self._make_layer(block, int(256*width_mult), layers[2], stride=s3,
                                       dilate=replace_stride_with_dilation[1])
        layer4 = self._make_layer(block, int(512*width_mult), layers[3], stride=s4,
                                       dilate=replace_stride_with_dilation[2])
        features += [('layer1',layer1), ('layer2',layer2), ('layer3',layer3), ('layer4',layer4)]
        # one Sequential for the whole backbone - see load_weights for the
        # torchvision state_dict name remapping this requires
        self.features = torch.nn.Sequential(collections.OrderedDict(features))
        # classifier head is optional: num_classes=0/None gives a pure feature extractor
        if self.num_classes:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.classifier = nn.Linear(int(512*width_mult) * block.expansion, num_classes)
        xnn.utils.module_weights_init(self)
        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one residual stage of `blocks` blocks.

        The first block applies `stride` (or converts it to dilation when
        `dilate` is True) and gets a 1x1-conv downsample shortcut whenever
        the shape changes; the remaining blocks keep stride 1.
        Side effect: advances self.inplanes (and self.dilation when dilating).
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # trade the stride for dilation: same receptive field growth,
            # but the spatial resolution is preserved
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits, or backbone feature maps when num_classes is falsy."""
        x = self.features(x)
        if self.num_classes:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.classifier(x)
        return x

    # define a load weights function in the module since the module is changed w.r.t. torchvision
    # since we want to be able to load the existing torchvision pretrained weights
    def load_weights(self, pretrained, change_names_dict=None, download_root=None):
        """Load weights (state_dict, path or URL), remapping torchvision names.

        Returns (self, change_names_dict). No-op when pretrained is None.
        """
        if change_names_dict is None:
            # the pretrained model provided by torchvision and what is defined here differs slightly
            # note: that this change_names_dict  will take effect only if the direct load fails
            change_names_dict = {'^conv1.':'features.conv1.', '^bn1.':'features.bn1.',
                                 '^relu.':'features.relu.', '^maxpool.':'features.maxpool.',
                                 '^layer':'features.layer' , '^fc.':'classifier.'}
        #
        if pretrained is not None:
            xnn.utils.load_weights(self, pretrained, change_names_dict=change_names_dict, download_root=download_root)
        return self, change_names_dict
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Construct a ResNet variant and optionally load pretrained weights.

    Args:
        arch: key into model_urls (used when pretrained is True).
        block: BasicBlock or Bottleneck.
        layers: blocks per stage, forwarded to ResNet.
        pretrained: True -> download the torchvision checkpoint;
            any other truthy value (path/URL/state_dict) -> load it directly;
            falsy -> random init.
        progress: show a download progress bar.
        **kwargs: forwarded to ResNet(); may additionally carry the
            loader-only options 'change_names_dict' and 'download_root'.

    Returns:
        The constructed (and possibly weight-loaded) ResNet.
    """
    # Pop loader-only options BEFORE constructing the model: ResNet.__init__
    # does not accept them, so leaving them in kwargs (as the previous
    # kwargs.get() did) raised TypeError whenever a caller supplied them.
    change_names_dict = kwargs.pop('change_names_dict', None)
    download_root = kwargs.pop('download_root', None)
    model = ResNet(block, layers, **kwargs)
    if pretrained is True:
        # download the torchvision checkpoint for this architecture
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_weights(state_dict, change_names_dict=change_names_dict)
    elif pretrained:
        # pretrained is a path / URL / state_dict provided by the caller
        model.load_weights(pretrained, change_names_dict=change_names_dict, download_root=download_root)
    return model
def resnet18(pretrained=False, progress=True, **kwargs):
    """Construct a ResNet-18 model.

    "Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    blocks_per_stage = [2, 2, 2, 2]
    return _resnet('resnet18', BasicBlock, blocks_per_stage, pretrained, progress, **kwargs)
def resnet34(pretrained=False, progress=True, **kwargs):
    """Construct a ResNet-34 model.

    "Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    blocks_per_stage = [3, 4, 6, 3]
    return _resnet('resnet34', BasicBlock, blocks_per_stage, pretrained, progress, **kwargs)
def resnet50(pretrained=False, progress=True, **kwargs):
    """Construct a ResNet-50 model.

    "Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    blocks_per_stage = [3, 4, 6, 3]
    return _resnet('resnet50', Bottleneck, blocks_per_stage, pretrained, progress, **kwargs)
def resnet101(pretrained=False, progress=True, **kwargs):
    """Construct a ResNet-101 model.

    "Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    blocks_per_stage = [3, 4, 23, 3]
    return _resnet('resnet101', Bottleneck, blocks_per_stage, pretrained, progress, **kwargs)
def resnet152(pretrained=False, progress=True, **kwargs):
    """Construct a ResNet-152 model.

    "Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    blocks_per_stage = [3, 8, 36, 3]
    return _resnet('resnet152', Bottleneck, blocks_per_stage, pretrained, progress, **kwargs)
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    """Construct a ResNeXt-50 32x4d model.

    "Aggregated Residual Transformation for Deep Neural Networks"
    <https://arxiv.org/pdf/1611.05431.pdf>

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # 32 groups of width 4 in every bottleneck's 3x3 conv
    kwargs.update(groups=32, width_per_group=4)
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    """Construct a ResNeXt-101 32x8d model.

    "Aggregated Residual Transformation for Deep Neural Networks"
    <https://arxiv.org/pdf/1611.05431.pdf>

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # 32 groups of width 8 in every bottleneck's 3x3 conv
    kwargs.update(groups=32, width_per_group=8)
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    """Construct a Wide ResNet-50-2 model.

    "Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>

    Same as ResNet-50 except every bottleneck's 3x3 conv has twice as many
    channels; the outer 1x1 convolutions are unchanged (e.g. the last block
    is 2048-1024-2048 instead of 2048-512-2048).

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs.update(width_per_group=128)  # 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)
def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    """Construct a Wide ResNet-101-2 model.

    "Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>

    Same as ResNet-101 except every bottleneck's 3x3 conv has twice as many
    channels; the outer 1x1 convolutions are unchanged (e.g. the last block
    is 2048-1024-2048 instead of 2048-512-2048).

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs.update(width_per_group=128)  # 64 * 2
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)
376 ###################################################
def get_config():
    """Return the default model config consumed by resnet50_with_model_config."""
    model_config = xnn.utils.ConfigNode()
    model_config.input_channels = 3
    model_config.num_classes = 1000
    model_config.strides = None #(2,2,2,2,2)
    model_config.fastdown = False
    # Default channel-width multiplier. resnet50_with_model_config reads
    # model_config.width_mult, but no default was defined here before, so a
    # caller config lacking the field crashed with AttributeError.
    model_config.width_mult = 1.0
    return model_config
def resnet50_with_model_config(model_config, pretrained=None):
    """Build a resnet50 from a ConfigNode-style model_config.

    Args:
        model_config: config overriding the defaults from get_config()
            (input_channels, num_classes, strides, fastdown, width_mult).
        pretrained: forwarded to resnet50 (True / path / state_dict / None).

    Returns:
        The constructed resnet50 model.
    """
    model_config = get_config().merge_from(model_config)
    # get_config() historically defined no width_mult default, so reading the
    # attribute directly raised AttributeError for configs without it; fall
    # back to the canonical multiplier of 1.0 in that case.
    width_mult = getattr(model_config, 'width_mult', 1.0)
    model = resnet50(input_channels=model_config.input_channels, strides=model_config.strides,
                     num_classes=model_config.num_classes, pretrained=pretrained,
                     width_mult=width_mult, fastdown=model_config.fastdown)
    return model