[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / models / resnet.py
diff --git a/modules/pytorch_jacinto_ai/vision/models/resnet.py b/modules/pytorch_jacinto_ai/vision/models/resnet.py
index 788fd8b46a51a47ea97cbd38734feb78f68e5457..f4f4026fffb3d47d93036696a633214cf730e17f 100644 (file)
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
model = ResNet(block, layers, **kwargs)
- if pretrained:
+ if pretrained is True:
+ change_names_dict = kwargs.get('change_names_dict', None)
+ state_dict = load_state_dict_from_url(model_urls[arch],
+ progress=progress)
+ model.load_weights(state_dict, change_names_dict=change_names_dict)
+ elif pretrained:
change_names_dict = kwargs.get('change_names_dict', None)
download_root = kwargs.get('download_root', None)
model.load_weights(pretrained, change_names_dict=change_names_dict, download_root=download_root)