[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / quantize / quant_calib_module.py
diff --git a/modules/pytorch_jacinto_ai/xnn/quantize/quant_calib_module.py b/modules/pytorch_jacinto_ai/xnn/quantize/quant_calib_module.py
index dba9fa91ea19a6d4bdc6d704b6f8c4a21ec257d2..c7f8cfb369db636bf1a95c729ad4f84caf8a60e1 100644 (file)
###########################################################
class QuantCalibrateModule(QuantTrainModule):
def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
- histogram_range=True, bias_calibration=True, constrain_weights=True, lr_calib=0.05):
+ histogram_range=True, bias_calibration=True, constrain_weights=None, lr_calib=0.05):
self.weights_calibration = False
self.lr_calib = lr_calib
self.calibration_factor = lr_calib
self.calibrate_repeats = 1
self.quantize_enable = True
self.update_range = True
- constrain_weights = (constrain_weights and bias_calibration)
+ constrain_weights = (bias_calibration and (not per_channel_q)) if constrain_weights is None else constrain_weights
super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
constrain_weights=constrain_weights)