Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/vision/models/pixel2pixel/deeplabv3lite.py
8ba9b194140dfc7ab464e77c0c648259ec6b2e3d
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / models / pixel2pixel / deeplabv3lite.py
1 import torch
2 import numpy as np
4 from .... import xnn
5 from .pixel2pixelnet import *
7 try: from .pixel2pixelnet_internal import *
8 except: pass
10 from ..multi_input_net import MobileNetV2TVMI4, MobileNetV2EricsunMI4, ResNet50MI4
12 ###########################################
# Public API of this module: the model/decoder classes and the builder functions.
__all__ = [
    'DeepLabV3Lite',
    'DeepLabV3LiteDecoder',
    'deeplabv3lite_mobilenetv2_tv',
    'deeplabv3lite_mobilenetv2_ericsun',
    'deeplabv3lite_resnet50',
]
18 ###########################################
class DeepLabV3LiteDecoder(torch.nn.Module):
    """Lightweight DeepLabV3-style decoder.

    Processing order (see ``forward``):
      1. optional ASPP block on the encoder output feature,
      2. upsampling of that branch to the resolution of a high-resolution shortcut feature,
      3. concatenation with the 1x1-projected shortcut,
      4. optional prediction head, optionally upsampled back to the input resolution.
    """

    def __init__(self, model_config):
        """Create the decoder layers.

        Args:
            model_config: config node (defaults in ``get_config_deeplav3lite_mnv2``).
                ``shortcut_channels`` / ``shortcut_strides`` are indexed as
                ``[0]`` = high-resolution shortcut and ``[-1]`` = encoder output.
        """
        super().__init__()

        self.model_config = model_config

        # the encoder output (last shortcut entry) is the input of the ASPP branch
        current_channels = model_config.shortcut_channels[-1]
        decoder_channels = round(model_config.decoder_chan*model_config.decoder_factor)
        aspp_channels = round(model_config.aspp_chan*model_config.decoder_factor)

        if model_config.use_aspp:
            ASPPBlock = xnn.layers.GWASPPLiteBlock if model_config.groupwise_sep else xnn.layers.DWASPPLiteBlock
            self.aspp = ASPPBlock(current_channels, aspp_channels, decoder_channels, dilation=model_config.aspp_dil,
                                  activation=model_config.activation, linear_dw=model_config.linear_dw)
        else:
            self.aspp = None
        #

        # width of the low-resolution branch after the (optional) ASPP block
        current_channels = decoder_channels if model_config.use_aspp else current_channels

        # 1x1 conv projecting the high-resolution shortcut to shortcut_out channels
        short_chan = model_config.shortcut_channels[0]
        self.shortcut = xnn.layers.ConvNormAct2d(short_chan, model_config.shortcut_out, kernel_size=1, activation=model_config.activation)

        # channel count after concatenating the low-resolution branch with the shortcut
        self.decoder_channels = merged_channels = (current_channels+model_config.shortcut_out)

        # prediction
        if self.model_config.final_prediction:
            ConvXWSepBlock = xnn.layers.ConvGWSepNormAct2d if model_config.groupwise_sep else xnn.layers.ConvDWSepNormAct2d
            # fixed-range activation on the output only when output_range is configured;
            # otherwise False is passed (presumably meaning "no activation")
            final_activation = xnn.layers.get_fixed_pact2(output_range = model_config.output_range) \
                    if (model_config.output_range is not None) else False
            self.pred = ConvXWSepBlock(merged_channels, model_config.output_channels, kernel_size=3, normalization=((not model_config.linear_dw),False),
                activation=(False,final_activation), groups = 1)
        #

        self.cat = xnn.layers.CatBlock()

        upstride1 = model_config.shortcut_strides[-1]//model_config.shortcut_strides[0]
        upstride2 = model_config.shortcut_strides[0]
        # upsample1 operates on the low-resolution branch, whose width is current_channels
        # (equal to decoder_channels only when ASPP is enabled)
        self.upsample1 = xnn.layers.UpsampleTo(current_channels, current_channels, upstride1, model_config.interpolation_type, model_config.interpolation_mode)
        self.upsample2 = xnn.layers.UpsampleTo(model_config.output_channels, model_config.output_channels, upstride2, model_config.interpolation_type, model_config.interpolation_mode)

    # The upsampling modules receive a (tensor, reference) pair so that the output can be
    # sized to the reference tensor. This supports sizes that are not a perfect ratio of the
    # stride (e.g. 257x513), which are common for segmentation.
    def forward(self, x, x_features, x_list):
        """Decode encoder features into the output.

        Args:
            x: list/tuple of at most 2 tensors. ``x[0]`` is the network input. Its shape is
                the reference used to locate the shortcut blob and to size the final upsample.
            x_features: encoder output. It is fed to the ASPP block, or used directly when
                ``use_aspp`` is False.
            x_list: intermediate encoder blobs. The shortcut is fetched with
                ``xnn.utils.get_blob_from_list`` using the expected shape at stride
                ``shortcut_strides[0]`` with ``shortcut_channels[0]`` channels.

        Returns:
            The decoder output. With ``final_prediction`` enabled, in eval mode and with
            ``output_type == 'segmentation'``, this is the channel argmax (channel dim kept).
        """
        assert isinstance(x, (list,tuple)) and len(x)<=2, 'incorrect input'

        x_input = x[0]
        in_shape = x_input.shape

        # high res shortcut
        shape_s = xnn.utils.get_shape_with_stride(in_shape, self.model_config.shortcut_strides[0])
        shape_s[1] = self.model_config.shortcut_channels[0]
        x_s = xnn.utils.get_blob_from_list(x_list, shape_s)
        x_s = self.shortcut(x_s)

        # stop gradients from flowing back into the encoder
        if self.model_config.freeze_encoder:
            x_s = x_s.detach()
            x_features = x_features.detach()

        # aspp/scse blocks at output stride
        x = self.aspp(x_features) if self.model_config.use_aspp else x_features

        # upsample low res features to match with shortcut
        x = self.upsample1((x, x_s))

        # combine and do high res prediction
        x = self.cat((x,x_s))

        if self.model_config.final_prediction:
            x = self.pred(x)

            if self.model_config.final_upsample:
                # final prediction is the upsampled one
                x = self.upsample2((x, x_input))

            if (not self.training) and (self.model_config.output_type == 'segmentation'):
                x = torch.argmax(x, dim=1, keepdim=True)

            # Only an upsampled prediction is expected to match the input resolution.
            # Without final_upsample the output stays at the shortcut stride.
            if self.model_config.final_upsample:
                assert int(in_shape[2]) == int(x.shape[2]) and int(in_shape[3]) == int(x.shape[3]), 'incorrect output shape'

        if self.model_config.freeze_decoder:
            x = x.detach()

        return x
class DeepLabV3Lite(Pixel2PixelNet):
    """Pixel2PixelNet combining ``base_model`` (the encoder) with DeepLabV3LiteDecoder."""

    def __init__(self, base_model, model_config):
        decoder_class = DeepLabV3LiteDecoder
        super(DeepLabV3Lite, self).__init__(base_model, decoder_class, model_config)
109 ###########################################
110 # config settings
def get_config_deeplav3lite_mnv2():
    """Return the default DeepLabV3Lite config for a MobileNetV2 backbone.

    The builder functions merge the caller's config on top of these defaults via
    ``merge_from``. Entries that differ between decoders (e.g. ``output_channels``)
    are lists and are expected to be passed from the main script.

    Returns:
        An ``xnn.utils.ConfigNode`` holding the defaults.
    """
    model_config = xnn.utils.ConfigNode()
    model_config.num_classes = None
    model_config.input_channels = (3,)
    model_config.output_channels = [19]
    # None -> the builders derive the number of decoders from len(output_channels)
    model_config.num_decoders = None
    model_config.intermediate_outputs = True
    model_config.normalize_input = False
    model_config.split_outputs = False
    model_config.use_aspp = True
    model_config.strides = (2,2,2,2,1)
    model_config.groupwise_sep = False
    # overall encoder output stride; np.prod returns a numpy scalar, so cast it to a
    # builtin int to keep the config free of numpy types
    encoder_stride = int(np.prod(model_config.strides))
    # (high-res shortcut stride, encoder output stride)
    model_config.shortcut_strides = (4,encoder_stride)
    model_config.shortcut_channels = (24,320) # this is for mobilenetv2 - change for other networks
    model_config.shortcut_out = 48
    model_config.decoder_chan = 256
    model_config.aspp_chan = 256
    model_config.aspp_dil = (6,12,18)
    model_config.final_prediction = True
    model_config.final_upsample = True
    model_config.output_range = None
    model_config.decoder_factor = 1.0
    model_config.output_type = None
    model_config.activation = xnn.layers.DefaultAct2d
    model_config.interpolation_type = 'upsample'
    model_config.interpolation_mode = 'bilinear'
    model_config.linear_dw = False
    model_config.normalize_gradients = False
    model_config.freeze_encoder = False
    model_config.freeze_decoder = False
    model_config.multi_task = False
    return model_config
def deeplabv3lite_mobilenetv2_tv(model_config, pretrained=None):
    """Build DeepLabV3Lite with a MobileNetV2TVMI4 encoder.

    Args:
        model_config: overrides merged on top of ``get_config_deeplav3lite_mnv2()``.
        pretrained: optional checkpoint; when truthy it is loaded with
            ``xnn.utils.load_weights_check``.

    Returns:
        ``(model, change_names_dict)``. The dict maps checkpoint-key regexes to the
        new prefix (or one prefix per input stream / decoder).
    """
    config = get_config_deeplav3lite_mnv2().merge_from(model_config)
    # the encoder gets its own copy of the config
    base_model = MobileNetV2TVMI4(config.clone())
    model = DeepLabV3Lite(base_model, config)

    num_inputs = len(config.input_channels)
    num_decoders = config.num_decoders
    if num_decoders is None:
        num_decoders = len(config.output_channels)

    change_names_dict = {}
    if num_inputs > 1:
        # replicate the encoder weights into every input stream
        for pattern in ('^features.', '^encoder.features.'):
            change_names_dict[pattern] = ['encoder.features.stream{}.'.format(i) for i in range(num_inputs)]
    else:
        change_names_dict['^features.'] = 'encoder.features.'
    # replicate the first decoder's weights into every decoder
    change_names_dict['^decoders.0.'] = ['decoders.{}.'.format(i) for i in range(num_decoders)]

    if pretrained:
        model = xnn.utils.load_weights_check(model, pretrained, change_names_dict, ignore_size=True, verbose=True)

    return model, change_names_dict
def deeplabv3lite_mobilenetv2_ericsun(model_config, pretrained=None):
    """Build DeepLabV3Lite with a MobileNetV2EricsunMI4 encoder.

    Args:
        model_config: overrides merged on top of ``get_config_deeplav3lite_mnv2()``.
        pretrained: optional checkpoint; when truthy it is loaded with
            ``xnn.utils.load_weights_check``.

    Returns:
        ``(model, change_names_dict)``, where the dict is the checkpoint-key rename
        table passed to the weight loader.
    """
    merged_config = get_config_deeplav3lite_mnv2().merge_from(model_config)
    encoder_config = merged_config.clone()
    model = DeepLabV3Lite(MobileNetV2EricsunMI4(encoder_config), merged_config)

    def stream_prefixes(count):
        # one 'encoder.features.streamN.' prefix per input stream
        return ['encoder.features.stream{}.'.format(idx) for idx in range(count)]

    num_inputs = len(merged_config.input_channels)
    if merged_config.num_decoders is None:
        num_decoders = len(merged_config.output_channels)
    else:
        num_decoders = merged_config.num_decoders

    decoder_prefixes = ['decoders.{}.'.format(idx) for idx in range(num_decoders)]
    if num_inputs <= 1:
        change_names_dict = {'^features.': 'encoder.features.',
                             '^decoders.0.': decoder_prefixes}
    else:
        change_names_dict = {'^features.': stream_prefixes(num_inputs),
                             '^encoder.features.': stream_prefixes(num_inputs),
                             '^decoders.0.': decoder_prefixes}

    if pretrained:
        model = xnn.utils.load_weights_check(model, pretrained, change_names_dict, ignore_size=True, verbose=True)

    return model, change_names_dict
200 ###########################################
# config settings for resnet50 backbone
def get_config_deeplav3lite_resnet50():
    """Return the DeepLabV3Lite defaults for a ResNet50 backbone.

    Starts from the MobileNetV2 defaults and overrides only ``shortcut_channels``:
    256 channels for the high-resolution shortcut and 2048 for the encoder output.
    """
    config = get_config_deeplav3lite_mnv2()
    config.shortcut_channels = (256, 2048)
    return config
def deeplabv3lite_resnet50(model_config, pretrained=None):
    """Build DeepLabV3Lite with a ResNet50MI4 encoder.

    The torchvision pretrained model and the model defined here differ slightly in
    parameter names. The returned rename table maps the checkpoint keys
    (``conv1``/``bn1``/``relu``/``maxpool``/``layer*``/``features``) to
    ``encoder.features.*``, with one target per input stream when there are several.
    Per the original note, this table takes effect only if the direct load fails.

    Args:
        model_config: overrides merged on top of ``get_config_deeplav3lite_resnet50()``.
        pretrained: optional checkpoint; when truthy it is loaded with
            ``xnn.utils.load_weights_check``.

    Returns:
        ``(model, change_names_dict)``.
    """
    config = get_config_deeplav3lite_resnet50().merge_from(model_config)
    base_model = ResNet50MI4(config.clone())
    model = DeepLabV3Lite(base_model, config)

    num_inputs = len(config.input_channels)
    num_decoders = config.num_decoders
    if num_decoders is None:
        num_decoders = len(config.output_channels)

    # (checkpoint-key regex, suffix placed after the encoder feature prefix), in lookup order
    encoder_renames = (('^conv1.', 'conv1.'),
                       ('^bn1.', 'bn1.'),
                       ('^relu.', 'relu.'),
                       ('^maxpool.', 'maxpool.'),
                       ('^layer', 'layer'),
                       ('^features.', ''))

    change_names_dict = {}
    if num_inputs > 1:
        for pattern, suffix in encoder_renames:
            change_names_dict[pattern] = ['encoder.features.stream{}.{}'.format(s, suffix) for s in range(num_inputs)]
        change_names_dict['^encoder.features.'] = ['encoder.features.stream{}.'.format(s) for s in range(num_inputs)]
    else:
        for pattern, suffix in encoder_renames:
            change_names_dict[pattern] = 'encoder.features.' + suffix
    change_names_dict['^decoders.0.'] = ['decoders.{}.'.format(d) for d in range(num_decoders)]

    if pretrained:
        model = xnn.utils.load_weights_check(model, pretrained, change_names_dict, ignore_size=True, verbose=True)

    return model, change_names_dict