[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / models / pixel2pixel / deeplabv3lite.py
diff --git a/modules/pytorch_jacinto_ai/xvision/models/pixel2pixel/deeplabv3lite.py b/modules/pytorch_jacinto_ai/xvision/models/pixel2pixel/deeplabv3lite.py
index dda072e2383353af5455e9d90325df435e5cc4fa..88e3ff1891e994b3d63c2fbf6d831da3eae3b487 100644 (file)
try: from .pixel2pixelnet_internal import *
except: pass
-from ..multi_input_net import MobileNetV2TVMI4, MobileNetV2EricsunMI4, ResNet50MI4
+from ..multi_input_net import MobileNetV2TVMI4, MobileNetV2EricsunMI4, \
+ ResNet50MI4, RegNetX800MFMI4
###########################################
__all__ = ['DeepLabV3Lite', 'DeepLabV3LiteDecoder',
'deeplabv3lite_mobilenetv2_tv', 'deeplabv3lite_mobilenetv2_tv_fd',
'deeplabv3lite_mobilenetv2_ericsun',
- 'deeplabv3lite_resnet50', 'deeplabv3lite_resnet50_p5', 'deeplabv3lite_resnet50_p5_fd']
+ 'deeplabv3lite_resnet50', 'deeplabv3lite_resnet50_p5', 'deeplabv3lite_resnet50_p5_fd',
+ 'deeplabv3lite_regnetx800mf']
###########################################
aspp_channels = round(model_config.aspp_chan*model_config.decoder_factor)
if model_config.use_aspp:
+ group_size_dw = model_config.group_size_dw if hasattr(model_config, 'group_size_dw') else None
ASPPBlock = xnn.layers.GWASPPLiteBlock if model_config.groupwise_sep else xnn.layers.DWASPPLiteBlock
self.aspp = ASPPBlock(current_channels, aspp_channels, decoder_channels, dilation=model_config.aspp_dil,
- activation=model_config.activation, linear_dw=model_config.linear_dw)
+ activation=model_config.activation, linear_dw=model_config.linear_dw,
+ group_size_dw=group_size_dw)
else:
self.aspp = None
if (not self.training) and (self.model_config.output_type == 'segmentation'):
x = torch.argmax(x, dim=1, keepdim=True)
- assert int(in_shape[2]) == int(x.shape[2]*self.model_config.target_input_ratio) and int(in_shape[3]) == int(x.shape[3]*self.model_config.target_input_ratio), 'incorrect output shape'
+ assert int(in_shape[2]) == int(x.shape[2]*self.model_config.target_input_ratio) and \
+ int(in_shape[3]) == int(x.shape[3]*self.model_config.target_input_ratio), 'incorrect output shape'
if self.model_config.freeze_decoder:
x = x.detach()
model_config.fastdown = True
model_config.shortcut_channels = (128,1024)
model_config.shortcut_strides = (8,64)
- return deeplabv3lite_resnet50(model_config, pretrained=pretrained)
\ No newline at end of file
+ return deeplabv3lite_resnet50(model_config, pretrained=pretrained)
+
+
+###########################################
+# config settings for regnetx800mf backbone
+def get_config_deeplav3lite_regnetx800mf():
+    # start from the mobilenetv2 config and override only the delta
+    # needed for the regnetx800mf backbone
+    model_config = get_config_deeplav3lite_mnv2()
+    # channels of the encoder feature taps used as decoder shortcuts
+    # assumes these match the RegNetX800MFMI4 stage outputs — TODO confirm
+    model_config.shortcut_channels = (64,672)
+    # group size for the grouped depthwise convolutions in the decoder/ASPP
+    model_config.group_size_dw = 16
+    return model_config
+
+
+# there is nothing specific about bgr in this model,
+# but is just a reminder that regnet models are typically trained with bgr input
+def deeplabv3lite_regnetx800mf(model_config, pretrained=None):
+    """Build a DeepLabV3Lite segmentation model with a RegNetX-800MF encoder.
+
+    model_config: user-supplied config; merged over the regnetx800mf defaults.
+    pretrained: optional checkpoint (path or dict) to initialize weights from.
+    Returns a tuple (model, change_names_dict).
+    """
+    model_config = get_config_deeplav3lite_regnetx800mf().merge_from(model_config)
+    # encoder setup
+    model_config_e = model_config.clone()
+    base_model = RegNetX800MFMI4(model_config_e)
+    # decoder setup
+    model = DeepLabV3Lite(base_model, model_config)
+
+    # the pretrained model provided by torchvision and what is defined here differs slightly
+    # note: that this change_names_dict will take effect only if the direct load fails
+    # finally take care of the change for deeplabv3lite (features->encoder.features)
+    num_inputs = len(model_config.input_channels)
+    num_decoders = len(model_config.output_channels) if (model_config.num_decoders is None) else model_config.num_decoders
+    if num_inputs > 1:
+        # multi-input: replicate encoder weights into each input stream
+        change_names_dict = {'^features.': ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)],
+                            '^encoder.features.': ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)],
+                            '^decoders.0.': ['decoders.{}.'.format(d) for d in range(num_decoders)]}
+    else:
+        # single input: remap the torchvision regnet stem/stage names into the
+        # encoder.features namespace used by this model
+        change_names_dict = {'^stem.': 'encoder.features.stem.',
+                            '^s1': 'encoder.features.s1', '^s2': 'encoder.features.s2',
+                            '^s3': 'encoder.features.s3', '^s4': 'encoder.features.s4',
+                            '^features.': 'encoder.features.',
+                            '^decoders.0.': ['decoders.{}.'.format(d) for d in range(num_decoders)]}
+    #
+
+    if pretrained:
+        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True,
+                                       state_dict_name=['state_dict','model_state'])
+    else:
+        # need to use state_dict_name as the checkpoint uses a different name for state_dict
+        # provide a custom load_weights for the model so it can be loaded later
+        def load_weights_func(pretrained, change_names_dict, ignore_size=True, verbose=True,
+                              state_dict_name=['state_dict','model_state']):
+            # NOTE(review): the return value of load_weights is discarded here,
+            # unlike the branch above — presumably it mutates `model` in place; verify
+            xnn.utils.load_weights(model, pretrained, change_names_dict=change_names_dict, ignore_size=ignore_size, verbose=verbose,
+                               state_dict_name=state_dict_name)
+        #
+        model.load_weights = load_weights_func
+
+    return model, change_names_dict
+