2252e385ade719a7c839b8a3d3db75dbc2cb62f1
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / models / pixel2pixel / deeplabv3lite.py
1 import torch
2 import numpy as np
4 from .... import xnn
5 from .pixel2pixelnet import *
7 try: from .pixel2pixelnet_internal import *
8 except: pass
10 from ..multi_input_net import MobileNetV2TVMI4, MobileNetV2EricsunMI4, ResNet50MI4
12 ###########################################
# names exported by `from <this module> import *`
__all__ = [
    'DeepLabV3Lite',
    'DeepLabV3LiteDecoder',
    'deeplabv3lite_mobilenetv2_tv',
    'deeplabv3lite_mobilenetv2_ericsun',
    'deeplabv3lite_resnet50',
]
18 ###########################################
class DeepLabV3LiteDecoder(torch.nn.Module):
    """Decoder head used by DeepLabV3Lite.

    The decoder optionally runs an ASPP block on the encoder output, upsamples
    the result to the resolution of a high-resolution shortcut taken from the
    encoder, concatenates the two, and (optionally) predicts and upsamples the
    final output. All behaviour is driven by the fields of ``model_config``.
    """

    def __init__(self, model_config):
        """Build the decoder layers described by ``model_config``.

        Args:
            model_config: configuration object; the fields read here are
                shortcut_channels, shortcut_strides, shortcut_out, decoder_chan,
                aspp_chan, decoder_factor, use_aspp, groupwise_sep, aspp_dil,
                activation, linear_dw, final_prediction, output_range,
                output_channels, interpolation_type and interpolation_mode.
        """
        super().__init__()

        self.model_config = model_config

        # channel count of the low resolution encoder output
        current_channels = model_config.shortcut_channels[-1]
        decoder_channels = round(model_config.decoder_chan*model_config.decoder_factor)
        aspp_channels = round(model_config.aspp_chan*model_config.decoder_factor)

        if model_config.use_aspp:
            ASPPBlock = xnn.layers.GWASPPLiteBlock if model_config.groupwise_sep else xnn.layers.DWASPPLiteBlock
            self.aspp = ASPPBlock(current_channels, aspp_channels, decoder_channels, dilation=model_config.aspp_dil,
                                  activation=model_config.activation, linear_dw=model_config.linear_dw)
            # the ASPP block is built with decoder_channels as its output width
            current_channels = decoder_channels
        else:
            # without ASPP the encoder features are passed on unchanged
            self.aspp = None

        # 1x1 projection of the high resolution shortcut
        short_chan = model_config.shortcut_channels[0]
        self.shortcut = xnn.layers.ConvNormAct2d(short_chan, model_config.shortcut_out, kernel_size=1,
                                                 activation=model_config.activation)

        # channel count after concatenating the upsampled features with the shortcut
        self.decoder_channels = merged_channels = (current_channels+model_config.shortcut_out)

        # prediction
        if self.model_config.final_prediction:
            ConvXWSepBlock = xnn.layers.ConvGWSepNormAct2d if model_config.groupwise_sep else xnn.layers.ConvDWSepNormAct2d
            final_activation = xnn.layers.get_fixed_pact2(output_range=model_config.output_range) \
                if (model_config.output_range is not None) else False
            self.pred = ConvXWSepBlock(merged_channels, model_config.output_channels, kernel_size=3,
                                       normalization=((not model_config.linear_dw), False),
                                       activation=(False, final_activation), groups=1)

        self.cat = xnn.layers.CatBlock()

        upstride1 = model_config.shortcut_strides[-1]//model_config.shortcut_strides[0]
        upstride2 = model_config.shortcut_strides[0]
        # upsample1 is fed either the ASPP output or the raw encoder features, so it must be
        # sized with current_channels: this equals decoder_channels when ASPP is enabled and
        # shortcut_channels[-1] when it is not (previously decoder_channels was used in both cases).
        self.upsample1 = xnn.layers.UpsampleTo(current_channels, current_channels, upstride1,
                                               model_config.interpolation_type, model_config.interpolation_mode)
        self.upsample2 = xnn.layers.UpsampleTo(model_config.output_channels, model_config.output_channels, upstride2,
                                               model_config.interpolation_type, model_config.interpolation_mode)

    # the upsampling is using functional form to support size based upsampling for odd sizes
    # that are not a perfect ratio (eg. 257x513), which seem to be popular for segmentation
    def forward(self, x, x_features, x_list):
        """Run the decoder.

        Args:
            x: list or tuple with at most two entries; ``x[0]`` is used as the
                reference for the output size.
            x_features: encoder output that is fed to the ASPP block (or used
                directly when ``use_aspp`` is off).
            x_list: collection of encoder blobs from which the high resolution
                shortcut is picked by shape.

        Returns:
            The decoder output tensor. In eval mode with
            ``output_type == 'segmentation'`` this is the argmax over dim 1.

        Raises:
            AssertionError: if ``x`` is not a list/tuple of length <= 2, or if
                the spatial size of the output does not match ``x[0]``.
        """
        assert isinstance(x, (list,tuple)) and len(x)<=2, 'incorrect input'

        x_input = x[0]
        in_shape = x_input.shape

        # high res shortcut
        shape_s = xnn.utils.get_shape_with_stride(in_shape, self.model_config.shortcut_strides[0])
        shape_s[1] = self.model_config.shortcut_channels[0]
        x_s = xnn.utils.get_blob_from_list(x_list, shape_s)
        x_s = self.shortcut(x_s)

        if self.model_config.freeze_encoder:
            # cut the graph so that no gradient reaches the encoder
            x_s = x_s.detach()
            x_features = x_features.detach()

        # aspp/scse blocks at output stride
        x = self.aspp(x_features) if self.model_config.use_aspp else x_features

        # upsample low res features to match with shortcut
        x = self.upsample1((x, x_s))

        # combine and do high res prediction
        x = self.cat((x, x_s))

        # NOTE(review): the upsample / argmax / shape check below are kept inside the
        # final_prediction branch - confirm this nesting against the upstream file.
        if self.model_config.final_prediction:
            x = self.pred(x)

            if self.model_config.final_upsample:
                # final prediction is the upsampled one
                x = self.upsample2((x, x_input))

            if (not self.training) and (self.model_config.output_type == 'segmentation'):
                x = torch.argmax(x, dim=1, keepdim=True)

            assert int(in_shape[2]) == int(x.shape[2]) and int(in_shape[3]) == int(x.shape[3]), 'incorrect output shape'

        if self.model_config.freeze_decoder:
            x = x.detach()

        return x
class DeepLabV3Lite(Pixel2PixelNet):
    """Pixel2PixelNet specialised with DeepLabV3LiteDecoder as its decoder class."""

    def __init__(self, base_model, model_config):
        """Forward the encoder and the config to Pixel2PixelNet.

        Args:
            base_model: encoder network handed to Pixel2PixelNet.
            model_config: configuration handed to Pixel2PixelNet.
        """
        decoder_class = DeepLabV3LiteDecoder
        super().__init__(base_model, decoder_class, model_config)
109 ###########################################
110 # config settings
def get_config_deeplav3lite_mnv2():
    """Return the default configuration for DeepLabV3Lite with a MobileNetV2 encoder.

    Returns:
        A fresh xnn.utils.ConfigNode populated with the defaults below.
    """
    # use list for entries that are different for different decoders.
    # and are expected to be passed from the main script.
    model_config = xnn.utils.ConfigNode()

    # inputs / outputs
    model_config.num_classes = None
    model_config.num_decoders = None
    model_config.input_channels = (3,)
    model_config.output_channels = [19]
    model_config.intermediate_outputs = True
    model_config.normalize_input = False
    model_config.split_outputs = False

    # encoder / shortcut layout
    model_config.use_aspp = True
    model_config.strides = (2,2,2,2,1)
    model_config.groupwise_sep = False
    # np.prod returns a NumPy scalar; store a plain Python int so that the config
    # (and the strides derived from it) hold only built-in types
    encoder_stride = int(np.prod(model_config.strides))
    model_config.shortcut_strides = (4,encoder_stride)
    model_config.shortcut_channels = (24,320) # this is for mobilenetv2 - change for other networks
    model_config.shortcut_out = 48

    # decoder
    model_config.decoder_chan = 256
    model_config.aspp_chan = 256
    model_config.aspp_dil = (6,12,18)
    model_config.final_prediction = True
    model_config.final_upsample = True
    model_config.output_range = None
    model_config.decoder_factor = 1.0
    model_config.output_type = None
    model_config.activation = xnn.layers.DefaultAct2d
    model_config.interpolation_type = 'upsample'
    model_config.interpolation_mode = 'bilinear'
    model_config.linear_dw = False

    # training options
    model_config.normalize_gradients = False
    model_config.freeze_encoder = False
    model_config.freeze_decoder = False
    model_config.multi_task = False
    return model_config
def deeplabv3lite_mobilenetv2_tv(model_config, pretrained=None):
    """Build a DeepLabV3Lite model on top of the MobileNetV2TVMI4 encoder.

    Args:
        model_config: overrides merged on top of get_config_deeplav3lite_mnv2().
        pretrained: when truthy, passed to xnn.utils.load_weights to initialise
            the model; nothing is loaded otherwise.

    Returns:
        Tuple ``(model, change_names_dict)`` where ``change_names_dict`` is the
        name translation table that is handed to xnn.utils.load_weights.
    """
    model_config = get_config_deeplav3lite_mnv2().merge_from(model_config)

    # the encoder is constructed from its own copy of the config
    encoder = MobileNetV2TVMI4(model_config.clone())
    model = DeepLabV3Lite(encoder, model_config)

    stream_count = len(model_config.input_channels)
    if model_config.num_decoders is None:
        decoder_count = len(model_config.output_channels)
    else:
        decoder_count = model_config.num_decoders

    if stream_count > 1:
        # several inputs: every encoder entry is mapped to one name per stream
        change_names_dict = {
            pattern: ['encoder.features.stream{}.'.format(idx) for idx in range(stream_count)]
            for pattern in ('^features.', '^encoder.features.')
        }
    else:
        change_names_dict = {'^features.': 'encoder.features.'}
    # entries of decoder 0 are mapped to every decoder
    change_names_dict['^decoders.0.'] = ['decoders.{}.'.format(idx) for idx in range(decoder_count)]

    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)

    return model, change_names_dict
def deeplabv3lite_mobilenetv2_ericsun(model_config, pretrained=None):
    """Build a DeepLabV3Lite model on top of the MobileNetV2EricsunMI4 encoder.

    Args:
        model_config: overrides merged on top of get_config_deeplav3lite_mnv2().
        pretrained: when truthy, passed to xnn.utils.load_weights to initialise
            the model; nothing is loaded otherwise.

    Returns:
        Tuple ``(model, change_names_dict)`` where ``change_names_dict`` is the
        name translation table that is handed to xnn.utils.load_weights.
    """
    model_config = get_config_deeplav3lite_mnv2().merge_from(model_config)

    # the encoder is constructed from its own copy of the config
    encoder = MobileNetV2EricsunMI4(model_config.clone())
    model = DeepLabV3Lite(encoder, model_config)

    stream_count = len(model_config.input_channels)
    if model_config.num_decoders is None:
        decoder_count = len(model_config.output_channels)
    else:
        decoder_count = model_config.num_decoders

    if stream_count > 1:
        # several inputs: every encoder entry is mapped to one name per stream
        change_names_dict = {
            pattern: ['encoder.features.stream{}.'.format(idx) for idx in range(stream_count)]
            for pattern in ('^features.', '^encoder.features.')
        }
    else:
        change_names_dict = {'^features.': 'encoder.features.'}
    # entries of decoder 0 are mapped to every decoder
    change_names_dict['^decoders.0.'] = ['decoders.{}.'.format(idx) for idx in range(decoder_count)]

    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)

    return model, change_names_dict
201 ###########################################
# config settings for resnet50 backbone
def get_config_deeplav3lite_resnet50():
    """Return the DeepLabV3Lite defaults adjusted for a ResNet50 encoder.

    Starts from get_config_deeplav3lite_mnv2() and overrides only
    ``shortcut_channels``.
    """
    resnet_config = get_config_deeplav3lite_mnv2()
    # the only difference from the mobilenetv2 defaults
    resnet_config.shortcut_channels = (256,2048)
    return resnet_config
def deeplabv3lite_resnet50(model_config, pretrained=None):
    """Build a DeepLabV3Lite model on top of the ResNet50MI4 encoder.

    Args:
        model_config: overrides merged on top of get_config_deeplav3lite_resnet50().
        pretrained: when truthy, passed to xnn.utils.load_weights to initialise
            the model; nothing is loaded otherwise.

    Returns:
        Tuple ``(model, change_names_dict)`` where ``change_names_dict`` is the
        name translation table that is handed to xnn.utils.load_weights.
    """
    model_config = get_config_deeplav3lite_resnet50().merge_from(model_config)

    # the encoder is constructed from its own copy of the config
    encoder = ResNet50MI4(model_config.clone())
    model = DeepLabV3Lite(encoder, model_config)

    stream_count = len(model_config.input_channels)
    if model_config.num_decoders is None:
        decoder_count = len(model_config.output_channels)
    else:
        decoder_count = model_config.num_decoders

    # the pretrained model provided by torchvision and what is defined here differ slightly.
    # note (from the original author): this change_names_dict takes effect only if the direct load fails.
    # the table also takes care of the change for deeplabv3lite (features->encoder.features)
    if stream_count > 1:
        # several inputs: each source name is mapped to one target per stream
        per_stream_templates = (
            ('^conv1.', 'encoder.features.stream{}.conv1.'),
            ('^bn1.', 'encoder.features.stream{}.bn1.'),
            ('^relu.', 'encoder.features.stream{}.relu.'),
            ('^maxpool.', 'encoder.features.stream{}.maxpool.'),
            ('^layer', 'encoder.features.stream{}.layer'),
            ('^features.', 'encoder.features.stream{}.'),
            ('^encoder.features.', 'encoder.features.stream{}.'),
        )
        change_names_dict = {
            pattern: [template.format(idx) for idx in range(stream_count)]
            for pattern, template in per_stream_templates
        }
    else:
        change_names_dict = {
            '^conv1.': 'encoder.features.conv1.',
            '^bn1.': 'encoder.features.bn1.',
            '^relu.': 'encoder.features.relu.',
            '^maxpool.': 'encoder.features.maxpool.',
            '^layer': 'encoder.features.layer',
            '^features.': 'encoder.features.',
        }
    # entries of decoder 0 are mapped to every decoder
    change_names_dict['^decoders.0.'] = ['decoders.{}.'.format(idx) for idx in range(decoder_count)]

    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)

    return model, change_names_dict