# jacinto-ai/pytorch-jacinto-ai-devkit
# modules/pytorch_jacinto_ai/vision/models/pixel2pixel/fpn_pixel2pixel.py
import torch
import numpy as np
from .... import xnn

from .pixel2pixelnet import *
from ..multi_input_net import MobileNetV2TVMI4, ResNet50MI4


__all__ = ['FPNPixel2PixelASPP', 'FPNPixel2PixelDecoder',
           'fpn_pixel2pixel_aspp_mobilenetv2_tv', 'fpn_pixel2pixel_aspp_mobilenetv2_tv_fd', 'fpn128_pixel2pixel_aspp_mobilenetv2_tv_fd',
           # no aspp models
           'fpn_pixel2pixel_mobilenetv2_tv', 'fpn_pixel2pixel_mobilenetv2_tv_fd',
           # resnet models
           'fpn_pixel2pixel_aspp_resnet50', 'fpn_pixel2pixel_aspp_resnet50_fd',
           ]
# config settings for mobilenetv2 backbone
def get_config_fpnp2p_mnv2():
    model_config = xnn.utils.ConfigNode()
    model_config.num_classes = None
    model_config.num_decoders = None
    model_config.intermediate_outputs = True
    model_config.use_aspp = True
    model_config.use_extra_strides = False
    model_config.groupwise_sep = False
    model_config.fastdown = False

    model_config.strides = (2,2,2,2,2)
    encoder_stride = np.prod(model_config.strides)
    model_config.shortcut_strides = (4,8,16,encoder_stride)
    model_config.shortcut_channels = (24,32,96,320) # this is for mobilenetv2 - change for other networks
    model_config.decoder_chan = 256
    model_config.aspp_chan = 256
    model_config.aspp_dil = (6,12,18)

    model_config.inloop_fpn = True # True: smooth convs are applied inside the top-down loop, after each upsample; False gives a plain FPN

    model_config.kernel_size_smooth = 3
    model_config.interpolation_type = 'upsample'
    model_config.interpolation_mode = 'bilinear'

    model_config.final_prediction = True
    model_config.final_upsample = True

    model_config.normalize_input = False
    model_config.split_outputs = False
    model_config.decoder_factor = 1.0
    model_config.activation = xnn.layers.DefaultAct2d
    model_config.linear_dw = False
    model_config.normalize_gradients = False
    model_config.freeze_encoder = False
    model_config.freeze_decoder = False
    model_config.multi_task = False
    return model_config
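
# Usage sketch (illustrative comment): the builders below merge the caller's config
# over these defaults via merge_from(). The caller must additionally supply at least
# input_channels, output_channels and output_type, which have no defaults here;
# the example values are assumptions.
#
#   model_config = get_config_fpnp2p_mnv2()
#   model_config.input_channels = (3,)    # a single RGB input stream
#   model_config.output_channels = (19,)  # e.g. 19-class semantic segmentation
#   model_config.output_type = 'segmentation'
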
###########################################
class FPNPyramid(torch.nn.Module):
    def __init__(self, current_channels, decoder_channels, shortcut_strides, shortcut_channels, activation, kernel_size_smooth, interpolation_type, interpolation_mode, inloop_fpn=False, all_outputs=False):
        super().__init__()
        self.inloop_fpn = inloop_fpn
        self.shortcut_strides = shortcut_strides
        self.shortcut_channels = shortcut_channels
        self.smooth_convs = torch.nn.ModuleList()
        self.shortcuts = torch.nn.ModuleList()
        self.upsamples = torch.nn.ModuleList()

        shortcut0 = self.create_shortcut(current_channels, decoder_channels, activation) if (current_channels != decoder_channels) else None
        self.shortcuts.append(shortcut0)

        smooth_conv0 = None #xnn.layers.ConvDWSepNormAct2d(decoder_channels, decoder_channels, kernel_size=kernel_size_smooth, activation=(activation, activation)) if all_outputs else None
        self.smooth_convs.append(smooth_conv0)

        upstride = 2
        for idx, (s_stride, feat_chan) in enumerate(zip(shortcut_strides, shortcut_channels)):
            shortcut = self.create_shortcut(feat_chan, decoder_channels, activation)
            self.shortcuts.append(shortcut)
            is_last = (idx == len(shortcut_channels)-1)
            smooth_conv = xnn.layers.ConvDWSepNormAct2d(decoder_channels, decoder_channels, kernel_size=kernel_size_smooth, activation=(activation,activation)) \
                        if (inloop_fpn or all_outputs or is_last) else None
            self.smooth_convs.append(smooth_conv)
            upsample = xnn.layers.UpsampleTo(decoder_channels, decoder_channels, upstride, interpolation_type, interpolation_mode)
            self.upsamples.append(upsample)
        #
    #

    def create_shortcut(self, inch, outch, activation):
        shortcut = xnn.layers.ConvNormAct2d(inch, outch, kernel_size=1, activation=activation)
        return shortcut
    #

    def forward(self, x_input, x_list):
        in_shape = x_input.shape
        x = x_list[-1]

        outputs = []
        x = self.shortcuts[0](x) if (self.shortcuts[0] is not None) else x
        y = self.smooth_convs[0](x) if (self.smooth_convs[0] is not None) else x
        x = y if self.inloop_fpn else x
        outputs.append(y)

        for idx, (shortcut, smooth_conv, s_stride, short_chan, upsample) in enumerate(zip(self.shortcuts[1:], self.smooth_convs[1:], self.shortcut_strides, self.shortcut_channels, self.upsamples)):
            # get the feature at the lower stride
            shape_s = xnn.utils.get_shape_with_stride(in_shape, s_stride)
            shape_s[1] = short_chan
            x_s = xnn.utils.get_blob_from_list(x_list, shape_s)
            x_s = shortcut(x_s)
            # upsample the current output and add the shortcut to it
            x = upsample((x,x_s))
            x = x + x_s
            # smooth conv
            y = smooth_conv(x) if (smooth_conv is not None) else x
            # in inloop_fpn mode, the smoothed output feeds the next level
            x = y if self.inloop_fpn else x
            # output
            outputs.append(y)
        #
        return outputs[::-1]
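
# Shape-flow sketch for the default mobilenetv2 config (illustrative comment only):
# the decoder passes shortcut_strides=(16,8,4) and shortcut_channels=(96,32,24)
# (the config tuples reversed, minus the deepest level), so the pyramid starts at
# the stride-32 feature and each iteration upsamples by 2, adds the 1x1-projected
# shortcut at that stride, and optionally smooths; outputs[::-1] puts the finest
# stride first.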


class InLoopFPNPyramid(FPNPyramid):
    def __init__(self, input_channels, decoder_channels, shortcut_strides, shortcut_channels, activation, kernel_size_smooth, interpolation_type, interpolation_mode, inloop_fpn=True, all_outputs=False):
        super().__init__(input_channels, decoder_channels, shortcut_strides, shortcut_channels, activation, kernel_size_smooth, interpolation_type, interpolation_mode, inloop_fpn=inloop_fpn, all_outputs=all_outputs)

###########################################
class FPNPixel2PixelDecoder(torch.nn.Module):
    def __init__(self, model_config):
        super().__init__()
        self.model_config = model_config
        activation = self.model_config.activation
        self.output_type = model_config.output_type
        self.decoder_channels = decoder_channels = round(self.model_config.decoder_chan*self.model_config.decoder_factor)

        upstride_final = self.model_config.shortcut_strides[0]
        self.upsample = xnn.layers.UpsampleTo(model_config.output_channels, model_config.output_channels, upstride_final, model_config.interpolation_type, model_config.interpolation_mode)

        self.rfblock = None
        if self.model_config.use_aspp:
            current_channels = self.model_config.shortcut_channels[-1]
            aspp_channels = round(self.model_config.aspp_chan * self.model_config.decoder_factor)
            self.rfblock = xnn.layers.DWASPPLiteBlock(current_channels, aspp_channels, decoder_channels, dilation=self.model_config.aspp_dil, avg_pool=False, activation=activation)
            current_channels = decoder_channels
        elif self.model_config.use_extra_strides:
            # a low complexity pyramid
            current_channels = self.model_config.shortcut_channels[-3]
            self.rfblock = torch.nn.Sequential(xnn.layers.ConvDWSepNormAct2d(current_channels, current_channels, kernel_size=3, stride=2, activation=(activation, activation)),
                                               xnn.layers.ConvDWSepNormAct2d(current_channels, decoder_channels, kernel_size=3, stride=2, activation=(activation, activation)))
            current_channels = decoder_channels
        else:
            current_channels = self.model_config.shortcut_channels[-1]
            self.rfblock = xnn.layers.ConvNormAct2d(current_channels, decoder_channels, kernel_size=1, stride=1)
            current_channels = decoder_channels
        #
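        # note: the receptive-field block above is selected by config - DWASPPLiteBlock
        # when use_aspp, two extra stride-2 depthwise-separable convs when
        # use_extra_strides, or a plain 1x1 projection to decoder_channels otherwise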

        shortcut_strides = self.model_config.shortcut_strides[::-1][1:]
        shortcut_channels = self.model_config.shortcut_channels[::-1][1:]
        FPNType = InLoopFPNPyramid if model_config.inloop_fpn else FPNPyramid
        self.fpn = FPNType(current_channels, decoder_channels, shortcut_strides, shortcut_channels, self.model_config.activation, self.model_config.kernel_size_smooth,
                           self.model_config.interpolation_type, self.model_config.interpolation_mode)

        # prediction
        if self.model_config.final_prediction:
            self.pred = xnn.layers.ConvDWSepNormAct2d(current_channels, self.model_config.output_channels, kernel_size=3, normalization=(True,False), activation=(False,False))
        #

    def forward(self, x_input, x, x_list):
        assert isinstance(x_input, (list,tuple)) and len(x_input)<=2, 'incorrect input'
        assert x is x_list[-1], 'the feature x must be the last entry in x_list'
        x_input = x_input[0]
        in_shape = x_input.shape

        if self.model_config.use_extra_strides:
            for blk in self.rfblock:
                x = blk(x)
                x_list += [x]
            #
        elif self.rfblock is not None:
            x = self.rfblock(x)
            x_list[-1] = x
        #

        x_list = self.fpn(x_input, x_list)
        x = x_list[0]

        if self.model_config.final_prediction:
            # prediction
            x = self.pred(x)

            # the final prediction is the upsampled one
            if self.model_config.final_upsample:
                x = self.upsample((x,x_input))

            if (not self.training) and (self.output_type == 'segmentation'):
                x = torch.argmax(x, dim=1, keepdim=True)

            assert int(in_shape[2]) == int(x.shape[2]) and int(in_shape[3]) == int(x.shape[3]), 'incorrect output shape'

        return x
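
# Inference note: in eval mode with output_type == 'segmentation', the decoder
# returns class indices (argmax over channels, shape Nx1xHxW) rather than logits;
# in training mode it returns the raw NxCxHxW prediction for the loss.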


###########################################
class FPNPixel2PixelASPP(Pixel2PixelNet):
    def __init__(self, base_model, model_config):
        super().__init__(base_model, FPNPixel2PixelDecoder, model_config)


###########################################
def fpn_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=None):
    model_config = get_config_fpnp2p_mnv2().merge_from(model_config)
    # encoder setup
    model_config_e = model_config.clone()
    base_model = MobileNetV2TVMI4(model_config_e)
    # decoder setup
    model = FPNPixel2PixelASPP(base_model, model_config)

    num_inputs = len(model_config.input_channels)
    num_decoders = len(model_config.output_channels) if (model_config.num_decoders is None) else model_config.num_decoders
    if num_inputs > 1:
        change_names_dict = {'^features.': ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)],
                            '^classifier.': 'encoder.classifier.',
                            '^encoder.features.': ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)],
                            '^decoders.0.': ['decoders.{}.'.format(d) for d in range(num_decoders)]}
    else:
        change_names_dict = {'^features.': 'encoder.features.',
                             '^classifier.': 'encoder.classifier.',
                             '^decoders.0.': ['decoders.{}.'.format(d) for d in range(num_decoders)]}
    #

    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
    #
    return model, change_names_dict
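
# Usage sketch (illustrative; the config values are assumptions, and the accepted
# forms of 'pretrained' - e.g. a checkpoint path or state_dict - are defined by
# xnn.utils.load_weights):
#
#   model_config = xnn.utils.ConfigNode()
#   model_config.input_channels = (3,)
#   model_config.output_channels = (19,)
#   model_config.output_type = 'segmentation'
#   model, change_names_dict = fpn_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=None)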


# fast down sampling model (encoder stride 64 model)
def fpn_pixel2pixel_aspp_mobilenetv2_tv_fd(model_config, pretrained=None):
    model_config = get_config_fpnp2p_mnv2().merge_from(model_config)
    model_config.fastdown = True
    model_config.strides = (2,2,2,2,2)
    model_config.shortcut_strides = (8,16,32,64)
    model_config.shortcut_channels = (24,32,96,320)
    model_config.decoder_chan = 256
    model_config.aspp_chan = 256
    return fpn_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=pretrained)


# fast down sampling model (encoder stride 64 model) with fpn decoder channels 128
def fpn128_pixel2pixel_aspp_mobilenetv2_tv_fd(model_config, pretrained=None):
    model_config = get_config_fpnp2p_mnv2().merge_from(model_config)
    model_config.fastdown = True
    model_config.strides = (2,2,2,2,2)
    model_config.shortcut_strides = (4,8,16,32,64)
    model_config.shortcut_channels = (16,24,32,96,320)
    model_config.decoder_chan = 128
    model_config.aspp_chan = 128
    return fpn_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=pretrained)


##################
# similar to the original fpn model, with extra strided convolutions (no aspp)
def fpn_pixel2pixel_mobilenetv2_tv(model_config, pretrained=None):
    model_config = get_config_fpnp2p_mnv2().merge_from(model_config)
    model_config.use_aspp = False
    model_config.use_extra_strides = True
    model_config.shortcut_strides = (4, 8, 16, 32, 64, 128)
    model_config.shortcut_channels = (24, 32, 96, 320, 320, 256)
    return fpn_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=pretrained)


# similar to the original fpn model, with extra strided convolutions (no aspp) - fast down sampling model (encoder stride 64 model)
def fpn_pixel2pixel_mobilenetv2_tv_fd(model_config, pretrained=None):
    model_config = get_config_fpnp2p_mnv2().merge_from(model_config)
    model_config.use_aspp = False
    model_config.use_extra_strides = True
    model_config.fastdown = True
    model_config.strides = (2,2,2,2,2)
    model_config.shortcut_strides = (8, 16, 32, 64, 128, 256)
    model_config.shortcut_channels = (24, 32, 96, 320, 320, 256)
    model_config.decoder_chan = 256
    model_config.aspp_chan = 256
    return fpn_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=pretrained)


###########################################
def get_config_fpnp2p_resnet50():
    # only the delta compared to the one defined for mobilenetv2
    model_config = get_config_fpnp2p_mnv2()
    model_config.shortcut_channels = (256,512,1024,2048)
    return model_config


def fpn_pixel2pixel_aspp_resnet50(model_config, pretrained=None):
    model_config = get_config_fpnp2p_resnet50().merge_from(model_config)
    # encoder setup
    model_config_e = model_config.clone()
    base_model = ResNet50MI4(model_config_e)
    # decoder setup
    model = FPNPixel2PixelASPP(base_model, model_config)

    # the pretrained model provided by torchvision and the one defined here differ slightly
    # note: this change_names_dict takes effect only if the direct load fails
    # finally, handle the rename for fpn (features -> encoder.features)
    num_inputs = len(model_config.input_channels)
    num_decoders = len(model_config.output_channels) if (model_config.num_decoders is None) else model_config.num_decoders
    if num_inputs > 1:
        change_names_dict = {'^conv1.': ['encoder.features.stream{}.conv1.'.format(stream) for stream in range(num_inputs)],
                            '^bn1.': ['encoder.features.stream{}.bn1.'.format(stream) for stream in range(num_inputs)],
                            '^relu.': ['encoder.features.stream{}.relu.'.format(stream) for stream in range(num_inputs)],
                            '^maxpool.': ['encoder.features.stream{}.maxpool.'.format(stream) for stream in range(num_inputs)],
                            '^layer': ['encoder.features.stream{}.layer'.format(stream) for stream in range(num_inputs)],
                            '^features.': ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)],
                            '^encoder.features.': ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)],
                            '^decoders.0.': ['decoders.{}.'.format(d) for d in range(num_decoders)]}
    else:
        change_names_dict = {'^conv1.': 'encoder.features.conv1.',
                             '^bn1.': 'encoder.features.bn1.',
                             '^relu.': 'encoder.features.relu.',
                             '^maxpool.': 'encoder.features.maxpool.',
                             '^layer': 'encoder.features.layer',
                             '^features.': 'encoder.features.',
                             '^decoders.0.': ['decoders.{}.'.format(d) for d in range(num_decoders)]}
    #

    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)

    return model, change_names_dict


def fpn_pixel2pixel_aspp_resnet50_fd(model_config, pretrained=None):
    model_config = get_config_fpnp2p_resnet50().merge_from(model_config)
    model_config.fastdown = True
    model_config.strides = (2,2,2,2,2)
    model_config.shortcut_strides = (8,16,32,64) #(4,8,16,32,64)
    model_config.shortcut_channels = (256,512,1024,2048) #(64,256,512,1024,2048)
    model_config.decoder_chan = 256 #128
    model_config.aspp_chan = 256 #128
    return fpn_pixel2pixel_aspp_resnet50(model_config, pretrained=pretrained)
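

###########################################
# Construction smoke test (a sketch, kept as a comment: this module uses relative
# imports, so it cannot run as a script, and the forward signature of
# Pixel2PixelNet is defined in pixel2pixelnet.py, not here. The list-of-tensors
# input below follows the decoder's assert on x_input.)
#
#   cfg = xnn.utils.ConfigNode()
#   cfg.input_channels = (3,)
#   cfg.output_channels = (19,)
#   cfg.output_type = 'segmentation'
#   net, _ = fpn_pixel2pixel_aspp_mobilenetv2_tv(cfg, pretrained=None)
#   net.eval()
#   with torch.no_grad():
#       out = net([torch.rand(1, 3, 512, 256)])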