788fd8b46a51a47ea97cbd38734feb78f68e5457
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / models / resnet.py
1 import torch
2 import torch.nn as nn
3 import collections
4 from .utils import load_state_dict_from_url
5 from ... import xnn
# Public API of this module. 'resnet50_with_model_config' is a jacinto-ai
# extension (config-driven constructor); the rest mirror torchvision.
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
           'wide_resnet50_2', 'wide_resnet101_2',
           'resnet50_with_model_config']


# Download URLs of the ImageNet-pretrained torchvision checkpoints, keyed by
# architecture name (same names as the factory functions below).
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution; padding equals dilation so spatial size is kept at stride 1."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 (pointwise) convolution, bias-free since a norm layer follows it."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
class BasicBlock(nn.Module):
    """Residual block used by ResNet-18/34: two 3x3 convolutions plus a shortcut.

    The elementwise residual addition goes through ``xnn.layers.AddBlock`` so
    that it is visible as a module (e.g. for quantization-aware tooling).
    """

    # Output channels of the block are planes * expansion.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # When stride != 1, both self.conv1 and self.downsample reduce resolution.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.add = xnn.layers.AddBlock()
        self.relu2 = nn.ReLU(inplace=True)
        self.stride = stride

    def forward(self, x):
        # Shortcut path: identity, or projection when shapes change.
        shortcut = x if self.downsample is None else self.downsample(x)
        # Main path: conv-bn-relu, conv-bn.
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Residual merge followed by the final activation.
        out = self.add((out, shortcut))
        return self.relu2(out)
class Bottleneck(nn.Module):
    """Bottleneck residual block used by ResNet-50/101/152 and the ResNeXt/wide variants.

    1x1 reduce -> 3x3 (with stride/groups/dilation) -> 1x1 expand, plus a
    shortcut. The residual addition goes through ``xnn.layers.AddBlock`` so it
    is visible as a module (e.g. for quantization-aware tooling).
    """

    # Output channels of the block are planes * expansion.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        # Inner width scales with base_width (wide resnets) and groups (resnext).
        width = int(planes * (base_width / 64.)) * groups
        # When stride != 1, both self.conv2 and self.downsample reduce resolution.
        self.conv1 = conv1x1(inplanes, width)
        self.relu1 = nn.ReLU(inplace=True)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.downsample = downsample
        self.add = xnn.layers.AddBlock()
        self.relu3 = nn.ReLU(inplace=True)
        self.stride = stride

    def forward(self, x):
        # Shortcut path: identity, or projection when shapes change.
        shortcut = x if self.downsample is None else self.downsample(x)
        # Main path: 1x1 reduce, 3x3, 1x1 expand (no relu after the last bn).
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        # Residual merge followed by the final activation.
        out = self.add((out, shortcut))
        return self.relu3(out)
class ResNet(nn.Module):
    """ResNet backbone with an optional linear classifier head.

    This is a modified copy of torchvision's ResNet. Differences:
    - the stem and the four residual stages are wrapped in a single
      ``self.features`` Sequential (so feature extractors can use it directly);
    - extra constructor knobs: ``input_channels``, per-layer ``strides``,
      ``width_mult`` (channel multiplier) and ``fastdown`` (extra stride-2 in
      the first residual stage for faster downsampling);
    - weights are initialized via ``xnn.utils.module_weights_init``;
    - the classifier is named ``classifier`` (torchvision uses ``fc``) — see
      ``load_weights`` for the name remapping.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: number of blocks in each of the four residual stages.
        num_classes: size of the classifier; falsy disables the head entirely
            and ``forward`` returns the feature map.
        zero_init_residual: zero the last BN weight of each block (see below).
        groups, width_per_group: ResNeXt / wide-resnet parameters.
        replace_stride_with_dilation: 3 bools, one per stage 2/3/4 — replace
            that stage's stride with dilation.
        norm_layer: normalization layer factory; defaults to BatchNorm2d.
        input_channels: channels of the input image (default RGB).
        strides: 5-tuple of strides for (conv1, maxpool, layer2, layer3,
            layer4); defaults to (2,2,2,2,2).
        width_mult: multiplier applied to every stage's channel count.
        fastdown: if True, layer1 gets an additional stride of 2.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, input_channels=3, strides=None, width_mult=1.0, fastdown=False):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.num_classes = num_classes

        # running channel count, mutated by _make_layer as stages are built
        self.inplanes = int(64 * width_mult)
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(f"replace_stride_with_dilation should be None or a 3-element tuple, got {replace_stride_with_dilation}")

        # strides of various layers
        strides = strides if (strides is not None) else (2,2,2,2,2)
        s0 = strides[0]  # conv1 (stem)
        s1 = strides[1]  # maxpool
        sf = 2 if fastdown else 1 # additional stride if fast down is true
        s2 = strides[2]  # layer2
        s3 = strides[3]  # layer3
        s4 = strides[4]  # layer4

        self.groups = groups
        self.base_width = width_per_group
        # stem: 7x7 conv + bn + relu + maxpool
        conv1 = nn.Conv2d(input_channels, self.inplanes, kernel_size=7, stride=s0, padding=3, bias=False)
        bn1 = norm_layer(self.inplanes)
        relu = nn.ReLU(inplace=True)
        maxpool = nn.MaxPool2d(kernel_size=3, stride=s1, padding=1)
        features = [('conv1',conv1), ('bn1',bn1), ('relu',relu), ('maxpool',maxpool)]

        # four residual stages; _make_layer must be called in order because it
        # updates self.inplanes / self.dilation as a side effect
        layer1 = self._make_layer(block, int(64*width_mult), layers[0], stride=sf)
        layer2 = self._make_layer(block, int(128*width_mult), layers[1], stride=s2,
                                       dilate=replace_stride_with_dilation[0])
        layer3 = self._make_layer(block, int(256*width_mult), layers[2], stride=s3,
                                       dilate=replace_stride_with_dilation[1])
        layer4 = self._make_layer(block, int(512*width_mult), layers[3], stride=s4,
                                       dilate=replace_stride_with_dilation[2])
        features += [('layer1',layer1), ('layer2',layer2), ('layer3',layer3), ('layer4',layer4)]
        self.features = torch.nn.Sequential(collections.OrderedDict(features))

        # classifier head is only created when num_classes is truthy
        if self.num_classes:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.classifier = nn.Linear(int(512*width_mult) * block.expansion, num_classes)

        xnn.utils.module_weights_init(self)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one residual stage of `blocks` blocks.

        Side effects: advances self.inplanes to the stage's output channels and,
        when dilate=True, multiplies self.dilation by the (consumed) stride.
        Only the first block of a stage strides/downsamples.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # trade stride for dilation (keeps resolution, grows receptive field)
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # projection shortcut: 1x1 conv + norm to match shape of the main path
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits, or the backbone feature map when num_classes is falsy."""
        x = self.features(x)
        if self.num_classes:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.classifier(x)

        return x


    # define a load weights function in the module since the module is changed w.r.t. torchvision
    # since we want to be able to load the existing torchvision pretrained weights
    def load_weights(self, pretrained, change_names_dict=None, download_root=None):
        """Load weights, remapping torchvision parameter names to this module's names.

        Args:
            pretrained: checkpoint path/URL/state-dict passed to xnn.utils.load_weights.
            change_names_dict: regex->prefix remapping; defaults to the
                torchvision-to-here mapping below.
            download_root: where to cache a downloaded checkpoint.

        Returns:
            (self, change_names_dict) tuple.
        """
        if change_names_dict is None:
            # the pretrained model provided by torchvision and what is defined here differs slightly
            # note: this change_names_dict will take effect only if the direct load fails
            change_names_dict = {'^conv1.':'features.conv1.', '^bn1.':'features.bn1.',
                                 '^relu.':'features.relu.', '^maxpool.':'features.maxpool.',
                                 '^layer':'features.layer' , '^fc.':'classifier.'}
        #
        if pretrained is not None:
            xnn.utils.load_weights(self, pretrained, change_names_dict=change_names_dict, download_root=download_root)
        return self, change_names_dict
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build a ResNet variant and optionally load pretrained weights.

    Args:
        arch (str): architecture name (kept for API parity with torchvision).
        block: residual block class (BasicBlock or Bottleneck).
        layers (list): number of blocks in each of the four stages.
        pretrained: falsy to skip loading; otherwise forwarded to
            ``ResNet.load_weights`` (path / URL / state dict).
        progress (bool): kept for torchvision API compatibility; unused here
            because loading is delegated to ``ResNet.load_weights``.
        **kwargs: forwarded to the ``ResNet`` constructor. The keys
            ``change_names_dict`` and ``download_root`` are consumed here and
            NOT forwarded, since ``ResNet.__init__`` does not accept them.

    Returns:
        The constructed (and possibly weight-loaded) ResNet model.
    """
    # pop() rather than get(): leaving these keys in kwargs would make
    # ResNet(block, layers, **kwargs) raise TypeError on the unexpected
    # keyword arguments whenever a caller supplies them.
    change_names_dict = kwargs.pop('change_names_dict', None)
    download_root = kwargs.pop('download_root', None)
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        model.load_weights(pretrained, change_names_dict=change_names_dict, download_root=download_root)
    return model
def resnet18(pretrained=False, progress=True, **kwargs):
    r"""Construct a ResNet-18 model.

    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    stage_depths = [2, 2, 2, 2]
    return _resnet('resnet18', BasicBlock, stage_depths, pretrained, progress, **kwargs)
def resnet34(pretrained=False, progress=True, **kwargs):
    r"""Construct a ResNet-34 model.

    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnet34', BasicBlock, stage_depths, pretrained, progress, **kwargs)
def resnet50(pretrained=False, progress=True, **kwargs):
    r"""Construct a ResNet-50 model.

    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnet50', Bottleneck, stage_depths, pretrained, progress, **kwargs)
def resnet101(pretrained=False, progress=True, **kwargs):
    r"""Construct a ResNet-101 model.

    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    stage_depths = [3, 4, 23, 3]
    return _resnet('resnet101', Bottleneck, stage_depths, pretrained, progress, **kwargs)
def resnet152(pretrained=False, progress=True, **kwargs):
    r"""Construct a ResNet-152 model.

    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    stage_depths = [3, 8, 36, 3]
    return _resnet('resnet152', Bottleneck, stage_depths, pretrained, progress, **kwargs)
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""Construct a ResNeXt-50 32x4d model.

    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # cardinality 32, 4 channels per group
    kwargs.update(groups=32, width_per_group=4)
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""Construct a ResNeXt-101 32x8d model.

    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # cardinality 32, 8 channels per group
    kwargs.update(groups=32, width_per_group=8)
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Construct a Wide ResNet-50-2 model.

    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    Same as ResNet-50 except the bottleneck's inner (3x3) channel count is
    doubled; the outer 1x1 channel counts are unchanged, e.g. the last block
    is 2048-1024-2048 here vs 2048-512-2048 in ResNet-50.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # double the per-group base width (64 -> 128)
    kwargs.update(width_per_group=64 * 2)
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Construct a Wide ResNet-101-2 model.

    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    Same as ResNet-101 except the bottleneck's inner (3x3) channel count is
    doubled; the outer 1x1 channel counts are unchanged (see wide_resnet50_2).

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # double the per-group base width (64 -> 128)
    kwargs.update(width_per_group=64 * 2)
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
371 ###################################################
def get_config():
    """Return the default model configuration used by the *_with_model_config builders.

    Every attribute read by ``resnet50_with_model_config`` must have a default
    here, so that a user config merged via ``merge_from`` may omit any of them.
    """
    model_config = xnn.utils.ConfigNode()
    model_config.input_channels = 3
    model_config.num_classes = 1000
    model_config.strides = None #(2,2,2,2,2)
    # width_mult was missing: resnet50_with_model_config reads it, so a config
    # that did not set it explicitly would raise AttributeError.
    model_config.width_mult = 1.0
    model_config.fastdown = False
    return model_config
def resnet50_with_model_config(model_config, pretrained=None):
    """Build a resnet50 from a ConfigNode-style configuration.

    Args:
        model_config: configuration merged over ``get_config()`` defaults;
            attributes used: input_channels, strides, num_classes, width_mult,
            fastdown.
        pretrained: optional checkpoint (path/URL/state dict) to load.

    Returns:
        The constructed resnet50 model.
    """
    model_config = get_config().merge_from(model_config)
    # getattr guard: older default configs did not define width_mult, which
    # made this line raise AttributeError; fall back to the neutral 1.0.
    width_mult = getattr(model_config, 'width_mult', 1.0)
    model = resnet50(input_channels=model_config.input_channels, strides=model_config.strides,
                     num_classes=model_config.num_classes, pretrained=pretrained,
                     width_mult=width_mult, fastdown=model_config.fastdown)
    return model