quantization cleanup and minor fixes
author: Manu Mathew <a0393608@ti.com>
Mon, 27 Apr 2020 18:34:44 +0000 (00:04 +0530)
committer: Manu Mathew <a0393608@ti.com>
Mon, 27 Apr 2020 18:58:59 +0000 (00:28 +0530)
minor doc update

release commit

release commit

release commit

release commit

release commit

14 files changed:
README.md
modules/pytorch_jacinto_ai/engine/infer_classification_onnx_rt.py
modules/pytorch_jacinto_ai/engine/infer_pixel2pixel.py
modules/pytorch_jacinto_ai/engine/test_classification.py
modules/pytorch_jacinto_ai/engine/train_classification.py
modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py
modules/pytorch_jacinto_ai/vision/models/resnet.py
modules/pytorch_jacinto_ai/xnn/quantize/__init__.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_base_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_calib_module.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_test_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_train_module.py
run_quantization.sh

index 96713c4e6d103c5bb3b81d8ed93b4557d2bb3871..abc93ff3d29eed0c44e602a5b27f771049a8acdc 100644 (file)
--- a/README.md
+++ b/README.md
@@ -27,19 +27,24 @@ This code also includes tools for **Quantization Aware Training** that can outpu
     ```
 
 ## Examples
-The following examples are currently available. Click on each of the links below to go into the full description of the example. 
-* Image Classification<br>
-    * [**Image Classification**](docs/Image_Classification.md)<br>
-* Pixel2Pixel Prediction<br>
-    * [**Semantic Segmentation**](docs/Semantic_Segmentation.md)<br>
-    * [Depth Estimation](docs/Depth_Estimation.md)<br>
-    * [Motion Segmentation](docs/Motion_Segmentation.md)<br>
-    * [**Multi Task Estimation**](docs/Multi_Task_Learning.md)<br>
-* Object Detection<br>
-    * Object Detection - coming soon..<br>
-    * Object Keypoint Estimation - coming soon..<br>
-* Quantization<br>
-    * [**Quantization Aware Training**](docs/Quantization.md)<br>
+The following examples are currently available. Click on each of the links below to go into the full description of the example.<br>
+    <br>
+* **Image Classification**<br>
+    - [**Image Classification**](docs/Image_Classification.md)<br>
+    <br>
+* **Pixel2Pixel Prediction**<br>
+    - [**Semantic Segmentation**](docs/Semantic_Segmentation.md)<br>
+    - [Depth Estimation](docs/Depth_Estimation.md)<br>
+    - [Motion Segmentation](docs/Motion_Segmentation.md)<br>
+    - [**Multi Task Estimation**](docs/Multi_Task_Learning.md)<br>
+    <br>
+* **Object Detection**<br>
+    - Object Detection - coming soon..<br>
+    - Object Keypoint Estimation - coming soon..<br>
+    <br>
+* **Quantization**<br>
+    - [**Quantization Aware Training**](docs/Quantization.md)<br>
+    <br>
 
 
 Some of the common training and validation commands are provided in shell scripts (.sh files) in the root folder.
index f1129cfd456a67a1e32a73d6eedf772770b8de21..b637955e8387176dcae1929777097211a3ac15f8 100644 (file)
@@ -54,7 +54,7 @@ def get_config():
     args.dataset_format = 'folder'                      # dataset format, choices=['folder','lmdb']
     args.count_flops = True                             # count flops and report
 
-    args.lr_calib = 0.                                # lr for bias calibration
+    args.lr_calib = 0.05                                # lr for bias calibration
 
     args.rand_seed = 1                                  # random seed
     args.generate_onnx = False                          # apply quantized inference or not
index 94c01b8f476260fc7ce0357cb37222ba2405eb98..d439c28b1bb9e73cfed13d415277fb9ed8fa127d 100644 (file)
@@ -283,8 +283,8 @@ def main(args):
         # Note: bias_calibration is not enabled in test
         model = xnn.quantize.QuantTestModule(model, per_channel_q=args.per_channel_q,
                         bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
-                        histogram_range=args.histogram_range, model_surgery_quantize=model_surgery_quantize,
-                        dummy_input=dummy_input)
+                        histogram_range=args.histogram_range, dummy_input=dummy_input,
+                        model_surgery_quantize=model_surgery_quantize)
 
     # load pretrained weights
     xnn.utils.load_weights(get_model_orig(model), pretrained=pretrained_data, change_names_dict=change_names_dict)
index 534c1110678945e487ba29f903e2a5dd4d4e454f..4674bfe908b39125f70ef9a8a527db3afe85ac52 100644 (file)
@@ -51,7 +51,7 @@ def get_config():
     args.dataset_format = 'folder'                      # dataset format, choices=['folder','lmdb']
     args.count_flops = True                             # count flops and report
 
-    args.lr_calib = 0.                                # lr for bias calibration
+    args.lr_calib = 0.05                                # lr for bias calibration
 
     args.rand_seed = 1                                  # random seed
     args.generate_onnx = False                          # apply quantized inference or not
@@ -153,14 +153,13 @@ def main(args):
         elif 'calibration' in args.phase:
             model = xnn.quantize.QuantCalibrateModule(model, per_channel_q=args.per_channel_q,
                         bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
-                        bias_calibration=args.bias_calibration, lr_calib=args.lr_calib,
-                        dummy_input=dummy_input)
+                        bias_calibration=args.bias_calibration, dummy_input=dummy_input, lr_calib=args.lr_calib)
         elif 'validation' in args.phase:
             # Note: bias_calibration is not enabled in test
             model = xnn.quantize.QuantTestModule(model, per_channel_q=args.per_channel_q,
                         bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
-                        histogram_range=args.histogram_range, model_surgery_quantize=model_surgery_quantize,
-                        dummy_input=dummy_input)
+                        histogram_range=args.histogram_range, dummy_input=dummy_input,
+                        model_surgery_quantize=model_surgery_quantize)
         else:
             assert False, f'invalid phase {args.phase}'
     #
index a2e14af09e58384962707be21548aa49e822f987..58e8be83a2af4c556ede9a712cb0e521130fe613 100644 (file)
@@ -56,8 +56,8 @@ def get_config():
     args.iter_size = 1                                  # iteration size. total_batch_size = batch_size*iter_size
 
     args.lr = 0.1                                       # initial learning rate
-    args.lr_clips = None                                 # use args.lr itself if it is None
-    args.lr_calib = 0.                                # lr for bias calibration
+    args.lr_clips = None                                # use args.lr itself if it is None
+    args.lr_calib = 0.05                                # lr for bias calibration
     args.momentum = 0.9                                 # momentum
     args.weight_decay = 1e-4                            # weight decay (default: 1e-4)
     args.bias_decay = None                              # bias decay (default: 0.0)
@@ -233,14 +233,14 @@ def main(args):
         elif 'calibration' in args.phase:
             model = xnn.quantize.QuantCalibrateModule(model, per_channel_q=args.per_channel_q,
                         bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
-                        histogram_range=args.histogram_range, bias_calibration=args.bias_calibration, lr_calib=args.lr_calib,
-                        dummy_input=dummy_input)
+                        histogram_range=args.histogram_range, bias_calibration=args.bias_calibration, dummy_input=dummy_input,
+                        lr_calib=args.lr_calib)
         elif 'validation' in args.phase:
             # Note: bias_calibration is not used in test
             model = xnn.quantize.QuantTestModule(model, per_channel_q=args.per_channel_q,
                         bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
-                        histogram_range=args.histogram_range, model_surgery_quantize=model_surgery_quantize,
-                        dummy_input=dummy_input)
+                        histogram_range=args.histogram_range, dummy_input=dummy_input,
+                        model_surgery_quantize=model_surgery_quantize)
         else:
             assert False, f'invalid phase {args.phase}'
     #
index db1b59ca3d748bfb65fe435661838e64d078995e..775bc484341477e0fe5bc4343f6a10edfcf88c70 100644 (file)
@@ -84,8 +84,8 @@ def get_config():
     args.iter_size = 1                                  # iteration size. total_batch_size = batch_size*iter_size
 
     args.lr = 1e-4                                      # initial learning rate
-    args.lr_clips = None                                 # use args.lr itself if it is None
-    args.lr_calib = 0.                                # lr for bias calibration
+    args.lr_clips = None                                # use args.lr itself if it is None
+    args.lr_calib = 0.05                                # lr for bias calibration
     args.warmup_epochs = 5                              # number of epochs to warmup
 
     args.momentum = 0.9                                 # momentum for sgd, alpha parameter for adam
@@ -356,13 +356,14 @@ def main(args):
         elif 'calibration' in args.phase:
             model = xnn.quantize.QuantCalibrateModule(model, per_channel_q=args.per_channel_q,
                         bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
-                        histogram_range=args.histogram_range, bias_calibration=args.bias_calibration, lr_calib=args.lr_calib, dummy_input=dummy_input)
+                        histogram_range=args.histogram_range, bias_calibration=args.bias_calibration,
+                        dummy_input=dummy_input, lr_calib=args.lr_calib)
         elif 'validation' in args.phase:
             # Note: bias_calibration is not emabled
             model = xnn.quantize.QuantTestModule(model, per_channel_q=args.per_channel_q,
                         bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
-                        histogram_range=args.histogram_range, model_surgery_quantize=model_surgery_quantize,
-                        dummy_input=dummy_input)
+                        histogram_range=args.histogram_range, dummy_input=dummy_input,
+                        model_surgery_quantize=model_surgery_quantize)
         else:
             assert False, f'invalid phase {args.phase}'
     #
index f4f4026fffb3d47d93036696a633214cf730e17f..333789f66860dd3e292f59c199e56368fbc9e1df 100644 (file)
@@ -378,6 +378,7 @@ def get_config():
     model_config = xnn.utils.ConfigNode()
     model_config.input_channels = 3
     model_config.num_classes = 1000
+    model_config.width_mult = 1.0
     model_config.strides = None #(2,2,2,2,2)
     model_config.fastdown = False
     return model_config
index 7a892a1a138abdfe1aa27088398c542256e6d9a7..b55c304f2634676a2a6c929b05bcaea7046089e6 100644 (file)
@@ -1,4 +1,5 @@
 from .quant_utils import *
 from .quant_train_module import *
+from .quant_calib_module import *
 from .quant_test_module import *
 
index 48a6fca4fa5673d93ad6e834e844356c67c84a29..6e50118d1c3ea9e9bc8e48cace6fedfc2d25fddb 100644 (file)
@@ -10,8 +10,50 @@ class QuantEstimationType:
 
 # base module to be use for all quantization modules
 class QuantBaseModule(QuantGraphModule):
-    def __init__(self, module):
+    def __init__(self, module, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
+                 histogram_range=True, bias_calibration=False, constrain_weights=False, dummy_input=None,
+                 model_surgery_quantize=False):
         super().__init__(module)
+        self.bitwidth_weights = bitwidth_weights
+        self.bitwidth_activations = bitwidth_activations
+        self.per_channel_q = per_channel_q
+        self.histogram_range = histogram_range
+        self.constrain_weights = constrain_weights
+        self.bias_calibration = bias_calibration
+        # for help in debug/print
+        utils.add_module_names(self)
+        # put in eval mode before analyze
+        self.eval()
+        # model surgery for quantization
+        if model_surgery_quantize:
+            with torch.no_grad():
+                utils.print_yellow("=> model surgery by '{}'".format(self.model_surgery_quantize.__name__))
+                assert dummy_input is not None, 'dummy input is needed by quantized models to analyze graph'
+                self.model_surgery_quantize(dummy_input)
+            #
+            # add hooks to execute the pact modules
+            self.add_activation_hooks()
+        #
+        # for help in debug/print
+        utils.add_module_names(self)
+
+
+    def add_activation_hooks(self):
+        # add a forward hook to call the extra activation that we added
+        def _forward_activation(op, inputs, outputs):
+            if hasattr(op, 'activation_q'):
+                outputs = op.activation_q(outputs)
+            #
+            return outputs
+        #
+        for m in self.modules():
+            m.register_forward_hook(_forward_activation)
+        #
+
+
+    def train(self, mode=True):
+        self.iter_in_epoch.fill_(-1.0)
+        super().train(mode)
 
 
     def _backup_weights_orig(self):
@@ -42,7 +84,6 @@ class QuantBaseModule(QuantGraphModule):
             self.__buffers_quant__[n] = copy.deepcopy(p.data)
         #
 
-
     def _restore_weights_quant(self):
         for n,p in self.named_parameters():
             p.data.copy_(self.__params_quant__[n].data)
@@ -70,4 +111,9 @@ class QuantBaseModule(QuantGraphModule):
             #
         #
         self.apply(_remove_output_means_op)
-    #
+
+
+
+
+
+
diff --git a/modules/pytorch_jacinto_ai/xnn/quantize/quant_calib_module.py b/modules/pytorch_jacinto_ai/xnn/quantize/quant_calib_module.py
new file mode 100644 (file)
index 0000000..2ab33d8
--- /dev/null
@@ -0,0 +1,159 @@
+###########################################################
+# Approximate quantized floating point simulation with gradients.
+# Can be used for quantized training of models.
+###########################################################
+
+import torch
+import numpy as np
+import copy
+import warnings
+
+from .. import layers
+from .. import utils
+from .quant_train_module import *
+
+warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)
+
+
+###########################################################
+class QuantCalibrateModule(QuantTrainModule):
+    def __init__(self, module, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
+                 histogram_range=True, bias_calibration=True, constrain_weights=True, dummy_input=None, lr_calib=0.05):
+        self.bias_calibration = bias_calibration
+        self.weights_calibration = False
+        self.lr_calib = lr_calib
+        self.calibration_factor = lr_calib
+        self.calibration_gamma = 0.5
+        self.calibrate_repeats = 1
+        self.quantize_enable = True
+        self.update_range = True
+        super().__init__(module, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
+                         per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
+                         constrain_weights=constrain_weights, dummy_input=dummy_input)
+
+
+    def forward(self, inputs):
+        # since this does not involve training, we can set merge_weights=True
+        self.analyze_graph(inputs=inputs, cleanup_states=True, merge_weights=True)
+
+        # actual forward call
+        if self.training and (self.bias_calibration or self.weights_calibration):
+            # calibration
+            outputs = self.forward_calibrate(inputs)
+        else:
+            outputs = self.module(inputs)
+        #
+        return outputs
+
+
+    def forward_calibrate(self, inputs):
+        # we don't need gradients for calibration
+        # prepare/backup weights
+        if self.num_batches_tracked == 0:
+            # lr_calib
+            self.calibration_factor = self.lr_calib * np.power(self.calibration_gamma, float(self.epoch))
+            # backup original weights
+            self._backup_weights_orig()
+            # backup quantized weights
+            self._backup_weights_quant()
+        #
+
+        # backup the current state
+        training = self.training
+
+        # Set all bns to eval so that they do not change. We can't set the whole model to eval because,
+        # we need the pact to learn the ranges - which will happen only in training mode.
+        # Also the model output itself may be different in eval mode (in certain cases -
+        # for example if in a segmentation model argmax is done instead of softmax in eval mode).
+        utils.freeze_bn(self)
+
+        # Compute the mean output in float first.
+        with torch.no_grad():
+            outputs = self.forward_float(inputs)
+        #
+
+        # Then adjust weights/bias so that the quantized output matches float output
+        if self.weights_calibration:
+            outputs = self.forward_quantized(inputs)
+        else:
+            with torch.no_grad():
+                outputs = self.forward_quantized(inputs)
+            #
+        #
+
+        self.train(training)
+        return outputs
+
+
+    def forward_float(self, inputs):
+        self._restore_weights_orig()
+        # disable quantization for a moment
+        quantize_enable_backup_value, update_range_backup_value = self.quantize_enable, self.update_range
+        utils.apply_setattr(self, quantize_enable=False, update_range=False)
+
+        self.add_call_hook(self.module, self.forward_float_hook)
+        outputs = self.module(inputs)
+        self.remove_call_hook(self.module)
+
+        # turn quantization back on - not a clean method
+        utils.apply_setattr(self, quantize_enable=quantize_enable_backup_value, update_range=update_range_backup_value)
+        self._backup_weights_orig()
+        return outputs
+    #
+    def forward_float_hook(self, op, *inputs_orig):
+        outputs = op.__forward_orig__(*inputs_orig)
+
+        # calibration at specific layers
+        output = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
+        reduce_dims = [0, 2, 3] if len(output.size()) == 4 else ([0] if len(output.size()) == 2 else None)
+
+        bias = op.bias if hasattr(op, 'bias') else None
+        if (self.bias_calibration and bias is not None):
+            op.__output_mean_orig__ = torch.mean(output, dim=reduce_dims).data
+        #
+
+        if self.weights_calibration and utils.is_conv_deconv(op):
+            op.__output_std_orig__ = torch.std(output, dim=reduce_dims).data
+        #
+        return outputs
+    #
+
+
+    def forward_quantized(self, input):
+        self._restore_weights_quant()
+        self.add_call_hook(self.module, self.forward_quantized_hook)
+        for _ in range(self.calibrate_repeats):
+            output = self.module(input)
+        #
+        self.remove_call_hook(self.module)
+        self._backup_weights_quant()
+        return output
+    #
+    def forward_quantized_hook(self, op, *inputs_orig):
+        outputs = op.__forward_orig__(*inputs_orig)
+
+        # calibration at specific layers
+        output = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
+
+        bias = op.bias if hasattr(op, 'bias') else None
+        if self.bias_calibration and bias is not None:
+            reduce_dims = [0,2,3] if len(output.size()) == 4 else ([0] if len(output.size()) == 2 else None)
+            output_mean = torch.mean(output, dim=reduce_dims).data
+            output_delta = op.__output_mean_orig__ - output_mean
+            output_delta = output_delta * self.calibration_factor
+            bias.data += (output_delta)
+        #
+
+        if self.weights_calibration and utils.is_conv_deconv(op):
+            eps = 1e-6
+            weight = op.weight
+            reduce_dims = [0, 2, 3] if len(output.size()) == 4 else ([0] if len(output.size()) == 2 else None)
+            output_std = torch.std(output, dim=reduce_dims).data
+            output_ratio = (op.__output_std_orig__ + eps) / (output_std + eps)
+            channels = output.size(1)
+            output_ratio = output_ratio.view(channels, 1, 1, 1) if len(weight.data.size()) > 1 else output_ratio
+            output_ratio = torch.pow(output_ratio, self.calibration_factor)
+            weight.data *= output_ratio
+        #
+        return outputs
+
index ba907754d42dc76c0e2ce80ce0b60f924d262f19..d4f5f7d9d89d651e5197599cb77f7933b028d682 100644 (file)
@@ -143,6 +143,8 @@ class QuantGraphModule(HookedModule):
                         activation_q = layers.PAct2(signed=False)
                     elif isinstance(module, torch.nn.Hardtanh):
                         activation_q = layers.PAct2(clip_range=(module.min_val, module.max_val))
+                    elif isinstance(module, layers.NoAct):
+                        activation_q = layers.PAct2(signed=None)
                     else:
                         activation_q = layers.PAct2(signed=None)
                     #
index 5814f025cb1e7836b0930616b86a707d224e2dfe..7166caade59c444511edbe57ec147cb33eba2fab 100644 (file)
@@ -9,17 +9,13 @@ from .quant_utils import *
 
 
 class QuantTestModule(QuantBaseModule):
-    def __init__(self, module, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False, power2_weights=True, histogram_range=True,
-                 range_calibration_online=False, bias_calibration=False, model_surgery_quantize=False, dummy_input=None):
-        super().__init__(module)
-        assert dummy_input is not None, 'dummy input is needed by quantized models to analyze graph'
-        self.bitwidth_weights = bitwidth_weights
-        self.bitwidth_activations = bitwidth_activations
-        self.per_channel_q = per_channel_q
-        self.power2_weights = power2_weights
-        # this is actually to indicate the bias calibration - indicates this was called from the derived class
-        self.bias_calibration = bias_calibration
-
+    def __init__(self, module, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False, histogram_range=True,
+                 range_calibration_online=False, bias_calibration=False, dummy_input=None, model_surgery_quantize=False):
+        super().__init__(module, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
+                         per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
+                         constrain_weights=False, dummy_input=dummy_input, model_surgery_quantize=model_surgery_quantize)
+        # use power2_weights for now
+        self.power2_weights = True
         # whether to do online adjustment of calibration using previous frame range
         self.range_calibration_online = range_calibration_online
         # number of offline calibration iters. during offline calibration, current frame range is used
@@ -39,22 +35,6 @@ class QuantTestModule(QuantBaseModule):
 
         self.idx_large_mse_for_act = 0
 
-        # put in eval mode before analyze
-        self.eval()
-
-        with torch.no_grad():
-            # model surgery for quantization
-            if model_surgery_quantize:
-                utils.print_yellow("=> model surgery by '{}'".format(self.model_surgery_quantize.__name__))
-                self.model_surgery_quantize(dummy_input)
-            #
-        #
-
-        # for help in debug/print
-        for n, m in self.named_modules():
-            m.name = n
-        #
-
 
     def model_surgery_quantize(self, dummy_input):
         super().model_surgery_quantize(dummy_input)
@@ -85,17 +65,6 @@ class QuantTestModule(QuantBaseModule):
         # apply recursively
         self.apply(replace_func)
 
-        # add a forward hook to call the extra activation that we added
-        def _forward_activation(op, inputs, outputs):
-            if hasattr(op, 'activation_q'):
-                outputs = op.activation_q(outputs)
-            #
-            return outputs
-        #
-        for m in self.modules():
-            m.register_forward_hook(_forward_activation)
-        #
-
         # clear
         self.clear_states()
     #
@@ -155,7 +124,7 @@ class QuantTestModule(QuantBaseModule):
 
 
     # implement this in a derived class to clamp weights
-    def constrain_weights(self, module):
+    def apply_constrain_weights(self, module):
         pass
 
 
@@ -330,7 +299,7 @@ class QuantTestModule(QuantBaseModule):
 
 
     def quantize_weights(self, module, tensor_in, qrange):
-        self.constrain_weights(module)
+        self.apply_constrain_weights(module)
 
         bitwidth_weights = self.get_bitwidth_weights(module)
         with torch.no_grad():
@@ -568,11 +537,11 @@ class QuantTestModule(QuantBaseModule):
 #     #
 #
 #
-#     def constrain_weights(self, module):
+#     def apply_constrain_weights(self, module):
 #         if not self.constrain_weights_enabled:
 #             return
 #         #
-#         constrained_weight = constrain_weight(module.weight.data)
+#         constrained_weight = quant_utils.constrain_weight(module.weight.data)
 #         module.weight.data.copy_(constrained_weight.data)
 #     #
 #
index 35144a3b8b30e1ba58a09bca5db34f52605d925a..2344b653ed05a8ab7c2cc5739409ad59a8eab720 100644 (file)
@@ -10,7 +10,8 @@ import warnings
 
 from .. import layers
 from .. import utils
-from .quant_train_utils import *
+from . import quant_utils
+from .quant_base_module import *
 
 warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)
 
@@ -18,27 +19,10 @@ warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)
 ###########################################################
 class QuantTrainModule(QuantBaseModule):
     def __init__(self, module, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False, histogram_range=True,
-                 constrain_weights=True, bias_calibration=False, dummy_input=None):
-        super().__init__(module)
-        assert dummy_input is not None, 'dummy input is needed by quantized models to analyze graph'
-        self.bitwidth_weights = bitwidth_weights
-        self.bitwidth_activations = bitwidth_activations
-        self.per_channel_q = per_channel_q
-        self.constrain_weights = constrain_weights #and (not bool(self.per_channel_q))
-        self.bias_calibration = bias_calibration
-
-        # for help in debug/print
-        utils.add_module_names(self)
-
-        # put in eval mode before analyze
-        self.eval()
-
-        with torch.no_grad():
-            # model surgery for quantization
-            utils.print_yellow("=> model surgery by '{}'".format(self.model_surgery_quantize.__name__))
-            self.model_surgery_quantize(dummy_input)
-        #
-
+                 bias_calibration=False, constrain_weights=True, dummy_input=None):
+        super().__init__(module, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
+                         per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
+                         constrain_weights=constrain_weights, dummy_input=dummy_input, model_surgery_quantize=True)
         # range shrink - 0.0 indicates no shrink
         percentile_range_shrink = (layers.PAct2.PACT2_RANGE_SHRINK if histogram_range else 0.0)
         # set attributes to all modules - can control the behaviour from here
@@ -47,33 +31,16 @@ class QuantTrainModule(QuantBaseModule):
                             percentile_range_shrink=percentile_range_shrink, constrain_weights=self.constrain_weights,
                             update_range=True, quantize_enable=True, quantize_weights=True, quantize_bias=True, quantize_activations=True)
 
-        # for help in debug/print
-        utils.add_module_names(self)
-
-
-    def train(self, mode=True):
-        self.iter_in_epoch.fill_(-1.0)
-        super().train(mode)
-
-
     def forward(self, inputs):
         # analyze
+        # since this involves training, we cannot set merge_weights=True
+        # merging weights modifies bn and training after that unstable.
         self.analyze_graph(inputs=inputs, cleanup_states=True)
-
-        # actual forward call
-        if self.training and self.bias_calibration:
-            # bias calibration
-            outputs = self.forward_calibrate_bias(inputs)
-        else:
-            outputs = self.module(inputs)
-        #
+        # outputs
+        outputs = self.module(inputs)
         return outputs
 
 
-    def forward_calibrate_bias(self, inputs):
-        assert False, 'forward_calibrate_bias is not implemented'
-
-
     def model_surgery_quantize(self, dummy_input):
         super().model_surgery_quantize(dummy_input)
 
@@ -124,154 +91,415 @@ class QuantTrainModule(QuantBaseModule):
         # apply recursively
         self.apply(replace_func)
 
-        # add a forward hook to call the extra activation that we added
-        def _forward_activation(op, inputs, outputs):
-            if hasattr(op, 'activation_q'):
-                outputs = op.activation_q(outputs)
-            #
-            return outputs
+        # clear
+        self.clear_states()
+    #
+
+
+
+
+###########################################################
+class QuantTrainParams:
+    pass
+
+
+def get_qparams():
+    qparams = QuantTrainParams()
+    qparams.inputs = []
+    qparams.modules = []
+    return qparams
+
+
+def is_merged_layer(x):
+    is_merged = (hasattr(x, 'qparams') and isinstance(x.qparams, QuantTrainParams) and len(x.qparams.modules)>0)
+    return is_merged
+
+
+###########################################################
+class QuantTrainConv2d(torch.nn.Conv2d):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.quantize_enable = True
+        self.bitwidth_weights = None
+        self.bitwidth_activations = None
+        self.per_channel_q = False
+
+    def forward(self, x):
+        is_merged = is_merged_layer(x)
+        if is_merged:
+           warnings.warn('please see if a PAct can be inserted before this module to collect ranges')
         #
-        for m in self.modules():
-            m.register_forward_hook(_forward_activation)
+
+        y = super().forward(x)
+
+        if not self.quantize_enable:
+            # if quantization is disabled - return
+            return y
         #
 
-        # clear
-        self.clear_states()
+        qparams = get_qparams()
+        qparams.inputs.append(x)
+        qparams.modules.append(self)
+        y.qparams = qparams
+        #
+        return y
     #
 
 
 ###########################################################
-class QuantCalibrateModule(QuantTrainModule):
-    def __init__(self, module, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False, bias_calibration=True,
-                 histogram_range=True, constrain_weights=True, lr_calib=0.1, dummy_input=None):
-        self.bias_calibration = bias_calibration
-        self.lr_calib = lr_calib
-        self.bias_calibration_factor = lr_calib
-        self.bias_calibration_gamma = 0.5
-        self.calibrate_weights = False
-        self.calibrate_repeats = 1
+class QuantTrainConvTranspose2d(torch.nn.ConvTranspose2d):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
         self.quantize_enable = True
-        self.update_range = True
-        # BNs can be adjusted based on the input provided - however this is not really required
-        self.calibrate_bn = False
-        super().__init__(module, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations, per_channel_q=per_channel_q,
-                         histogram_range=histogram_range, constrain_weights=constrain_weights, bias_calibration=bias_calibration, dummy_input=dummy_input)
+        self.bitwidth_weights = None
+        self.bitwidth_activations = None
+        self.per_channel_q = False
+
+    def forward(self, x):
+        is_merged = is_merged_layer(x)
+        if is_merged:
+           warnings.warn('please see if a PAct can be inserted before this module to collect ranges')
+        #
+
+        y = super().forward(x)
+
+        if not self.quantize_enable:
+            # if quantization is disabled - return
+            return y
+        #
+
+        qparams = get_qparams()
+        qparams.inputs.append(x)
+        qparams.modules.append(self)
+        y.qparams = qparams
+        #
+        return y
     #
 
 
-    def forward_calibrate_bias(self, inputs):
-        # we don't need gradients for calibration
-        with torch.no_grad():
-            # prepare/backup weights
-            if self.num_batches_tracked == 0:
-                # lr_calib
-                self.bias_calibration_factor = self.lr_calib * np.power(self.bias_calibration_gamma, float(self.epoch))
-                # backup original weights
-                self._backup_weights_orig()
-                # backup quantized weights
-                self._backup_weights_quant()
-            #
+###########################################################
+class QuantTrainBatchNorm2d(torch.nn.BatchNorm2d):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.quantize_enable = True
 
-            # backup the current state
-            training = self.training
 
-            # compute the mean output in float
-            # also, set all bns to eval. we can't set the whole model to eval because
-            # we need the pact to learn the ranges - which will happen only in training mode.
-            # also the model output itself may be different in eval mode.
-            if self.calibrate_bn:
-                outputs = self.forward_compute_oputput_stats(inputs)
-                utils.freeze_bn(self)
-            else:
-                utils.freeze_bn(self)
-                outputs = self.forward_compute_oputput_stats(inputs)
-            #
+    def forward(self, x):
+        y = super().forward(x)
 
-            # adjust the quantized output to match the mean
-            outputs = self.forward_adjust_bias(inputs)
+        if not self.quantize_enable:
+            # if quantization is disabled - return
+            return y
+        #
 
-            self.train(training)
+        if is_merged_layer(x) and utils.is_conv_deconv(x.qparams.modules[-1]):
+            qparams = get_qparams()
+            qparams.inputs = [x.qparams.inputs[0], x]
+            qparams.modules = [x.qparams.modules[0], self]
+            y.qparams = qparams
+        #
+
+        return y
+    #
 
-            return outputs
+
+###########################################################
+# fake quantized PAct2 for training
+class QuantTrainPAct2(layers.PAct2):
+    def __init__(self, inplace=False, signed=False, clip_range=None, bitwidth_weights=None, bitwidth_activations=None, per_channel_q=False):
+        super().__init__(inplace=inplace, signed=signed, clip_range=clip_range)
+
+        self.bitwidth_weights = bitwidth_weights
+        self.bitwidth_activations = bitwidth_activations
+        self.per_channel_q = per_channel_q
+        # weight shrinking is done by clamp weights - set this factor to zero.
+        # this must be zero - as in pact we do not have the actual weight param, but just a temporary tensor
+        # so any clipping we do here is not stored in the weight params
+        self.range_shrink_weights = 0.0
+        self.round_dither = 0.0
+        self.update_range = True
+        self.quantize_enable = True
+        self.quantize_weights = True
+        self.quantize_bias = True
+        self.quantize_activations = True
+        self.constrain_weights = True
+        self.bias_calibration = False
+        # save quantized weight/bias once in a while into the params - not needed
+        self.params_save_frequency = None #(10 if self.bias_calibration else None)
+
+        # set to STRAIGHT_THROUGH_ESTIMATION or ALPHA_BLENDING_ESTIMATION
+        # For a comparison of STE and ABE, read:
+        # Learning low-precision neural networks without Straight-Through Estimator (STE):
+        # https://arxiv.org/pdf/1903.01061.pdf
+        self.quantized_estimation_type = QuantEstimationType.STRAIGHT_THROUGH_ESTIMATION
+        self.alpha_blending_estimation_factor = 0.5
+
+        if (layers.PAct2.PACT2_RANGE_LEARN):
+            assert self.quantized_estimation_type != QuantEstimationType.STRAIGHT_THROUGH_ESTIMATION, \
+                'straight through estimation should not used when PACT clip values are being learned as it doesnt backpropagate gradients though quantization'
         #
 
 
-    def forward_compute_oputput_stats(self, inputs):
-        self._restore_weights_orig()
-        # disable quantization for a moment
-        quantize_enable_backup_value, update_range_backup_value = self.quantize_enable, self.update_range
-        utils.apply_setattr(self, quantize_enable=False, update_range=False)
+    def forward(self, x):
+        assert (self.bitwidth_weights is not None) and (self.bitwidth_activations is not None), \
+                        'bitwidth_weights and bitwidth_activations must not be None'
 
-        self.add_call_hook(self.module, self._forward_compute_oputput_stats_hook)
-        outputs = self.module(inputs)
-        self.remove_call_hook(self.module)
+        # the pact range update happens here - but range clipping depends on quantize_enable
+        y = super().forward(x, update_range=self.update_range, enable=self.quantize_enable)
 
-        # turn quantization back on - not a clean method
-        utils.apply_setattr(self, quantize_enable=quantize_enable_backup_value, update_range=update_range_backup_value)
-        self._backup_weights_orig()
-        return outputs
-    #
-    def _forward_compute_oputput_stats_hook(self, op, *inputs_orig):
-        outputs = op.__forward_orig__(*inputs_orig)
-        # calibration at specific layers
-        bias = op.bias if (hasattr(op, 'bias') and (op.bias is not None)) else None
-        weight = op.weight if (hasattr(op, 'weight') and (op.weight is not None)) else None
-        if (bias is not None) or (self.calibrate_weights and weight is not None):
-            output = outputs[0] if isinstance(outputs, (list,tuple)) else outputs
-            reduce_dims = [0,2,3] if len(output.size()) == 4 else ([0] if len(output.size()) == 2 else None)
-            op.__output_mean_orig__ = torch.mean(output, dim=reduce_dims)
-            op.__output_std_orig__ = torch.std(output, dim=reduce_dims)
+        if not self.quantize_enable:
+            return y
+        #
+
+        # previous intermediate outputs and other information are available
+        # for example - conv-bn-relu may need to be merged together.
+        is_merged = is_merged_layer(x)
+        if is_merged:
+            qparams = x.qparams
+            xorg = qparams.inputs[0]
+
+            conv, bn = None, None
+            # merge weight and bias (if possible) across layers
+            if len(qparams.modules) == 2 and utils.is_conv_deconv(qparams.modules[-2]) and isinstance(
+                    qparams.modules[-1], torch.nn.BatchNorm2d):
+                conv = qparams.modules[-2]
+                bn = qparams.modules[-1]
+            elif len(qparams.modules) == 1 and utils.is_conv_deconv(qparams.modules[-1]):
+                conv = qparams.modules[-1]
+            elif len(qparams.modules) == 1 and isinstance(qparams.modules[-1], torch.nn.BatchNorm2d):
+                assert False, f'quantization: previous layer is a BN without Conv {qparams.modules} - prease inspect the model carefully'
+                bn = qparams.modules[-1]
+            #
+            else:
+                assert False, f'QuantTrainPAct2: both conv & bn layes cannot be None in a merged scenario - prease inspect the model carefully'
+            #
+            conv, weight, bias = self.merge_quantize_weights(conv, bn)
+        else:
+            conv, weight, bias = None, None, None
+        #
+
+        if is_merged and utils.is_conv(conv):
+            xq = torch.nn.functional.conv2d(xorg, weight, bias, stride=conv.stride, padding=conv.padding, dilation=conv.dilation, groups=conv.groups)
+        elif is_merged and utils.is_deconv(conv):
+            xq = torch.nn.functional.conv_transpose2d(xorg, weight, bias, stride=conv.stride, padding=conv.padding, output_padding=conv.output_padding,
+                                                      dilation=conv.dilation, groups=conv.groups)
+        else:
+            xq = x
         #
-        return outputs
-    #
 
+        if (self.quantize_enable and self.quantize_activations):
+            clip_min, clip_max, scale, scale_inv = self.get_clips_scale_act()
+            width_min, width_max = self.get_widths_act()
+            # no need to call super().forward here as clipping with width_min/width_max-1 after scaling has the same effect.
+            # currently the gradient is set to 1 in round_up_g. shouldn't the gradient be (y-x)?
+            # see eqn 6 in https://arxiv.org/pdf/1903.08066v2.pdf
+            # yq = layers.clamp_g(layers.round_up_g(xq * scale), width_min, width_max-1, self.training) * scale_inv
+            yq = layers.quantize_dequantize_g(xq, scale, width_min, width_max-1, self.power2, 'round_up')
+        else:
+            yq = super().forward(xq, update_range=False, enable=True)
+        #
 
-    def forward_adjust_bias(self, input):
-        self._restore_weights_quant()
-        self.add_call_hook(self.module, self._forward_adjust_bias_hook)
-        for _ in range(self.calibrate_repeats):
-            output = self.module(input)
+        if (self.quantized_estimation_type == QuantEstimationType.STRAIGHT_THROUGH_ESTIMATION):
+            # replace the float output with quantized version
+            # the entire weight merging and quantization process is bypassed in the forward pass
+            # however, the backward gradients flow through only the float path - this is called straight through estimation (STE)
+            with torch.no_grad():
+                y.data.copy_(yq.data)
+            #
+        elif self.training and (self.quantized_estimation_type == QuantEstimationType.ALPHA_BLENDING_ESTIMATION):
+            # TODO: vary the alpha blending factor over the epochs
+            y = y * (1.0-self.alpha_blending_estimation_factor) + yq * self.alpha_blending_estimation_factor
+        elif (self.quantized_estimation_type == QuantEstimationType.QUANTIZED_THROUGH_ESTIMATION):
+            # pass on the quantized output - the backward gradients also flow through quantization.
+            # however, note the gradients of round and ceil operators are forced to be unity (1.0).
+            y = yq
+        else:
+            assert False, f'unsupported value for quantized_estimation_type: {self.quantized_estimation_type}'
         #
-        self.remove_call_hook(self.module)
-        self._backup_weights_quant()
-        return output
+
+        return y
     #
-    def _forward_adjust_bias_hook(self, op, *inputs_orig):
-        outputs = op.__forward_orig__(*inputs_orig)
-        # calibration at specific layers
-        bias = op.bias if (hasattr(op, 'bias') and (op.bias is not None)) else None
-        if bias is not None:
-            output = outputs[0] if isinstance(outputs, (list,tuple)) else outputs
-            reduce_dims = [0,2,3] if len(output.size()) == 4 else ([0] if len(output.size()) == 2 else None)
-            output_mean = torch.mean(output, dim=reduce_dims)
-            output_delta = op.__output_mean_orig__ - output_mean
-            output_delta = output_delta * self.bias_calibration_factor
-            bias.data += (output_delta)
-            # # TODO: is this required?
-            # if len(output.size()) == 4:
-            #     output.data += output_delta.data.view(1,-1,1,1)
-            # elif len(output.size()) == 2:
-            #     output.data += output_delta.data.view(1,-1)
-            # else:
-            #     assert False, 'unknown dimensions'
-            # #
+
+
+    def apply_constrain_weights(self, merged_weight):
+        return quant_utils.constrain_weight(merged_weight)
+
+
+    def merge_quantize_weights(self, conv, bn):
+        # store the quantized weights and biases infrequently - otherwise learning will be poor
+        # since this may not be done at the end of the epoch, there can be a slight mismatch in validation accuracy
+        first_training_iter = self.training and (self.num_batches_tracked == 0)
+        is_store_weight_bias_iter = (self.params_save_frequency is not None) and (torch.remainder(self.num_batches_tracked, self.params_save_frequency) == 0)
+
+        # merge weight and bias (if possible) across layers
+        if conv is not None and bn is not None:
+            conv_bias = conv.bias if (conv.bias is not None) else torch.tensor(0.0).to(conv.weight.device)
+            #
+            bn_weight = bn.weight if (bn.weight is not None) else torch.tensor(0.0).to(bn.running_mean.device)
+            bn_bias = bn.bias if (bn.bias is not None) else torch.tensor(0.0).to(bn.running_mean.device)
+            #
+            merged_scale = bn_weight / torch.sqrt(bn.running_var + bn.eps)
+            merged_bias = (conv_bias - bn.running_mean) * merged_scale + bn_bias
+            merged_weight = conv.weight * merged_scale.view(-1, 1, 1, 1)
+            #
+            merged_scale_sign = merged_scale.sign()
+            merged_scale_sign = merged_scale_sign + (merged_scale_sign == 0) # make the 0s in sign to 1
+            merged_scale_eps = merged_scale.abs().clamp(min=bn.eps) * merged_scale_sign
+            merged_scale_inv = 1.0 / merged_scale_eps
+            #
+        elif conv is not None:
+            merged_weight = conv.weight
+            merged_bias = conv.bias if (conv.bias is not None) else torch.zeros(conv.out_channels).to(conv.weight.device)
+            merged_scale = torch.ones(conv.out_channels).to(conv.weight.device)
+            merged_scale_inv = torch.ones(conv.out_channels).to(conv.weight.device)
+        elif bn is not None:
+            merged_weight = bn.weight if (bn.weight is not None) else torch.zeros(bn.num_features).to(bn.running_mean.device)
+            merged_bias = bn.bias if (bn.bias is not None) else torch.zeros(bn.num_features).to(bn.running_mean.device)
+        else:
+            assert False, f'merge_quantize_weights(): both conv & bn layes cannot be None in a merged scenario - prease inspect the model carefully'
+            merged_weight = 0.0
+            merged_bias = 0.0
         #
 
-        # weight = op.weight if (hasattr(op, 'weight') and (op.weight is not None)) else None
-        # iter_threshold = (1.0/self.bias_calibration_factor)
-        # if self.calibrate_weights and (weight is not None) and (self.num_batches_tracked > iter_threshold):
-        #         eps = 1e-3
-        #         output = outputs[0] if isinstance(outputs, (list,tuple)) else outputs
-        #         reduce_dims = [0, 2, 3] if len(output.size()) == 4 else ([0] if len(output.size()) == 2 else None)
-        #         output_std = torch.std(output, dim=reduce_dims)
-        #         output_ratio = (op.__output_std_orig__ + eps) / (output_std + eps)
-        #         channels = output.size(1)
-        #         output_ratio = output_ratio.view(channels, 1, 1, 1) if len(weight.data.size()) > 1 else output_ratio
-        #         output_ratio = torch.pow(output_ratio, self.bias_calibration_factor)
-        #         output_ratio = torch.clamp(output_ratio, 1.0-self.bias_calibration_factor, 1.0+self.bias_calibration_factor)
-        #         weight.data *= output_ratio
-        #     #
-        # #
+        # quantize weight and bias
+        if (conv is not None):
+            if (self.quantize_enable and self.quantize_weights):
+                if self.constrain_weights and first_training_iter:
+                    with torch.no_grad():
+                        # clamp merged weights, invert the bn and copy to conv weight
+                        constrained_weight = self.apply_constrain_weights(merged_weight.data)
+                        merged_weight.data.copy_(constrained_weight.data)
+                        # store clipped weight after inverting bn
+                        conv.weight.data.copy_(merged_weight.data * merged_scale_inv.view(-1, 1, 1, 1))
+                    #
+                #
 
-        return outputs
+                is_dw = utils.is_dwconv(conv)
+                use_per_channel_q = (is_dw and self.per_channel_q is True) or (self.per_channel_q == 'all')
+                if use_per_channel_q:
+                    channels = int(merged_weight.size(0))
+                    scale2 = torch.zeros(channels,1,1,1).to(merged_weight.device)
+                    scale_inv2 = torch.zeros(channels,1,1,1).to(merged_weight.device)
+                    for chan_id in range(channels):
+                        clip_min, clip_max, scale2_value, scale_inv2_value = self.get_clips_scale_w(merged_weight[chan_id])
+                        scale2[chan_id,0,0,0] = scale2_value
+                        scale_inv2[chan_id,0,0,0] = scale_inv2_value
+                    #
+                else:
+                    clip_min, clip_max, scale2, scale_inv2 = self.get_clips_scale_w(merged_weight)
+                #
+                width_min, width_max = self.get_widths_w()
+                # merged_weight = layers.clamp_g(layers.round_sym_g(merged_weight * scale2), width_min, width_max-1, self.training) * scale_inv2
+                merged_weight = layers.quantize_dequantize_g(merged_weight, scale2, width_min, width_max-1, self.power2, 'round_sym')
+            #
+
+            if (self.quantize_enable and self.quantize_bias):
+                bias_width_min, bias_width_max = self.get_widths_bias()
+                bias_clip_min, bias_clip_max, bias_scale2, bias_scale_inv2 = self.get_clips_scale_bias(merged_bias)
+                # merged_bias = layers.clamp_g(layers.round_sym_g(merged_bias * bias_scale2), bias_width_min, bias_width_max-1, self.training) * bias_scale_inv2
+                merged_bias = layers.quantize_dequantize_g(merged_bias, bias_scale2, bias_width_min, bias_width_max - 1, self.power2, 'round_sym')
+            #
+
+            # invert the bn operation and store weights/bias
+            if self.training and is_store_weight_bias_iter:
+                with torch.no_grad():
+                    if self.quantize_enable and self.quantize_weights:
+                        conv.weight.data.copy_(merged_weight.data * merged_scale_inv.view(-1, 1, 1, 1))
+                    #
+                    if self.quantize_enable and self.quantize_bias:
+                        if conv.bias is not None:
+                            if bn is not None:
+                                conv_bias = (merged_bias - bn_bias) * merged_scale_inv.view(-1) + bn.running_mean
+                                conv.bias.data.copy_(conv_bias.data)
+                            else:
+                                conv.bias.data.copy_(merged_bias.data)
+                            #
+                        elif bn is not None and bn.bias is not None:
+                            bn_bias = merged_bias + bn.running_mean * merged_scale.view(-1)
+                            bn.bias.data.copy_(bn_bias.data)
+                        #
+                    #
+                #
+            #
+        #
+        return conv, merged_weight, merged_bias
+
+
+    def get_clips_w(self, tensor):
+        # find the clip values
+        w_min, w_max = utils.extrema_fast(tensor.data, percentile_range_shrink=self.range_shrink_weights)
+        clip_max = torch.max(torch.abs(w_min), torch.abs(w_max))
+        clip_max = torch.clamp(clip_max, min=self.eps)
+        # in range learning mode + training - this power2 is taken care in the quantize function
+        use_power2 = (self.power2 and (not (self.PACT2_RANGE_LEARN and self.training)))
+        clip_max2 = layers.functional.ceil2_g(clip_max) if use_power2 else clip_max
+        clip_min2 = -clip_max2
+        return (clip_min2, clip_max2)
+
+    # bias uses the same kind of clips
+    get_clips_bias = get_clips_w
+
+
+    def get_clips_scale_w(self, weight):
+        # convert to scale
+        clip_min, clip_max = self.get_clips_w(weight)
+        width_min, width_max = self.get_widths_w()
+        scale2 = (width_max / clip_max)
+        scale2 = torch.clamp(scale2, min=self.eps)
+        scale_inv2 = scale2.pow(-1.0)
+        return (clip_min, clip_max, scale2, scale_inv2)
+
+
+    # in reality, bias quantization will also depend on the activation scale
+    # this is not perfect - just a quick and dirty quantization for bias
+    def get_clips_scale_bias(self, bias):
+        # convert to scale
+        clip_min, clip_max = self.get_clips_bias(bias)
+        width_min, width_max = self.get_widths_bias()
+        scale2 = (width_max / clip_max)
+        scale2 = torch.clamp(scale2, min=self.eps)
+        scale_inv2 = scale2.pow(-1.0)
+        return (clip_min, clip_max, scale2, scale_inv2)
+
+
+    def get_widths_w(self):
+        # weights
+        bw = (self.bitwidth_weights - 1)
+        width_max = np.power(2.0, bw)
+        width_min = -width_max
+        # return
+        return (width_min, width_max)
+
+
+    def get_widths_bias(self):
+        # bias
+        bitwidth_bias = (2*self.bitwidth_activations)
+        bias_width_max = np.power(2.0, bitwidth_bias-1)
+        bias_width_min = -bias_width_max
+        # return
+        return (bias_width_min, bias_width_max)
+
+
+    # activation utility functions
+    def get_clips_scale_act(self):
+        # convert to scale
+        clip_min, clip_max = self.get_clips_act()
+        width_min, width_max = self.get_widths_act()
+        scale2 = width_max / clip_max
+        scale2 = torch.clamp(scale2, min=self.eps)
+        scale_inv2 = scale2.pow(-1.0)
+        return (clip_min, clip_max, scale2, scale_inv2)
+
+
+    def get_widths_act(self):
+        if self.signed is None:
+            clip_min, clip_max = self.get_clips_act()
+            signed = (clip_min < 0.0)
+        else:
+            signed = self.signed
+        #
+        bw = (self.bitwidth_activations - 1) if signed else self.bitwidth_activations
+        width_max = np.power(2.0, bw)
+        width_min = -width_max if signed else 0.0
+        return width_min, width_max
 
index ba9f80d616c026c799f50e1099ce95f0f1865897..9f914f3f2b9f0e7264b5462dda8fec2b61173e43 100755 (executable)
@@ -1,26 +1,31 @@
 # Quantization
 
 ## =====================================================================================
-## Trained Quantization
+## Quantization Aware Training
 ## =====================================================================================
 #
-#### Image Classification - Trained Quantization - MobileNetV2
+#### Image Classification - Quantization Aware Training - MobileNetV2
 #python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification \
 #--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth \
 #--batch_size 64 --quantize True --epochs 25 --epoch_size 1000 --lr 1e-5 --evaluate_start False
 #
 #
-#### Image Classification - Trained Quantization - MobileNetV2(Shicai) - a TOUGH MobileNetV2 pretrained model
+#### Image Classification - Quantization Aware Training - MobileNetV2(Shicai) - a TOUGH MobileNetV2 pretrained model
 #python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_shicai_x1 --data_path ./data/datasets/image_folder_classification \
 #--pretrained ./data/modelzoo/experimental/pytorch/others/shicai/MobileNet-Caffe/mobilenetv2_shicai_rgb.tar \
 #--batch_size 64 --quantize True --epochs 25 --epoch_size 1000 --lr 1e-5 --evaluate_start False
 #
 #
-#### Semantic Segmentation - Trained Quantization for MobileNetV2+DeeplabV3Lite
+#### Semantic Segmentation - Quantization Aware Training for MobileNetV2+DeeplabV3Lite
 #python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
-#--pretrained ./data/modelzoo/pytorch/semantic_segmentation/cityscapes/jacinto_ai/deeplabv3lite_mobilenetv2_tv_resize768x384_best.pth.tar \
-#--batch_size 12 --quantize True --epochs 150 --lr 1e-5 --evaluate_start False
-
+#--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/deeplabv3lite_mobilenetv2_tv_768x384_best.pth.tar \
+#--batch_size 6 --quantize True --epochs 150 --lr 1e-5 --evaluate_start False
+#
+#
+#### Semantic Segmentation - Quantization Aware Training for MobileNetV2+UNetLite
+#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name unetlite_pixel2pixel_aspp_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
+#--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/unet_aspp_mobilenetv2_tv_768x384_best.pth.tar \
+#--batch_size 6 --quantize True --epochs 150 --lr 1e-5 --evaluate_start False
 
 ## =====================================================================================
 ## Accuracy Evaluation with Post Training Quantization - cannot save quantized model - only accuracy evaluation
 
 #### Image Classification - Accuracy Estimation with Post Training Quantization - A TOUGH MobileNetV2 pretrained model
 #python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name mobilenetv2_shicai_x1 --data_path ./data/datasets/image_folder_classification \
-#--pretrained ./data/modelzoo/experimental/pytorch/others/shicai/MobileNet-Caffe/mobilenetv2_shicai_rgb.tar \
+#--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/deeplabv3lite_mobilenetv2_tv_768x384_best.pth.tar \
 #--batch_size 64 --quantize True
 
-#### Semantic Segmentation - Accuracy Estimation with Post Training Quantization
+#### Semantic Segmentation - Accuracy Estimation with Post Training Quantization - MobileNetV2+DeeplabV3Lite
 #python ./scripts/train_segmentation_main.py --phase validation --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
-#--pretrained './data/modelzoo/pytorch/semantic_segmentation/cityscapes/jacinto_ai/deeplabv3lite_mobilenetv2_tv_resize768x384_best.pth.tar' \
+#--pretrained './data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/deeplabv3lite_mobilenetv2_tv_768x384_best.pth.tar' \
+#--batch_size 1 --quantize True
+
+#### Semantic Segmentation - Accuracy Estimation with Post Training Quantization - MobileNetV2+UNetLite
+#python ./scripts/train_segmentation_main.py --phase validation --dataset_name cityscapes_segmentation --model_name unetlite_pixel2pixel_aspp_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
+#--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/unet_aspp_mobilenetv2_tv_768x384_best.pth.tar \
 #--batch_size 1 --quantize True
 
 
 #### Image Classification - Post Training Calibration & Quantization - ResNet50
 #python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name resnet50_x1 --data_path ./data/datasets/image_folder_classification \
 #--pretrained https://download.pytorch.org/models/resnet50-19c8e357.pth \
-#--batch_size 64 --quantize True --epochs 1 --epoch_size 100
+#--batch_size 64 --quantize True --epochs 1 --epoch_size 100 --evaluate_start False
 #
 #
 #### Image Classification - Post Training Calibration & Quantization - MobileNetV2
 #python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification \
 #--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth \
-#--batch_size 64 --quantize True --epochs 1 --epoch_size 100
+#--batch_size 64 --quantize True --epochs 1 --epoch_size 100 --evaluate_start False
 #
 #
 #### Image Classification - Post Training Calibration & Quantization for a TOUGH MobileNetV2 pretrained model
 #python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name mobilenetv2_shicai_x1 --data_path ./data/datasets/image_folder_classification \
-#--pretrained ./data/modelzoo/experimental/pytorch/others/shicai/MobileNet-Caffe/mobilenetv2_shicai_rgb.tar \
-#--batch_size 64 --quantize True --epochs 1 --epoch_size 100
+#--pretrained ./data/modelzoo/pytorch/image_classification/imagenet1k/shicai/mobilenetv2_shicai_rgb.pth.tar \
+#--batch_size 64 --quantize True --epochs 1 --epoch_size 100 --evaluate_start False
 #
 #
-#### Semantic Segmentation - Post Training Calibration &  Quantization for MobileNetV2+DeeplabV3Lite
+#### Semantic Segmentation - Post Training Calibration & Quantization for MobileNetV2+DeeplabV3Lite
 #python ./scripts/train_segmentation_main.py --phase calibration --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
-#--pretrained ./data/modelzoo/pytorch/semantic_segmentation/cityscapes/jacinto_ai/deeplabv3lite_mobilenetv2_tv_resize768x384_best.pth.tar \
-#--batch_size 12 --quantize True --epochs 1 --epoch_size 100
-
-
+#--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/deeplabv3lite_mobilenetv2_tv_768x384_best.pth.tar \
+#--batch_size 12 --quantize True --epochs 1 --epoch_size 100 --evaluate_start False
+#
+#
+#### Semantic Segmentation - Post Training Calibration & Quantization for MobileNetV2+UNetLite
+#python ./scripts/train_segmentation_main.py --phase calibration --dataset_name cityscapes_segmentation --model_name unetlite_pixel2pixel_aspp_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
+#--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/unet_aspp_mobilenetv2_tv_768x384_best.pth.tar \
+#--batch_size 12 --quantize True --epochs 1 --epoch_size 100 --evaluate_start False
\ No newline at end of file