]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blobdiff - modules/pytorch_jacinto_ai/engine/test_classification.py
release commit
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / engine / test_classification.py
index 554cc5e51f95a0a7d67fc712c342c8a6e8aeb783..909f0f5411f786eb6391ab82ad21c02c79666e61 100644 (file)
@@ -146,16 +146,16 @@ def main(args):
         is_cuda = next(model.parameters()).is_cuda
         dummy_input = create_rand_inputs(args, is_cuda=is_cuda)
         #
-        if args.phase == 'training':
+        if 'training' in args.phase:
             model = xnn.quantize.QuantTrainModule(model, per_channel_q=args.per_channel_q,
                         bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
                         dummy_input=dummy_input)
-        elif args.phase == 'calibration':
+        elif 'calibration' in args.phase:
             model = xnn.quantize.QuantCalibrateModule(model, per_channel_q=args.per_channel_q,
                         bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
                         bias_calibration=args.bias_calibration, lr_calib=args.lr_calib,
                         dummy_input=dummy_input)
-        elif args.phase == 'validation':
+        elif 'validation' in args.phase:
             # Note: bias_calibration is not enabled in test
             model = xnn.quantize.QuantTestModule(model, per_channel_q=args.per_channel_q,
                         bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
@@ -165,7 +165,6 @@ def main(args):
             assert False, f'invalid phase {args.phase}'
     #
 
-
     # load pretrained
     xnn.utils.load_weights_check(get_model_orig(model), pretrained=pretrained_data, change_names_dict=change_names_dict)
 
@@ -191,7 +190,7 @@ def main(args):
 
     #################################################
     # multi gpu mode is not yet supported with quantization in evaluate
-    if args.parallel_model and (args.phase=='training'):
+    if args.parallel_model and ('training' in args.phase):
         model = torch.nn.DataParallel(model)
 
     #################################################