quantization - fix for ConvTranspose2d and BatchNorm merge. Do not merge weights...
author	Manu Mathew <a0393608@ti.com>
Tue, 12 May 2020 16:03:18 +0000 (21:33 +0530)
committer	Manu Mathew <a0393608@ti.com>
Wed, 13 May 2020 03:00:58 +0000 (08:30 +0530)
modules/pytorch_jacinto_ai/xnn/quantize/quant_calib_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_test_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_train_module.py

index b56cd41638d94f43e58561da3c64b4e469b555d6..dba9fa91ea19a6d4bdc6d704b6f8c4a21ec257d2 100644 (file)
@@ -33,27 +33,30 @@ class QuantCalibrateModule(QuantTrainModule):
 
 
     def forward(self, inputs):
-        # backup the current state
-        training = self.training
-
-        # Set all bns to eval so that they do not change. We can't set the whole model to eval because
-        # we need the PAct2 activations to learn the ranges - which happens only in training mode.
-        # Also, the model output itself may be different in eval mode (in certain cases -
-        # for example, a segmentation model may do argmax instead of softmax in eval mode).
-        utils.freeze_bn(self)
-
-        # since this does not involve training, we can set merge_weights=True
-        self.analyze_graph(inputs=inputs, cleanup_states=True, merge_weights=True)
-
-        # actual forward call
-        if self.training and (self.bias_calibration or self.weights_calibration):
-            # calibration
-            outputs = self.forward_calibrate(inputs)
-        else:
-            outputs = self.module(inputs)
+        # calibration doesn't need gradients
+        with torch.no_grad():
+            # backup the current state
+            training = self.training
+
+            # Set all bns to eval so that they do not change. We can't set the whole model to eval because
+            # we need the PAct2 activations to learn the ranges - which happens only in training mode.
+            # Also, the model output itself may be different in eval mode (in certain cases -
+            # for example, a segmentation model may do argmax instead of softmax in eval mode).
+            utils.freeze_bn(self)
+
+            # counters such as num_batches_tracked are used; update them.
+            self.update_counters()
+
+            # actual forward call
+            if self.training and (self.bias_calibration or self.weights_calibration):
+                # calibration
+                outputs = self.forward_calibrate(inputs)
+            else:
+                outputs = self.module(inputs)
+            #
+
+            self.train(training)
         #
-
-        self.train(training)
         return outputs
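
Note on the calibration forward above: the BNs must be frozen while the rest of the module stays in train() mode, so that the PAct2 range estimators keep learning while the BN statistics stay fixed. A minimal sketch of what utils.freeze_bn is assumed to do (inferred from its usage here, not from its source):

import torch.nn as nn

def freeze_bn(model):
    # put only the BatchNorm layers in eval mode so their running stats
    # stop updating; everything else keeps its current (training) mode,
    # which lets activation-range estimators continue to learn
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.eval()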
 
 
@@ -69,12 +72,11 @@ class QuantCalibrateModule(QuantTrainModule):
             self._backup_weights_quant()
         #
 
-        with torch.no_grad():
-            # Compute the mean output in float first.
-            outputs = self.forward_float(inputs)
-            # Then adjust weights/bias so that the quantized output matches float output
-            outputs = self.forward_quantized(inputs)
-        #
+        # Compute the mean output in float first.
+        outputs = self.forward_float(inputs)
+        # Then adjust weights/bias so that the quantized output matches float output
+        outputs = self.forward_quantized(inputs)
+
         return outputs
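
The float-then-quantized double forward above exists so the calibration hooks can compare the two outputs and nudge the parameters toward each other. A hypothetical sketch of such an adjustment step (calibrate_bias, the step factor, and the NCHW channel reduction are illustrative assumptions, not this module's actual code):

import torch

def calibrate_bias(bias, float_out, quant_out, factor=0.05):
    # hypothetical helper: step the bias toward closing the per-channel
    # gap between the float and quantized output means (inputs assumed NCHW)
    delta = float_out.mean(dim=(0, 2, 3)) - quant_out.mean(dim=(0, 2, 3))
    return bias + factor * delta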
 
 
index d4f5f7d9d89d651e5197599cb77f7933b028d682..39c6482841f761a9263693bc6d6746ae590c183a 100644 (file)
@@ -89,18 +89,23 @@ class QuantGraphModule(HookedModule):
     def forward(self, inputs):
         assert False, 'forward is not defined'
 
+
+    def update_counters(self, force_update=False):
+        if self.training or force_update:
+            self.num_batches_tracked += 1
+            if self.num_batches_tracked == 0:
+                self.epoch += 1.0
+            #
+        #
+        self.iter_in_epoch += 1
+    #
+
     # force_update is used to increment the counters even when not training
     # used for validation in QuantTestModule
     def analyze_graph(self, inputs, force_update=False, merge_weights=False, cleanup_states=False):
         with torch.no_grad():
             self.init_states()
-            if self.training or force_update:
-                self.num_batches_tracked += 1
-                if self.num_batches_tracked == 0:
-                    self.epoch += 1.0
-                #
-            #
-            self.iter_in_epoch += 1
+            self.update_counters(force_update=force_update)
             if (self.get_state().analyzed_graph == False):
                 # forward and analyze
                 self.forward_analyze_modules(inputs)
@@ -332,8 +337,15 @@ class QuantGraphModule(HookedModule):
 
             # merged weight and offset
             merged_scale = torch.rsqrt(bn.running_var.data + bn.eps) * bn_weight
-            merged_weight = conv.weight.data * merged_scale.view(-1, 1, 1, 1)
-            merged_bias = (conv_bias - bn.running_mean.data) * merged_scale + bn_bias
+            if utils.is_conv(conv):
+                merged_scale = merged_scale.view(-1, 1, 1, 1)
+            elif utils.is_deconv(conv):
+                merged_scale = merged_scale.view(1, -1, 1, 1)
+            else:
+                assert False, 'unable to merge convolution and BN'
+            #
+            merged_weight = conv.weight.data * merged_scale
+            merged_bias = (conv_bias - bn.running_mean.data) * merged_scale.view(-1) + bn_bias
 
             # bn is set to unity
             bn.running_mean.data.fill_(0.0)
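
For reference, the layout difference that motivates this fix: nn.Conv2d stores weights as (out_channels, in_channels, kH, kW), while nn.ConvTranspose2d stores them as (in_channels, out_channels, kH, kW), so the per-channel BN scale must broadcast along dim 0 in one case and dim 1 in the other. A standalone sketch (not part of this commit) that checks the folding numerically for a deconvolution:

import torch
import torch.nn as nn
import torch.nn.functional as F

def fold_bn(conv, bn):
    # per-channel BN scale: gamma / sqrt(var + eps)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    if isinstance(conv, nn.Conv2d):
        w = conv.weight * scale.view(-1, 1, 1, 1)   # out_channels on dim 0
    elif isinstance(conv, nn.ConvTranspose2d):
        w = conv.weight * scale.view(1, -1, 1, 1)   # out_channels on dim 1
    else:
        raise ValueError('unable to merge convolution and BN')
    bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    b = (bias - bn.running_mean) * scale + bn.bias
    return w, b

torch.manual_seed(0)
x = torch.randn(2, 4, 8, 8)
deconv = nn.ConvTranspose2d(4, 6, kernel_size=3)
bn = nn.BatchNorm2d(6).eval()
bn.weight.data.uniform_(0.5, 1.5)
bn.running_mean.normal_()
bn.running_var.uniform_(0.5, 1.5)
w, b = fold_bn(deconv, bn)
print(torch.allclose(bn(deconv(x)), F.conv_transpose2d(x, w, b), atol=1e-5))  # True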
index 557a504dd90350a161b548d1f2d1bdfab9f8ab73..55f2301e9d720b41bc69ea8859442107d60e22a4 100644 (file)
@@ -71,7 +71,7 @@ class QuantTestModule(QuantBaseModule):
 
 
     def forward(self, inputs):
-        # analyze
+        # analyze - merge_weights is needed here, so call analyze_graph() instead of just update_counters()
         self.analyze_graph(inputs=inputs, force_update=True, merge_weights=True, cleanup_states=True)
 
         # batch_size = inputs[0].size(0) if utils.is_list(inputs) else inputs.size(0)
index ca4a01220be3b740572902c906ca13269855b3b6..f940fdf785575338f60015aaeba953d6300bb608 100644 (file)
@@ -32,10 +32,8 @@ class QuantTrainModule(QuantBaseModule):
                             update_range=True, quantize_enable=True, quantize_weights=True, quantize_bias=True, quantize_activations=True)
 
     def forward(self, inputs):
-        # analyze
-        # since this involves training, we cannot set merge_weights=True
-        # merging weights modifies bn and makes training unstable after that.
-        self.analyze_graph(inputs=inputs, cleanup_states=True)
+        # counters such as num_batches_tracked are used; update them.
+        self.update_counters()
         # outputs
         outputs = self.module(inputs)
         return outputs
@@ -339,8 +337,15 @@ class QuantTrainPAct2(layers.PAct2):
             bn_bias = bn.bias if (bn.bias is not None) else torch.tensor(0.0).to(bn.running_mean.device)
             #
             merged_scale = bn_weight / torch.sqrt(bn.running_var + bn.eps)
-            merged_bias = (conv_bias - bn.running_mean) * merged_scale + bn_bias
-            merged_weight = conv.weight * merged_scale.view(-1, 1, 1, 1)
+            if utils.is_conv(conv):
+                merged_scale = merged_scale.view(-1, 1, 1, 1)
+            elif utils.is_deconv(conv):
+                merged_scale = merged_scale.view(1, -1, 1, 1)
+            else:
+                assert False, 'unable to merge convolution and BN'
+            #
+            merged_bias = (conv_bias - bn.running_mean) * merged_scale.view(-1) + bn_bias
+            merged_weight = conv.weight * merged_scale
             #
             merged_scale_sign = merged_scale.sign()
             merged_scale_sign = merged_scale_sign + (merged_scale_sign == 0) # map the zeros in the sign to 1
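
The sign handling above guards against a merged scale of exactly zero, whose sign() would be 0 and would zero out the weight when the sign is multiplied back in. A quick illustration of the trick:

import torch

s = torch.tensor([-0.3, 0.0, 1.2]).sign()   # tensor([-1., 0., 1.])
s = s + (s == 0)                             # bool promotes to 0/1, so 0 -> 1
print(s)                                     # tensor([-1., 1., 1.])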
@@ -350,8 +355,8 @@ class QuantTrainPAct2(layers.PAct2):
         elif conv is not None:
             merged_weight = conv.weight
             merged_bias = conv.bias if (conv.bias is not None) else torch.zeros(conv.out_channels).to(conv.weight.device)
-            merged_scale = torch.ones(conv.out_channels).to(conv.weight.device)
-            merged_scale_inv = torch.ones(conv.out_channels).to(conv.weight.device)
+            merged_scale = 1.0
+            merged_scale_inv = 1.0
         elif bn is not None:
             merged_weight = bn.weight if (bn.weight is not None) else torch.zeros(bn.num_features).to(bn.running_mean.device)
             merged_bias = bn.bias if (bn.bias is not None) else torch.zeros(bn.num_features).to(bn.running_mean.device)
@@ -370,7 +375,7 @@ class QuantTrainPAct2(layers.PAct2):
                         constrained_weight = self.apply_constrain_weights(merged_weight.data)
                         merged_weight.data.copy_(constrained_weight.data)
                         # store the clipped weight after inverting bn - not really needed since it is saved again below
-                        conv.weight.data.copy_(merged_weight.data * merged_scale_inv.view(-1, 1, 1, 1))
+                        # conv.weight.data.copy_(merged_weight.data * merged_scale_inv)
                     #
                 #
 
@@ -404,7 +409,7 @@ class QuantTrainPAct2(layers.PAct2):
             if first_training_iter or (self.training and is_store_weight_bias_iter):
                 with torch.no_grad():
                     if self.quantize_enable and self.quantize_weights:
-                        conv.weight.data.copy_(merged_weight.data * merged_scale_inv.view(-1, 1, 1, 1))
+                        conv.weight.data.copy_(merged_weight.data * merged_scale_inv)
                     #
                     if self.quantize_enable and self.quantize_bias:
                         if conv.bias is not None:
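
The copy_ call above writes the constrained/quantized weight back into conv.weight after undoing the BN folding with merged_scale_inv. A minimal sketch of that inversion (the explicit zero guard is an assumption; the surrounding code protects against zero scales via the sign trick instead):

import torch

def unfold_weight(merged_weight, merged_scale):
    # invert W_merged = W_conv * scale, guarding against scale == 0
    safe = torch.where(merged_scale == 0, torch.ones_like(merged_scale), merged_scale)
    return merged_weight / safe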