release commit
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / models / resnet.py
index 788fd8b46a51a47ea97cbd38734feb78f68e5457..f4f4026fffb3d47d93036696a633214cf730e17f 100644 (file)
@@ -237,7 +237,12 @@ class ResNet(nn.Module):
 
 def _resnet(arch, block, layers, pretrained, progress, **kwargs):
     model = ResNet(block, layers, **kwargs)
 
 def _resnet(arch, block, layers, pretrained, progress, **kwargs):
     model = ResNet(block, layers, **kwargs)
-    if pretrained:
+    if pretrained is True:
+        change_names_dict = kwargs.get('change_names_dict', None)
+        state_dict = load_state_dict_from_url(model_urls[arch],
+                                              progress=progress)
+        model.load_weights(state_dict, change_names_dict=change_names_dict)
+    elif pretrained:
         change_names_dict = kwargs.get('change_names_dict', None)
         download_root = kwargs.get('download_root', None)
         model.load_weights(pretrained, change_names_dict=change_names_dict, download_root=download_root)
         change_names_dict = kwargs.get('change_names_dict', None)
         download_root = kwargs.get('download_root', None)
         model.load_weights(pretrained, change_names_dict=change_names_dict, download_root=download_root)