constrain_weights: in the default scenario, this is not used together with per_channel_q.
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / quantize / quant_train_module.py
index f940fdf785575338f60015aaeba953d6300bb608..7fe340e705ebf742f4a1ed8f9abed618f7c89c18 100644 (file)
@@ -19,7 +19,8 @@ warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)
 ###########################################################
 class QuantTrainModule(QuantBaseModule):
     def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
-                 histogram_range=True, bias_calibration=False, constrain_weights=True):
+                 histogram_range=True, bias_calibration=False, constrain_weights=None):
+        constrain_weights = (not per_channel_q) if constrain_weights is None else constrain_weights
         super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
                          per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
                          constrain_weights=constrain_weights, model_surgery_quantize=True)
         super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
                          per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
                          constrain_weights=constrain_weights, model_surgery_quantize=True)