]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/vision/models/pixel2pixel/deeplabv3lite.py
2252e385ade719a7c839b8a3d3db75dbc2cb62f1
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / models / pixel2pixel / deeplabv3lite.py
1 import torch
2 import numpy as np
4 from .... import xnn
5 from .pixel2pixelnet import *
7 try: from .pixel2pixelnet_internal import *
8 except: pass
10 from ..multi_input_net import MobileNetV2TVMI4, MobileNetV2EricsunMI4, ResNet50MI4
12 ###########################################
# public names exported by this module
__all__ = [
    'DeepLabV3Lite',
    'DeepLabV3LiteDecoder',
    'deeplabv3lite_mobilenetv2_tv',
    'deeplabv3lite_mobilenetv2_ericsun',
    'deeplabv3lite_resnet50',
]
18 ###########################################
class DeepLabV3LiteDecoder(torch.nn.Module):
    """Decoder head used by DeepLabV3Lite.

    The low resolution encoder features optionally pass through an ASPP block,
    are upsampled to the resolution of a high resolution shortcut, concatenated
    with a 1x1-convolved version of that shortcut and, if enabled, converted
    into the final prediction.

    Args:
        model_config: configuration object. The attributes read by this class
            are shortcut_channels, shortcut_strides, shortcut_out,
            decoder_chan, aspp_chan, decoder_factor, use_aspp, aspp_dil,
            groupwise_sep, activation, linear_dw, final_prediction,
            final_upsample, output_range, output_channels, output_type,
            interpolation_type, interpolation_mode, freeze_encoder and
            freeze_decoder.
    """

    def __init__(self, model_config):
        super().__init__()

        self.model_config = model_config

        # channels of the last (lowest resolution) encoder feature
        current_channels = model_config.shortcut_channels[-1]
        decoder_channels = round(model_config.decoder_chan*model_config.decoder_factor)
        aspp_channels = round(model_config.aspp_chan*model_config.decoder_factor)

        if model_config.use_aspp:
            ASPPBlock = xnn.layers.GWASPPLiteBlock if model_config.groupwise_sep else xnn.layers.DWASPPLiteBlock
            self.aspp = ASPPBlock(current_channels, aspp_channels, decoder_channels, dilation=model_config.aspp_dil,
                                  activation=model_config.activation, linear_dw=model_config.linear_dw)
            # the aspp block changes the channel count of the low res branch
            current_channels = decoder_channels
        else:
            self.aspp = None

        # 1x1 convolution applied on the high resolution shortcut
        short_chan = model_config.shortcut_channels[0]
        self.shortcut = xnn.layers.ConvNormAct2d(short_chan, model_config.shortcut_out, kernel_size=1,
                                                 activation=model_config.activation)

        # channels after concatenating the low res branch and the shortcut
        self.decoder_channels = merged_channels = (current_channels+model_config.shortcut_out)

        # prediction
        if self.model_config.final_prediction:
            ConvXWSepBlock = xnn.layers.ConvGWSepNormAct2d if model_config.groupwise_sep else xnn.layers.ConvDWSepNormAct2d
            final_activation = xnn.layers.get_fixed_pact2(output_range = model_config.output_range) \
                    if (model_config.output_range is not None) else False
            self.pred = ConvXWSepBlock(merged_channels, model_config.output_channels, kernel_size=3,
                                       normalization=((not model_config.linear_dw),False),
                                       activation=(False,final_activation), groups = 1)

        self.cat = xnn.layers.CatBlock()

        upstride1 = model_config.shortcut_strides[-1]//model_config.shortcut_strides[0]
        upstride2 = model_config.shortcut_strides[0]
        # upsample1 operates on the low res branch, whose channel count is current_channels:
        # decoder_channels with aspp, shortcut_channels[-1] without it.
        # (previously decoder_channels was used unconditionally, which is wrong when use_aspp is False)
        self.upsample1 = xnn.layers.UpsampleTo(current_channels, current_channels, upstride1,
                                               model_config.interpolation_type, model_config.interpolation_mode)
        self.upsample2 = xnn.layers.UpsampleTo(model_config.output_channels, model_config.output_channels, upstride2,
                                               model_config.interpolation_type, model_config.interpolation_mode)

    # the upsampling is using functional form to support size based upsampling for odd sizes
    # that are not a perfect ratio (eg. 257x513), which seem to be popular for segmentation
    def forward(self, x, x_features, x_list):
        """Run the decoder.

        Args:
            x: list or tuple with at most two entries; x[0] is the network
                input, whose shape is the reference for the output size.
            x_features: encoder output that is fed to the ASPP block (or used
                as is when use_aspp is False).
            x_list: collection of intermediate blobs from which the high
                resolution shortcut is picked by shape.

        Returns:
            The decoder output. When final_prediction is enabled this is the
            prediction (upsampled to the input size if final_upsample is set,
            and reduced with argmax over dim 1 in eval mode for the
            'segmentation' output type); otherwise the concatenated features.
        """
        assert isinstance(x, (list,tuple)) and len(x)<=2, 'incorrect input'

        x_input = x[0]
        in_shape = x_input.shape

        # high res shortcut
        shape_s = xnn.utils.get_shape_with_stride(in_shape, self.model_config.shortcut_strides[0])
        shape_s[1] = self.model_config.shortcut_channels[0]
        x_s = xnn.utils.get_blob_from_list(x_list, shape_s)
        x_s = self.shortcut(x_s)

        if self.model_config.freeze_encoder:
            # NOTE(review): x_s is detached after self.shortcut, so no gradient
            # reaches the shortcut convolution either - confirm this is intended.
            x_s = x_s.detach()
            x_features = x_features.detach()

        # aspp/scse blocks at output stride
        x = self.aspp(x_features) if self.model_config.use_aspp else x_features

        # upsample low res features to match with shortcut
        x = self.upsample1((x, x_s))

        # combine and do high res prediction
        x = self.cat((x,x_s))

        if self.model_config.final_prediction:
            x = self.pred(x)

            if self.model_config.final_upsample:
                # final prediction is the upsampled one
                x = self.upsample2((x, x_input))
                # the output can match the input size only after the final upsample;
                # without it the prediction stays at the shortcut resolution,
                # so the check is restricted to this branch.
                assert int(in_shape[2]) == int(x.shape[2]) and int(in_shape[3]) == int(x.shape[3]), 'incorrect output shape'

            if (not self.training) and (self.model_config.output_type == 'segmentation'):
                x = torch.argmax(x, dim=1, keepdim=True)

        if self.model_config.freeze_decoder:
            x = x.detach()

        return x
class DeepLabV3Lite(Pixel2PixelNet):
    """Pixel2PixelNet specialisation that is built with DeepLabV3LiteDecoder."""

    def __init__(self, base_model, model_config):
        # the decoder is handed over as a class, not as an instance
        decoder_class = DeepLabV3LiteDecoder
        super().__init__(base_model, decoder_class, model_config)
109 ###########################################
110 # config settings
def get_config_deeplav3lite_mnv2():
    """Create a new ConfigNode filled with the default DeepLabV3Lite settings.

    The shortcut channel values are the ones for mobilenetv2.

    Returns:
        The populated xnn.utils.ConfigNode.
    """
    # use list for entries that are different for different decoders.
    # and are expected to be passed from the main script.
    config = xnn.utils.ConfigNode()

    # input/output and encoder related entries
    for key, value in (
            ('num_classes', None),
            ('num_decoders', None),
            ('input_channels', (3,)),
            ('output_channels', [19]),
            ('intermediate_outputs', True),
            ('normalize_input', False),
            ('split_outputs', False),
            ('use_aspp', True),
            ('strides', (2,2,2,2,1)),
            ('groupwise_sep', False),
    ):
        setattr(config, key, value)

    # overall encoder stride is the product of the individual strides
    encoder_stride = np.prod(config.strides)

    # decoder, training and miscellaneous entries
    for key, value in (
            ('shortcut_strides', (4,encoder_stride)),
            ('shortcut_channels', (24,320)),  # this is for mobilenetv2 - change for other networks
            ('shortcut_out', 48),
            ('decoder_chan', 256),
            ('aspp_chan', 256),
            ('aspp_dil', (6,12,18)),
            ('final_prediction', True),
            ('final_upsample', True),
            ('output_range', None),
            ('decoder_factor', 1.0),
            ('output_type', None),
            ('activation', xnn.layers.DefaultAct2d),
            ('interpolation_type', 'upsample'),
            ('interpolation_mode', 'bilinear'),
            ('linear_dw', False),
            ('normalize_gradients', False),
            ('freeze_encoder', False),
            ('freeze_decoder', False),
            ('multi_task', False),
    ):
        setattr(config, key, value)

    return config
def deeplabv3lite_mobilenetv2_tv(model_config, pretrained=None):
    """Build a DeepLabV3Lite model on top of the MobileNetV2TVMI4 encoder.

    Args:
        model_config: settings merged over the defaults from
            get_config_deeplav3lite_mnv2().
        pretrained: weights passed to xnn.utils.load_weights; nothing is
            loaded when it is falsy.

    Returns:
        Tuple (model, change_names_dict), where change_names_dict is the
        parameter renaming table that is also used for loading the weights.
    """
    model_config = get_config_deeplav3lite_mnv2().merge_from(model_config)

    # encoder works on its own copy of the config, the decoder gets the merged one
    encoder = MobileNetV2TVMI4(model_config.clone())
    model = DeepLabV3Lite(encoder, model_config)

    stream_count = len(model_config.input_channels)
    if model_config.num_decoders is None:
        decoder_count = len(model_config.output_channels)
    else:
        decoder_count = model_config.num_decoders

    if stream_count > 1:
        # one target prefix per input stream, for both possible source prefixes
        change_names_dict = {
            source: ['encoder.features.stream{}.'.format(idx) for idx in range(stream_count)]
            for source in ('^features.', '^encoder.features.')
        }
    else:
        change_names_dict = {'^features.': 'encoder.features.'}
    # weights of decoder 0 are replicated to every decoder
    change_names_dict['^decoders.0.'] = ['decoders.{}.'.format(idx) for idx in range(decoder_count)]

    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)

    return model, change_names_dict
def deeplabv3lite_mobilenetv2_ericsun(model_config, pretrained=None):
    """Build a DeepLabV3Lite model on top of the MobileNetV2EricsunMI4 encoder.

    Args:
        model_config: settings merged over the defaults from
            get_config_deeplav3lite_mnv2().
        pretrained: weights passed to xnn.utils.load_weights; nothing is
            loaded when it is falsy.

    Returns:
        Tuple (model, change_names_dict), where change_names_dict is the
        parameter renaming table that is also used for loading the weights.
    """
    model_config = get_config_deeplav3lite_mnv2().merge_from(model_config)

    # encoder setup (on a copy of the config), followed by the full model
    backbone = MobileNetV2EricsunMI4(model_config.clone())
    model = DeepLabV3Lite(backbone, model_config)

    n_streams = len(model_config.input_channels)
    n_decoders = model_config.num_decoders
    if n_decoders is None:
        n_decoders = len(model_config.output_channels)

    change_names_dict = {}
    if n_streams > 1:
        # every input stream receives a copy of the encoder weights
        for pattern in ('^features.', '^encoder.features.'):
            change_names_dict[pattern] = ['encoder.features.stream{}.'.format(s) for s in range(n_streams)]
    else:
        change_names_dict['^features.'] = 'encoder.features.'
    # every decoder receives a copy of the weights of decoder 0
    change_names_dict['^decoders.0.'] = ['decoders.{}.'.format(d) for d in range(n_decoders)]

    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)

    return model, change_names_dict
201 ###########################################
# config settings for resnet50 backbone
def get_config_deeplav3lite_resnet50():
    """Return the default config with the shortcut channels replaced by (256, 2048).

    Returns:
        The config object produced by get_config_deeplav3lite_mnv2() with
        shortcut_channels overridden.
    """
    # only the delta compared to the one defined for mobilenetv2
    config = get_config_deeplav3lite_mnv2()
    setattr(config, 'shortcut_channels', (256,2048))
    return config
def deeplabv3lite_resnet50(model_config, pretrained=None):
    """Build a DeepLabV3Lite model on top of the ResNet50MI4 encoder.

    Args:
        model_config: settings merged over the defaults from
            get_config_deeplav3lite_resnet50().
        pretrained: weights passed to xnn.utils.load_weights; nothing is
            loaded when it is falsy.

    Returns:
        Tuple (model, change_names_dict), where change_names_dict is the
        parameter renaming table that is also used for loading the weights.
    """
    model_config = get_config_deeplav3lite_resnet50().merge_from(model_config)

    # encoder setup (on a copy of the config), followed by the full model
    encoder = ResNet50MI4(model_config.clone())
    model = DeepLabV3Lite(encoder, model_config)

    # the pretrained model provided by torchvision and what is defined here differs slightly
    # note: that this change_names_dict  will take effect only if the direct load fails
    # finally take care of the change for deeplabv3lite (features->encoder.features)
    stream_count = len(model_config.input_channels)
    if model_config.num_decoders is None:
        decoder_count = len(model_config.output_channels)
    else:
        decoder_count = model_config.num_decoders

    # (source pattern, suffix that is kept behind the new prefix)
    renames = (('^conv1.', 'conv1.'),
               ('^bn1.', 'bn1.'),
               ('^relu.', 'relu.'),
               ('^maxpool.', 'maxpool.'),
               ('^layer', 'layer'),
               ('^features.', ''))

    if stream_count > 1:
        change_names_dict = {
            pattern: ['encoder.features.stream{}.'.format(idx) + suffix for idx in range(stream_count)]
            for pattern, suffix in renames
        }
        change_names_dict['^encoder.features.'] = ['encoder.features.stream{}.'.format(idx)
                                                   for idx in range(stream_count)]
    else:
        change_names_dict = {pattern: 'encoder.features.' + suffix for pattern, suffix in renames}
    # weights of decoder 0 are replicated to every decoder
    change_names_dict['^decoders.0.'] = ['decoders.{}.'.format(idx) for idx in range(decoder_count)]

    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)

    return model, change_names_dict