[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / models / pixel2pixel / deeplabv3lite.py
1 import torch
2 import numpy as np
4 from .... import xnn
5 from .pixel2pixelnet import *
7 try: from .pixel2pixelnet_internal import *
8 except: pass
10 from ..multi_input_net import MobileNetV2TVMI4, MobileNetV2EricsunMI4, ResNet50MI4
12 ###########################################
# public names of this module (what `from ... import *` exposes)
__all__ = [
    'DeepLabV3Lite',
    'DeepLabV3LiteDecoder',
    'deeplabv3lite_mobilenetv2_tv',
    'deeplabv3lite_mobilenetv2_tv_fd',
    'deeplabv3lite_mobilenetv2_ericsun',
    'deeplabv3lite_resnet50',
    'deeplabv3lite_resnet50_p5',
    'deeplabv3lite_resnet50_p5_fd',
]
19 ###########################################
class DeepLabV3LiteDecoder(torch.nn.Module):
    """DeepLabV3Lite decoder.

    Optionally applies an ASPP block to the lowest resolution encoder feature.
    That result is upsampled by shortcut_strides[-1]//shortcut_strides[0] and
    concatenated with a 1x1-projected high resolution shortcut feature taken
    from the encoder. When model_config.final_prediction is set, prediction
    (and optional final upsample) modules are applied on top.
    """
    def __init__(self, model_config):
        super().__init__()

        self.model_config = model_config

        current_channels = model_config.shortcut_channels[-1]
        decoder_channels = round(model_config.decoder_chan*model_config.decoder_factor)
        aspp_channels = round(model_config.aspp_chan*model_config.decoder_factor)

        if model_config.use_aspp:
            ASPPBlock = xnn.layers.GWASPPLiteBlock if model_config.groupwise_sep else xnn.layers.DWASPPLiteBlock
            self.aspp = ASPPBlock(current_channels, aspp_channels, decoder_channels, dilation=model_config.aspp_dil,
                                  activation=model_config.activation, linear_dw=model_config.linear_dw)
        else:
            self.aspp = None
        #

        # channels of the low resolution branch after the (optional) aspp block
        current_channels = decoder_channels if model_config.use_aspp else current_channels

        short_chan = model_config.shortcut_channels[0]
        self.shortcut = xnn.layers.ConvNormAct2d(short_chan, model_config.shortcut_out, kernel_size=1, activation=model_config.activation)

        self.decoder_channels = merged_channels = (current_channels+model_config.shortcut_out)

        upstride1 = model_config.shortcut_strides[-1]//model_config.shortcut_strides[0]
        # use UpsampleWithGeneric() instead of UpsampleWith() to break down large upsampling factors to multiples of 4 and 2 -
        # useful if upsampling factors other than 4 and 2 are not supported.
        # upsample1 consumes the aspp output, or the raw encoder feature when aspp is disabled,
        # so its channel count is current_channels (== decoder_channels only when use_aspp is set).
        self.upsample1 = xnn.layers.UpsampleWith(current_channels, current_channels, upstride1, model_config.interpolation_type, model_config.interpolation_mode)

        self.cat = xnn.layers.CatBlock()

        # add prediction & upsample modules (creates self.pred and self.upsample2)
        if self.model_config.final_prediction:
            add_lite_prediction_modules(self, model_config, merged_channels, module_names=('pred','upsample2'))
        #

    # the upsampling is using functional form to support size based upsampling for odd sizes
    # that are not a perfect ratio (eg. 257x513), which seem to be popular for segmentation
    def forward(self, x, x_features, x_list):
        """Run the decoder.

        Args:
            x: list/tuple (at most 2 entries) whose first entry is the network input
               (or a list/tuple of inputs); used only to derive the input shape.
            x_features: encoder feature at the lowest resolution (input of aspp).
            x_list: encoder blobs; the high resolution shortcut is looked up here by shape.

        Returns:
            The decoder output. In eval mode with output_type == 'segmentation' and
            final_prediction set, this is the argmax class map (keepdim=True).
        """
        assert isinstance(x, (list,tuple)) and len(x)<=2, 'incorrect input'

        x_input = x[0]
        in_shape = x_input[0].shape if isinstance(x_input, (list,tuple)) else x_input.shape

        # high res shortcut - pick the encoder blob whose shape matches the shortcut stride/channels
        shape_s = xnn.utils.get_shape_with_stride(in_shape, self.model_config.shortcut_strides[0])
        shape_s[1] = self.model_config.shortcut_channels[0]
        x_s = xnn.utils.get_blob_from_list(x_list, shape_s)

        if self.model_config.freeze_encoder:
            # cut the gradient at the encoder outputs only - the decoder's own
            # layers (shortcut and aspp) must keep training.
            x_s = x_s.detach()
            x_features = x_features.detach()
        #

        x_s = self.shortcut(x_s)

        # aspp/scse blocks at output stride
        x = self.aspp(x_features) if self.model_config.use_aspp else x_features

        # upsample low res features to match with shortcut
        x = self.upsample1(x)

        # combine and do high res prediction
        x = self.cat((x,x_s))

        if self.model_config.final_prediction:
            x = self.pred(x)

            if self.model_config.final_upsample:
                x = self.upsample2(x)
            #

            if (not self.training) and (self.model_config.output_type == 'segmentation'):
                x = torch.argmax(x, dim=1, keepdim=True)
            #

            assert int(in_shape[2]) == int(x.shape[2]*self.model_config.target_input_ratio) and int(in_shape[3]) == int(x.shape[3]*self.model_config.target_input_ratio), 'incorrect output shape'
        #

        if self.model_config.freeze_decoder:
            x = x.detach()
        #

        return x
class DeepLabV3Lite(Pixel2PixelNet):
    """Pixel2PixelNet that pairs an encoder with DeepLabV3LiteDecoder.

    Args:
        base_model: the encoder network built by the model factory functions.
        model_config: configuration forwarded unchanged to Pixel2PixelNet.
    """
    def __init__(self, base_model, model_config):
        decoder_type = DeepLabV3LiteDecoder
        super(DeepLabV3Lite, self).__init__(base_model, decoder_type, model_config)
106 ###########################################
107 # config settings
def get_config_deeplav3lite_mnv2():
    """Return the default DeepLabV3Lite config for a MobileNetV2 encoder.

    Entries that differ between decoders use lists and are expected to be
    passed from the main script (merged on top of this via merge_from()).
    """
    model_config = xnn.utils.ConfigNode()
    model_config.num_classes = None
    model_config.num_decoders = None
    model_config.input_channels = (3,)
    model_config.output_channels = [19]
    model_config.intermediate_outputs = True
    model_config.normalize_input = False
    model_config.split_outputs = False
    model_config.use_aspp = True
    model_config.fastdown = False
    model_config.target_input_ratio = 1

    model_config.strides = (2,2,2,2,1)
    model_config.groupwise_sep = False
    # overall encoder stride; stored as a plain int rather than a numpy scalar
    encoder_stride = int(np.prod(model_config.strides))
    model_config.shortcut_strides = (4,encoder_stride)
    model_config.shortcut_channels = (24,320) # this is for mobilenetv2 - change for other networks
    model_config.shortcut_out = 48
    model_config.decoder_chan = 256
    model_config.aspp_chan = 256
    model_config.aspp_dil = (6,12,18)
    model_config.final_prediction = True
    model_config.final_upsample = True
    model_config.output_range = None
    model_config.decoder_factor = 1.0
    model_config.output_type = None
    model_config.activation = xnn.layers.DefaultAct2d
    model_config.interpolation_type = 'upsample'
    model_config.interpolation_mode = 'bilinear'
    model_config.linear_dw = False
    model_config.normalize_gradients = False
    model_config.freeze_encoder = False
    model_config.freeze_decoder = False
    model_config.multi_task = False
    return model_config
def _deeplabv3lite_mobilenetv2(encoder_class, model_config, pretrained=None):
    """Shared builder for the MobileNetV2-based DeepLabV3Lite variants.

    Args:
        encoder_class: encoder constructor taking a config
            (MobileNetV2TVMI4 or MobileNetV2EricsunMI4).
        model_config: user config, merged on top of get_config_deeplav3lite_mnv2().
        pretrained: optional weights; loaded with xnn.utils.load_weights() when truthy.

    Returns:
        (model, change_names_dict). change_names_dict is the name-remapping dict
        handed to xnn.utils.load_weights().
    """
    model_config = get_config_deeplav3lite_mnv2().merge_from(model_config)
    # encoder setup - the encoder is built from its own copy of the config
    model_config_e = model_config.clone()
    base_model = encoder_class(model_config_e)
    # decoder setup
    model = DeepLabV3Lite(base_model, model_config)

    num_inputs = len(model_config.input_channels)
    num_decoders = len(model_config.output_channels) if (model_config.num_decoders is None) else model_config.num_decoders
    # list values presumably map one source prefix onto several targets
    # (one per input stream / per decoder) - see xnn.utils.load_weights
    if num_inputs > 1:
        change_names_dict = {'^features.': ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)],
                             '^encoder.features.': ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)],
                             '^decoders.0.': ['decoders.{}.'.format(d) for d in range(num_decoders)]}
    else:
        change_names_dict = {'^features.': 'encoder.features.',
                             '^decoders.0.': ['decoders.{}.'.format(d) for d in range(num_decoders)]}
    #

    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)

    return model, change_names_dict


def deeplabv3lite_mobilenetv2_tv(model_config, pretrained=None):
    """DeepLabV3Lite with a MobileNetV2TVMI4 encoder. Returns (model, change_names_dict)."""
    return _deeplabv3lite_mobilenetv2(MobileNetV2TVMI4, model_config, pretrained=pretrained)


def deeplabv3lite_mobilenetv2_tv_fd(model_config, pretrained=None):
    """'fd' variant of deeplabv3lite_mobilenetv2_tv: sets fastdown and shortcut strides (8,32).

    The overrides are applied to a merged copy of the config. They take
    precedence over the caller's values for these keys.
    """
    model_config = get_config_deeplav3lite_mnv2().merge_from(model_config)
    model_config.fastdown = True
    model_config.strides = (2,2,2,2,1)
    model_config.shortcut_strides = (8,32)
    model_config.shortcut_channels = (24,320)
    model_config.decoder_chan = 256
    model_config.aspp_chan = 256
    return deeplabv3lite_mobilenetv2_tv(model_config, pretrained=pretrained)


def deeplabv3lite_mobilenetv2_ericsun(model_config, pretrained=None):
    """DeepLabV3Lite with a MobileNetV2EricsunMI4 encoder. Returns (model, change_names_dict)."""
    return _deeplabv3lite_mobilenetv2(MobileNetV2EricsunMI4, model_config, pretrained=pretrained)
213 ###########################################
# config settings for resnet50 backbone
def get_config_deeplav3lite_resnet50():
    """Return the default DeepLabV3Lite config for a ResNet50 encoder.

    Starts from the MobileNetV2 defaults and overrides only what differs.
    """
    resnet_config = get_config_deeplav3lite_mnv2()
    # channels of the high resolution shortcut feature and of the final encoder feature
    resnet_config.shortcut_channels = (256,2048)
    return resnet_config
def deeplabv3lite_resnet50(model_config, pretrained=None):
    """Build DeepLabV3Lite on a ResNet50MI4 encoder.

    Args:
        model_config: user config, merged on top of get_config_deeplav3lite_resnet50().
        pretrained: optional weights; loaded with xnn.utils.load_weights() when truthy.

    Returns:
        (model, change_names_dict). change_names_dict is the name-remapping dict
        handed to xnn.utils.load_weights().
    """
    model_config = get_config_deeplav3lite_resnet50().merge_from(model_config)
    # the encoder is built from its own copy of the config; the decoder uses the merged one
    encoder = ResNet50MI4(model_config.clone())
    model = DeepLabV3Lite(encoder, model_config)

    num_inputs = len(model_config.input_channels)
    if model_config.num_decoders is None:
        num_decoders = len(model_config.output_channels)
    else:
        num_decoders = model_config.num_decoders
    #

    multi_stream = num_inputs > 1

    def _encoder_targets(suffix):
        # multiple inputs: one target prefix per input stream; single input: one plain prefix
        if multi_stream:
            return ['encoder.features.stream{}.{}'.format(stream, suffix) for stream in range(num_inputs)]
        return 'encoder.features.{}'.format(suffix)

    # torchvision style top level names (conv1, bn1, relu, maxpool, layerN) and the
    # older 'features.' prefix are moved under encoder.features (per the original
    # notes, this remapping only takes effect if the direct load fails).
    change_names_dict = {'^' + name: _encoder_targets(name) for name in ('conv1.', 'bn1.', 'relu.', 'maxpool.', 'layer')}
    change_names_dict['^features.'] = _encoder_targets('')
    if multi_stream:
        change_names_dict['^encoder.features.'] = _encoder_targets('')
    change_names_dict['^decoders.0.'] = ['decoders.{}.'.format(d) for d in range(num_decoders)]

    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)

    return model, change_names_dict
def deeplabv3lite_resnet50_p5(model_config, pretrained=None):
    """DeepLabV3Lite on a half width (width_mult=0.5) ResNet50 encoder.

    The overrides are applied to a config freshly merged from the resnet50
    defaults rather than to the caller's object. This matches
    deeplabv3lite_mobilenetv2_tv_fd and avoids mutating the argument
    (assumes merge_from() copies values into the fresh default config).
    Returns (model, change_names_dict).
    """
    model_config = get_config_deeplav3lite_resnet50().merge_from(model_config)
    model_config.width_mult = 0.5
    # shortcut / encoder output channels halve along with the width
    model_config.shortcut_channels = (128,1024)
    return deeplabv3lite_resnet50(model_config, pretrained=pretrained)
def deeplabv3lite_resnet50_p5_fd(model_config, pretrained=None):
    """'fd' (fastdown) variant of the half width (width_mult=0.5) ResNet50 DeepLabV3Lite.

    The overrides are applied to a config freshly merged from the resnet50
    defaults rather than to the caller's object. This matches
    deeplabv3lite_mobilenetv2_tv_fd and avoids mutating the argument
    (assumes merge_from() copies values into the fresh default config).
    Returns (model, change_names_dict).
    """
    model_config = get_config_deeplav3lite_resnet50().merge_from(model_config)
    model_config.width_mult = 0.5
    model_config.fastdown = True
    # shortcut / encoder output channels halve along with the width
    model_config.shortcut_channels = (128,1024)
    model_config.shortcut_strides = (8,64)
    return deeplabv3lite_resnet50(model_config, pretrained=pretrained)