]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/commitdiff
re-implemented QuantTestModule using QuantTrainModule. constrain_bias added.
authorManu Mathew <a0393608@ti.com>
Wed, 27 May 2020 12:42:48 +0000 (18:12 +0530)
committerManu Mathew <a0393608@ti.com>
Wed, 27 May 2020 13:11:54 +0000 (18:41 +0530)
release commit

modules/pytorch_jacinto_ai/vision/models/classification/__init__.py
modules/pytorch_jacinto_ai/xnn/layers/activation.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_base_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_calib_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_test_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_train_module.py
modules/pytorch_jacinto_ai/xnn/utils/load_weights.py
modules/pytorch_jacinto_ai/xnn/utils/module_utils.py
run_quantization.sh

index 8331aa19b7443420c0cbbf3ce060e68643704f9b..cf4437a81434979eeed1239f9c63c55bade216df 100644 (file)
@@ -28,7 +28,7 @@ __all__ = ['mobilenetv1_x1', 'mobilenetv2_tv_x1', 'mobilenetv2_x1', 'mobilenetv2
 
 
 #####################################################################
 
 
 #####################################################################
-def resnet50_x1(model_config, pretrained=None):
+def resnet50_x1(model_config=None, pretrained=None):
     model_config = resnet.get_config().merge_from(model_config)
     model = resnet.resnet50_with_model_config(model_config)
     # the pretrained model provided by torchvision and what is defined here differs slightly
     model_config = resnet.get_config().merge_from(model_config)
     model = resnet.resnet50_with_model_config(model_config)
     # the pretrained model provided by torchvision and what is defined here differs slightly
@@ -41,13 +41,13 @@ def resnet50_x1(model_config, pretrained=None):
     return model, change_names_dict
 
 
     return model, change_names_dict
 
 
-def resnet50_xp5(model_config, pretrained=None):
+def resnet50_xp5(model_config=None, pretrained=None):
     model_config.width_mult = 0.5
     return resnet50_x1(model_config=model_config, pretrained=pretrained)
 
 
 #####################################################################
     model_config.width_mult = 0.5
     return resnet50_x1(model_config=model_config, pretrained=pretrained)
 
 
 #####################################################################
-def resnet18_x1(model_config, pretrained=None):
+def resnet18_x1(model_config=None, pretrained=None):
     model_config = resnet.get_config().merge_from(model_config)
     model = resnet.resnet18_with_model_config(model_config)
     # the pretrained model provided by torchvision and what is defined here differs slightly
     model_config = resnet.get_config().merge_from(model_config)
     model = resnet.resnet18_with_model_config(model_config)
     # the pretrained model provided by torchvision and what is defined here differs slightly
@@ -61,14 +61,14 @@ def resnet18_x1(model_config, pretrained=None):
 
 
 #####################################################################
 
 
 #####################################################################
-def mobilenetv1_x1(model_config, pretrained=None):
+def mobilenetv1_x1(model_config=None, pretrained=None):
     model_config = mobilenetv1.get_config().merge_from(model_config)
     model = mobilenetv1.MobileNetV1(model_config=model_config)
     if pretrained:
         model = xnn.utils.load_weights(model, pretrained)
     return model
 
     model_config = mobilenetv1.get_config().merge_from(model_config)
     model = mobilenetv1.MobileNetV1(model_config=model_config)
     if pretrained:
         model = xnn.utils.load_weights(model, pretrained)
     return model
 
-def mobilenetv1_multi_label_x1(model_config, pretrained=None):
+def mobilenetv1_multi_label_x1(model_config=None, pretrained=None):
     model_config = mobilenetv1.get_config().merge_from(model_config)
     model = mobilenetv1_internal.MobileNetV1MultiLabel(model_config=model_config)
     if pretrained:
     model_config = mobilenetv1.get_config().merge_from(model_config)
     model = mobilenetv1_internal.MobileNetV1MultiLabel(model_config=model_config)
     if pretrained:
@@ -77,7 +77,7 @@ def mobilenetv1_multi_label_x1(model_config, pretrained=None):
 
 
 #####################################################################
 
 
 #####################################################################
-def mobilenetv2_tv_x1(model_config, pretrained=None):
+def mobilenetv2_tv_x1(model_config=None, pretrained=None):
     model_config = mobilenetv2.get_config().merge_from(model_config)
     model = mobilenetv2.MobileNetV2TV(model_config=model_config)
     if pretrained:
     model_config = mobilenetv2.get_config().merge_from(model_config)
     model = mobilenetv2.MobileNetV2TV(model_config=model_config)
     if pretrained:
@@ -88,7 +88,7 @@ def mobilenetv2_tv_x1(model_config, pretrained=None):
 mobilenetv2_x1 = mobilenetv2_tv_x1
 
 
 mobilenetv2_x1 = mobilenetv2_tv_x1
 
 
-def mobilenetv2_tv_x2_t2(model_config, pretrained=None):
+def mobilenetv2_tv_x2_t2(model_config=None, pretrained=None):
     model_config = mobilenetv2.get_config().merge_from(model_config)
     model_config.width_mult = 2.0
     model_config.expand_ratio = 2.0
     model_config = mobilenetv2.get_config().merge_from(model_config)
     model_config.width_mult = 2.0
     model_config.expand_ratio = 2.0
@@ -99,7 +99,7 @@ def mobilenetv2_tv_x2_t2(model_config, pretrained=None):
 
 
 #####################################################################
 
 
 #####################################################################
-def mobilenetv2_tv_gws_x1(model_config, pretrained=None):
+def mobilenetv2_tv_gws_x1(model_config=None, pretrained=None):
     model_config = mobilenetv2_internal.get_config_mnetv2_gws().merge_from(model_config)
     model = mobilenetv2_internal.MobileNetV2TVGWS(model_config=model_config)
     if pretrained:
     model_config = mobilenetv2_internal.get_config_mnetv2_gws().merge_from(model_config)
     model = mobilenetv2_internal.MobileNetV2TVGWS(model_config=model_config)
     if pretrained:
@@ -108,7 +108,7 @@ def mobilenetv2_tv_gws_x1(model_config, pretrained=None):
 
 
 #####################################################################
 
 
 #####################################################################
-def mobilenetv2_ericsun_x1(model_config, pretrained=None):
+def mobilenetv2_ericsun_x1(model_config=None, pretrained=None):
     model_config = mobilenetv2_ericsun_internal.get_config().merge_from(model_config)
     model = mobilenetv2_ericsun_internal.MobileNetV2Ericsun(model_config=model_config)
     if pretrained:
     model_config = mobilenetv2_ericsun_internal.get_config().merge_from(model_config)
     model = mobilenetv2_ericsun_internal.MobileNetV2Ericsun(model_config=model_config)
     if pretrained:
@@ -116,17 +116,10 @@ def mobilenetv2_ericsun_x1(model_config, pretrained=None):
     return model
 
 
     return model
 
 
-def mobilenetv2_shicai_x1(model_config, pretrained=None):
+def mobilenetv2_shicai_x1(model_config=None, pretrained=None):
     model_config = mobilenetv2_shicai_internal.get_config().merge_from(model_config)
     model = mobilenetv2_shicai_internal.mobilenetv2_shicai(model_config=model_config)
     if pretrained:
         model = xnn.utils.load_weights(model, pretrained)
     return model
 
     model_config = mobilenetv2_shicai_internal.get_config().merge_from(model_config)
     model = mobilenetv2_shicai_internal.mobilenetv2_shicai(model_config=model_config)
     if pretrained:
         model = xnn.utils.load_weights(model, pretrained)
     return model
 
-
-def flownetslite_base_x1(model_config, pretrained=None):
-    model_config = flownetbase_internal.get_config().merge_from(model_config)
-    model = flownetbase_internal.flownetslite_base(model_config, pretrained=pretrained)
-    if pretrained:
-        model = xnn.utils.load_weights(model, pretrained)
-    return model
\ No newline at end of file
index 325764e0caaea4cc4fb3dbe87ba9283f709fb41a..a1e6ee68af0dab1a6891e821eea07959b237b5a8 100644 (file)
@@ -15,7 +15,7 @@ class PAct2(torch.nn.Module):
     PACT2_RANGE_INIT = 8.0      # this is the starting range
     PACT2_RANGE_EXPANSION = 1.1 # expand the calculated range for margin
 
     PACT2_RANGE_INIT = 8.0      # this is the starting range
     PACT2_RANGE_EXPANSION = 1.1 # expand the calculated range for margin
 
-    def __init__(self, inplace=False, signed=None, percentile_range_shrink=PACT2_RANGE_SHRINK, clip_range=None, **kwargs):
+    def __init__(self, inplace=False, signed=None, percentile_range_shrink=PACT2_RANGE_SHRINK, clip_range=None, power2_activation_range=True, **kwargs):
         super().__init__()
         if (clip_range is not None) and (signed is not None):
             assert signed == (clip_range[0]<0.0), 'the signed flag provided did not match the clip_range provided'
         super().__init__()
         if (clip_range is not None) and (signed is not None):
             assert signed == (clip_range[0]<0.0), 'the signed flag provided did not match the clip_range provided'
@@ -27,7 +27,7 @@ class PAct2(torch.nn.Module):
         self.fixed_range = (clip_range is not None)
         self.learn_range = (self.PACT2_RANGE_LEARN and (not self.fixed_range))
         self.eps = np.power(2.0, -16.0)
         self.fixed_range = (clip_range is not None)
         self.learn_range = (self.PACT2_RANGE_LEARN and (not self.fixed_range))
         self.eps = np.power(2.0, -16.0)
-        self.power2 = True   # power of 2 ranges
+        self.power2_activation_range = power2_activation_range   # power of 2 ranges
         self.log_base = None # 2.0  # log is used only in learned mode if log_base is not None
 
         # any validation before at-least one iteration of training wll use default clip values.
         self.log_base = None # 2.0  # log is used only in learned mode if log_base is not None
 
         # any validation before at-least one iteration of training wll use default clip values.
@@ -54,8 +54,8 @@ class PAct2(torch.nn.Module):
         #
 
 
         #
 
 
-    def forward(self, x, update_range=True, enable=True):
-        if (self.training and update_range):
+    def forward(self, x, update_activation_range=True, enable=True):
+        if (self.training and update_activation_range):
             self.num_batches_tracked += 1
             # even in learn_range mode - do this for a few iterations to get a good starting point
             if (not self.fixed_range) and ((not self.learn_range) or (self.num_batches_tracked < 100)):
             self.num_batches_tracked += 1
             # even in learn_range mode - do this for a few iterations to get a good starting point
             if (not self.fixed_range) and ((not self.learn_range) or (self.num_batches_tracked < 100)):
@@ -66,10 +66,11 @@ class PAct2(torch.nn.Module):
         #
         if not enable:
             signed = self.clips_act[0] < 0.0 if (self.signed is None) else self.signed
         #
         if not enable:
             signed = self.clips_act[0] < 0.0 if (self.signed is None) else self.signed
-            return x if signed else torch.nn.functional.relu(x)
+            y = x if signed else torch.nn.functional.relu(x)
+        else:
+            clips = self.get_clips_act()
+            y = clamp_g(x, clips[0], clips[1], self.training, self.inplace, requires_grad=self.learn_range)
         #
         #
-        clips = self.get_clips_act()
-        y = clamp_g(x, clips[0], clips[1], self.training, self.inplace, requires_grad=self.learn_range)
         return y
 
 
         return y
 
 
@@ -110,9 +111,10 @@ class PAct2(torch.nn.Module):
         clip_max = torch.max(torch.abs(self.clips_act)) if signed else torch.abs(self.clips_act[1])
         clip_max = torch.clamp(clip_max, min=self.eps)
         clip_max = self.convert_to_linear(clip_max)
         clip_max = torch.max(torch.abs(self.clips_act)) if signed else torch.abs(self.clips_act[1])
         clip_max = torch.clamp(clip_max, min=self.eps)
         clip_max = self.convert_to_linear(clip_max)
-        # in range learning mode + training - this power2 is taken care in the quantize function
-        use_power2 = (self.power2 and (not (self.PACT2_RANGE_LEARN and self.training)))
-        clip_max2 = ceil2_g(clip_max) if use_power2 else clip_max
+        # in range learning mode + training - this power2_activation_range is taken care in the quantize function
+        is_learning_range = (self.PACT2_RANGE_LEARN and self.training)
+        use_power2_activation_range = (self.power2_activation_range and (not is_learning_range))
+        clip_max2 = ceil2_g(clip_max) if use_power2_activation_range else clip_max
         clip_min2 = (-clip_max2 if signed else clip_max2*0.0)
         return (clip_min2, clip_max2)
 
         clip_min2 = (-clip_max2 if signed else clip_max2*0.0)
         return (clip_min2, clip_max2)
 
@@ -136,7 +138,22 @@ class ReLU1(torch.nn.Hardtanh):
 
 
 ###############################################################
 
 
 ###############################################################
-class NoAct(torch.nn.Module):
+# Always quantized activation function.
+# Inserting this activation function is a simple way to ensure quantization happens at certain places.
+class QAct(torch.nn.Module):
+    def __init__(self, inplace=False, signed=True, **kwargs):
+        super().__init__()
+        self.inplace = inplace
+        self.signed = signed
+
+    def forward(self, x):
+        return x
+
+
+# Never quantized activation function.
+# Also if the next block is this, the previous block output is also not quantized.
+# Inserting this activation function is a simple way to avoid quantization at certain places.
+class NoQAct(torch.nn.Module):
     def __init__(self, inplace=False, signed=True, **kwargs):
         super().__init__()
         self.inplace = inplace
     def __init__(self, inplace=False, signed=True, **kwargs):
         super().__init__()
         self.inplace = inplace
index 30a8f91d58116c0670b30d975a2968cd567948d4..fa6b75396e10e1ac25de878c5c6a2b02aae217fe 100644 (file)
@@ -11,15 +11,20 @@ class QuantEstimationType:
 # base module to be use for all quantization modules
 class QuantBaseModule(QuantGraphModule):
     def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
 # base module to be use for all quantization modules
 class QuantBaseModule(QuantGraphModule):
     def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
-                 histogram_range=True, bias_calibration=False, constrain_weights=False,
-                 model_surgery_quantize=False):
+                 histogram_range=True, bias_calibration=False, constrain_weights=False, constrain_bias=None,
+                 model_surgery_quantize=True, power2_weight_range=None, power2_activation_range=None):
         super().__init__(module)
         self.bitwidth_weights = bitwidth_weights
         self.bitwidth_activations = bitwidth_activations
         self.per_channel_q = per_channel_q
         self.histogram_range = histogram_range
         self.constrain_weights = constrain_weights
         super().__init__(module)
         self.bitwidth_weights = bitwidth_weights
         self.bitwidth_activations = bitwidth_activations
         self.per_channel_q = per_channel_q
         self.histogram_range = histogram_range
         self.constrain_weights = constrain_weights
+        self.constrain_bias = True if (constrain_bias is None) else constrain_bias
         self.bias_calibration = bias_calibration
         self.bias_calibration = bias_calibration
+        self.power2_weight_range = True if (power2_weight_range is None) else power2_weight_range
+        self.power2_activation_range = True if (power2_activation_range is None) else power2_activation_range
+        # range shrink - 0.0 indicates no shrink
+        self.percentile_range_shrink = (layers.PAct2.PACT2_RANGE_SHRINK if histogram_range else 0.0)
         # for help in debug/print
         utils.add_module_names(self)
         # put in eval mode before analyze
         # for help in debug/print
         utils.add_module_names(self)
         # put in eval mode before analyze
@@ -37,17 +42,38 @@ class QuantBaseModule(QuantGraphModule):
         # for help in debug/print
         utils.add_module_names(self)
 
         # for help in debug/print
         utils.add_module_names(self)
 
+        # set attributes to all modules - can control the behaviour from here
+        utils.apply_setattr(self, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
+                            histogram_range=histogram_range, bias_calibration=self.bias_calibration, per_channel_q=self.per_channel_q,
+                            percentile_range_shrink=self.percentile_range_shrink, constrain_weights=self.constrain_weights,
+                            power2_weight_range=self.power2_weight_range, power2_activation_range=self.power2_activation_range,
+                            constrain_bias=self.constrain_bias)
 
     def add_activation_hooks(self):
         # add a forward hook to call the extra activation that we added
 
     def add_activation_hooks(self):
         # add a forward hook to call the extra activation that we added
-        def _forward_activation(op, inputs, outputs):
+        def _forward_input_activation(op, inputs):
+            if hasattr(op, 'activation_in'):
+                # hook passes the input as tuple - expand it
+                to_squeeze = isinstance(inputs, tuple) and len(inputs) == 1
+                inputs = inputs[0] if to_squeeze else inputs
+                inputs = op.activation_in(inputs)
+                inputs = (inputs,) if to_squeeze else inputs
+            #
+            return inputs
+        #
+        def _forward_output_activation(op, inputs, outputs):
             if hasattr(op, 'activation_q'):
             if hasattr(op, 'activation_q'):
+                # hook passes the input as tuple - expand it
+                to_squeeze = isinstance(outputs, tuple) and len(outputs) == 1
+                outputs = outputs[0] if to_squeeze else outputs
                 outputs = op.activation_q(outputs)
                 outputs = op.activation_q(outputs)
+                outputs = (outputs,) if to_squeeze else outputs
             #
             return outputs
         #
         for m in self.modules():
             #
             return outputs
         #
         for m in self.modules():
-            m.register_forward_hook(_forward_activation)
+            m.register_forward_pre_hook(_forward_input_activation)
+            m.register_forward_hook(_forward_output_activation)
         #
 
 
         #
 
 
index 266b0dabc0a3aed066194f9cf84f70a8899b457b..954ea45368fa25a72eeb16f59e28c63f545e1d58 100644 (file)
@@ -18,18 +18,19 @@ warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)
 ###########################################################
 class QuantCalibrateModule(QuantTrainModule):
     def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
 ###########################################################
 class QuantCalibrateModule(QuantTrainModule):
     def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
-                 histogram_range=True, bias_calibration=True, constrain_weights=None, lr_calib=0.05):
+                 histogram_range=True, bias_calibration=True, constrain_weights=None,
+                 power2_weight_range=None, power2_activation_range=None, constrain_bias=None, lr_calib=0.05):
         self.weights_calibration = False
         self.lr_calib = lr_calib
         self.calibration_factor = lr_calib
         self.calibration_gamma = 0.5
         self.calibrate_repeats = 1
         self.quantize_enable = True
         self.weights_calibration = False
         self.lr_calib = lr_calib
         self.calibration_factor = lr_calib
         self.calibration_gamma = 0.5
         self.calibrate_repeats = 1
         self.quantize_enable = True
-        self.update_range = True
+        self.update_activation_range = True
         constrain_weights = (bias_calibration and (not per_channel_q)) if constrain_weights is None else constrain_weights
         super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
         constrain_weights = (bias_calibration and (not per_channel_q)) if constrain_weights is None else constrain_weights
         super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
-                         per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
-                         constrain_weights=constrain_weights)
+                         per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration, constrain_weights=constrain_weights,
+                         power2_weight_range=power2_weight_range, power2_activation_range=power2_activation_range, constrain_bias=constrain_bias)
 
 
     def forward(self, inputs):
 
 
     def forward(self, inputs):
@@ -83,15 +84,15 @@ class QuantCalibrateModule(QuantTrainModule):
     def forward_float(self, inputs):
         self._restore_weights_orig()
         # disable quantization for a moment
     def forward_float(self, inputs):
         self._restore_weights_orig()
         # disable quantization for a moment
-        quantize_enable_backup_value, update_range_backup_value = self.quantize_enable, self.update_range
-        utils.apply_setattr(self, quantize_enable=False, update_range=False)
+        quantize_enable_backup_value, update_activation_range_backup_value = self.quantize_enable, self.update_activation_range
+        utils.apply_setattr(self, quantize_enable=False, update_activation_range=False)
 
         self.add_call_hook(self.module, self.forward_float_hook)
         outputs = self.module(inputs)
         self.remove_call_hook(self.module)
 
         # turn quantization back on - not a clean method
 
         self.add_call_hook(self.module, self.forward_float_hook)
         outputs = self.module(inputs)
         self.remove_call_hook(self.module)
 
         # turn quantization back on - not a clean method
-        utils.apply_setattr(self, quantize_enable=quantize_enable_backup_value, update_range=update_range_backup_value)
+        utils.apply_setattr(self, quantize_enable=quantize_enable_backup_value, update_activation_range=update_activation_range_backup_value)
         self._backup_weights_orig()
         return outputs
     #
         self._backup_weights_orig()
         return outputs
     #
index 1086cb559fa50063a7d95bd99e3adff9a9c95d4c..cbcb3c1ebb6fd8c4a465ab5d1eb80d0b87b53cec 100644 (file)
@@ -14,6 +14,13 @@ class QuantGraphModule(HookedModule):
         self.num_batches_tracked = -1
         self.iter_in_epoch = -1
         self.epoch = -1
         self.num_batches_tracked = -1
         self.iter_in_epoch = -1
         self.epoch = -1
+        # these are the blocks whose output we quantize for sure.
+        # outputs of other clocks such as Conv2d, ConvTranspose2d, BatchNorm2d, Lindear are quantized conditionally
+        self.quantize_out_blocks = (torch.nn.ReLU, torch.nn.ReLU6, torch.nn.Hardtanh, layers.QAct, layers.PAct2,
+                                    layers.AddBlock, layers.CatBlock, layers.MultBlock, torch.nn.MaxPool2d, torch.nn.AvgPool2d)
+
+        # this block is not quantized. Also if the next block is this, current block is not quantized
+        self.ignore_out_blocks = (layers.NoQAct,torch.nn.Dropout2d)
 
         # TBD: is this required
         # # if the original module has load_weights, add it to the quant module also
 
         # TBD: is this required
         # # if the original module has load_weights, add it to the quant module also
@@ -93,7 +100,7 @@ class QuantGraphModule(HookedModule):
         self.clear_qstate()
         # analyze
         self.analyze_graph(dummy_input)
         self.clear_qstate()
         # analyze
         self.analyze_graph(dummy_input)
-        # insert NoAct wherever range clipping needs to be done
+        # insert QAct wherever range clipping needs to be done
         self.model_surgery_activations()
         # since we might have added new activations, clear the sates as they may not be valid
         self.clear_qstate()
         self.model_surgery_activations()
         # since we might have added new activations, clear the sates as they may not be valid
         self.clear_qstate()
@@ -111,7 +118,7 @@ class QuantGraphModule(HookedModule):
                         activation_q = layers.PAct2(signed=False)
                     elif isinstance(module, torch.nn.Hardtanh):
                         activation_q = layers.PAct2(clip_range=(module.min_val, module.max_val))
                         activation_q = layers.PAct2(signed=False)
                     elif isinstance(module, torch.nn.Hardtanh):
                         activation_q = layers.PAct2(clip_range=(module.min_val, module.max_val))
-                    elif isinstance(module, layers.NoAct):
+                    elif isinstance(module, layers.QAct):
                         activation_q = layers.PAct2(signed=None)
                     else:
                         activation_q = layers.PAct2(signed=None)
                         activation_q = layers.PAct2(signed=None)
                     else:
                         activation_q = layers.PAct2(signed=None)
@@ -126,6 +133,12 @@ class QuantGraphModule(HookedModule):
                     activation_q.train(self.training)
                     module.activation_q = activation_q
                 #
                     activation_q.train(self.training)
                     module.activation_q = activation_q
                 #
+            elif qparams.quantize_in:
+                if not hasattr(module, 'activation_in'):
+                    activation_in = layers.PAct2(signed=None)
+                    activation_in.train(self.training)
+                    module.activation_in = activation_in
+                #
             else:
                 pass
             #
             else:
                 pass
             #
@@ -150,14 +163,14 @@ class QuantGraphModule(HookedModule):
 
     def _forward_analyze_modules_impl(self, inputs):
         self.start_call()
 
     def _forward_analyze_modules_impl(self, inputs):
         self.start_call()
-        self.add_call_hook(self.module, self._analyze_modules_op)
+        self.add_call_hook(self, self._analyze_modules_op)
         output = self.module(inputs)
         self.remove_call_hook(self.module)
         self.finish_call()
         return output
 
     def _analyze_modules_op(self, op, *inputs_orig):
         output = self.module(inputs)
         self.remove_call_hook(self.module)
         self.finish_call()
         return output
 
     def _analyze_modules_op(self, op, *inputs_orig):
-        inputs = utils.squeeze_list(inputs_orig)
+        inputs = utils.squeeze_list2(inputs_orig)
         self.start_node(op)
         self.add_node(op, inputs)
         outputs = op.__forward_orig__(*inputs_orig)
         self.start_node(op)
         self.add_node(op, inputs)
         outputs = op.__forward_orig__(*inputs_orig)
@@ -208,50 +221,55 @@ class QuantGraphModule(HookedModule):
 
     ################################################################
     def analyze_connections(self):
 
     ################################################################
     def analyze_connections(self):
+        first_module = None
         prediction_module = None
         for module_hash, qparams in self.get_qstate().qparams.items():
             module = self.get_module(module_hash)
         prediction_module = None
         for module_hash, qparams in self.get_qstate().qparams.items():
             module = self.get_module(module_hash)
-            if utils.is_conv(module) or utils.is_linear(module) or utils.is_normalization(module) or utils.is_activation(module):
+            if utils.is_conv_deconv_linear(module) or utils.is_normalization(module) or utils.is_activation(module):
+                first_module = module if first_module is None else first_module
                 prediction_module = module
             #
         #
         for module_hash, qparams in self.get_qstate().qparams.items():
             module = self.get_module(module_hash)
                 prediction_module = module
             #
         #
         for module_hash, qparams in self.get_qstate().qparams.items():
             module = self.get_module(module_hash)
-            is_prediction = (prediction_module is module)
-            self._analyse_connections_op(module_hash, module, qparams, is_prediction)
+            is_first_module = (first_module is module)
+            is_prediction_module = (prediction_module is module)
+            self._analyse_connections_op(module_hash, module, qparams, is_first_module, is_prediction_module)
         #
 
         #
 
-    def _analyse_connections_op(self, module_hash, module, qparams, is_prediction):
-        previous_module = [self.get_module(p) for p in qparams.previous_node] if len(qparams.previous_node)>0 else []
-        next_module = [self.get_module(n) for n in qparams.next_node] if len(qparams.next_node)>0 else []
+    def _analyse_connections_op(self, module_hash, module, qparams, is_first_module, is_prediction_module):
+        previous_modules = [self.get_module(p) for p in qparams.previous_node] if len(qparams.previous_node)>0 else []
+        next_modules = [self.get_module(n) for n in qparams.next_node] if len(qparams.next_node)>0 else []
 
         quantize_out = False
 
         quantize_out = False
-        if utils.is_activation(module):
-            if len(next_module)==1 and utils.is_activation(next_module[0]):
+        if isinstance(module, self.ignore_out_blocks):
+            quantize_out = False
+        elif utils.is_activation(module):
+            if len(next_modules)==1 and utils.is_activation(next_modules[0]):
                 quantize_out = False
             else:
                 quantize_out = True
             #
                 quantize_out = False
             else:
                 quantize_out = True
             #
-        elif isinstance(module, (layers.AddBlock, layers.CatBlock, layers.MultBlock)):
-            if len(next_module)==1 and utils.is_activation(next_module[0]):
+        elif isinstance(module, self.quantize_out_blocks):
+            if len(next_modules)==1 and utils.is_activation(next_modules[0]):
                 quantize_out = False
             else:
                 quantize_out = True
             #
         elif utils.is_normalization(module):
                 quantize_out = False
             else:
                 quantize_out = True
             #
         elif utils.is_normalization(module):
-            if len(next_module)==1 and utils.is_activation(next_module[0]):
+            if len(next_modules)==1 and utils.is_activation(next_modules[0]):
                 quantize_out = False
             else:
                 quantize_out = True
             #
         elif utils.is_conv(module) or utils.is_deconv(module):
                 quantize_out = False
             else:
                 quantize_out = True
             #
         elif utils.is_conv(module) or utils.is_deconv(module):
-            if len(next_module)==1 and (utils.is_normalization(next_module[0]) or utils.is_activation(next_module[0])):
+            if len(next_modules)==1 and (utils.is_normalization(next_modules[0]) or utils.is_activation(next_modules[0])):
                 quantize_out = False
             else:
                 quantize_out = True
             #
         elif utils.is_linear(module):
                 quantize_out = False
             else:
                 quantize_out = True
             #
         elif utils.is_linear(module):
-            if len(next_module)==1 and (utils.is_normalization(next_module[0]) or utils.is_activation(next_module[0])):
+            if len(next_modules)==1 and (utils.is_normalization(next_modules[0]) or utils.is_activation(next_modules[0])):
                 quantize_out = False
             else:
                 quantize_out = True
                 quantize_out = False
             else:
                 quantize_out = True
@@ -260,16 +278,32 @@ class QuantGraphModule(HookedModule):
         #     quantize_out = True
         # #
 
         #     quantize_out = True
         # #
 
-        qparams.quantize_w = isinstance(module, (torch.nn.Conv2d,torch.nn.ConvTranspose2d,torch.nn.Linear))     # all conv/deconv layers will be quantized
-        qparams.quantize_b = isinstance(module, (torch.nn.Conv2d,torch.nn.ConvTranspose2d,torch.nn.Linear))     # all conv/deconv layers will be quantized
-        qparams.quantize_out = quantize_out                                            # selectively quantize output
-        qparams.quantize_in = qparams.is_input                                         # only top modules's input need to be quantized
-        qparams.align_in = isinstance(module, (layers.AddBlock, layers.CatBlock,torch.nn.AdaptiveAvgPool2d))# all tensors to be made same q at the input
-        qparams.scale_in = 64.0 if isinstance(module, torch.nn.AdaptiveAvgPool2d) else 1.0  #additional scaleup to simulate fixed point
-        qparams.unquantize_out = qparams.is_input                                      # only top modules's output need to be unquantized
+        if len(qparams.previous_node) > 0:
+            previous_module_hash = qparams.previous_node[-1]
+            previous_module = self.get_module(previous_module_hash)
+            previous_module_qparams = self.get_qstate().qparams[previous_module_hash]
+            is_input_ignored = isinstance(previous_module, self.ignore_out_blocks)
+            is_input_quantized = previous_module_qparams.quantize_out if \
+                hasattr(previous_module_qparams, 'quantize_out') else False
+        else:
+            is_input_ignored = False
+            is_input_quantized = False
+        #
+
+        quantize_in = utils.is_conv_deconv_linear(module) and not is_input_quantized and \
+                      not is_input_ignored and is_first_module
+        qparams.quantize_w = utils.is_conv_deconv_linear(module)                                # all conv/deconv layers will be quantized
+        qparams.quantize_b = utils.is_conv_deconv_linear(module)                                # all conv/deconv layers will be quantized
+        qparams.quantize_out = quantize_out                                                     # selectively quantize output
+        qparams.quantize_in = quantize_in                                                       # only top modules's input need to be quantized
+        multi_input_blocks = (layers.AddBlock, layers.CatBlock, torch.nn.AdaptiveAvgPool2d)
+        qparams.align_in = isinstance(module, multi_input_blocks)                               # all tensors to be made same q at the input
+        qparams.scale_in = 64.0 if isinstance(module, torch.nn.AdaptiveAvgPool2d) else 1.0      # additional scaleup to simulate fixed point
+        qparams.unquantize_out = qparams.is_input                                               # only top modules's output need to be unquantized
         qparams.is_dwconv = utils.is_dwconv(module)
         qparams.is_dwconv = utils.is_dwconv(module)
-        qparams.next_module = next_module
-        qparams.is_prediction = is_prediction
+        qparams.next_modules = next_modules
+        qparams.is_first_module = is_first_module
+        qparams.is_prediction_module = is_prediction_module
 
 
     ################################################################
 
 
     ################################################################
@@ -286,7 +320,7 @@ class QuantGraphModule(HookedModule):
         is_conv = utils.is_conv_deconv(module)
 
         # note: we consider merging only if there is a single next node
         is_conv = utils.is_conv_deconv(module)
 
         # note: we consider merging only if there is a single next node
-        next_module = qparams.next_module[0] if len(qparams.next_module) == 1 else None
+        next_module = qparams.next_modules[0] if len(qparams.next_modules) == 1 else None
         next_bn = isinstance(next_module, torch.nn.BatchNorm2d) if (next_module is not None) else None
 
         # if the next module is a bn, appy bn merging step
         next_bn = isinstance(next_module, torch.nn.BatchNorm2d) if (next_module is not None) else None
 
         # if the next module is a bn, appy bn merging step
@@ -404,7 +438,7 @@ class QuantGraphModule(HookedModule):
 
     def format_tensors(self, inputs):
         # make a list/tuple if inputs is not. if it is a double list, remove the extra one
 
     def format_tensors(self, inputs):
         # make a list/tuple if inputs is not. if it is a double list, remove the extra one
-        inputs = utils.squeeze_list(utils.make_list(inputs))
+        inputs = utils.squeeze_list2(utils.make_list(inputs))
         # remove lists/tuple
         inputs = [ipt for ipt in inputs if utils.is_tensor(ipt)]
         return inputs
         # remove lists/tuple
         inputs = [ipt for ipt in inputs if utils.is_tensor(ipt)]
         return inputs
index f0d47898ac97938f398e08a032bc23ca0dc7d092..76a139160718e44658dfca591a108eb12e5b4c6f 100644 (file)
@@ -4,18 +4,46 @@ import copy
 import warnings
 import numpy as np
 
 import warnings
 import numpy as np
 
+
+########################################################################
+from .quant_train_module import *
+
+class QuantTestModule(QuantTrainModule):
+    def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
+                 histogram_range=True, bias_calibration=False, constrain_weights=None, model_surgery_quantize=True,
+                 power2_weight_range=None, power2_activation_range=None, constrain_bias=None):
+        constrain_weights = (not per_channel_q) if constrain_weights is None else constrain_weights
+        super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
+                         per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
+                         constrain_weights=constrain_weights,
+                         power2_weight_range=power2_weight_range, power2_activation_range=power2_activation_range, constrain_bias=constrain_bias)
+        assert model_surgery_quantize == True, f'{self.__class__.__name__} does not support model_surgery_quantize=False. please use a qat or calibrated module.'
+        self.eval()
+
+
+    def train(self, mode=True):
+        assert mode == False, 'QuantTestModule cannot be used in train mode'
+        super().train(mode)
+
+
+########################################################################
 from .quant_base_module import *
 from .quant_utils import *
 
 
 from .quant_base_module import *
 from .quant_utils import *
 
 
-class QuantTestModule(QuantBaseModule):
-    def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False, histogram_range=True,
-                 range_calibration_online=False, model_surgery_quantize=True):
+class QuantEstimateModule(QuantBaseModule):
+    '''
+    QuantEstimateModule  can be used to estimate the quantization accuracy of a float model
+    that has not gone through QAT or Calibration. However, this is an approximate method.
+    '''
+    def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
+                 histogram_range=True, range_calibration_online=False, model_surgery_quantize=True,
+                 power2_weight_range=None, power2_activation_range=None, constrain_bias=None):
         super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
                          per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=False,
         super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
                          per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=False,
-                         constrain_weights=False, model_surgery_quantize=model_surgery_quantize)
-        # use power2_weights for now
-        self.power2_weights = True
+                         constrain_weights=False, model_surgery_quantize=model_surgery_quantize,
+                         power2_weight_range=power2_weight_range, power2_activation_range=power2_activation_range, constrain_bias=constrain_bias)
+        assert False, 'recommend to use QuantTestModule instead'
         # whether to do online adjustment of calibration using previous frame range
         self.range_calibration_online = range_calibration_online
         # number of offline calibration iters. during offline calibration, current frame range is used
         # whether to do online adjustment of calibration using previous frame range
         self.range_calibration_online = range_calibration_online
         # number of offline calibration iters. during offline calibration, current frame range is used
@@ -41,7 +69,7 @@ class QuantTestModule(QuantBaseModule):
 
         def replace_func(op):
             for name, m in op._modules.items():
 
         def replace_func(op):
             for name, m in op._modules.items():
-                if isinstance(m, layers.NoAct):
+                if isinstance(m, layers.QAct):
                     new_m = layers.PAct2(signed=None)
                 else:
                     new_m = None
                     new_m = layers.PAct2(signed=None)
                 else:
                     new_m = None
@@ -91,7 +119,7 @@ class QuantTestModule(QuantBaseModule):
 
 
     def _forward_quantize_hook(self, op, *inputs_orig):
 
 
     def _forward_quantize_hook(self, op, *inputs_orig):
-        inputs = utils.squeeze_list(inputs_orig)
+        inputs = utils.squeeze_list2(inputs_orig)
         self.start_node(op)
         self.start_quantize(op)
 
         self.start_node(op)
         self.start_quantize(op)
 
@@ -143,13 +171,13 @@ class QuantTestModule(QuantBaseModule):
 
         if qparams.quantize_w and weight is not None:
             qparams.qrange_w = Dict()
 
         if qparams.quantize_w and weight is not None:
             qparams.qrange_w = Dict()
-            self.quantize_weights(module, weight, qparams.qrange_w)
+            self.quantize_weights_tensor(module, weight, qparams.qrange_w)
         else:
             qparams.qrange_w = None
 
         if qparams.quantize_b and bias is not None:
             qparams.qparams_b = Dict()
         else:
             qparams.qrange_w = None
 
         if qparams.quantize_b and bias is not None:
             qparams.qparams_b = Dict()
-            self.quantize_bias(module, bias, qparams.qparams_b)
+            self.quantize_bias_tensor(module, bias, qparams.qparams_b)
         else:
             qparams.qparams_b = None
 
         else:
             qparams.qparams_b = None
 
@@ -167,12 +195,12 @@ class QuantTestModule(QuantBaseModule):
         for inp in inputs:
             inp.scale = inp.scale  if hasattr(inp,'scale') else self.current_scale
 
         for inp in inputs:
             inp.scale = inp.scale  if hasattr(inp,'scale') else self.current_scale
 
-        qrange_cur = self.quantize_inputs(module, inputs, outputs, qparams_prev, qparams)
+        qrange_cur = self.quantize_input_tensors(module, inputs, outputs, qparams_prev, qparams)
 
         # create the current scale in proccess_inputs instead of process_outputs.
         # otherwise exit condition for aggregate modules (eg. torch.nn.Sequential, Bottleneck in ResNet) will cause trouble.
 
         # create the current scale in proccess_inputs instead of process_outputs.
         # otherwise exit condition for aggregate modules (eg. torch.nn.Sequential, Bottleneck in ResNet) will cause trouble.
-        # all the inputs scales are assumed to be aligned at this point (see align_inputs)
-        # any module that needs special handling needs to be considered in quantize_inputs / align_inputs.
+        # all the inputs scales are assumed to be aligned at this point (see align_input_tensors)
+        # any module that needs special handling needs to be considered in quantize_input_tensors / align_input_tensors.
         has_weight_scale = (hasattr(module,'weight') and (module.weight is not None) and hasattr(module.weight,'scale'))
         if has_weight_scale:
             is_dw = utils.is_dwconv(module)
         has_weight_scale = (hasattr(module,'weight') and (module.weight is not None) and hasattr(module.weight,'scale'))
         if has_weight_scale:
             is_dw = utils.is_dwconv(module)
@@ -210,7 +238,7 @@ class QuantTestModule(QuantBaseModule):
         for idx, opt in enumerate(output):
             opt.scale = self.current_scale
 
         for idx, opt in enumerate(output):
             opt.scale = self.current_scale
 
-        qrange_cur = self.quantize_outputs(module, inputs, output, qparams_prev, qparams)
+        qrange_cur = self.quantize_output_tensors(module, inputs, output, qparams_prev, qparams)
         self.viz_act(en=False, opt_q=output[0], opt_fl=inputs[0])
         self.current_scale = output[0].scale
 
         self.viz_act(en=False, opt_q=output[0], opt_fl=inputs[0])
         self.current_scale = output[0].scale
 
@@ -251,8 +279,8 @@ class QuantTestModule(QuantBaseModule):
 
     def _update_activation_ranges(self, module, tensor_in, running_update, qrange_cur, qrange_prev):
         is_calibration = (self.iter_in_epoch < self.range_calibration_offline_iters)
 
     def _update_activation_ranges(self, module, tensor_in, running_update, qrange_cur, qrange_prev):
         is_calibration = (self.iter_in_epoch < self.range_calibration_offline_iters)
-        update_range = (is_calibration or self.range_calibration_online)
-        if update_range:
+        update_activation_range = (is_calibration or self.range_calibration_online)
+        if update_activation_range:
             # in the case of fixed range module, we do not expand the ranges
             fixed_range_module = utils.is_fixed_range(module)
             if fixed_range_module:
             # in the case of fixed range module, we do not expand the ranges
             fixed_range_module = utils.is_fixed_range(module)
             if fixed_range_module:
@@ -293,7 +321,7 @@ class QuantTestModule(QuantBaseModule):
         return bitwidth_activations
 
 
         return bitwidth_activations
 
 
-    def quantize_weights(self, module, tensor_in, qrange):
+    def quantize_weights_tensor(self, module, tensor_in, qrange):
         self.apply_constrain_weights(module)
 
         bitwidth_weights = self.get_bitwidth_weights(module)
         self.apply_constrain_weights(module)
 
         bitwidth_weights = self.get_bitwidth_weights(module)
@@ -307,7 +335,7 @@ class QuantTestModule(QuantBaseModule):
                 for chan in range(tensor_in.shape[0]):
                     # Range
                     mn, mx = self.compute_tensor_range(module, tensor_in[chan], percentile_range_shrink=self.percentile_range_shrink_weights)
                 for chan in range(tensor_in.shape[0]):
                     # Range
                     mn, mx = self.compute_tensor_range(module, tensor_in[chan], percentile_range_shrink=self.percentile_range_shrink_weights)
-                    tensor_scale, clamp_limits = compute_tensor_scale(tensor_in[chan], mn, mx, bitwidth_weights, self.power2_weights)
+                    tensor_scale, clamp_limits = compute_tensor_scale(tensor_in[chan], mn, mx, bitwidth_weights, self.power2_weight_range)
                     qrange.min.append(mn)
                     qrange.max.append(mx)
                     # Quantize
                     qrange.min.append(mn)
                     qrange.max.append(mx)
                     # Quantize
@@ -320,7 +348,7 @@ class QuantTestModule(QuantBaseModule):
             else:
                 # Range
                 mn, mx = self.compute_tensor_range(module, tensor_in, percentile_range_shrink=self.percentile_range_shrink_weights)
             else:
                 # Range
                 mn, mx = self.compute_tensor_range(module, tensor_in, percentile_range_shrink=self.percentile_range_shrink_weights)
-                tensor_scale, clamp_limits = compute_tensor_scale(tensor_in, mn, mx, bitwidth_weights, self.power2_weights)
+                tensor_scale, clamp_limits = compute_tensor_scale(tensor_in, mn, mx, bitwidth_weights, self.power2_weight_range)
                 qrange.min = mn
                 qrange.max = mx
                 # Quantize
                 qrange.min = mn
                 qrange.max = mx
                 # Quantize
@@ -331,7 +359,7 @@ class QuantTestModule(QuantBaseModule):
                 tensor_in.scale = 1.0
 
 
                 tensor_in.scale = 1.0
 
 
-    def quantize_bias(self, module, tensor_in, qparams):
+    def quantize_bias_tensor(self, module, tensor_in, qparams):
         quant_for_bias = True
         if quant_for_bias:
             bitwidth_weights = self.get_bitwidth_weights(module)
         quant_for_bias = True
         if quant_for_bias:
             bitwidth_weights = self.get_bitwidth_weights(module)
@@ -340,7 +368,7 @@ class QuantTestModule(QuantBaseModule):
             bitwidth_bias = bitwidth_weights
             
             mn, mx = self.compute_tensor_range(module, tensor_in, percentile_range_shrink=0.0)
             bitwidth_bias = bitwidth_weights
             
             mn, mx = self.compute_tensor_range(module, tensor_in, percentile_range_shrink=0.0)
-            tensor_scale, clamp_limits = compute_tensor_scale(tensor_in, mn, mx, bitwidth_bias, self.power2_weights)
+            tensor_scale, clamp_limits = compute_tensor_scale(tensor_in, mn, mx, bitwidth_bias, self.power2_weight_range)
 
             # --
             tensor = symmetric_round_tensor(tensor_in * tensor_scale)
 
             # --
             tensor = symmetric_round_tensor(tensor_in * tensor_scale)
@@ -353,7 +381,7 @@ class QuantTestModule(QuantBaseModule):
             tensor_in.scale = 1.0
 
 
             tensor_in.scale = 1.0
 
 
-    def quantize_inputs(self, module, input, output, qparams_prev, qparams):
+    def quantize_input_tensors(self, module, input, output, qparams_prev, qparams):
         qrange_cur = []
         use_current_range = (not self.range_calibration_offline_iters) or (self.iter_in_epoch < self.range_calibration_offline_iters)
         for idx, inp in enumerate(input):
         qrange_cur = []
         use_current_range = (not self.range_calibration_offline_iters) or (self.iter_in_epoch < self.range_calibration_offline_iters)
         for idx, inp in enumerate(input):
@@ -365,7 +393,7 @@ class QuantTestModule(QuantBaseModule):
         return qrange_cur
 
 
         return qrange_cur
 
 
-    def quantize_outputs(self, module, input, output, qparams_prev, qparams):
+    def quantize_output_tensors(self, module, input, output, qparams_prev, qparams):
         qrange_cur = []
         use_current_range = (not self.range_calibration_offline_iters) or (self.iter_in_epoch < self.range_calibration_offline_iters)
         for idx, opt in enumerate(output):
         qrange_cur = []
         use_current_range = (not self.range_calibration_offline_iters) or (self.iter_in_epoch < self.range_calibration_offline_iters)
         for idx, opt in enumerate(output):
@@ -444,7 +472,7 @@ class QuantTestModule(QuantBaseModule):
 
 
     def compute_mse_for_quant(self, tensor_wt_fl, mn=0, mx=0, bitwidth_weights=8):
 
 
     def compute_mse_for_quant(self, tensor_wt_fl, mn=0, mx=0, bitwidth_weights=8):
-        tensor_scale, clamp_limits = compute_tensor_scale(tensor_wt_fl, mn, mx, bitwidth_weights, self.power2_weights)
+        tensor_scale, clamp_limits = compute_tensor_scale(tensor_wt_fl, mn, mx, bitwidth_weights, self.power2_weight_range)
 
         # print("mn : mx  {} {}".format(mn, mx))
 
 
         # print("mn : mx  {} {}".format(mn, mx))
 
index e7601cdc40a60c95981beb51079c38197a37124f..f0b4529f1235068fad816207fa914b1a89195288 100644 (file)
@@ -19,18 +19,13 @@ warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)
 ###########################################################
 class QuantTrainModule(QuantBaseModule):
     def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
 ###########################################################
 class QuantTrainModule(QuantBaseModule):
     def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
-                 histogram_range=True, bias_calibration=False, constrain_weights=None):
+                 histogram_range=True, bias_calibration=False, constrain_weights=None,
+                 power2_weight_range=None, power2_activation_range=None, constrain_bias=None):
         constrain_weights = (not per_channel_q) if constrain_weights is None else constrain_weights
         super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
                          per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
         constrain_weights = (not per_channel_q) if constrain_weights is None else constrain_weights
         super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
                          per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
-                         constrain_weights=constrain_weights, model_surgery_quantize=True)
-        # range shrink - 0.0 indicates no shrink
-        percentile_range_shrink = (layers.PAct2.PACT2_RANGE_SHRINK if histogram_range else 0.0)
-        # set attributes to all modules - can control the behaviour from here
-        utils.apply_setattr(self, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
-                            per_channel_q=self.per_channel_q, bias_calibration=self.bias_calibration,
-                            percentile_range_shrink=percentile_range_shrink, constrain_weights=self.constrain_weights,
-                            update_range=True, quantize_enable=True, quantize_weights=True, quantize_bias=True, quantize_activations=True)
+                         constrain_weights=constrain_weights, model_surgery_quantize=True,
+                         power2_weight_range=power2_weight_range, power2_activation_range=power2_activation_range, constrain_bias=constrain_bias)
 
     def forward(self, inputs):
         # counters such as num_batches_tracked are used. update them.
 
     def forward(self, inputs):
         # counters such as num_batches_tracked are used. update them.
@@ -55,30 +50,38 @@ class QuantTrainModule(QuantBaseModule):
                     padding_mode = m.padding_mode
                     new_m = QuantTrainConvTranspose2d(in_channels=m.in_channels, out_channels=m.out_channels, kernel_size=m.kernel_size, stride=m.stride,
                                             padding=m.padding, output_padding=m.output_padding, groups=m.groups, bias=bias, dilation=m.dilation, padding_mode=padding_mode)
                     padding_mode = m.padding_mode
                     new_m = QuantTrainConvTranspose2d(in_channels=m.in_channels, out_channels=m.out_channels, kernel_size=m.kernel_size, stride=m.stride,
                                             padding=m.padding, output_padding=m.output_padding, groups=m.groups, bias=bias, dilation=m.dilation, padding_mode=padding_mode)
+                elif utils.is_linear(m):
+                    bias = (m.bias is not None)
+                    new_m = QuantTrainLinear(in_features=m.in_features, out_features=m.out_features, bias=bias)
                 elif utils.is_bn(m):
                     new_m = QuantTrainBatchNorm2d(num_features=m.num_features, eps=m.eps, momentum=m.momentum, affine=m.affine,
                                             track_running_stats=m.track_running_stats)
                 elif isinstance(m, layers.PAct2):
                     new_m = QuantTrainPAct2(inplace=m.inplace, signed=m.signed, clip_range=m.clip_range,
                                              bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
                 elif utils.is_bn(m):
                     new_m = QuantTrainBatchNorm2d(num_features=m.num_features, eps=m.eps, momentum=m.momentum, affine=m.affine,
                                             track_running_stats=m.track_running_stats)
                 elif isinstance(m, layers.PAct2):
                     new_m = QuantTrainPAct2(inplace=m.inplace, signed=m.signed, clip_range=m.clip_range,
                                              bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
-                                             per_channel_q=self.per_channel_q)
-                elif isinstance(m, layers.NoAct):
+                                            per_channel_q=self.per_channel_q, percentile_range_shrink=self.percentile_range_shrink,
+                                            power2_weight_range=self.power2_weight_range, power2_activation_range=self.power2_activation_range)
+                elif isinstance(m, layers.QAct):
                     new_m = QuantTrainPAct2(signed=None, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
                     new_m = QuantTrainPAct2(signed=None, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
-                                             per_channel_q=self.per_channel_q)
+                                             per_channel_q=self.per_channel_q, percentile_range_shrink=self.percentile_range_shrink,
+                                            power2_weight_range=self.power2_weight_range, power2_activation_range=self.power2_activation_range)
                 elif isinstance(m, (torch.nn.ReLU, torch.nn.ReLU6)):
                     new_m = QuantTrainPAct2(signed=False, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
                 elif isinstance(m, (torch.nn.ReLU, torch.nn.ReLU6)):
                     new_m = QuantTrainPAct2(signed=False, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
-                                             per_channel_q=self.per_channel_q)
+                                             per_channel_q=self.per_channel_q, percentile_range_shrink=self.percentile_range_shrink,
+                                            power2_weight_range=self.power2_weight_range, power2_activation_range=self.power2_activation_range)
                 else:
                     new_m = None
                 #
                 if new_m is not None:
                 else:
                     new_m = None
                 #
                 if new_m is not None:
+                    copy_attr_list = ('weight', 'bias', 'eps', 'clips_act', 'clips_w')
                     for attr in dir(m):
                         value = getattr(m,attr)
                         if isinstance(value,torch.Tensor) and value is not None:
                             getattr(new_m,attr).data.copy_(value.data)
                     for attr in dir(m):
                         value = getattr(m,attr)
                         if isinstance(value,torch.Tensor) and value is not None:
                             getattr(new_m,attr).data.copy_(value.data)
-                        elif isinstance(value,torch.nn.Module) and value is not None:
+                        elif isinstance(value,torch.nn.Module):
                             setattr(new_m, attr, getattr(m,attr))
                             setattr(new_m, attr, getattr(m,attr))
-                        elif attr in ('weight','bias','eps','clips_act','clips_w'): # add more attribute name here
+                        elif attr in copy_attr_list:
+                            # copy attributes that need to be copied
                             setattr(new_m, attr, getattr(m, attr))
                         #
                     #
                             setattr(new_m, attr, getattr(m, attr))
                         #
                     #
@@ -139,8 +142,10 @@ class QuantTrainConv2d(torch.nn.Conv2d):
         qparams = get_qparams()
         qparams.inputs.append(x)
         qparams.modules.append(self)
         qparams = get_qparams()
         qparams.inputs.append(x)
         qparams.modules.append(self)
-        y.qparams = qparams
+        if hasattr(x, 'clips_act'):
+            qparams.clips_input = x.clips_act
         #
         #
+        y.qparams = qparams
         return y
     #
 
         return y
     #
 
@@ -170,12 +175,47 @@ class QuantTrainConvTranspose2d(torch.nn.ConvTranspose2d):
         qparams = get_qparams()
         qparams.inputs.append(x)
         qparams.modules.append(self)
         qparams = get_qparams()
         qparams.inputs.append(x)
         qparams.modules.append(self)
-        y.qparams = qparams
+        if hasattr(x, 'clips_act'):
+            qparams.clips_input = x.clips_act
         #
         #
+        y.qparams = qparams
         return y
     #
 
 
         return y
     #
 
 
+###########################################################
+class QuantTrainLinear(torch.nn.Linear):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.quantize_enable = True
+        self.bitwidth_weights = None
+        self.bitwidth_activations = None
+        self.per_channel_q = False
+
+    def forward(self, x):
+        is_merged = is_merged_layer(x)
+        if is_merged:
+           warnings.warn('please see if a PAct can be inserted before this module to collect ranges')
+        #
+
+        y = super().forward(x)
+
+        if not self.quantize_enable:
+            # if quantization is disabled - return
+            return y
+        #
+
+        qparams = get_qparams()
+        qparams.inputs.append(x)
+        qparams.modules.append(self)
+        if hasattr(x, 'clips_act'):
+            qparams.clips_input = x.clips_act
+        #
+        y.qparams = qparams
+        return y
+    #
+       
+       
 ###########################################################
 class QuantTrainBatchNorm2d(torch.nn.BatchNorm2d):
     def __init__(self, *args, **kwargs):
 ###########################################################
 class QuantTrainBatchNorm2d(torch.nn.BatchNorm2d):
     def __init__(self, *args, **kwargs):
@@ -195,6 +235,9 @@ class QuantTrainBatchNorm2d(torch.nn.BatchNorm2d):
             qparams = get_qparams()
             qparams.inputs = [x.qparams.inputs[0], x]
             qparams.modules = [x.qparams.modules[0], self]
             qparams = get_qparams()
             qparams.inputs = [x.qparams.inputs[0], x]
             qparams.modules = [x.qparams.modules[0], self]
+            if hasattr(x.qparams, 'clips_input'):
+                qparams.clips_input = x.qparams.clips_input
+            #
             y.qparams = qparams
         #
 
             y.qparams = qparams
         #
 
@@ -205,26 +248,31 @@ class QuantTrainBatchNorm2d(torch.nn.BatchNorm2d):
 ###########################################################
 # fake quantized PAct2 for training
 class QuantTrainPAct2(layers.PAct2):
 ###########################################################
 # fake quantized PAct2 for training
 class QuantTrainPAct2(layers.PAct2):
-    def __init__(self, inplace=False, signed=False, clip_range=None, bitwidth_weights=None, bitwidth_activations=None, per_channel_q=False):
-        super().__init__(inplace=inplace, signed=signed, clip_range=clip_range)
+    def __init__(self, inplace=False, signed=False, clip_range=None, bitwidth_weights=None, bitwidth_activations=None,
+                 per_channel_q=False, percentile_range_shrink=layers.PAct2.PACT2_RANGE_SHRINK, power2_weight_range=True, power2_activation_range=True):
+        super().__init__(inplace=inplace, signed=signed, clip_range=clip_range, percentile_range_shrink=percentile_range_shrink,
+                         power2_activation_range=power2_activation_range)
 
         self.bitwidth_weights = bitwidth_weights
         self.bitwidth_activations = bitwidth_activations
         self.per_channel_q = per_channel_q
 
         self.bitwidth_weights = bitwidth_weights
         self.bitwidth_activations = bitwidth_activations
         self.per_channel_q = per_channel_q
+        self.power2_weight_range = power2_weight_range
+
         # weight shrinking is done by clamp weights - set this factor to zero.
         # this must me zero - as in pact we do not have the actual weight param, but just a temporary tensor
         # so any clipping we do here is not stored int he weight params
         self.range_shrink_weights = 0.0
         self.round_dither = 0.0
         # weight shrinking is done by clamp weights - set this factor to zero.
         # this must me zero - as in pact we do not have the actual weight param, but just a temporary tensor
         # so any clipping we do here is not stored int he weight params
         self.range_shrink_weights = 0.0
         self.round_dither = 0.0
-        self.update_range = True
+        self.update_activation_range = True
         self.quantize_enable = True
         self.quantize_weights = True
         self.quantize_bias = True
         self.quantize_activations = True
         self.quantize_enable = True
         self.quantize_weights = True
         self.quantize_bias = True
         self.quantize_activations = True
+        self.constrain_bias = None
         self.constrain_weights = True
         self.bias_calibration = False
         self.constrain_weights = True
         self.bias_calibration = False
-        # save quantized weight/bias once in a while into the params - not needed
-        self.params_save_frequency = None #(10 if self.bias_calibration else None)
+        # do joint quantization only after the activation range has stabilized reasonably.
+        self.constrain_bias_start_iter = 75
 
         # set to STRAIGHT_THROUGH_ESTIMATION or ALPHA_BLENDING_ESTIMATION
         # For a comparison of STE and ABE, read:
 
         # set to STRAIGHT_THROUGH_ESTIMATION or ALPHA_BLENDING_ESTIMATION
         # For a comparison of STE and ABE, read:
@@ -242,9 +290,8 @@ class QuantTrainPAct2(layers.PAct2):
     def forward(self, x):
         assert (self.bitwidth_weights is not None) and (self.bitwidth_activations is not None), \
                         'bitwidth_weights and bitwidth_activations must not be None'
     def forward(self, x):
         assert (self.bitwidth_weights is not None) and (self.bitwidth_activations is not None), \
                         'bitwidth_weights and bitwidth_activations must not be None'
-
         # the pact range update happens here - but range clipping depends on quantize_enable
         # the pact range update happens here - but range clipping depends on quantize_enable
-        y = super().forward(x, update_range=self.update_range, enable=self.quantize_enable)
+        y = super().forward(x, update_activation_range=self.update_activation_range, enable=self.quantize_enable)
 
         if not self.quantize_enable:
             return y
 
         if not self.quantize_enable:
             return y
@@ -259,11 +306,11 @@ class QuantTrainPAct2(layers.PAct2):
 
             conv, bn = None, None
             # merge weight and bias (if possible) across layers
 
             conv, bn = None, None
             # merge weight and bias (if possible) across layers
-            if len(qparams.modules) == 2 and utils.is_conv_deconv(qparams.modules[-2]) and isinstance(
+            if len(qparams.modules) == 2 and utils.is_conv_deconv_linear(qparams.modules[-2]) and isinstance(
                     qparams.modules[-1], torch.nn.BatchNorm2d):
                 conv = qparams.modules[-2]
                 bn = qparams.modules[-1]
                     qparams.modules[-1], torch.nn.BatchNorm2d):
                 conv = qparams.modules[-2]
                 bn = qparams.modules[-1]
-            elif len(qparams.modules) == 1 and utils.is_conv_deconv(qparams.modules[-1]):
+            elif len(qparams.modules) == 1 and utils.is_conv_deconv_linear(qparams.modules[-1]):
                 conv = qparams.modules[-1]
             elif len(qparams.modules) == 1 and isinstance(qparams.modules[-1], torch.nn.BatchNorm2d):
                 assert False, f'quantization: previous layer is a BN without Conv {qparams.modules} - prease inspect the model carefully'
                 conv = qparams.modules[-1]
             elif len(qparams.modules) == 1 and isinstance(qparams.modules[-1], torch.nn.BatchNorm2d):
                 assert False, f'quantization: previous layer is a BN without Conv {qparams.modules} - prease inspect the model carefully'
@@ -272,7 +319,7 @@ class QuantTrainPAct2(layers.PAct2):
             else:
                 assert False, f'QuantTrainPAct2: both conv & bn layes cannot be None in a merged scenario - prease inspect the model carefully'
             #
             else:
                 assert False, f'QuantTrainPAct2: both conv & bn layes cannot be None in a merged scenario - prease inspect the model carefully'
             #
-            conv, weight, bias = self.merge_quantize_weights(conv, bn)
+            conv, weight, bias = self.merge_quantize_weights(qparams, conv, bn)
         else:
             conv, weight, bias = None, None, None
         #
         else:
             conv, weight, bias = None, None, None
         #
@@ -282,6 +329,8 @@ class QuantTrainPAct2(layers.PAct2):
         elif is_merged and utils.is_deconv(conv):
             xq = torch.nn.functional.conv_transpose2d(xorg, weight, bias, stride=conv.stride, padding=conv.padding, output_padding=conv.output_padding,
                                                       dilation=conv.dilation, groups=conv.groups)
         elif is_merged and utils.is_deconv(conv):
             xq = torch.nn.functional.conv_transpose2d(xorg, weight, bias, stride=conv.stride, padding=conv.padding, output_padding=conv.output_padding,
                                                       dilation=conv.dilation, groups=conv.groups)
+        elif is_merged and utils.is_linear(conv):
+            xq = torch.nn.functional.linear(xorg, weight, bias)
         else:
             xq = x
         #
         else:
             xq = x
         #
@@ -293,9 +342,9 @@ class QuantTrainPAct2(layers.PAct2):
             # currently the gradient is set to 1 in round_up_g. shouldn't the gradient be (y-x)?
             # see eqn 6 in https://arxiv.org/pdf/1903.08066v2.pdf
             # yq = layers.clamp_g(layers.round_up_g(xq * scale), width_min, width_max-1, self.training) * scale_inv
             # currently the gradient is set to 1 in round_up_g. shouldn't the gradient be (y-x)?
             # see eqn 6 in https://arxiv.org/pdf/1903.08066v2.pdf
             # yq = layers.clamp_g(layers.round_up_g(xq * scale), width_min, width_max-1, self.training) * scale_inv
-            yq = layers.quantize_dequantize_g(xq, scale, width_min, width_max-1, self.power2, 'round_up')
+            yq = layers.quantize_dequantize_g(xq, scale, width_min, width_max-1, self.power2_activation_range, 'round_up')
         else:
         else:
-            yq = super().forward(xq, update_range=False, enable=True)
+            yq = super().forward(xq, update_activation_range=False, enable=True)
         #
 
         if (self.quantized_estimation_type == QuantEstimationType.STRAIGHT_THROUGH_ESTIMATION):
         #
 
         if (self.quantized_estimation_type == QuantEstimationType.STRAIGHT_THROUGH_ESTIMATION):
@@ -316,6 +365,8 @@ class QuantTrainPAct2(layers.PAct2):
             assert False, f'unsupported value for quantized_estimation_type: {self.quantized_estimation_type}'
         #
 
             assert False, f'unsupported value for quantized_estimation_type: {self.quantized_estimation_type}'
         #
 
+        # pass on the clips to be used in the next quantization
+        y.clips_act = self.get_clips_act()
         return y
     #
 
         return y
     #
 
@@ -324,11 +375,12 @@ class QuantTrainPAct2(layers.PAct2):
         return quant_utils.constrain_weight(merged_weight)
 
 
         return quant_utils.constrain_weight(merged_weight)
 
 
-    def merge_quantize_weights(self, conv, bn):
-        # store the quantized weights and biases in-frequently - otherwise learning will be poor
-        # since this may not be done at the of the epoch, there can be a slight mismatch in validation accuracy
-        first_training_iter = self.training and (self.num_batches_tracked == 0)
-        is_store_weight_bias_iter = (self.params_save_frequency is not None) and (torch.remainder(self.num_batches_tracked, self.params_save_frequency) == 0)
+    def merge_quantize_weights(self, qparams, conv, bn):
+        num_batches_tracked = int(self.num_batches_tracked)
+        is_constrain_weights_iter = self.training and (num_batches_tracked == 0)
+        is_store_weights_iter = self.training and (num_batches_tracked == 0)
+        is_constrain_bias_iter = self.training and (num_batches_tracked>=self.constrain_bias_start_iter)
+        is_store_bias_iter = self.training and (num_batches_tracked==self.constrain_bias_start_iter)
 
         # merge weight and bias (if possible) across layers
         if conv is not None and bn is not None:
 
         # merge weight and bias (if possible) across layers
         if conv is not None and bn is not None:
@@ -370,7 +422,7 @@ class QuantTrainPAct2(layers.PAct2):
         # quantize weight and bias
         if (conv is not None):
             if (self.quantize_enable and self.quantize_weights):
         # quantize weight and bias
         if (conv is not None):
             if (self.quantize_enable and self.quantize_weights):
-                if self.constrain_weights and first_training_iter:
+                if self.constrain_weights and is_constrain_weights_iter:
                     with torch.no_grad():
                         # clamp merged weights, invert the bn and copy to conv weight
                         constrained_weight = self.apply_constrain_weights(merged_weight.data)
                     with torch.no_grad():
                         # clamp merged weights, invert the bn and copy to conv weight
                         constrained_weight = self.apply_constrain_weights(merged_weight.data)
@@ -396,58 +448,73 @@ class QuantTrainPAct2(layers.PAct2):
                 #
                 width_min, width_max = self.get_widths_w()
                 # merged_weight = layers.clamp_g(layers.round_sym_g(merged_weight * scale2), width_min, width_max-1, self.training) * scale_inv2
                 #
                 width_min, width_max = self.get_widths_w()
                 # merged_weight = layers.clamp_g(layers.round_sym_g(merged_weight * scale2), width_min, width_max-1, self.training) * scale_inv2
-                merged_weight = layers.quantize_dequantize_g(merged_weight, scale2, width_min, width_max-1, self.power2, 'round_sym')
+                merged_weight = layers.quantize_dequantize_g(merged_weight, scale2, width_min, width_max-1, self.power2_weight_range, 'round_sym')
             #
 
             if (self.quantize_enable and self.quantize_bias):
                 bias_width_min, bias_width_max = self.get_widths_bias()
                 bias_clip_min, bias_clip_max, bias_scale2, bias_scale_inv2 = self.get_clips_scale_bias(merged_bias)
             #
 
             if (self.quantize_enable and self.quantize_bias):
                 bias_width_min, bias_width_max = self.get_widths_bias()
                 bias_clip_min, bias_clip_max, bias_scale2, bias_scale_inv2 = self.get_clips_scale_bias(merged_bias)
-                # merged_bias = layers.clamp_g(layers.round_sym_g(merged_bias * bias_scale2), bias_width_min, bias_width_max-1, self.training) * bias_scale_inv2
-                merged_bias = layers.quantize_dequantize_g(merged_bias, bias_scale2, bias_width_min, bias_width_max - 1, self.power2, 'round_sym')
+                power2_bias_range = (self.power2_weight_range and self.power2_activation_range)
+                merged_bias = layers.quantize_dequantize_g(merged_bias, bias_scale2, bias_width_min, bias_width_max - 1, power2_bias_range, 'round_sym')
+            #
+
+            # in some cases, bias quantization can have additional restrictions if for example,
+            # bias that is being added to accumulator is limited to 16bit.
+            # scale factor to be used for bias is the product of scale factors of weight and input
+            if self.quantize_enable and self.constrain_bias and is_constrain_bias_iter:
+                clips_scale_joint = self.get_clips_scale_joint(qparams, merged_weight, merged_bias)
+                if clips_scale_joint is not None:
+                    bias_width_min, bias_width_max = self.get_widths_joint()
+                    bias_clip_min, bias_clip_max, bias_scale2, bias_scale_inv2 = clips_scale_joint
+                    power2_bias_range = (self.power2_weight_range and self.power2_activation_range)
+                    merged_bias = layers.quantize_dequantize_g(merged_bias, bias_scale2, bias_width_min, bias_width_max - 1, power2_bias_range, 'round_sym')
+                #
             #
 
             # invert the bn operation and store weights/bias
             #
 
             # invert the bn operation and store weights/bias
-            if first_training_iter or (self.training and is_store_weight_bias_iter):
-                with torch.no_grad():
-                    if self.quantize_enable and self.quantize_weights:
-                        conv.weight.data.copy_(merged_weight.data * merged_scale_inv)
-                    #
-                    if self.quantize_enable and self.quantize_bias:
-                        if conv.bias is not None:
-                            if bn is not None:
-                                conv_bias = (merged_bias - bn_bias) * merged_scale_inv.view(-1) + bn.running_mean
-                                conv.bias.data.copy_(conv_bias.data)
-                            else:
-                                conv.bias.data.copy_(merged_bias.data)
-                            #
-                        elif bn is not None and bn.bias is not None:
-                            bn_bias = merged_bias + bn.running_mean * merged_scale.view(-1)
-                            bn.bias.data.copy_(bn_bias.data)
-                        #
+            if self.quantize_enable and self.quantize_weights and is_store_weights_iter:
+                conv.weight.data.copy_(merged_weight.data * merged_scale_inv)
+            #
+            if self.quantize_enable and self.quantize_bias and is_store_bias_iter:
+                if conv.bias is not None:
+                    if bn is not None:
+                        conv_bias = (merged_bias - bn_bias) * merged_scale_inv.view(-1) + bn.running_mean
+                        conv.bias.data.copy_(conv_bias.data)
+                    else:
+                        conv.bias.data.copy_(merged_bias.data)
                     #
                     #
+                elif bn is not None and bn.bias is not None:
+                    bn_bias = merged_bias + bn.running_mean * merged_scale.view(-1)
+                    bn.bias.data.copy_(bn_bias.data)
                 #
             #
         #
         return conv, merged_weight, merged_bias
 
 
                 #
             #
         #
         return conv, merged_weight, merged_bias
 
 
+    ###########################################################
+    def get_widths_w(self):
+        # weights
+        bw = (self.bitwidth_weights - 1)
+        width_max = np.power(2.0, bw)
+        width_min = -width_max
+        # return
+        return (width_min, width_max)
+
+
     def get_clips_w(self, tensor):
         # find the clip values
         w_min, w_max = utils.extrema_fast(tensor.data, percentile_range_shrink=self.range_shrink_weights)
         clip_max = torch.max(torch.abs(w_min), torch.abs(w_max))
         clip_max = torch.clamp(clip_max, min=self.eps)
     def get_clips_w(self, tensor):
         # find the clip values
         w_min, w_max = utils.extrema_fast(tensor.data, percentile_range_shrink=self.range_shrink_weights)
         clip_max = torch.max(torch.abs(w_min), torch.abs(w_max))
         clip_max = torch.clamp(clip_max, min=self.eps)
-        # in range learning mode + training - this power2 is taken care in the quantize function
-        use_power2 = (self.power2 and (not (self.PACT2_RANGE_LEARN and self.training)))
+        # in range learning mode + training - this power2_weight_range is taken care in the quantize function
+        use_power2 = (self.power2_weight_range and (not (self.PACT2_RANGE_LEARN and self.training)))
         clip_max2 = layers.functional.ceil2_g(clip_max) if use_power2 else clip_max
         clip_min2 = -clip_max2
         return (clip_min2, clip_max2)
 
         clip_max2 = layers.functional.ceil2_g(clip_max) if use_power2 else clip_max
         clip_min2 = -clip_max2
         return (clip_min2, clip_max2)
 
-    # bias uses the same kind of clips
-    get_clips_bias = get_clips_w
-
 
     def get_clips_scale_w(self, weight):
 
     def get_clips_scale_w(self, weight):
-        # convert to scale
         clip_min, clip_max = self.get_clips_w(weight)
         width_min, width_max = self.get_widths_w()
         scale2 = (width_max / clip_max)
         clip_min, clip_max = self.get_clips_w(weight)
         width_min, width_max = self.get_widths_w()
         scale2 = (width_max / clip_max)
@@ -455,57 +522,91 @@ class QuantTrainPAct2(layers.PAct2):
         scale_inv2 = scale2.pow(-1.0)
         return (clip_min, clip_max, scale2, scale_inv2)
 
         scale_inv2 = scale2.pow(-1.0)
         return (clip_min, clip_max, scale2, scale_inv2)
 
+    ###########################################################
+    def get_widths_act(self):
+        if self.signed is None:
+            clip_min, clip_max = self.get_clips_act()
+            signed = (clip_min < 0.0)
+        else:
+            signed = self.signed
+        #
+        bw = (self.bitwidth_activations - 1) if signed else self.bitwidth_activations
+        width_max = np.power(2.0, bw)
+        width_min = -width_max if signed else 0.0
+        return width_min, width_max
 
 
-    # in reality, bias quantization will also depend on the activation scale
-    # this is not perfect - just a quick and dirty quantization for bias
-    def get_clips_scale_bias(self, bias):
-        # convert to scale
-        clip_min, clip_max = self.get_clips_bias(bias)
-        width_min, width_max = self.get_widths_bias()
-        scale2 = (width_max / clip_max)
+
+    def get_clips_scale_act(self):
+        clip_min, clip_max = self.get_clips_act()
+        width_min, width_max = self.get_widths_act()
+        scale2 = width_max / clip_max
         scale2 = torch.clamp(scale2, min=self.eps)
         scale_inv2 = scale2.pow(-1.0)
         return (clip_min, clip_max, scale2, scale_inv2)
 
 
         scale2 = torch.clamp(scale2, min=self.eps)
         scale_inv2 = scale2.pow(-1.0)
         return (clip_min, clip_max, scale2, scale_inv2)
 
 
-    def get_widths_w(self):
-        # weights
-        bw = (self.bitwidth_weights - 1)
-        width_max = np.power(2.0, bw)
-        width_min = -width_max
-        # return
-        return (width_min, width_max)
+    ###########################################################
+    # bias uses the same kind of widths
+    get_widths_bias = get_widths_w
 
 
 
 
-    def get_widths_bias(self):
-        # bias
-        bitwidth_bias = (2*self.bitwidth_activations)
-        bias_width_max = np.power(2.0, bitwidth_bias-1)
-        bias_width_min = -bias_width_max
-        # return
-        return (bias_width_min, bias_width_max)
+    # bias uses the same kind of clips
+    get_clips_bias = get_clips_w
 
 
 
 
-    # activation utility functions
-    def get_clips_scale_act(self):
-        # convert to scale
-        clip_min, clip_max = self.get_clips_act()
-        width_min, width_max = self.get_widths_act()
-        scale2 = width_max / clip_max
+    def get_clips_scale_bias(self, bias):
+        clip_min, clip_max = self.get_clips_bias(bias)
+        width_min, width_max = self.get_widths_bias()
+        scale2 = (width_max / clip_max)
         scale2 = torch.clamp(scale2, min=self.eps)
         scale_inv2 = scale2.pow(-1.0)
         return (clip_min, clip_max, scale2, scale_inv2)
 
 
         scale2 = torch.clamp(scale2, min=self.eps)
         scale_inv2 = scale2.pow(-1.0)
         return (clip_min, clip_max, scale2, scale_inv2)
 
 
-    def get_widths_act(self):
-        if self.signed is None:
-            clip_min, clip_max = self.get_clips_act()
-            signed = (clip_min < 0.0)
+    ###########################################################
+    def get_widths_joint(self):
+        bw = (2*self.bitwidth_weights - 1)
+        width_max = np.power(2.0, bw)
+        width_min = -width_max
+        return (width_min, width_max)
+
+
+    def get_clips_input(self, qparams):
+        if hasattr(qparams, 'clips_input'):
+            return qparams.clips_input
         else:
         else:
-            signed = self.signed
+            return None
         #
         #
+
+    def get_widths_input(self, clip_min, clip_max):
+        signed = (clip_min < 0.0)
         bw = (self.bitwidth_activations - 1) if signed else self.bitwidth_activations
         width_max = np.power(2.0, bw)
         width_min = -width_max if signed else 0.0
         return width_min, width_max
 
         bw = (self.bitwidth_activations - 1) if signed else self.bitwidth_activations
         width_max = np.power(2.0, bw)
         width_min = -width_max if signed else 0.0
         return width_min, width_max
 
+
+    def get_clips_scale_input(self, qparams):
+        clips_input = self.get_clips_input(qparams)
+        if clips_input is not None:
+            clip_min, clip_max = clips_input
+            width_min, width_max = self.get_widths_input(clip_min, clip_max)
+            scale2 = (width_max / clip_max)
+            scale2 = torch.clamp(scale2, min=self.eps)
+            scale_inv2 = scale2.pow(-1.0)
+            return (clip_min, clip_max, scale2, scale_inv2)
+        else:
+            return None
+        #
+
+
+    def get_clips_scale_joint(self, qparams, weights, bias):
+        clips_scale_input = self.get_clips_scale_input(qparams)
+        if clips_scale_input is not None:
+            clip_min_input, clip_max_input, scale2_input, scale_inv2_input = clips_scale_input
+            clip_min_w, clip_max_w, scale2_w, scale_inv2_w = self.get_clips_scale_w(weights)
+            clip_min_bias, clip_max_bias, scale2_bias, scale_inv2_bias = self.get_clips_scale_bias(bias)
+            return (clip_min_bias, clip_max_bias, scale2_w*scale2_input, scale_inv2_w*scale_inv2_input)
+        else:
+            return None
+        #
\ No newline at end of file
index eab5d3f043bb1c2703b406323ed880be1a5c3102..3bc9743abb5204804cf0078bc3977b1fbfc52904 100644 (file)
@@ -12,6 +12,7 @@ from . import utils_data
 ######################################################
 def load_weights(model, pretrained, change_names_dict=None, keep_original_names=False, width_mult=1.0,
                        ignore_size=True, verbose=False, num_batches_tracked = None, download_root=None, **kwargs):
 ######################################################
 def load_weights(model, pretrained, change_names_dict=None, keep_original_names=False, width_mult=1.0,
                        ignore_size=True, verbose=False, num_batches_tracked = None, download_root=None, **kwargs):
+    download_root = './' if (download_root is None) else download_root
     if pretrained is None or pretrained is False:
         print_utils.print_yellow(f'=> weights could not be loaded. pretrained data given is {pretrained}')
         return model
     if pretrained is None or pretrained is False:
         print_utils.print_yellow(f'=> weights could not be loaded. pretrained data given is {pretrained}')
         return model
index c12a26fdf73109f70d14f55452aa8f797a43ca16..9fcd6e9198eae2c3079e6de22c7f38abc3f76d20 100644 (file)
@@ -10,7 +10,7 @@ def is_normalization(module):
 
 def is_activation(module):
     is_act = isinstance(module, (torch.nn.ReLU, torch.nn.ReLU6, torch.nn.Hardtanh,
 
 def is_activation(module):
     is_act = isinstance(module, (torch.nn.ReLU, torch.nn.ReLU6, torch.nn.Hardtanh,
-                                 layers.NoAct, layers.PAct2))
+                                 layers.PAct2, layers.QAct, layers.NoQAct))
     return is_act
 
 def is_pact2(module):
     return is_act
 
 def is_pact2(module):
@@ -26,6 +26,9 @@ def is_deconv(module):
 def is_conv_deconv(module):
     return isinstance(module, (torch.nn.Conv2d,torch.nn.ConvTranspose2d))
 
 def is_conv_deconv(module):
     return isinstance(module, (torch.nn.Conv2d,torch.nn.ConvTranspose2d))
 
+def is_conv_deconv_linear(module):
+    return isinstance(module, (torch.nn.Conv2d,torch.nn.ConvTranspose2d,torch.nn.Linear))
+
 def is_linear(module):
     return isinstance(module, torch.nn.Linear)
 
 def is_linear(module):
     return isinstance(module, torch.nn.Linear)
 
@@ -106,6 +109,10 @@ def add_module_names(model):
 
 
 def squeeze_list(inputs):
 
 
 def squeeze_list(inputs):
+    return inputs[0] if (is_list(inputs) and len(inputs)==1) else inputs
+
+
+def squeeze_list2(inputs):
     return inputs[0] if (is_list(inputs) and (len(inputs)==1) and is_list(inputs[0])) else inputs
 
 
     return inputs[0] if (is_list(inputs) and (len(inputs)==1) and is_list(inputs[0])) else inputs
 
 
index 3f80ed99cf69ffbbdd738e8ae6b84155e2e0d0f8..461dbc4bb0147bc9c5b506770bb411b351b9ee32 100755 (executable)
@@ -1,34 +1,5 @@
-# Quantization
-
-## =====================================================================================
-## Quantization Aware Training
-## =====================================================================================
-#
-#### Image Classification - Quantization Aware Training - MobileNetV2
-#python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification \
-#--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth \
-#--batch_size 64 --quantize True --epochs 25 --epoch_size 0.1 --lr 1e-5 --evaluate_start False
+## Quantization
 #
 #
-#
-#### Image Classification - Quantization Aware Training - MobileNetV2(Shicai) - a TOUGH MobileNetV2 pretrained model
-#python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_shicai_x1 --data_path ./data/datasets/image_folder_classification \
-#--pretrained ./data/modelzoo/pytorch/image_classification/imagenet1k/shicai/mobilenetv2_shicai_rgb.pth \
-#--batch_size 64 --quantize True --epochs 25 --epoch_size 0.1 --lr 1e-5 --evaluate_start False
-#
-#
-#### Semantic Segmentation - Quantization Aware Training for MobileNetV2+DeeplabV3Lite
-#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
-#--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/deeplabv3lite_mobilenetv2_tv_768x384_best.pth \
-#--batch_size 6 --quantize True --epochs 150 --lr 1e-5 --evaluate_start False
-#
-#
-#### Semantic Segmentation - Quantization Aware Training for MobileNetV2+UNetLite
-#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name unetlite_pixel2pixel_aspp_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
-#--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/unet_aspp_mobilenetv2_tv_768x384_best.pth \
-#--batch_size 6 --quantize True --epochs 150 --lr 1e-5 --evaluate_start False
-
-
-
 ## =====================================================================================
 ## Post Training Calibration & Quantization - this is fast, but may not always yield best quantized accuracy (not recommended)
 ## =====================================================================================
 ## =====================================================================================
 ## Post Training Calibration & Quantization - this is fast, but may not always yield best quantized accuracy (not recommended)
 ## =====================================================================================
 #python ./scripts/train_segmentation_main.py --phase calibration --dataset_name cityscapes_segmentation --model_name unetlite_pixel2pixel_aspp_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
 #--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/unet_aspp_mobilenetv2_tv_768x384_best.pth \
 #--batch_size 6 --quantize True --epochs 1 --evaluate_start False
 #python ./scripts/train_segmentation_main.py --phase calibration --dataset_name cityscapes_segmentation --model_name unetlite_pixel2pixel_aspp_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
 #--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/unet_aspp_mobilenetv2_tv_768x384_best.pth \
 #--batch_size 6 --quantize True --epochs 1 --evaluate_start False
-
-
-
+#
+#
+## =====================================================================================
+## Quantization Aware Training
+## =====================================================================================
+#
+#### Image Classification - Quantization Aware Training - MobileNetV2
+#python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification \
+#--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth \
+#--batch_size 64 --quantize True --epochs 25 --epoch_size 0.1 --lr 1e-5 --evaluate_start False
+#
+#
+#### Image Classification - Quantization Aware Training - MobileNetV2(Shicai) - a TOUGH MobileNetV2 pretrained model
+#python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_shicai_x1 --data_path ./data/datasets/image_folder_classification \
+#--pretrained ./data/modelzoo/pytorch/image_classification/imagenet1k/shicai/mobilenetv2_shicai_rgb.pth \
+#--batch_size 64 --quantize True --epochs 25 --epoch_size 0.1 --lr 1e-5 --evaluate_start False
+#
+#
+#### Semantic Segmentation - Quantization Aware Training for MobileNetV2+DeeplabV3Lite
+#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
+#--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/deeplabv3lite_mobilenetv2_tv_768x384_best.pth \
+#--batch_size 6 --quantize True --epochs 150 --lr 1e-5 --evaluate_start False
+#
+#
+#### Semantic Segmentation - Quantization Aware Training for MobileNetV2+UNetLite
+#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name unetlite_pixel2pixel_aspp_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
+#--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/unet_aspp_mobilenetv2_tv_768x384_best.pth \
+#--batch_size 6 --quantize True --epochs 150 --lr 1e-5 --evaluate_start False
+#
+#
+#
 ## =====================================================================================
 ## =====================================================================================
-## Acuracy Evaluation with Post Training Quantization - cannot save quantized model - only accuracy evaluation
+## Acuracy Evaluation with Post Training Quantization - this is not supported anymore.
+## Either Calibration or QAT has to be performed first, to get correct accuracy.
+## Please use one of the sections above.
 ## =====================================================================================
 ## =====================================================================================
-
+#
 #### Image Classification - Accuracy Estimation with Post Training Quantization - MobileNetV2
 #python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification \
 #--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth \
 #--batch_size 64 --quantize True
 #### Image Classification - Accuracy Estimation with Post Training Quantization - MobileNetV2
 #python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification \
 #--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth \
 #--batch_size 64 --quantize True
-
+#
 #### Image Classification - Accuracy Estimation with Post Training Quantization - ResNet50
 #python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name resnet50_x1 --data_path ./data/datasets/image_folder_classification \
 #--pretrained https://download.pytorch.org/models/resnet50-19c8e357.pth \
 #--batch_size 64 --quantize True
 #### Image Classification - Accuracy Estimation with Post Training Quantization - ResNet50
 #python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name resnet50_x1 --data_path ./data/datasets/image_folder_classification \
 #--pretrained https://download.pytorch.org/models/resnet50-19c8e357.pth \
 #--batch_size 64 --quantize True
-
+#
 #### Image Classification - Accuracy Estimation with Post Training Quantization - ResNet18
 #python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name resnet18_x1 --data_path ./data/datasets/image_folder_classification \
 #--pretrained https://download.pytorch.org/models/resnet18-5c106cde.pth \
 #--batch_size 64 --quantize True
 #### Image Classification - Accuracy Estimation with Post Training Quantization - ResNet18
 #python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name resnet18_x1 --data_path ./data/datasets/image_folder_classification \
 #--pretrained https://download.pytorch.org/models/resnet18-5c106cde.pth \
 #--batch_size 64 --quantize True
-
+#
 #### Image Classification - Accuracy Estimation with Post Training Quantization - A TOUGH MobileNetV2 pretrained model
 #python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name mobilenetv2_shicai_x1 --data_path ./data/datasets/image_folder_classification \
 #--pretrained ./data/modelzoo/pytorch/image_classification/imagenet1k/shicai/mobilenetv2_shicai_rgb.pth \
 #--batch_size 64 --quantize True
 #### Image Classification - Accuracy Estimation with Post Training Quantization - A TOUGH MobileNetV2 pretrained model
 #python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name mobilenetv2_shicai_x1 --data_path ./data/datasets/image_folder_classification \
 #--pretrained ./data/modelzoo/pytorch/image_classification/imagenet1k/shicai/mobilenetv2_shicai_rgb.pth \
 #--batch_size 64 --quantize True
-
+#
 #### Semantic Segmentation - Accuracy Estimation with Post Training Quantization - MobileNetV2+DeeplabV3Lite
 #python ./scripts/train_segmentation_main.py --phase validation --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
 #--pretrained './data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/deeplabv3lite_mobilenetv2_tv_768x384_best.pth' \
 #--batch_size 1 --quantize True
 #### Semantic Segmentation - Accuracy Estimation with Post Training Quantization - MobileNetV2+DeeplabV3Lite
 #python ./scripts/train_segmentation_main.py --phase validation --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
 #--pretrained './data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/deeplabv3lite_mobilenetv2_tv_768x384_best.pth' \
 #--batch_size 1 --quantize True
-
+#
 #### Semantic Segmentation - Accuracy Estimation with Post Training Quantization - MobileNetV2+UNetLite
 #python ./scripts/train_segmentation_main.py --phase validation --dataset_name cityscapes_segmentation --model_name unetlite_pixel2pixel_aspp_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
 #--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/unet_aspp_mobilenetv2_tv_768x384_best.pth \
 #### Semantic Segmentation - Accuracy Estimation with Post Training Quantization - MobileNetV2+UNetLite
 #python ./scripts/train_segmentation_main.py --phase validation --dataset_name cityscapes_segmentation --model_name unetlite_pixel2pixel_aspp_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
 #--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/unet_aspp_mobilenetv2_tv_768x384_best.pth \