]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/xvision/models/pixel2pixel/unetlite_pixel2pixel.py
0d6c5b6da449d782fa00e7e724f2e939999ae099
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / models / pixel2pixel / unetlite_pixel2pixel.py
1 import torch
2 import numpy as np
3 from .... import xnn
5 from .pixel2pixelnet import *
6 from ..multi_input_net import MobileNetV2TVMI4, ResNet50MI4
# public API of this module (names exported by `from ... import *`)
__all__ = [
    'UNetLitePixel2PixelASPP',
    'UNetLitePixel2PixelDecoder',
    'unetlite_pixel2pixel_aspp_mobilenetv2_tv',
    'unetlite_pixel2pixel_aspp_mobilenetv2_tv_fd',
    'unetlite_pixel2pixel_aspp_resnet50',
    'unetlite_pixel2pixel_aspp_resnet50_fd',
]
# config settings for mobilenetv2 backbone
def get_config_unetlitep2p_mnv2():
    """Return the default UNet-lite pixel2pixel config for a MobileNetV2 encoder.

    The factory functions in this module merge the caller's config over these
    defaults with ``merge_from``. ``num_classes`` and ``num_decoders`` default
    to None; with ``num_decoders`` left as None the factories fall back to
    ``len(output_channels)``.
    """
    cfg = xnn.utils.ConfigNode()
    cfg.num_classes = None
    cfg.num_decoders = None
    cfg.intermediate_outputs = True
    # receptive-field block selection used by UNetLitePixel2PixelDecoder
    cfg.use_aspp = True
    cfg.use_extra_strides = False
    cfg.groupwise_sep = False
    cfg.fastdown = False
    cfg.width_mult = 1.0
    cfg.target_input_ratio = 1

    # encoder strides; the deepest shortcut is at their product (32 for five 2s)
    cfg.strides = (2,2,2,2,2)
    cfg.shortcut_strides = (2,4,8,16,np.prod(cfg.strides))
    # shortcut feature channels for mobilenetv2 - change for other networks
    cfg.shortcut_channels = (16,24,32,96,320)
    cfg.decoder_chan = 256
    cfg.aspp_chan = 256
    cfg.aspp_dil = (6,12,18)

    # pyramid smoothing conv and upsampling flavour
    cfg.kernel_size_smooth = 3
    cfg.interpolation_type = 'upsample'
    cfg.interpolation_mode = 'bilinear'

    # prediction head
    cfg.final_prediction = True
    cfg.final_upsample = True
    cfg.output_range = None

    cfg.normalize_input = False
    cfg.split_outputs = False
    cfg.decoder_factor = 1.0
    cfg.activation = xnn.layers.DefaultAct2d
    cfg.linear_dw = False
    cfg.normalize_gradients = False
    cfg.freeze_encoder = False
    cfg.freeze_decoder = False
    cfg.multi_task = False
    return cfg
55 ###########################################
class UNetLitePyramid(torch.nn.Module):
    """Top-down UNet-style pyramid: repeatedly upsample by 2, concatenate a
    shortcut feature and smooth the result with a depthwise-separable conv.

    Args:
        current_channels: channels of the deepest (starting) feature.
        minimum_channels: lower bound for the channels of every smoothed output.
        shortcut_strides: strides of the shortcut features to fuse, deepest first.
        shortcut_channels: channels of those shortcut features, same order.
        activation: activation used (twice) inside each smoothing conv.
        kernel_size_smooth: kernel size of the smoothing convs.
        interpolation_type, interpolation_mode: forwarded to xnn.layers.UpsampleWith.
    """
    def __init__(self, current_channels, minimum_channels, shortcut_strides, shortcut_channels, activation, kernel_size_smooth, interpolation_type, interpolation_mode):
        super().__init__()
        self.shortcut_strides = shortcut_strides
        self.shortcut_channels = shortcut_channels
        self.upsamples = torch.nn.ModuleList()
        self.concats = torch.nn.ModuleList()
        self.smooth_convs = torch.nn.ModuleList()

        # slot 0 is an empty placeholder: it is the (unused) stage for the starting
        # feature, so entry i (i >= 1) pairs with upsamples[i-1] in forward()
        self.smooth_convs.append(None)
        self.concats.append(None)

        act_pair = (activation, activation)
        in_chans = current_channels
        # zip (not just shortcut_channels) so the stage count is the shorter of the two
        for _stride, feat_chans in zip(shortcut_strides, shortcut_channels):
            # creation order (upsample, concat, smooth conv) is kept stable on purpose:
            # it fixes the parameter initialisation sequence
            self.upsamples.append(xnn.layers.UpsampleWith(in_chans, in_chans, 2, interpolation_type, interpolation_mode))
            self.concats.append(xnn.layers.CatBlock())
            out_chans = max(minimum_channels, feat_chans)
            self.smooth_convs.append(xnn.layers.ConvDWSepNormAct2d(in_chans + feat_chans, out_chans, kernel_size=kernel_size_smooth, activation=act_pair))
            in_chans = out_chans
        #
    #

    def forward(self, x_input, x_list):
        """Run the pyramid starting from ``x_list[-1]``.

        ``x_input.shape`` is the reference for locating shortcut features: for each
        stage the expected shape at that stride (channel dim replaced by the
        stage's shortcut channel count) is looked up in ``x_list`` via
        ``xnn.utils.get_blob_from_list`` (presumably a shape match - confirm).

        Returns:
            list of stage outputs ordered shallowest first (reverse of the order
            in which they were computed); element 0 is the final, finest output.
        """
        in_shape = x_input.shape
        feat = x_list[-1]
        first_smooth = self.smooth_convs[0]
        if first_smooth is not None:
            feat = first_smooth(feat)
        #
        pyramid = [feat]

        stages = zip(self.concats[1:], self.smooth_convs[1:], self.shortcut_strides, self.shortcut_channels, self.upsamples)
        for concat, smooth_conv, s_stride, s_chans, upsample in stages:
            # locate the shortcut feature at this stride
            skip_shape = xnn.utils.get_shape_with_stride(in_shape, s_stride)
            skip_shape[1] = s_chans
            skip = xnn.utils.get_blob_from_list(x_list, skip_shape)
            # bring the running feature up to the shortcut's resolution, then fuse
            feat = upsample(feat)
            if concat is not None:
                feat = concat((feat, skip))
            #
            if smooth_conv is not None:
                feat = smooth_conv(feat)
            #
            pyramid.append(feat)
        #
        return pyramid[::-1]
106 ###########################################
class UNetLitePixel2PixelDecoder(torch.nn.Module):
    """Lightweight UNet-style decoder for pixel-to-pixel tasks.

    The decoder has three stages:
      1. A receptive-field block (``rfblock``) on the deepest encoder feature.
         It is an ASPP-lite block, two extra stride-2 convs, or a 1x1 projection,
         depending on ``use_aspp`` / ``use_extra_strides``.
      2. A :class:`UNetLitePyramid` that upsamples and fuses the remaining
         shortcut features, deepest first.
      3. When ``model_config.final_prediction`` is set, ``pred``/``upsample``
         modules added by ``add_lite_prediction_modules``.
    """
    def __init__(self, model_config):
        super().__init__()
        self.model_config = model_config
        activation = self.model_config.activation
        self.output_type = model_config.output_type
        self.decoder_channels = decoder_channels = round(self.model_config.decoder_chan*self.model_config.decoder_factor)

        # receptive-field block on the deepest encoder feature
        self.rfblock = None
        if self.model_config.use_aspp:
            current_channels = self.model_config.shortcut_channels[-1]
            aspp_channels = round(self.model_config.aspp_chan * self.model_config.decoder_factor)
            self.rfblock = xnn.layers.DWASPPLiteBlock(current_channels, aspp_channels, decoder_channels, dilation=self.model_config.aspp_dil, avg_pool=False, activation=activation)
            current_channels = decoder_channels
        elif self.model_config.use_extra_strides:
            # a low complexity pyramid: two stride-2 depthwise-separable convs whose
            # outputs are appended to the feature list in forward().
            # NOTE(review): the first conv is sized for shortcut_channels[-3] while
            # forward() feeds it x_list[-1] - confirm the encoder output matches.
            current_channels = self.model_config.shortcut_channels[-3]
            self.rfblock = torch.nn.Sequential(xnn.layers.ConvDWSepNormAct2d(current_channels, current_channels, kernel_size=3, stride=2, activation=(activation, activation)),
                                               xnn.layers.ConvDWSepNormAct2d(current_channels, decoder_channels, kernel_size=3, stride=2, activation=(activation, activation)))
            current_channels = decoder_channels
        else:
            # plain 1x1 projection of the deepest feature to decoder_channels
            current_channels = self.model_config.shortcut_channels[-1]
            self.rfblock = xnn.layers.ConvNormAct2d(current_channels, decoder_channels, kernel_size=1, stride=1)
            current_channels = decoder_channels
        #

        minimum_channels = max(self.model_config.output_channels*2, 32)
        # drop the deepest shortcut level and list the rest deepest-first
        shortcut_strides = self.model_config.shortcut_strides[::-1][1:]
        shortcut_channels = self.model_config.shortcut_channels[::-1][1:]
        self.unet = UNetLitePyramid(current_channels, minimum_channels, shortcut_strides, shortcut_channels, self.model_config.activation, self.model_config.kernel_size_smooth,
                           self.model_config.interpolation_type, self.model_config.interpolation_mode)
        # channels produced by the pyramid's last (shallowest) smoothing conv
        current_channels = max(minimum_channels, shortcut_channels[-1])

        # add prediction & upsample modules
        if self.model_config.final_prediction:
            add_lite_prediction_modules(self, model_config, current_channels, module_names=('pred','upsample'))
        #

    def forward(self, x_input, x, x_list):
        """Decode encoder features into the task output.

        Args:
            x_input: list/tuple with at most 2 entries. Only ``x_input[0]`` is used;
                its shape is the reference resolution.
            x: deepest encoder feature; must be the very object ``x_list[-1]``.
            x_list: encoder features (list or tuple). It is NOT modified. A
                shallow copy is taken before the rfblock output(s) are written in.

        Returns:
            With ``final_prediction``: the prediction, optionally upsampled. In
            eval mode with ``output_type == 'segmentation'`` it is arg-maxed over
            dim 1 (keepdim). Without ``final_prediction``: the shallowest pyramid
            output.
        """
        assert isinstance(x_input, (list,tuple)) and len(x_input)<=2, 'incorrect input'
        assert x is x_list[-1], 'the features must the last one in x_list'
        # Work on a private copy. Writing into the caller's list would leak this
        # decoder's rfblock output (different channel count) into any other
        # consumer of the same list, e.g. another decoder. It would also fail
        # outright for a tuple.
        x_list = list(x_list)
        x_input = x_input[0]
        in_shape = x_input.shape

        if self.model_config.use_extra_strides:
            # each extra-stride output becomes a new, deeper entry of the list
            for blk in self.rfblock:
                x = blk(x)
                x_list.append(x)
            #
        elif self.rfblock is not None:
            # the rfblock output replaces the deepest feature
            x = self.rfblock(x)
            x_list[-1] = x
        #

        x_list = self.unet(x_input, x_list)
        # pyramid outputs are ordered shallowest first
        x = x_list[0]

        if self.model_config.final_prediction:
            # prediction
            x = self.pred(x)

            # final prediction is the upsampled one
            if self.model_config.final_upsample:
                x = self.upsample(x)
            #
            if (not self.training) and (self.output_type == 'segmentation'):
                x = torch.argmax(x, dim=1, keepdim=True)
            #
            assert int(in_shape[2]) == int(x.shape[2]*self.model_config.target_input_ratio) and int(in_shape[3]) == int(x.shape[3]*self.model_config.target_input_ratio), 'incorrect output shape'
        #
        return x
181 ###########################################
class UNetLitePixel2PixelASPP(Pixel2PixelNet):
    """Pixel2PixelNet configured with UNetLitePixel2PixelDecoder as its decoder class.

    Args:
        base_model: the encoder network.
        model_config: merged model config, forwarded unchanged to Pixel2PixelNet.
    """
    def __init__(self, base_model, model_config):
        decoder_class = UNetLitePixel2PixelDecoder
        super(UNetLitePixel2PixelASPP, self).__init__(base_model, decoder_class, model_config)
187 ###########################################
def unetlite_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=None):
    """Build a UNet-lite pixel2pixel model on a (multi-input) MobileNetV2 encoder.

    Args:
        model_config: caller config, merged over get_config_unetlitep2p_mnv2().
        pretrained: if truthy, weights loaded with xnn.utils.load_weights.

    Returns:
        (model, change_names_dict). change_names_dict maps regex name prefixes
        of a checkpoint to this model's parameter names. A list value maps one
        prefix to several targets (one per input stream / decoder).
    """
    cfg = get_config_unetlitep2p_mnv2().merge_from(model_config)
    # the encoder gets its own copy of the config
    encoder = MobileNetV2TVMI4(cfg.clone())
    model = UNetLitePixel2PixelASPP(encoder, cfg)

    num_inputs = len(cfg.input_channels)
    num_decoders = cfg.num_decoders
    if num_decoders is None:
        num_decoders = len(cfg.output_channels)
    #

    decoder_names = ['decoders.{}.'.format(d) for d in range(num_decoders)]
    if num_inputs > 1:
        # one encoder stream per input
        stream_names = ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)]
        change_names_dict = {'^features.': list(stream_names),
                             '^classifier.': 'encoder.classifier.',
                             '^encoder.features.': list(stream_names),
                             '^decoders.0.': decoder_names}
    else:
        change_names_dict = {'^features.': 'encoder.features.',
                             '^classifier.': 'encoder.classifier.',
                             '^decoders.0.': decoder_names}
    #

    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
    #
    return model, change_names_dict
# fast down sampling model (encoder stride 64 model)
def unetlite_pixel2pixel_aspp_mobilenetv2_tv_fd(model_config, pretrained=None):
    """Build the fast-downsampling MobileNetV2 variant (shortcuts at strides 4..64).

    The settings below are applied after merge_from, so they take precedence
    over caller-supplied values for the same keys. Returns what
    unetlite_pixel2pixel_aspp_mobilenetv2_tv returns: (model, change_names_dict).
    """
    cfg = get_config_unetlitep2p_mnv2().merge_from(model_config)
    fd_settings = (
        ('fastdown', True),
        ('strides', (2,2,2,2,2)),
        ('shortcut_strides', (4,8,16,32,64)),
        ('shortcut_channels', (16,24,32,96,320)),
        ('decoder_chan', 256),
        ('aspp_chan', 256),
    )
    for key, value in fd_settings:
        setattr(cfg, key, value)
    #
    return unetlite_pixel2pixel_aspp_mobilenetv2_tv(cfg, pretrained=pretrained)
227 ###########################################
def get_config_unetlitep2p_resnet50():
    """Return the default config for a ResNet50 encoder.

    Identical to get_config_unetlitep2p_mnv2() except for the shortcut
    strides/channels. The channels presumably correspond to ResNet50's stem
    and layer1..layer4 outputs.
    """
    cfg = get_config_unetlitep2p_mnv2()
    cfg.shortcut_strides, cfg.shortcut_channels = (2,4,8,16,32), (64,256,512,1024,2048)
    return cfg
def unetlite_pixel2pixel_aspp_resnet50(model_config, pretrained=None):
    """Build a UNet-lite pixel2pixel model on a (multi-input) ResNet50 encoder.

    Args:
        model_config: caller config, merged over get_config_unetlitep2p_resnet50().
        pretrained: if truthy, weights loaded with xnn.utils.load_weights.

    Returns:
        (model, change_names_dict); see unetlite_pixel2pixel_aspp_mobilenetv2_tv
        for the meaning of the rename map.
    """
    cfg = get_config_unetlitep2p_resnet50().merge_from(model_config)
    # the encoder gets its own copy of the config
    encoder = ResNet50MI4(cfg.clone())
    model = UNetLitePixel2PixelASPP(encoder, cfg)

    # The torchvision pretrained model and the network defined here differ
    # slightly in naming. Top-level names (conv1, bn1, relu, maxpool, layerN,
    # features) move under encoder.features: one stream per input when there are
    # several inputs. Decoder 0 weights are replicated to every decoder.
    # Note: this change_names_dict takes effect only if the direct load fails.
    num_inputs = len(cfg.input_channels)
    num_decoders = cfg.num_decoders
    if num_decoders is None:
        num_decoders = len(cfg.output_channels)
    #

    stem_prefixes = ('conv1.', 'bn1.', 'relu.', 'maxpool.', 'layer')
    decoder_names = ['decoders.{}.'.format(d) for d in range(num_decoders)]
    if num_inputs > 1:
        stream_names = ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)]
        change_names_dict = {'^' + prefix: [stream + prefix for stream in stream_names] for prefix in stem_prefixes}
        change_names_dict['^features.'] = list(stream_names)
        change_names_dict['^encoder.features.'] = list(stream_names)
    else:
        change_names_dict = {'^' + prefix: 'encoder.features.' + prefix for prefix in stem_prefixes}
        change_names_dict['^features.'] = 'encoder.features.'
    #
    change_names_dict['^decoders.0.'] = decoder_names

    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
    #
    return model, change_names_dict
def unetlite_pixel2pixel_aspp_resnet50_fd(model_config, pretrained=None):
    """Build the fast-downsampling ResNet50 variant (six shortcut levels, strides 2..64).

    The settings below are applied after merge_from, so they take precedence
    over caller-supplied values for the same keys. Returns what
    unetlite_pixel2pixel_aspp_resnet50 returns: (model, change_names_dict).
    """
    cfg = get_config_unetlitep2p_resnet50().merge_from(model_config)
    # alternative settings: shortcut_strides=(4,8,16,32,64),
    # shortcut_channels=(64,256,512,1024,2048), decoder_chan=128, aspp_chan=128
    fd_settings = (
        ('fastdown', True),
        ('strides', (2,2,2,2,2)),
        ('shortcut_strides', (2,4,8,16,32,64)),
        ('shortcut_channels', (64,64,256,512,1024,2048)),
        ('decoder_chan', 256),
        ('aspp_chan', 256),
    )
    for key, value in fd_settings:
        setattr(cfg, key, value)
    #
    return unetlite_pixel2pixel_aspp_resnet50(cfg, pretrained=pretrained)