Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/xvision/models/pixel2pixel/deeplabv3lite.py
Cleaned up STE for QAT. Added RegNetX models.
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / models / pixel2pixel / deeplabv3lite.py
1 import torch
2 import numpy as np
4 from .... import xnn
5 from .pixel2pixelnet import *
7 try: from .pixel2pixelnet_internal import *
8 except: pass
10 from ..multi_input_net import MobileNetV2TVMI4, MobileNetV2EricsunMI4, \
11                               ResNet50MI4, RegNetX800MFMI4
13 ###########################################
__all__ = [
    'DeepLabV3Lite',
    'DeepLabV3LiteDecoder',
    'deeplabv3lite_mobilenetv2_tv',
    'deeplabv3lite_mobilenetv2_tv_fd',
    'deeplabv3lite_mobilenetv2_ericsun',
    'deeplabv3lite_resnet50',
    'deeplabv3lite_resnet50_p5',
    'deeplabv3lite_resnet50_p5_fd',
    'deeplabv3lite_regnetx800mf',
]
21 ###########################################
class DeepLabV3LiteDecoder(torch.nn.Module):
    """DeepLabV3-style "lite" decoder.

    Optionally applies an ASPP block to the lowest-resolution encoder feature, upsamples the
    result to the resolution of a higher-resolution shortcut feature and concatenates the two.
    When model_config.final_prediction is set, prediction (and optional final upsampling)
    modules are added on top.

    Args:
        model_config: config node; fields read here include shortcut_channels, shortcut_strides,
            shortcut_out, decoder_chan, aspp_chan, decoder_factor, use_aspp, groupwise_sep,
            group_size_dw (optional), aspp_dil, activation, linear_dw, interpolation_type,
            interpolation_mode, final_prediction, final_upsample, output_type,
            target_input_ratio, freeze_encoder and freeze_decoder.
    """
    def __init__(self, model_config):
        super().__init__()

        self.model_config = model_config

        # channel count of the lowest-resolution (last) encoder feature
        current_channels = model_config.shortcut_channels[-1]
        decoder_channels = round(model_config.decoder_chan*model_config.decoder_factor)
        aspp_channels = round(model_config.aspp_chan*model_config.decoder_factor)

        if model_config.use_aspp:
            # group_size_dw is optional - only some backbone configs set it
            group_size_dw = getattr(model_config, 'group_size_dw', None)
            ASPPBlock = xnn.layers.GWASPPLiteBlock if model_config.groupwise_sep else xnn.layers.DWASPPLiteBlock
            self.aspp = ASPPBlock(current_channels, aspp_channels, decoder_channels, dilation=model_config.aspp_dil,
                                  activation=model_config.activation, linear_dw=model_config.linear_dw,
                                  group_size_dw=group_size_dw)
            # the aspp block outputs decoder_channels
            current_channels = decoder_channels
        else:
            self.aspp = None

        short_chan = model_config.shortcut_channels[0]
        self.shortcut = xnn.layers.ConvNormAct2d(short_chan, model_config.shortcut_out, kernel_size=1, activation=model_config.activation)

        # channels after concatenating the upsampled features with the shortcut features
        self.decoder_channels = merged_channels = (current_channels+model_config.shortcut_out)

        upstride1 = model_config.shortcut_strides[-1]//model_config.shortcut_strides[0]
        # use UpsampleWithGeneric() instead of UpsampleWith() to break down large upsampling factors to multiples of 4 and 2 -
        # useful if upsampling factors other than 4 and 2 are not supported.
        # upsample1 receives current_channels channels: decoder_channels after aspp, or the raw
        # encoder channels when aspp is disabled - size it accordingly so that channel-aware
        # interpolation types get the right channel count (merged_channels above assumes this).
        self.upsample1 = xnn.layers.UpsampleWith(current_channels, current_channels, upstride1,
                                                 model_config.interpolation_type, model_config.interpolation_mode)

        self.cat = xnn.layers.CatBlock()

        # add prediction & upsample modules (registered as self.pred and self.upsample2)
        if self.model_config.final_prediction:
            add_lite_prediction_modules(self, model_config, merged_channels, module_names=('pred','upsample2'))
        #

    # NOTE(review): the original comment said the upsampling uses a functional form to support
    # size based upsampling for odd sizes that are not a perfect ratio (eg. 257x513), which seem to
    # be popular for segmentation - confirm how upsample1/upsample2 handle such sizes.
    def forward(self, x, x_features, x_list):
        """Decode encoder features into the final output.

        Args:
            x: list/tuple (length <= 2) whose first entry is the network input; if that entry is
                itself a list/tuple, its first element provides the reference input shape.
            x_features: the encoder output fed to aspp (or used directly when aspp is disabled).
            x_list: intermediate encoder blobs; the shortcut blob is looked up in it with
                xnn.utils.get_blob_from_list using the expected shortcut shape.

        Returns:
            the decoder output; in eval mode with output_type == 'segmentation' this is the
            argmax over dim 1 (keepdim=True).
        """
        assert isinstance(x, (list,tuple)) and len(x)<=2, 'incorrect input'

        x_input = x[0]
        in_shape = x_input[0].shape if isinstance(x_input, (list,tuple)) else x_input.shape

        # high res shortcut: expected shape at the shortcut stride with the shortcut channel count
        shape_s = xnn.utils.get_shape_with_stride(in_shape, self.model_config.shortcut_strides[0])
        shape_s[1] = self.model_config.shortcut_channels[0]
        x_s = xnn.utils.get_blob_from_list(x_list, shape_s)
        x_s = self.shortcut(x_s)

        if self.model_config.freeze_encoder:
            # stop gradients from flowing back into the encoder
            x_s = x_s.detach()
            x_features = x_features.detach()

        # aspp/scse blocks at output stride
        x = self.aspp(x_features) if self.model_config.use_aspp else x_features

        # upsample low res features to match with shortcut
        x = self.upsample1(x)

        # combine and do high res prediction
        x = self.cat((x,x_s))

        if self.model_config.final_prediction:
            x = self.pred(x)

            if self.model_config.final_upsample:
                x = self.upsample2(x)

            if (not self.training) and (self.model_config.output_type == 'segmentation'):
                x = torch.argmax(x, dim=1, keepdim=True)

            assert int(in_shape[2]) == int(x.shape[2]*self.model_config.target_input_ratio) and \
                   int(in_shape[3]) == int(x.shape[3]*self.model_config.target_input_ratio), 'incorrect output shape'

        if self.model_config.freeze_decoder:
            x = x.detach()

        return x
class DeepLabV3Lite(Pixel2PixelNet):
    """Pixel2PixelNet that pairs the given encoder (base_model) with DeepLabV3LiteDecoder."""
    def __init__(self, base_model, model_config):
        decoder_class = DeepLabV3LiteDecoder
        super(DeepLabV3Lite, self).__init__(base_model, decoder_class, model_config)
111 ###########################################
112 # config settings
def get_config_deeplav3lite_mnv2():
    """Return the default DeepLabV3Lite config for a MobileNetV2 backbone.

    Entries that differ between decoders are given as lists and are expected to be passed
    from the main script. Configs for other backbones start from this one and override
    only the fields that differ.
    """
    model_config = xnn.utils.ConfigNode()
    model_config.num_classes = None
    model_config.num_decoders = None
    model_config.input_channels = (3,)
    model_config.output_channels = [19]
    model_config.intermediate_outputs = True
    model_config.normalize_input = False
    model_config.split_outputs = False
    model_config.use_aspp = True
    model_config.fastdown = False
    model_config.target_input_ratio = 1

    model_config.strides = (2,2,2,2,1)
    model_config.groupwise_sep = False
    # overall encoder stride (16 for the strides above); cast to a plain int so that the
    # config does not hold a numpy scalar
    encoder_stride = int(np.prod(model_config.strides))
    model_config.shortcut_strides = (4,encoder_stride)
    model_config.shortcut_channels = (24,320) # this is for mobilenetv2 - change for other networks
    model_config.shortcut_out = 48
    model_config.decoder_chan = 256
    model_config.aspp_chan = 256
    model_config.aspp_dil = (6,12,18)
    model_config.final_prediction = True
    model_config.final_upsample = True
    model_config.output_range = None
    model_config.decoder_factor = 1.0
    model_config.output_type = None
    model_config.activation = xnn.layers.DefaultAct2d
    model_config.interpolation_type = 'upsample'
    model_config.interpolation_mode = 'bilinear'
    model_config.linear_dw = False
    model_config.normalize_gradients = False
    model_config.freeze_encoder = False
    model_config.freeze_decoder = False
    model_config.multi_task = False
    return model_config
def deeplabv3lite_mobilenetv2_tv(model_config, pretrained=None):
    """Build DeepLabV3Lite with a MobileNetV2TVMI4 encoder.

    Args:
        model_config: overrides merged on top of get_config_deeplav3lite_mnv2().
        pretrained: optional checkpoint; when truthy it is loaded with xnn.utils.load_weights.

    Returns:
        (model, change_names_dict): change_names_dict maps parameter-name patterns
        (keys start with '^') to the names used by this model and is passed to load_weights.
    """
    model_config = get_config_deeplav3lite_mnv2().merge_from(model_config)
    # the encoder gets its own copy of the config
    base_model = MobileNetV2TVMI4(model_config.clone())
    model = DeepLabV3Lite(base_model, model_config)

    num_inputs = len(model_config.input_channels)
    if model_config.num_decoders is None:
        num_decoders = len(model_config.output_channels)
    else:
        num_decoders = model_config.num_decoders
    decoder_names = ['decoders.{}.'.format(dec_idx) for dec_idx in range(num_decoders)]

    if num_inputs > 1:
        # every input stream gets its own copy of the encoder weights
        stream_names = ['encoder.features.stream{}.'.format(stream_idx) for stream_idx in range(num_inputs)]
        change_names_dict = {'^features.': stream_names,
                             '^encoder.features.': list(stream_names),
                             '^decoders.0.': decoder_names}
    else:
        change_names_dict = {'^features.': 'encoder.features.',
                             '^decoders.0.': decoder_names}

    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)

    return model, change_names_dict
def deeplabv3lite_mobilenetv2_tv_fd(model_config, pretrained=None):
    """DeepLabV3Lite + MobileNetV2TVMI4 with fastdown enabled.

    model_config is first merged into a fresh default config; the fastdown-specific settings
    below then override it before delegating to deeplabv3lite_mobilenetv2_tv.
    """
    model_config = get_config_deeplav3lite_mnv2().merge_from(model_config)
    fastdown_overrides = (('fastdown', True),
                          ('strides', (2,2,2,2,1)),
                          ('shortcut_strides', (8,32)),
                          ('shortcut_channels', (24,320)),
                          ('decoder_chan', 256),
                          ('aspp_chan', 256))
    for field_name, field_value in fastdown_overrides:
        setattr(model_config, field_name, field_value)
    return deeplabv3lite_mobilenetv2_tv(model_config, pretrained=pretrained)
def deeplabv3lite_mobilenetv2_ericsun(model_config, pretrained=None):
    """Build DeepLabV3Lite with a MobileNetV2EricsunMI4 encoder.

    Args:
        model_config: overrides merged on top of get_config_deeplav3lite_mnv2().
        pretrained: optional checkpoint; when truthy it is loaded with xnn.utils.load_weights.

    Returns:
        (model, change_names_dict) - the name-remapping dict used for weight loading.
    """
    model_config = get_config_deeplav3lite_mnv2().merge_from(model_config)
    encoder = MobileNetV2EricsunMI4(model_config.clone())
    model = DeepLabV3Lite(encoder, model_config)

    num_inputs = len(model_config.input_channels)
    num_decoders = model_config.num_decoders
    if num_decoders is None:
        num_decoders = len(model_config.output_channels)

    def _stream_prefixes():
        # one target prefix per input stream (a fresh list on each call)
        return ['encoder.features.stream{}.'.format(idx) for idx in range(num_inputs)]

    change_names_dict = {}
    if num_inputs > 1:
        change_names_dict['^features.'] = _stream_prefixes()
        change_names_dict['^encoder.features.'] = _stream_prefixes()
    else:
        change_names_dict['^features.'] = 'encoder.features.'
    change_names_dict['^decoders.0.'] = ['decoders.{}.'.format(idx) for idx in range(num_decoders)]

    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)

    return model, change_names_dict
218 ###########################################
# config settings for resnet50 backbone
def get_config_deeplav3lite_resnet50():
    """Return the default DeepLabV3Lite config for a ResNet50 backbone.

    Identical to the MobileNetV2 config except for the shortcut channel counts.
    """
    config = get_config_deeplav3lite_mnv2()
    config.shortcut_channels = (256,2048)
    return config
def deeplabv3lite_resnet50(model_config, pretrained=None):
    """Build DeepLabV3Lite with a ResNet50MI4 encoder.

    Args:
        model_config: overrides merged on top of get_config_deeplav3lite_resnet50().
        pretrained: optional checkpoint; when truthy it is loaded with xnn.utils.load_weights.

    Returns:
        (model, change_names_dict) - the name-remapping dict used for weight loading.
    """
    model_config = get_config_deeplav3lite_resnet50().merge_from(model_config)
    encoder = ResNet50MI4(model_config.clone())
    model = DeepLabV3Lite(encoder, model_config)

    # The pretrained (torchvision) model and the one defined here differ slightly in naming;
    # this change_names_dict only takes effect if the direct load fails. It also handles the
    # deeplabv3lite rename features -> encoder.features.
    num_inputs = len(model_config.input_channels)
    num_decoders = model_config.num_decoders
    if num_decoders is None:
        num_decoders = len(model_config.output_channels)

    # top-level backbone name prefixes; '^layer' matches every name starting with 'layer'
    backbone_prefixes = ('conv1.', 'bn1.', 'relu.', 'maxpool.', 'layer')
    if num_inputs > 1:
        change_names_dict = {
            '^' + prefix: ['encoder.features.stream{}.{}'.format(stream, prefix) for stream in range(num_inputs)]
            for prefix in backbone_prefixes
        }
        change_names_dict['^features.'] = ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)]
        change_names_dict['^encoder.features.'] = ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)]
    else:
        change_names_dict = {'^' + prefix: 'encoder.features.' + prefix for prefix in backbone_prefixes}
        change_names_dict['^features.'] = 'encoder.features.'
    change_names_dict['^decoders.0.'] = ['decoders.{}.'.format(dec) for dec in range(num_decoders)]

    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)

    return model, change_names_dict
def deeplabv3lite_resnet50_p5(model_config, pretrained=None):
    """DeepLabV3Lite with a half-width (width_mult=0.5) ResNet50 encoder.

    The overrides are applied to a freshly merged config (as deeplabv3lite_mobilenetv2_tv_fd
    does) instead of being written into the caller's model_config object.
    """
    model_config = get_config_deeplav3lite_resnet50().merge_from(model_config)
    model_config.width_mult = 0.5
    # half of the full-width resnet50 shortcut channels (256,2048)
    model_config.shortcut_channels = (128,1024)
    return deeplabv3lite_resnet50(model_config, pretrained=pretrained)
def deeplabv3lite_resnet50_p5_fd(model_config, pretrained=None):
    """DeepLabV3Lite with a half-width (width_mult=0.5) ResNet50 encoder and fastdown enabled.

    The overrides are applied to a freshly merged config (as deeplabv3lite_mobilenetv2_tv_fd
    does) instead of being written into the caller's model_config object.
    """
    model_config = get_config_deeplav3lite_resnet50().merge_from(model_config)
    model_config.width_mult = 0.5
    model_config.fastdown = True
    # half of the full-width resnet50 shortcut channels (256,2048)
    model_config.shortcut_channels = (128,1024)
    model_config.shortcut_strides = (8,64)
    return deeplabv3lite_resnet50(model_config, pretrained=pretrained)
279 ###########################################
# config settings for regnetx800mf backbone
def get_config_deeplav3lite_regnetx800mf():
    """Return the default DeepLabV3Lite config for a RegNetX-800MF backbone.

    Starts from the MobileNetV2 config, overrides the shortcut channel counts and sets
    group_size_dw (read by the decoder when building its ASPP block).
    """
    config = get_config_deeplav3lite_mnv2()
    config.shortcut_channels = (64,672)
    config.group_size_dw = 16
    return config
# there is nothing specific to bgr in this model -
# this is just a reminder that regnet models are typically trained with bgr input
def deeplabv3lite_regnetx800mf(model_config, pretrained=None):
    """Build DeepLabV3Lite with a RegNetX800MFMI4 encoder.

    Args:
        model_config: overrides merged on top of get_config_deeplav3lite_regnetx800mf().
        pretrained: optional checkpoint. When truthy it is loaded right away; otherwise a
            model.load_weights(...) helper is attached that loads with the right state_dict names.

    Returns:
        (model, change_names_dict) - the name-remapping dict used for weight loading.
    """
    model_config = get_config_deeplav3lite_regnetx800mf().merge_from(model_config)
    # encoder setup
    model_config_e = model_config.clone()
    base_model = RegNetX800MFMI4(model_config_e)
    # decoder setup
    model = DeepLabV3Lite(base_model, model_config)

    # The pretrained checkpoint's parameter names can differ from the ones defined here;
    # note: this change_names_dict takes effect only if the direct load fails.
    # It also handles the deeplabv3lite rename features -> encoder.features.
    num_inputs = len(model_config.input_channels)
    num_decoders = len(model_config.output_channels) if (model_config.num_decoders is None) else model_config.num_decoders
    # top-level regnet name prefixes - remapped for single- and multi-input models alike
    # (consistent with deeplabv3lite_resnet50)
    regnet_prefixes = ('stem.', 's1', 's2', 's3', 's4')
    if num_inputs > 1:
        change_names_dict = {'^'+prefix: ['encoder.features.stream{}.{}'.format(stream, prefix) for stream in range(num_inputs)]
                             for prefix in regnet_prefixes}
        change_names_dict['^features.'] = ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)]
        change_names_dict['^encoder.features.'] = ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)]
    else:
        change_names_dict = {'^'+prefix: 'encoder.features.'+prefix for prefix in regnet_prefixes}
        change_names_dict['^features.'] = 'encoder.features.'
    change_names_dict['^decoders.0.'] = ['decoders.{}.'.format(d) for d in range(num_decoders)]
    #

    if pretrained:
        # the checkpoint may store its weights under 'model_state' instead of 'state_dict'
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True,
                                       state_dict_name=['state_dict','model_state'])
    else:
        # need to use state_dict_name as the checkpoint uses a different name for state_dict -
        # provide a custom load_weights for the model.
        # a private sentinel (rather than a shared mutable list default) marks "use the default names",
        # so an explicitly passed None is still forwarded unchanged.
        use_default_names = object()

        def load_weights_func(pretrained, change_names_dict, ignore_size=True, verbose=True,
                              state_dict_name=use_default_names):
            if state_dict_name is use_default_names:
                state_dict_name = ['state_dict','model_state']
            # return the result so callers can use it like xnn.utils.load_weights
            return xnn.utils.load_weights(model, pretrained, change_names_dict=change_names_dict, ignore_size=ignore_size,
                                          verbose=verbose, state_dict_name=state_dict_name)
        #
        model.load_weights = load_weights_func

    return model, change_names_dict