constrain_weights: default changed to None; when unspecified, it is now enabled only if bias_calibration is set and per_channel_q is not used. An explicitly passed value is respected as before.
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / quantize / quant_calib_module.py
index dba9fa91ea19a6d4bdc6d704b6f8c4a21ec257d2..c7f8cfb369db636bf1a95c729ad4f84caf8a60e1 100644 (file)
@@ -18,7 +18,7 @@ warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)
 ###########################################################
 class QuantCalibrateModule(QuantTrainModule):
     def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
-                 histogram_range=True, bias_calibration=True, constrain_weights=True, lr_calib=0.05):
+                 histogram_range=True, bias_calibration=True, constrain_weights=None, lr_calib=0.05):
         self.weights_calibration = False
         self.lr_calib = lr_calib
         self.calibration_factor = lr_calib
@@ -26,7 +26,7 @@ class QuantCalibrateModule(QuantTrainModule):
         self.calibrate_repeats = 1
         self.quantize_enable = True
         self.update_range = True
-        constrain_weights = (constrain_weights and bias_calibration)
+        constrain_weights = (bias_calibration and (not per_channel_q)) if constrain_weights is None else constrain_weights
         super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
                          per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
                          constrain_weights=constrain_weights)