[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / engine / test_classification.py
diff --git a/modules/pytorch_jacinto_ai/engine/test_classification.py b/modules/pytorch_jacinto_ai/engine/test_classification.py
index 554cc5e51f95a0a7d67fc712c342c8a6e8aeb783..909f0f5411f786eb6391ab82ad21c02c79666e61 100644 (file)
is_cuda = next(model.parameters()).is_cuda
dummy_input = create_rand_inputs(args, is_cuda=is_cuda)
#
- if args.phase == 'training':
+ if 'training' in args.phase:
model = xnn.quantize.QuantTrainModule(model, per_channel_q=args.per_channel_q,
bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
dummy_input=dummy_input)
- elif args.phase == 'calibration':
+ elif 'calibration' in args.phase:
model = xnn.quantize.QuantCalibrateModule(model, per_channel_q=args.per_channel_q,
bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
bias_calibration=args.bias_calibration, lr_calib=args.lr_calib,
dummy_input=dummy_input)
- elif args.phase == 'validation':
+ elif 'validation' in args.phase:
# Note: bias_calibration is not enabled in test
model = xnn.quantize.QuantTestModule(model, per_channel_q=args.per_channel_q,
bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
assert False, f'invalid phase {args.phase}'
#
-
# load pretrained
xnn.utils.load_weights_check(get_model_orig(model), pretrained=pretrained_data, change_names_dict=change_names_dict)
#################################################
# multi gpu mode is not yet supported with quantization in evaluate
- if args.parallel_model and (args.phase=='training'):
+ if args.parallel_model and ('training' in args.phase):
model = torch.nn.DataParallel(model)
#################################################