calibration - fix bugs that were introduced during the recent restructure
author Manu Mathew <a0393608@ti.com>
Tue, 12 May 2020 07:51:18 +0000 (13:21 +0530)
committer Manu Mathew <a0393608@ti.com>
Wed, 13 May 2020 03:00:29 +0000 (08:30 +0530)
modules/pytorch_jacinto_ai/xnn/quantize/quant_calib_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_train_module.py
run_quantization.sh

index 00227a0918a9b2c8ca06b1b2e51234503866ab1e..b56cd41638d94f43e58561da3c64b4e469b555d6 100644 (file)
@@ -19,7 +19,6 @@ warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)
 class QuantCalibrateModule(QuantTrainModule):
     def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
                  histogram_range=True, bias_calibration=True, constrain_weights=True, lr_calib=0.05):
-        self.bias_calibration = bias_calibration
         self.weights_calibration = False
         self.lr_calib = lr_calib
         self.calibration_factor = lr_calib
@@ -27,12 +26,22 @@ class QuantCalibrateModule(QuantTrainModule):
         self.calibrate_repeats = 1
         self.quantize_enable = True
         self.update_range = True
+        constrain_weights = (constrain_weights and bias_calibration)
         super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
                          per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
                          constrain_weights=constrain_weights)
 
 
     def forward(self, inputs):
+        # backup the current state
+        training = self.training
+
+        # Set all bns to eval so that they do not change. We can't set the whole model to eval because,
+        # we need the pact to learn the ranges - which will happen only in training mode.
+        # Also the model output itself may be different in eval mode (in certain cases -
+        # for example if in a segmentation model argmax is done instead of softmax in eval mode).
+        utils.freeze_bn(self)
+
         # since this does not involve training, we can set merge_weights=True
         self.analyze_graph(inputs=inputs, cleanup_states=True, merge_weights=True)
 
@@ -43,6 +52,8 @@ class QuantCalibrateModule(QuantTrainModule):
         else:
             outputs = self.module(inputs)
         #
+
+        self.train(training)
         return outputs
 
 
@@ -58,30 +69,12 @@ class QuantCalibrateModule(QuantTrainModule):
             self._backup_weights_quant()
         #
 
-        # backup the current state
-        training = self.training
-
-        # Set all bns to eval so that they do not change. We can't set the whole model to eval because,
-        # we need the pact to learn the ranges - which will happen only in training mode.
-        # Also the model output itself may be different in eval mode (in certain cases -
-        # for example if in a segmentation model argmax is done instead of softmax in eval mode).
-        utils.freeze_bn(self)
-
-        # Compute the mean output in float first.
         with torch.no_grad():
+            # Compute the mean output in float first.
             outputs = self.forward_float(inputs)
-        #
-
-        # Then adjust weights/bias so that the quantized output matches float output
-        if self.weights_calibration:
+            # Then adjust weights/bias so that the quantized output matches float output
             outputs = self.forward_quantized(inputs)
-        else:
-            with torch.no_grad():
-                outputs = self.forward_quantized(inputs)
-            #
         #
-
-        self.train(training)
         return outputs
 
 
index 85319f0212dde5d0d3b925921eb2982764db5aa9..ca4a01220be3b740572902c906ca13269855b3b6 100644 (file)
@@ -26,8 +26,8 @@ class QuantTrainModule(QuantBaseModule):
         # range shrink - 0.0 indicates no shrink
         percentile_range_shrink = (layers.PAct2.PACT2_RANGE_SHRINK if histogram_range else 0.0)
         # set attributes to all modules - can control the behaviour from here
-        utils.apply_setattr(self, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
-                            per_channel_q=per_channel_q, bias_calibration=bias_calibration,
+        utils.apply_setattr(self, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
+                            per_channel_q=self.per_channel_q, bias_calibration=self.bias_calibration,
                             percentile_range_shrink=percentile_range_shrink, constrain_weights=self.constrain_weights,
                             update_range=True, quantize_enable=True, quantize_weights=True, quantize_bias=True, quantize_activations=True)
 
@@ -369,7 +369,7 @@ class QuantTrainPAct2(layers.PAct2):
                         # clamp merged weights, invert the bn and copy to conv weight
                         constrained_weight = self.apply_constrain_weights(merged_weight.data)
                         merged_weight.data.copy_(constrained_weight.data)
-                        # store clipped weight after inverting bn
+                        # store clipped weight after inverting bn - not really needed as there is a saving below as well
                         conv.weight.data.copy_(merged_weight.data * merged_scale_inv.view(-1, 1, 1, 1))
                     #
                 #
@@ -401,7 +401,7 @@ class QuantTrainPAct2(layers.PAct2):
             #
 
             # invert the bn operation and store weights/bias
-            if self.training and is_store_weight_bias_iter:
+            if first_training_iter or (self.training and is_store_weight_bias_iter):
                 with torch.no_grad():
                     if self.quantize_enable and self.quantize_weights:
                         conv.weight.data.copy_(merged_weight.data * merged_scale_inv.view(-1, 1, 1, 1))
index 96876974a016f4d7ca25d83554e1391020a7387f..c36767a6c9a182b38496e6634027b1e27a3281f4 100755 (executable)
@@ -82,7 +82,7 @@
 ### Semantic Segmentation - Post Training Calibration &  Quantization for MobileNetV2+DeeplabV3Lite
 #python ./scripts/train_segmentation_main.py --phase calibration --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
 #--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/deeplabv3lite_mobilenetv2_tv_768x384_best.pth \
-#--batch_size 12 --quantize True --epochs 1 --epoch_size 100 --evaluate_start False
+#--batch_size 6 --quantize True --epochs 1 --epoch_size 100 --evaluate_start False
 #
 #
 ### Semantic Segmentation - Post Training Calibration &  Quantization for MobileNetV2+UNetLite