default of model_surgery_quantize is now True for QuantTestModule.
author Manu Mathew <a0393608@ti.com>
Sat, 23 May 2020 12:52:30 +0000 (18:22 +0530)
committer Manu Mathew <a0393608@ti.com>
Sat, 23 May 2020 12:58:11 +0000 (18:28 +0530)
model_surgery_quantize must be True if the pretrained model is a QAT or Calibration module.
To test the accuracy of a purely float model, set this flag to False.

docs/Quantization.md
modules/pytorch_jacinto_ai/vision/models/classification/__init__.py
modules/pytorch_jacinto_ai/vision/models/resnet.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_base_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_calib_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_test_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_utils.py
run_quantization.sh

index 5ae553150c33e181505131148c9f6c14e8a41be2..30350506dabac9eb41f9a30dfe22ceb95fb7d415 100644 (file)
@@ -31,7 +31,7 @@ To get best accuracy at the quantization stage, it is important that the model i
 - However, if a function does not change the range of feature map, it is not critical to use it in Module form. An example of this is torch.nn.functional.interpolate<br>
 - **Multi-GPU training/calibration/validation with DataParallel is supported with our QAT module** QuantTrainModule. This addresses a major concern that earlier existed when doing QAT with QuantTrainModule. (However, it is not supported for QuantCalibrateModule/QuantTestModule - these calibration/test phases take much less time, so hopefully this is not a big issue. In our example training scripts train_classification.py and train_pixel2pixel.py in pytorch_jacinto_ai/engine, we do not wrap the model in DataParallel if the model is QuantCalibrateModule or QuantTestModule, but we do that for QuantTrainModule.)<br>
 - If your training/calibration crashes because of insufficient GPU memory, reduce the batch size and try again.
-- This repository has several useful functions and Modules as part of the xnn python module. Most notable ones are: [xnn.layers.resize_with, xnn.layers.ResizeWith](../modules/pytorch_jacinto_ai/xnn/resize_blocks.py) to export a clean resize/interpolate/upsamle graph, [xnn.layers.AddBlock, xnn.layers.CatBlock](../modules/pytorch_jacinto_ai/xnn/common_blocks.py) to do elementwise addition & concatenation in a torch.nn.Module form.
+- This repository has several useful functions and Modules as part of the xnn python module. Most notable ones are: [xnn.layers.resize_with, xnn.layers.ResizeWith](../modules/pytorch_jacinto_ai/xnn/layers/resize_blocks.py) to export a clean resize/interpolate/upsample graph, [xnn.layers.AddBlock, xnn.layers.CatBlock](../modules/pytorch_jacinto_ai/xnn/layers/common_blocks.py) to do elementwise addition & concatenation in a torch.nn.Module form.
 - If you are using TIDL to infer a model trained using QAT (or calibrated using PTQ) tools provided in this repository, please set the following in the import config file for best accuracy: **quantizationStyle = 3** to use power of 2 quantization. **foldPreBnConv2D = 0** to avoid a slight accuracy degradation due to incorrect folding of BatchNormalization that comes before Convolution (input mean/scale is implemented in TIDL as a PreBN - so this affects most networks).
 
 ## Post Training Calibration For Quantization (PTQ a.k.a. Calibration)
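The TIDL import-config advice above can be summarized as a two-line fragment. This is only a sketch of the two fields the doc text names; the rest of the import config file is omitted:

```
quantizationStyle = 3   # power-of-2 quantization, for models trained/calibrated with this repo
foldPreBnConv2D   = 0   # do not fold BatchNormalization that precedes Convolution
```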
index 690def51a4cf023dd4e4b3476441a2daf2ed7c36..8331aa19b7443420c0cbbf3ce060e68643704f9b 100644 (file)
@@ -21,7 +21,7 @@ except: pass
 from .... import xnn
 
 __all__ = ['mobilenetv1_x1', 'mobilenetv2_tv_x1', 'mobilenetv2_x1', 'mobilenetv2_tv_x2_t2',
-           'resnet50_x1', 'resnet50_xp5',
+           'resnet50_x1', 'resnet50_xp5', 'resnet18_x1',
            # experimental
            'mobilenetv2_ericsun_x1', 'mobilenetv2_shicai_x1',
            'mobilenetv2_tv_gws_x1', 'flownetslite_base_x1', 'mobilenetv1_multi_label_x1']
@@ -31,7 +31,6 @@ __all__ = ['mobilenetv1_x1', 'mobilenetv2_tv_x1', 'mobilenetv2_x1', 'mobilenetv2
 def resnet50_x1(model_config, pretrained=None):
     model_config = resnet.get_config().merge_from(model_config)
     model = resnet.resnet50_with_model_config(model_config)
-
     # the pretrained model provided by torchvision and the model defined here differ slightly
     # note that this change_names_dict takes effect only if the direct load fails
     change_names_dict = {'^conv1.': 'features.conv1.', '^bn1.': 'features.bn1.',
@@ -47,6 +46,20 @@ def resnet50_xp5(model_config, pretrained=None):
     return resnet50_x1(model_config=model_config, pretrained=pretrained)
 
 
+#####################################################################
+def resnet18_x1(model_config, pretrained=None):
+    model_config = resnet.get_config().merge_from(model_config)
+    model = resnet.resnet18_with_model_config(model_config)
+    # the pretrained model provided by torchvision and the model defined here differ slightly
+    # note that this change_names_dict takes effect only if the direct load fails
+    change_names_dict = {'^conv1.': 'features.conv1.', '^bn1.': 'features.bn1.',
+                         '^relu.': 'features.relu.', '^maxpool.': 'features.maxpool.',
+                         '^layer': 'features.layer', '^fc.': 'classifier.'}
+    if pretrained:
+        model = xnn.utils.load_weights(model, pretrained, change_names_dict=change_names_dict)
+    return model, change_names_dict
+
+
 #####################################################################
 def mobilenetv1_x1(model_config, pretrained=None):
     model_config = mobilenetv1.get_config().merge_from(model_config)
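The change_names_dict used by the new resnet18_x1 entry follows the same prefix-remapping convention as resnet50_x1: torchvision checkpoint keys are renamed into the `features.`/`classifier.` layout when a direct load fails. As a rough self-contained illustration of that convention (`remap_keys` is a hypothetical helper, not the repository's `xnn.utils.load_weights`):

```python
import re

# the same mapping as in the diff above; '^' anchors each pattern at the
# start of a state_dict key
change_names_dict = {'^conv1.': 'features.conv1.', '^bn1.': 'features.bn1.',
                     '^relu.': 'features.relu.', '^maxpool.': 'features.maxpool.',
                     '^layer': 'features.layer', '^fc.': 'classifier.'}

def remap_keys(keys, change_names_dict):
    # hypothetical sketch: rewrite each key with the first matching pattern
    remapped = []
    for key in keys:
        for pattern, replacement in change_names_dict.items():
            new_key, count = re.subn(pattern, replacement, key)
            if count:
                key = new_key
                break
        remapped.append(key)
    return remapped
```

For example, `conv1.weight` becomes `features.conv1.weight` and `fc.bias` becomes `classifier.bias`, while keys that already match the target layout pass through unchanged.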
index 333789f66860dd3e292f59c199e56368fbc9e1df..3a523741c951198ee815ca2975ee473e5c27014c 100644 (file)
@@ -8,7 +8,7 @@ from ... import xnn
 __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
            'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
            'wide_resnet50_2', 'wide_resnet101_2',
-           'resnet50_with_model_config']
+           'resnet50_with_model_config', 'resnet18_with_model_config']
 
 
 model_urls = {
@@ -391,3 +391,11 @@ def resnet50_with_model_config(model_config, pretrained=None):
                      width_mult=model_config.width_mult, fastdown=model_config.fastdown)
     return model
 
+
+def resnet18_with_model_config(model_config, pretrained=None):
+    model_config = get_config().merge_from(model_config)
+    model = resnet18(input_channels=model_config.input_channels, strides=model_config.strides,
+                     num_classes=model_config.num_classes, pretrained=pretrained,
+                     width_mult=model_config.width_mult, fastdown=model_config.fastdown)
+    return model
+
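resnet18_with_model_config mirrors the merge-then-construct pattern of resnet50_with_model_config: defaults from get_config() are overlaid with the caller's settings before the model is built. A loose torch-free sketch of that idea, where `ConfigNode` is a hypothetical dict-based stand-in for the repository's actual config class:

```python
class ConfigNode(dict):
    # hypothetical stand-in: merge_from overlays user-supplied settings
    # on top of the defaults and returns the merged config
    def merge_from(self, other):
        merged = ConfigNode(self)
        if other:
            merged.update(other)
        return merged

def get_config():
    # default fields are taken from the parameters the diff passes through
    return ConfigNode(input_channels=3, num_classes=1000, strides=None,
                      width_mult=1.0, fastdown=False)
```

So `get_config().merge_from({'num_classes': 10})` yields a config with `num_classes=10` but the default `input_channels=3`.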
index c53c3e4a7e61cc88f0f13c5d05ebcca29ebc3049..30a8f91d58116c0670b30d975a2968cd567948d4 100644 (file)
@@ -52,7 +52,7 @@ class QuantBaseModule(QuantGraphModule):
 
 
     def train(self, mode=True):
-        self.iter_in_epoch.fill_(-1.0)
+        self.iter_in_epoch = -1
         super().train(mode)
 
 
index c7f8cfb369db636bf1a95c729ad4f84caf8a60e1..266b0dabc0a3aed066194f9cf84f70a8899b457b 100644 (file)
@@ -35,6 +35,9 @@ class QuantCalibrateModule(QuantTrainModule):
     def forward(self, inputs):
         # calibration doesn't need gradients
         with torch.no_grad():
+            # counters such as num_batches_tracked are used. update them.
+            self.update_counters()
+
             # backup the current state
             training = self.training
 
@@ -44,9 +47,6 @@ class QuantCalibrateModule(QuantTrainModule):
             # for example if in a segmentation model argmax is done instead of softmax in eval mode).
             utils.freeze_bn(self)
 
-            # counters such as num_batches_tracked are used. update them.
-            self.update_counters()
-
             # actual forward call
             if self.training and (self.bias_calibration or self.weights_calibration):
                 # calibration
index 3f40d38e49dfe90ebafe1df7bfb0cb7cf9425ad7..1086cb559fa50063a7d95bd99e3adff9a9c95d4c 100644 (file)
@@ -11,9 +11,9 @@ class QuantGraphModule(HookedModule):
         super().__init__()
         self.module = module
         self.init_qstate()
-        self.register_buffer('num_batches_tracked', torch.tensor(-1.0))
-        self.register_buffer('iter_in_epoch', torch.tensor(-1.0))
-        self.register_buffer('epoch', torch.tensor(-1.0))
+        self.num_batches_tracked = -1
+        self.iter_in_epoch = -1
+        self.epoch = -1
 
         # TBD: is this required
         # # if the original module has load_weights, add it to the quant module also
@@ -54,13 +54,13 @@ class QuantGraphModule(HookedModule):
 
 
     def update_counters(self, force_update=False):
+        self.iter_in_epoch += 1
         if self.training or force_update:
             self.num_batches_tracked += 1
-            if self.num_batches_tracked == 0:
+            if self.iter_in_epoch == 0:
-                self.epoch += 1.0
+                self.epoch += 1
             #
         #
-        self.iter_in_epoch += 1
     #
 
     # force_update is used to increment the counters even in non training
@@ -134,7 +134,7 @@ class QuantGraphModule(HookedModule):
 
 
     def train(self, mode=True):
-        self.iter_in_epoch.fill_(-1.0)
+        self.iter_in_epoch = -1
         super().train(mode)
 
 
index 64bce323f6a8635059ee5aac1beb572a2ad5fb55..f0d47898ac97938f398e08a032bc23ca0dc7d092 100644 (file)
@@ -10,7 +10,7 @@ from .quant_utils import *
 
 class QuantTestModule(QuantBaseModule):
     def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False, histogram_range=True,
-                 range_calibration_online=False, model_surgery_quantize=False):
+                 range_calibration_online=False, model_surgery_quantize=True):
         super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
                          per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=False,
                          constrain_weights=False, model_surgery_quantize=model_surgery_quantize)
index a2c437e8e5ca9af63a4b1d58addc6667c34c41f0..f76f86258fda044cc7f2756ebbce41d06ba62db0 100644 (file)
@@ -50,23 +50,27 @@ def compute_tensor_scale(tensor, mn, mx, bitwidth, power2_scaling):
     return tensor_scale, clamp_limits
 
 
-# constants --------------------
-weights_clamp_value = 15.0 #31
-weights_clamp_ratio = 16.0 #32
-# -----------------------------
-
-def clamp_weight_simple(merged_weight):
+def clamp_weight_simple(merged_weight, clamp_ratio, clamp_value):
     # a simple clamp - this may not be suitable for all models
-    clamped_weight = merged_weight.clamp(-weights_clamp_value, weights_clamp_value)
+    clamped_weight = merged_weight.clamp(-clamp_value, clamp_value)
     return clamped_weight
 
 
-def clamp_weight_ratio(merged_weight):
+def clamp_weight_soft(weight, clamp_ratio, clamp_value):
+    weight_max = weight.abs().max()
+    weight_median = weight.abs().median()
+    if (weight_max > clamp_value) and (weight_max > (weight_median*clamp_ratio)):
+        weight = torch.tanh(weight/clamp_value)*(clamp_value)
+    #
+    return weight
+
+
+def clamp_weight_ratio(merged_weight, clamp_ratio, clamp_value):
     # an intelligent clamp - look at the statistics and then clamp
     weight_max = merged_weight.abs().max()
     weight_median = merged_weight.abs().median()
-    if (weight_max > weights_clamp_value) and (weight_max > (weight_median*weights_clamp_ratio)):
-        weight_max = torch.min(weight_max, weight_median*weights_clamp_ratio)
+    if (weight_max > clamp_value) and (weight_max > (weight_median*clamp_ratio)):
+        weight_max = torch.min(weight_max, weight_median*clamp_ratio)
         weight_max2 = layers.ceil2_g(weight_max)
         scale_max2 = 128.0 / weight_max2
         # minimum 1 - using slightly higher margin to ensure quantization aware training
@@ -80,18 +84,12 @@ def clamp_weight_ratio(merged_weight):
     return clamped_weight
 
 
-def clamp_weight_soft(weight):
-    weight_max = weight.abs().max()
-    weight_median = weight.abs().median()
-    if (weight_max > weights_clamp_value) and (weight_max > (weight_median*weights_clamp_ratio)):
-        weight = torch.tanh(weight/weights_clamp_value)*(weights_clamp_value)
-    #
-    return weight
-
-
-def constrain_weight(weight, per_channel=True):
-    weight = clamp_weight_ratio(weight)
-    # weight = layers.standardize_weight(weight, per_channel)
-    # weight = clamp_weight_soft(weight)
-    return weight
-
+def constrain_weight(weight, clamp_ratio=16.0, clamp_value=15.0):
+    '''
+    for mild constraining, use: clamp_ratio=32.0, clamp_value=31.0
+    for aggressive constraining, use: clamp_ratio=16.0, clamp_value=15.0
+    '''
+    # weight = clamp_weight_simple(weight, clamp_ratio, clamp_value)
+    # weight = clamp_weight_soft(weight, clamp_ratio, clamp_value)
+    weight = clamp_weight_ratio(weight, clamp_ratio, clamp_value)
+    return weight
\ No newline at end of file
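The refactor above threads clamp_ratio and clamp_value through as parameters instead of module-level constants. The decision of whether to clamp at all is the same in clamp_weight_soft and clamp_weight_ratio: the maximum magnitude must exceed both the absolute threshold and clamp_ratio times the median magnitude. A torch-free sketch of just that condition (`should_clamp` is a hypothetical stand-in for the check inside those functions, using a simple median for odd-length inputs):

```python
def should_clamp(weights, clamp_ratio=16.0, clamp_value=15.0):
    # clamp only when the largest magnitude exceeds the absolute threshold
    # AND is more than clamp_ratio times the median magnitude
    mags = sorted(abs(w) for w in weights)
    w_max = mags[-1]
    w_median = mags[len(mags) // 2]
    return (w_max > clamp_value) and (w_max > w_median * clamp_ratio)
```

A lone outlier weight (e.g. 100.0 among values around 0.2) triggers clamping, while a distribution whose median is already large does not, since clamping there would distort most of the weights.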
index e8372cf90aab49236da69d0c4d69b599f06f3a9d..3f80ed99cf69ffbbdd738e8ae6b84155e2e0d0f8 100755 (executable)
 #--batch_size 64 --quantize True --epochs 1 --epoch_size 0.1 --evaluate_start False
 #
 #
+#### Image Classification - Post Training Calibration & Quantization - ResNet18
+#python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name resnet18_x1 --data_path ./data/datasets/image_folder_classification \
+#--pretrained https://download.pytorch.org/models/resnet18-5c106cde.pth \
+#--batch_size 64 --quantize True --epochs 1 --epoch_size 0.1 --evaluate_start False
+#
+#
 #### Image Classification - Post Training Calibration & Quantization - MobileNetV2
 #python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification \
 #--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth \
 #--pretrained https://download.pytorch.org/models/resnet50-19c8e357.pth \
 #--batch_size 64 --quantize True
 
+#### Image Classification - Accuracy Estimation with Post Training Quantization - ResNet18
+#python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name resnet18_x1 --data_path ./data/datasets/image_folder_classification \
+#--pretrained https://download.pytorch.org/models/resnet18-5c106cde.pth \
+#--batch_size 64 --quantize True
+
 #### Image Classification - Accuracy Estimation with Post Training Quantization - A TOUGH MobileNetV2 pretrained model
 #python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name mobilenetv2_shicai_x1 --data_path ./data/datasets/image_folder_classification \
 #--pretrained ./data/modelzoo/pytorch/image_classification/imagenet1k/shicai/mobilenetv2_shicai_rgb.pth \