shufflenetv2 mnodel loading fix
authorManu Mathew <a0393608@ti.com>
Fri, 21 Feb 2020 08:24:14 +0000 (13:54 +0530)
committerManu Mathew <a0393608@ti.com>
Fri, 21 Feb 2020 08:25:22 +0000 (13:55 +0530)
modules/pytorch_jacinto_ai/vision/models/shufflenetv2.py
modules/pytorch_jacinto_ai/xnn/utils/load_weights.py
run_quantization_example.sh

index be529d6a775bcfef4fbee7ebc62933de95c12a5b..3a45fa198fdabf2a6257e380f82a886ca8116cc4 100644 (file)
@@ -138,21 +138,36 @@ class ShuffleNetV2(nn.Module):
         return x
 
 
+    # define a load weights fuinction in the module since the module is changed w.r.t. to torchvision
+    # since we want to be able to laod the existing torchvision pretrained weights
+    def load_weights(self, pretrained, change_names_dict=None, download_root=None):
+        if change_names_dict is None:
+            # note: that this change_names_dict  will take effect only if the direct load fails
+            change_names_dict = {'^conv': 'features.conv', '^maxpool.': 'features.maxpool.',
+                                 '^stage': 'features.stage', '^fc.': 'classifier.'}
+        #
+        if pretrained is not None:
+            xnn.utils.load_weights(self, pretrained, change_names_dict=change_names_dict, download_root=download_root)
+        #
+        return self, change_names_dict
+
+
 def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
     model = ShuffleNetV2(*args, **kwargs)
-
-    if pretrained:
+    if pretrained is True:
+        change_names_dict = kwargs.get('change_names_dict', None)
         model_url = model_urls[arch]
         if model_url is None:
             raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
         else:
             state_dict = load_state_dict_from_url(model_url, progress=progress)
-            # the pretrained model provided by torchvision and what is defined here differs slightly
-            # note: that this change_names_dict  will take effect only if the direct load fails
-            change_names_dict = {'^conv': 'features.conv', '^maxpool.': 'features.maxpool.',
-                                 '^stage': 'features.stage', '^fc.': 'classifier.'}
-            model = xnn.utils.load_weights(model, state_dict, change_names_dict=change_names_dict)
-
+            model.load_weights(state_dict, change_names_dict=change_names_dict)
+        #
+    elif pretrained:
+        change_names_dict = kwargs.get('change_names_dict', None)
+        download_root = kwargs.get('download_root', None)
+        model.load_weights(pretrained, change_names_dict=change_names_dict, download_root=download_root)
+    #
     return model
 
 
index acd39115cfb6dbbaec49bf3cc3c56bca18245e65..fc467be7646b0db4df1bd87c49e1380aa0fc99fd 100644 (file)
@@ -12,8 +12,8 @@ from . import utils_data
 ######################################################
 def load_weights(model, pretrained, change_names_dict=None, keep_original_names=False, width_mult=1.0,
                        ignore_size=True, verbose=False, num_batches_tracked = None, download_root=None, **kwargs):
-    if pretrained is None:
-        print_utils.print_yellow('=> weights could not be loaded. pretrained data given is None')
+    if pretrained is None or pretrained is False:
+        print_utils.print_yellow(f'=> weights could not be loaded. pretrained data given is {pretrained}')
         return model
 
     if isinstance(pretrained, str):
index a9e550edce7e6f0ba8713971220ec1f31f6edd16..1f189878c362c461cffc90eca2597852a0db3465 100755 (executable)
@@ -22,7 +22,7 @@ exec &> >(tee -a "$log_file")
 declare -A model_pretrained=(
   [mobilenet_v2]=https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
   [resnet50]=https://download.pytorch.org/models/resnet50-19c8e357.pth
-  [shufflenetv2_x1.0]=https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth
+  [shufflenet_v2_x1_0]=https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth
 #  [mobilenetv2_shicai]='./data/modelzoo/pretrained/pytorch/others/shicai/MobileNet-Caffe/mobilenetv2_shicai_rgb.tar'
 )