]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/vision/models/pixel2pixel/deeplabv3lite.py
updated python package requirements (don't need tensorflow for tensorboard). not...
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / models / pixel2pixel / deeplabv3lite.py
1 import torch
2 import numpy as np
4 from .... import xnn
5 from .pixel2pixelnet import *
7 try: from .pixel2pixelnet_internal import *
8 except: pass
10 from ..multi_input_net import MobileNetV2TVMI4, MobileNetV2EricsunMI4, ResNet50MI4
12 ###########################################
# public names exported by this module
__all__ = [
    'DeepLabV3Lite',
    'DeepLabV3LiteDecoder',
    'deeplabv3lite_mobilenetv2_tv',
    'deeplabv3lite_mobilenetv2_tv_fd',
    'deeplabv3lite_mobilenetv2_ericsun',
    'deeplabv3lite_resnet50',
    'deeplabv3lite_resnet50_p5',
    'deeplabv3lite_resnet50_p5_fd',
]
19 ###########################################
class DeepLabV3LiteDecoder(torch.nn.Module):
    """Decoder head of DeepLabV3Lite.

    Optionally runs an ASPP block on the encoder output, upsamples the result
    to the resolution of the high resolution shortcut, concatenates both and
    (optionally) applies the prediction and final upsample modules.

    Args:
        model_config: configuration object. The constructor reads
            shortcut_channels, shortcut_strides, shortcut_out, decoder_chan,
            aspp_chan, decoder_factor, use_aspp, groupwise_sep, aspp_dil,
            activation, linear_dw, interpolation_type, interpolation_mode and
            final_prediction. forward() additionally reads freeze_encoder,
            freeze_decoder, final_upsample, output_type and target_input_ratio.
    """

    def __init__(self, model_config):
        super().__init__()

        self.model_config = model_config

        # channels of the tensor that flows through the low resolution path
        current_channels = model_config.shortcut_channels[-1]
        decoder_channels = round(model_config.decoder_chan*model_config.decoder_factor)
        aspp_channels = round(model_config.aspp_chan*model_config.decoder_factor)

        if model_config.use_aspp:
            ASPPBlock = xnn.layers.GWASPPLiteBlock if model_config.groupwise_sep else xnn.layers.DWASPPLiteBlock
            self.aspp = ASPPBlock(current_channels, aspp_channels, decoder_channels, dilation=model_config.aspp_dil,
                                  activation=model_config.activation, linear_dw=model_config.linear_dw)
            # the aspp block is constructed with decoder_channels as its output size
            current_channels = decoder_channels
        else:
            self.aspp = None

        short_chan = model_config.shortcut_channels[0]
        self.shortcut = xnn.layers.ConvNormAct2d(short_chan, model_config.shortcut_out, kernel_size=1, activation=model_config.activation)

        self.decoder_channels = merged_channels = (current_channels+model_config.shortcut_out)

        upstride1 = model_config.shortcut_strides[-1]//model_config.shortcut_strides[0]
        # use UpsampleWithGeneric() instead of UpsampleWith() to break down large upsampling factors to multiples of 4 and 2 -
        # useful if upsampling factors other than 4 and 2 are not supported.
        # the upsample module sees current_channels (not always decoder_channels): without aspp the
        # encoder features are upsampled directly, and merged_channels above is computed on that basis.
        self.upsample1 = xnn.layers.UpsampleWith(current_channels, current_channels, upstride1, model_config.interpolation_type, model_config.interpolation_mode)

        self.cat = xnn.layers.CatBlock()

        # add prediction & upsample modules
        # NOTE(review): presumably attaches self.pred and self.upsample2 (see module_names) - confirm in add_lite_prediction_modules
        if self.model_config.final_prediction:
            add_lite_prediction_modules(self, model_config, merged_channels, module_names=('pred','upsample2'))
        #

    # the upsampling is using functional form to support size based upsampling for odd sizes
    # that are not a perfect ratio (eg. 257x513), which seem to be popular for segmentation
    def forward(self, x, x_features, x_list):
        """Run the decoder.

        Args:
            x: list or tuple with at most two entries; x[0] is the network
                input (a tensor, or a list/tuple whose first entry is a tensor)
                and is only used for its shape.
            x_features: encoder output that is fed to the aspp block (or used
                directly when use_aspp is off).
            x_list: collection of intermediate encoder outputs from which the
                high resolution shortcut is selected by shape.

        Returns:
            The decoder output. In eval mode with output_type 'segmentation'
            and final_prediction enabled, this is the argmax over dim 1.
        """
        assert isinstance(x, (list,tuple)) and len(x)<=2, 'incorrect input'

        x_input = x[0]
        in_shape = x_input[0].shape if isinstance(x_input, (list,tuple)) else x_input.shape

        # high res shortcut
        shape_s = xnn.utils.get_shape_with_stride(in_shape, self.model_config.shortcut_strides[0])
        shape_s[1] = self.model_config.shortcut_channels[0]
        x_s = xnn.utils.get_blob_from_list(x_list, shape_s)
        x_s = self.shortcut(x_s)

        if self.model_config.freeze_encoder:
            x_s = x_s.detach()
            x_features = x_features.detach()

        # aspp/scse blocks at output stride
        x = self.aspp(x_features) if self.model_config.use_aspp else x_features

        # upsample low res features to match with shortcut
        x = self.upsample1(x)

        # combine and do high res prediction
        x = self.cat((x,x_s))

        if self.model_config.final_prediction:
            x = self.pred(x)

            if self.model_config.final_upsample:
                x = self.upsample2(x)

            if (not self.training) and (self.model_config.output_type == 'segmentation'):
                x = torch.argmax(x, dim=1, keepdim=True)

            assert int(in_shape[2]) == int(x.shape[2]*self.model_config.target_input_ratio) and int(in_shape[3]) == int(x.shape[3]*self.model_config.target_input_ratio), 'incorrect output shape'

        if self.model_config.freeze_decoder:
            x = x.detach()

        return x
class DeepLabV3Lite(Pixel2PixelNet):
    """Pixel2PixelNet that uses DeepLabV3LiteDecoder as its decoder class."""

    def __init__(self, base_model, model_config):
        decoder_class = DeepLabV3LiteDecoder
        super().__init__(base_model, decoder_class, model_config)
106 ###########################################
107 # config settings
def get_config_deeplav3lite_mnv2():
    """Return the default DeepLabV3Lite configuration (mobilenetv2 channel sizes).

    Returns:
        A new xnn.utils.ConfigNode populated with the default settings below.
    """
    # use list for entries that are different for different decoders.
    # and are expected to be passed from the main script.
    model_config = xnn.utils.ConfigNode()
    model_config.num_classes = None
    model_config.num_decoders = None
    model_config.input_channels = (3,)
    model_config.output_channels = [19]
    model_config.intermediate_outputs = True
    model_config.normalize_input = False
    model_config.split_outputs = False
    model_config.use_aspp = True
    model_config.fastdown = False
    model_config.target_input_ratio = 1

    model_config.strides = (2,2,2,2,1)
    model_config.groupwise_sep = False
    # keep a plain python int in the config instead of a numpy scalar
    encoder_stride = int(np.prod(model_config.strides))
    model_config.shortcut_strides = (4,encoder_stride)
    model_config.shortcut_channels = (24,320) # this is for mobilenetv2 - change for other networks
    model_config.shortcut_out = 48
    model_config.decoder_chan = 256
    model_config.aspp_chan = 256
    model_config.aspp_dil = (6,12,18)
    model_config.final_prediction = True
    model_config.final_upsample = True
    model_config.output_range = None
    model_config.decoder_factor = 1.0
    model_config.output_type = None
    model_config.activation = xnn.layers.DefaultAct2d
    model_config.interpolation_type = 'upsample'
    model_config.interpolation_mode = 'bilinear'
    model_config.linear_dw = False
    model_config.normalize_gradients = False
    model_config.freeze_encoder = False
    model_config.freeze_decoder = False
    model_config.multi_task = False
    return model_config
def deeplabv3lite_mobilenetv2_tv(model_config, pretrained=None):
    """Build DeepLabV3Lite on top of the MobileNetV2TVMI4 encoder.

    Args:
        model_config: overrides that are merged into the default config.
        pretrained: passed to xnn.utils.load_weights when truthy.

    Returns:
        Tuple (model, change_names_dict), where change_names_dict is the
        renaming table that was handed to the weight loader.
    """
    model_config = get_config_deeplav3lite_mnv2().merge_from(model_config)

    # the encoder is built from its own copy of the config
    encoder_config = model_config.clone()
    encoder = MobileNetV2TVMI4(encoder_config)
    model = DeepLabV3Lite(encoder, model_config)

    stream_count = len(model_config.input_channels)
    if model_config.num_decoders is None:
        decoder_count = len(model_config.output_channels)
    else:
        decoder_count = model_config.num_decoders

    decoder_names = ['decoders.{}.'.format(idx) for idx in range(decoder_count)]
    if stream_count > 1:
        change_names_dict = {
            '^features.': ['encoder.features.stream{}.'.format(idx) for idx in range(stream_count)],
            '^encoder.features.': ['encoder.features.stream{}.'.format(idx) for idx in range(stream_count)],
            '^decoders.0.': decoder_names,
        }
    else:
        change_names_dict = {
            '^features.': 'encoder.features.',
            '^decoders.0.': decoder_names,
        }

    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)

    return model, change_names_dict
def deeplabv3lite_mobilenetv2_tv_fd(model_config, pretrained=None):
    """Build the 'fastdown' variant of deeplabv3lite_mobilenetv2_tv.

    The given config is merged into the defaults, the fastdown specific
    values are forced on top of it, and the result is forwarded to
    deeplabv3lite_mobilenetv2_tv.
    """
    fastdown_overrides = {
        'fastdown': True,
        'strides': (2,2,2,2,1),
        'shortcut_strides': (8,32),
        'shortcut_channels': (24,320),
        'decoder_chan': 256,
        'aspp_chan': 256,
    }
    model_config = get_config_deeplav3lite_mnv2().merge_from(model_config)
    for key, value in fastdown_overrides.items():
        setattr(model_config, key, value)
    return deeplabv3lite_mobilenetv2_tv(model_config, pretrained=pretrained)
def deeplabv3lite_mobilenetv2_ericsun(model_config, pretrained=None):
    """Build DeepLabV3Lite on top of the MobileNetV2EricsunMI4 encoder.

    Args:
        model_config: overrides that are merged into the default config.
        pretrained: passed to xnn.utils.load_weights when truthy.

    Returns:
        Tuple (model, change_names_dict).
    """
    model_config = get_config_deeplav3lite_mnv2().merge_from(model_config)

    # encoder (built from a cloned config) followed by the full network
    backbone = MobileNetV2EricsunMI4(model_config.clone())
    model = DeepLabV3Lite(backbone, model_config)

    n_inputs = len(model_config.input_channels)
    n_decoders = model_config.num_decoders
    if n_decoders is None:
        n_decoders = len(model_config.output_channels)

    # renaming table: a single stream maps to a plain string, several streams map to one name per stream
    change_names_dict = {}
    if n_inputs > 1:
        change_names_dict['^features.'] = ['encoder.features.stream{}.'.format(s) for s in range(n_inputs)]
        change_names_dict['^encoder.features.'] = ['encoder.features.stream{}.'.format(s) for s in range(n_inputs)]
    else:
        change_names_dict['^features.'] = 'encoder.features.'
    change_names_dict['^decoders.0.'] = ['decoders.{}.'.format(d) for d in range(n_decoders)]

    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)

    return model, change_names_dict
213 ###########################################
# config settings for resnet50 backbone
def get_config_deeplav3lite_resnet50():
    """Return the default config with the shortcut channels used for resnet50."""
    # start from the mobilenetv2 defaults and change only what differs
    resnet_config = get_config_deeplav3lite_mnv2()
    resnet_config.shortcut_channels = (256,2048)
    return resnet_config
def deeplabv3lite_resnet50(model_config, pretrained=None):
    """Build DeepLabV3Lite on top of the ResNet50MI4 encoder.

    Args:
        model_config: overrides that are merged into the resnet50 defaults.
        pretrained: passed to xnn.utils.load_weights when truthy.

    Returns:
        Tuple (model, change_names_dict).
    """
    model_config = get_config_deeplav3lite_resnet50().merge_from(model_config)

    # encoder (from a cloned config) and the complete network
    encoder_config = model_config.clone()
    encoder = ResNet50MI4(encoder_config)
    model = DeepLabV3Lite(encoder, model_config)

    # the pretrained model provided by torchvision and what is defined here differs slightly
    # note: that this change_names_dict  will take effect only if the direct load fails
    # finally take care of the change for deeplabv3lite (features->encoder.features)
    stream_count = len(model_config.input_channels)
    if model_config.num_decoders is None:
        decoder_count = len(model_config.output_channels)
    else:
        decoder_count = model_config.num_decoders

    # top level names of the torchvision resnet that move below encoder.features
    resnet_prefixes = ('conv1.', 'bn1.', 'relu.', 'maxpool.', 'layer')

    change_names_dict = {}
    if stream_count > 1:
        for prefix in resnet_prefixes:
            change_names_dict['^' + prefix] = [
                'encoder.features.stream{}.'.format(idx) + prefix for idx in range(stream_count)]
        change_names_dict['^features.'] = ['encoder.features.stream{}.'.format(idx) for idx in range(stream_count)]
        change_names_dict['^encoder.features.'] = ['encoder.features.stream{}.'.format(idx) for idx in range(stream_count)]
    else:
        for prefix in resnet_prefixes:
            change_names_dict['^' + prefix] = 'encoder.features.' + prefix
        change_names_dict['^features.'] = 'encoder.features.'
    change_names_dict['^decoders.0.'] = ['decoders.{}.'.format(idx) for idx in range(decoder_count)]

    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)

    return model, change_names_dict
def deeplabv3lite_resnet50_p5(model_config, pretrained=None):
    """Build deeplabv3lite_resnet50 with width_mult 0.5.

    The given config is merged into the resnet50 defaults first, so the
    overrides below are applied to the merged config and not written into
    the caller's object (same pattern as deeplabv3lite_mobilenetv2_tv_fd).
    NOTE(review): assumes merge_from returns the merged config, as its other uses in this file do.

    Args:
        model_config: overrides that are merged into the resnet50 defaults.
        pretrained: forwarded to deeplabv3lite_resnet50.

    Returns:
        Whatever deeplabv3lite_resnet50 returns: (model, change_names_dict).
    """
    model_config = get_config_deeplav3lite_resnet50().merge_from(model_config)
    model_config.width_mult = 0.5
    model_config.shortcut_channels = (128,1024)
    return deeplabv3lite_resnet50(model_config, pretrained=pretrained)
def deeplabv3lite_resnet50_p5_fd(model_config, pretrained=None):
    """Build the 'fastdown' variant of deeplabv3lite_resnet50 with width_mult 0.5.

    The given config is merged into the resnet50 defaults first, so the
    overrides below are applied to the merged config and not written into
    the caller's object (same pattern as deeplabv3lite_mobilenetv2_tv_fd).
    NOTE(review): assumes merge_from returns the merged config, as its other uses in this file do.

    Args:
        model_config: overrides that are merged into the resnet50 defaults.
        pretrained: forwarded to deeplabv3lite_resnet50.

    Returns:
        Whatever deeplabv3lite_resnet50 returns: (model, change_names_dict).
    """
    model_config = get_config_deeplav3lite_resnet50().merge_from(model_config)
    model_config.width_mult = 0.5
    model_config.fastdown = True
    model_config.shortcut_channels = (128,1024)
    model_config.shortcut_strides = (8,64)
    return deeplabv3lite_resnet50(model_config, pretrained=pretrained)