constrain_weights: in the default scenario, this is now disabled when per_channel_q is used.
author  Manu Mathew <a0393608@ti.com>
Thu, 14 May 2020 03:46:27 +0000 (09:16 +0530)
committer  Manu Mathew <a0393608@ti.com>
Thu, 14 May 2020 07:56:46 +0000 (13:26 +0530)
modules/pytorch_jacinto_ai/xnn/quantize/quant_calib_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_train_module.py

index dba9fa91ea19a6d4bdc6d704b6f8c4a21ec257d2..c7f8cfb369db636bf1a95c729ad4f84caf8a60e1 100644 (file)
@@ -18,7 +18,7 @@ warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)
 ###########################################################
 class QuantCalibrateModule(QuantTrainModule):
     def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
-                 histogram_range=True, bias_calibration=True, constrain_weights=True, lr_calib=0.05):
+                 histogram_range=True, bias_calibration=True, constrain_weights=None, lr_calib=0.05):
         self.weights_calibration = False
         self.lr_calib = lr_calib
         self.calibration_factor = lr_calib
@@ -26,7 +26,7 @@ class QuantCalibrateModule(QuantTrainModule):
         self.calibrate_repeats = 1
         self.quantize_enable = True
         self.update_range = True
-        constrain_weights = (constrain_weights and bias_calibration)
+        constrain_weights = (bias_calibration and (not per_channel_q)) if constrain_weights is None else constrain_weights
         super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
                          per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
                          constrain_weights=constrain_weights)
index f940fdf785575338f60015aaeba953d6300bb608..7fe340e705ebf742f4a1ed8f9abed618f7c89c18 100644 (file)
@@ -19,7 +19,8 @@ warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)
 ###########################################################
 class QuantTrainModule(QuantBaseModule):
     def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
-                 histogram_range=True, bias_calibration=False, constrain_weights=True):
+                 histogram_range=True, bias_calibration=False, constrain_weights=None):
+        constrain_weights = (not per_channel_q) if constrain_weights is None else constrain_weights
         super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
                          per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
                          constrain_weights=constrain_weights, model_surgery_quantize=True)