[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / models / resnet.py
diff --git a/modules/pytorch_jacinto_ai/vision/models/resnet.py b/modules/pytorch_jacinto_ai/vision/models/resnet.py
index 0f8e3ee5399ef17bd86ecf37c71c0e4bfbce8a8d..18d0f8edd729098bc01b3c32333b760782b079a8 100644 (file)
return x
+ # define a load weights fuinction in the module since the module is changed w.r.t. to torchvision
+ # since we want to be able to laod the existing torchvision pretrained weights
+ def load_weights(self, pretrained, change_names_dict=None, download_root=None):
+ if change_names_dict is None:
+ # the pretrained model provided by torchvision and what is defined here differs slightly
+ # note: that this change_names_dict will take effect only if the direct load fails
+ change_names_dict = {'^conv1.':'features.conv1.', '^bn1.':'features.bn1.',
+ '^relu.':'features.relu.', '^maxpool.':'features.maxpool.',
+ '^layer':'features.layer' , '^fc.':'classifier.'}
+ #
+ if pretrained is not None:
+ xnn.utils.load_weights(self, pretrained, change_names_dict=change_names_dict, download_root=download_root)
+ return self, change_names_dict
+
+
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
model = ResNet(block, layers, **kwargs)
if pretrained:
- state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
- # the pretrained model provided by torchvision and what is defined here differs slightly
- # note: that this change_names_dict will take effect only if the direct load fails
- change_names_dict = {'^conv1.':'features.conv1.', '^bn1.':'features.bn1.',
- '^relu.':'features.relu.', '^maxpool.':'features.maxpool.',
- '^layer':'features.layer' , '^fc.':'classifier.'}
- model = xnn.utils.load_weights_check(model, state_dict, change_names_dict=change_names_dict)
+ change_names_dict = kwargs.get('change_names_dict', None)
+ download_root = kwargs.get('download_root', None)
+ model.load_weights(pretrained, change_names_dict=change_names_dict, download_root=download_root)
return model