4bf0c0b23b6e76908e77bd51c7c5cc149218334a
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / models / pixel2pixel / deeplabv3lite.py
1 import torch
2 import numpy as np
4 from .... import xnn
5 from .pixel2pixelnet import *
7 try: from .pixel2pixelnet_internal import *
8 except: pass
10 from ..multi_input_net import MobileNetV2TVMI4, MobileNetV2EricsunMI4, ResNet50MI4
###########################################
# public API of this module: the model/decoder classes plus the
# backbone-specific constructor functions defined below
__all__ = ['DeepLabV3Lite', 'DeepLabV3LiteDecoder',
           'deeplabv3lite_mobilenetv2_tv', 'deeplabv3lite_mobilenetv2_tv_fd',
           'deeplabv3lite_mobilenetv2_ericsun',
           'deeplabv3lite_resnet50', 'deeplabv3lite_resnet50_p5', 'deeplabv3lite_resnet50_p5_fd']
19 ###########################################
class DeepLabV3LiteDecoder(torch.nn.Module):
    """DeepLabV3Lite decoder head.

    Applies an (optional) lite ASPP block on the lowest-resolution encoder
    features, projects a high-resolution encoder shortcut with a 1x1 conv,
    upsamples and concatenates the two, then makes the final prediction with
    a separable conv followed by upsampling back to the input resolution.
    """
    def __init__(self, model_config):
        """
        Args:
            model_config: config node with the fields referenced below
                (see get_config_deeplav3lite_mnv2() for the defaults).
        """
        super().__init__()

        self.model_config = model_config

        # encoder output width is the last entry of shortcut_channels
        current_channels = model_config.shortcut_channels[-1]
        # decoder_factor scales both the decoder and aspp widths together
        decoder_channels = round(model_config.decoder_chan*model_config.decoder_factor)
        aspp_channels = round(model_config.aspp_chan*model_config.decoder_factor)

        if model_config.use_aspp:
            # groupwise-separable or depthwise-separable ASPP variant
            ASPPBlock = xnn.layers.GWASPPLiteBlock if model_config.groupwise_sep else xnn.layers.DWASPPLiteBlock
            self.aspp = ASPPBlock(current_channels, aspp_channels, decoder_channels, dilation=model_config.aspp_dil,
                                  activation=model_config.activation, linear_dw=model_config.linear_dw)
        else:
            self.aspp = None

        # after ASPP the working width is decoder_channels; otherwise unchanged
        current_channels = decoder_channels if model_config.use_aspp else current_channels

        # 1x1 projection of the high-resolution shortcut feature
        short_chan = model_config.shortcut_channels[0]
        self.shortcut = xnn.layers.ConvNormAct2d(short_chan, model_config.shortcut_out, kernel_size=1, activation=model_config.activation)

        self.decoder_channels = merged_channels = (current_channels+model_config.shortcut_out)

        # upsampling factor from the encoder output stride down to the shortcut stride
        upstride1 = model_config.shortcut_strides[-1]//model_config.shortcut_strides[0]
        # use UpsampleGenericTo() instead of UpsampleTo() to break down large upsampling factors to multiples of 4 and 2 -
        # useful if upsampling factors other than 4 and 2 are not supported.
        self.upsample1 = xnn.layers.UpsampleTo(decoder_channels, decoder_channels, upstride1, model_config.interpolation_type, model_config.interpolation_mode)

        self.cat = xnn.layers.CatBlock()

        # prediction
        if self.model_config.final_prediction:
            ConvXWSepBlock = xnn.layers.ConvGWSepNormAct2d if model_config.groupwise_sep else xnn.layers.ConvDWSepNormAct2d
            # clamp the output to a fixed range when output_range is given
            final_activation = xnn.layers.get_fixed_pact2(output_range = model_config.output_range) if (model_config.output_range is not None) else False
            self.pred = ConvXWSepBlock(merged_channels, model_config.output_channels, kernel_size=3, normalization=((not model_config.linear_dw),False), activation=(False,final_activation), groups=1)
            if self.model_config.final_upsample:
                # remaining factor back to full input resolution
                upstride2 = model_config.shortcut_strides[0]
                # use UpsampleGenericTo() instead of UpsampleTo() to break down large upsampling factors to multiples of 4 and 2 -
                # useful if upsampling factors other than 4 and 2 are not supported.
                self.upsample2 = xnn.layers.UpsampleTo(model_config.output_channels, model_config.output_channels, upstride2, model_config.interpolation_type, model_config.interpolation_mode)
            #

    # the upsampling is using functional form to support size based upsampling for odd sizes
    # that are not a perfect ratio (eg. 257x513), which seem to be popular for segmentation
    def forward(self, x, x_features, x_list):
        """
        Args:
            x: (input,) or (input, target) list/tuple; only x[0] is read here.
            x_features: lowest-resolution encoder feature map.
            x_list: intermediate encoder feature maps; the shortcut blob is
                selected from it by its expected shape (via get_blob_from_list).
        Returns:
            Prediction tensor at input spatial resolution; for segmentation in
            eval mode, an argmax class-index map of shape (N, 1, H, W).
        """
        assert isinstance(x, (list,tuple)) and len(x)<=2, 'incorrect input'

        x_input = x[0]
        in_shape = x_input.shape

        # high res shortcut: expected shape at shortcut_strides[0] with
        # shortcut_channels[0] channels, projected by the 1x1 conv
        shape_s = xnn.utils.get_shape_with_stride(in_shape, self.model_config.shortcut_strides[0])
        shape_s[1] = self.model_config.shortcut_channels[0]
        x_s = xnn.utils.get_blob_from_list(x_list, shape_s)
        x_s = self.shortcut(x_s)

        # stop gradients flowing back into a frozen encoder
        if self.model_config.freeze_encoder:
            x_s = x_s.detach()
            x_features = x_features.detach()

        # aspp/scse blocks at output stride
        x = self.aspp(x_features) if self.model_config.use_aspp else x_features

        # upsample low res features to match with shortcut
        x = self.upsample1((x, x_s))

        # combine and do high res prediction
        x = self.cat((x,x_s))

        if self.model_config.final_prediction:
            x = self.pred(x)

            if self.model_config.final_upsample:
                # final prediction is the upsampled one
                #scale_factor = (in_shape[2]/x.shape[2], in_shape[3]/x.shape[3])
                x = self.upsample2((x, x_input))

            if (not self.training) and (self.model_config.output_type == 'segmentation'):
                # at inference, collapse per-class scores to a label map
                x = torch.argmax(x, dim=1, keepdim=True)

            assert int(in_shape[2]) == int(x.shape[2]) and int(in_shape[3]) == int(x.shape[3]), 'incorrect output shape'

        # stop gradients at the decoder output when the decoder is frozen
        if self.model_config.freeze_decoder:
            x = x.detach()

        return x
class DeepLabV3Lite(Pixel2PixelNet):
    """DeepLabV3Lite model: a Pixel2PixelNet that pairs the given encoder
    (base_model) with the DeepLabV3LiteDecoder head."""
    def __init__(self, base_model, model_config):
        super().__init__(base_model, DeepLabV3LiteDecoder, model_config)
114 ###########################################
115 # config settings
def get_config_deeplav3lite_mnv2():
    """Return the default model configuration for DeepLabV3Lite with a
    MobileNetV2 encoder.

    Use list for entries that are different for different decoders
    and are expected to be passed from the main script.

    Note: fix — `fastdown` was previously assigned False twice; the
    duplicate assignment has been removed.
    """
    model_config = xnn.utils.ConfigNode()
    model_config.num_classes = None
    model_config.num_decoders = None
    model_config.input_channels = (3,)
    model_config.output_channels = [19]
    model_config.intermediate_outputs = True
    model_config.normalize_input = False
    model_config.split_outputs = False
    model_config.use_aspp = True
    model_config.fastdown = False

    model_config.strides = (2,2,2,2,1)
    model_config.groupwise_sep = False
    # overall encoder output stride is the product of the per-stage strides
    encoder_stride = np.prod(model_config.strides)
    model_config.shortcut_strides = (4,encoder_stride)
    model_config.shortcut_channels = (24,320) # this is for mobilenetv2 - change for other networks
    model_config.shortcut_out = 48
    model_config.decoder_chan = 256
    model_config.aspp_chan = 256
    model_config.aspp_dil = (6,12,18)
    model_config.final_prediction = True
    model_config.final_upsample = True
    model_config.output_range = None
    model_config.decoder_factor = 1.0
    model_config.output_type = None
    model_config.activation = xnn.layers.DefaultAct2d
    model_config.interpolation_type = 'upsample'
    model_config.interpolation_mode = 'bilinear'
    model_config.linear_dw = False
    model_config.normalize_gradients = False
    model_config.freeze_encoder = False
    model_config.freeze_decoder = False
    model_config.multi_task = False
    return model_config
def deeplabv3lite_mobilenetv2_tv(model_config, pretrained=None):
    """Build DeepLabV3Lite with the torchvision-style MobileNetV2 encoder.

    Args:
        model_config: user config; merged over the MobileNetV2 defaults.
        pretrained: optional checkpoint to load (names remapped as needed).
    Returns:
        (model, change_names_dict) where change_names_dict maps checkpoint
        parameter-name prefixes to their names in this model.
    """
    model_config = get_config_deeplav3lite_mnv2().merge_from(model_config)
    # encoder setup - clone so encoder-side changes do not affect the decoder config
    encoder = MobileNetV2TVMI4(model_config.clone())
    # decoder setup
    model = DeepLabV3Lite(encoder, model_config)

    n_inputs = len(model_config.input_channels)
    n_decoders = model_config.num_decoders if (model_config.num_decoders is not None) else len(model_config.output_channels)
    # decoder weights are replicated from decoder 0 into every decoder
    decoder_targets = ['decoders.{}.'.format(d) for d in range(n_decoders)]
    if n_inputs > 1:
        # multi-input: fan encoder weights out to every input stream
        stream_targets = ['encoder.features.stream{}.'.format(s) for s in range(n_inputs)]
        change_names_dict = {'^features.': stream_targets,
                             '^encoder.features.': stream_targets,
                             '^decoders.0.': decoder_targets}
    else:
        change_names_dict = {'^features.': 'encoder.features.',
                             '^decoders.0.': decoder_targets}
    #
    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)

    return model, change_names_dict
def deeplabv3lite_mobilenetv2_tv_fd(model_config, pretrained=None):
    """Fast-down variant of deeplabv3lite_mobilenetv2_tv: extra early
    downsampling, with shortcut strides shifted accordingly."""
    model_config = get_config_deeplav3lite_mnv2().merge_from(model_config)
    overrides = {'fastdown': True,
                 'strides': (2,2,2,2,1),
                 'shortcut_strides': (8,32),
                 'shortcut_channels': (24,320),
                 'decoder_chan': 256,
                 'aspp_chan': 256}
    for key, value in overrides.items():
        setattr(model_config, key, value)
    return deeplabv3lite_mobilenetv2_tv(model_config, pretrained=pretrained)
def deeplabv3lite_mobilenetv2_ericsun(model_config, pretrained=None):
    """Build DeepLabV3Lite with the Ericsun MobileNetV2 encoder.

    Args:
        model_config: user config; merged over the MobileNetV2 defaults.
        pretrained: optional checkpoint to load (names remapped as needed).
    Returns:
        (model, change_names_dict) with the checkpoint name-prefix mapping.
    """
    model_config = get_config_deeplav3lite_mnv2().merge_from(model_config)
    # encoder setup - clone so encoder-side changes do not affect the decoder config
    encoder = MobileNetV2EricsunMI4(model_config.clone())
    # decoder setup
    model = DeepLabV3Lite(encoder, model_config)

    n_inputs = len(model_config.input_channels)
    n_decoders = model_config.num_decoders if (model_config.num_decoders is not None) else len(model_config.output_channels)
    # decoder weights are replicated from decoder 0 into every decoder
    decoder_targets = ['decoders.{}.'.format(d) for d in range(n_decoders)]
    if n_inputs > 1:
        # multi-input: fan encoder weights out to every input stream
        stream_targets = ['encoder.features.stream{}.'.format(s) for s in range(n_inputs)]
        change_names_dict = {'^features.': stream_targets,
                             '^encoder.features.': stream_targets,
                             '^decoders.0.': decoder_targets}
    else:
        change_names_dict = {'^features.': 'encoder.features.',
                             '^decoders.0.': decoder_targets}
    #
    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)

    return model, change_names_dict
220 ###########################################
# config settings for resnet50 backbone
def get_config_deeplav3lite_resnet50():
    """Default config for the ResNet-50 backbone: only the delta compared
    to the MobileNetV2 config (shortcut channel counts differ)."""
    cfg = get_config_deeplav3lite_mnv2()
    cfg.shortcut_channels = (256,2048)
    return cfg
def deeplabv3lite_resnet50(model_config, pretrained=None):
    """Build DeepLabV3Lite with a ResNet-50 encoder.

    Args:
        model_config: user config; merged over the ResNet-50 defaults.
        pretrained: optional checkpoint to load (names remapped as needed).
    Returns:
        (model, change_names_dict) with the checkpoint name-prefix mapping.
    """
    model_config = get_config_deeplav3lite_resnet50().merge_from(model_config)
    # encoder setup - clone so encoder-side changes do not affect the decoder config
    encoder = ResNet50MI4(model_config.clone())
    # decoder setup
    model = DeepLabV3Lite(encoder, model_config)

    # the pretrained model provided by torchvision and what is defined here differs slightly
    # note: that this change_names_dict will take effect only if the direct load fails
    # finally take care of the change for deeplabv3lite (features->encoder.features)
    n_inputs = len(model_config.input_channels)
    n_decoders = model_config.num_decoders if (model_config.num_decoders is not None) else len(model_config.output_channels)
    decoder_targets = ['decoders.{}.'.format(d) for d in range(n_decoders)]
    # torchvision resnet stem parameter-name prefixes
    stem_keys = ('conv1.', 'bn1.', 'relu.', 'maxpool.')
    if n_inputs > 1:
        # multi-input: fan each encoder weight out to every input stream
        def fanout(suffix):
            return ['encoder.features.stream{}.{}'.format(s, suffix) for s in range(n_inputs)]
        change_names_dict = {'^{}'.format(k): fanout(k) for k in stem_keys}
        change_names_dict['^layer'] = fanout('layer')
        change_names_dict['^features.'] = fanout('')
        change_names_dict['^encoder.features.'] = fanout('')
        change_names_dict['^decoders.0.'] = decoder_targets
    else:
        change_names_dict = {'^{}'.format(k): 'encoder.features.{}'.format(k) for k in stem_keys}
        change_names_dict['^layer'] = 'encoder.features.layer'
        change_names_dict['^features.'] = 'encoder.features.'
        change_names_dict['^decoders.0.'] = decoder_targets
    #
    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)

    return model, change_names_dict
def deeplabv3lite_resnet50_p5(model_config, pretrained=None):
    """ResNet-50 backbone at 0.5 width multiplier; shortcut channel
    counts are halved to match."""
    model_config.shortcut_channels = (128,1024)
    model_config.width_mult = 0.5
    return deeplabv3lite_resnet50(model_config, pretrained=pretrained)
def deeplabv3lite_resnet50_p5_fd(model_config, pretrained=None):
    """ResNet-50 backbone at 0.5 width multiplier with fast-down
    (extra early downsampling; shortcut strides shifted to match)."""
    model_config.shortcut_channels = (128,1024)
    model_config.shortcut_strides = (8,64)
    model_config.width_mult = 0.5
    model_config.fastdown = True
    return deeplabv3lite_resnet50(model_config, pretrained=pretrained)