X-Git-Url: https://git.ti.com/gitweb?p=jacinto-ai%2Fpytorch-jacinto-ai-devkit.git;a=blobdiff_plain;f=modules%2Fpytorch_jacinto_ai%2Fxnn%2Fquantize%2Fquant_calib_module.py;h=c7f8cfb369db636bf1a95c729ad4f84caf8a60e1;hp=dba9fa91ea19a6d4bdc6d704b6f8c4a21ec257d2;hb=b1387506d4f0dd058aab3c27224d9e985bd013ac;hpb=3f0c4a7a88cbfe0723558fefcd8c91053c727ba1 diff --git a/modules/pytorch_jacinto_ai/xnn/quantize/quant_calib_module.py b/modules/pytorch_jacinto_ai/xnn/quantize/quant_calib_module.py index dba9fa9..c7f8cfb 100644 --- a/modules/pytorch_jacinto_ai/xnn/quantize/quant_calib_module.py +++ b/modules/pytorch_jacinto_ai/xnn/quantize/quant_calib_module.py @@ -18,7 +18,7 @@ warnings.filterwarnings('ignore', category=torch.jit.TracerWarning) ########################################################### class QuantCalibrateModule(QuantTrainModule): def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False, - histogram_range=True, bias_calibration=True, constrain_weights=True, lr_calib=0.05): + histogram_range=True, bias_calibration=True, constrain_weights=None, lr_calib=0.05): self.weights_calibration = False self.lr_calib = lr_calib self.calibration_factor = lr_calib @@ -26,7 +26,7 @@ class QuantCalibrateModule(QuantTrainModule): self.calibrate_repeats = 1 self.quantize_enable = True self.update_range = True - constrain_weights = (constrain_weights and bias_calibration) + constrain_weights = (bias_calibration and (not per_channel_q)) if constrain_weights is None else constrain_weights super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations, per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration, constrain_weights=constrain_weights)