]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - docs/Quantization.md
14c48810c2eb0a933fa63c176b8d0d0728497589
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / docs / Quantization.md
1 # Quantization
3 As we know Quantization is the process of converting floating point data & operations to fixed point (integer). CNNs can be quantized to 8-bits integer data/operations without significant accuracy loss. This includes quantization of weights, feature maps and all operations (including convolution of weights). **We use power-of-2, symmetric quantization for both weights and activations**.
5 There are two primary methods of quantization - Post Training Quantization and Trained Quantization. 
7 ## Post Training Calibration & Quantization
9 Post Training Calibration & Quantization can take a model trained in floating point and with a few steps convert it to a model that is friendly for quantized inference. Compared to the alternative (Trained Quantization), the advantages of this method are:
10 - Calibration is fast - a typical calibration finishes in a few minutes.  
11 - Ground truth is not required - just input images are sufficient.
12 - Loss function or backward (back-propagation) are not required. 
14 Thus, this is the preferred method of quantization from an ease of use point of view.  As explained earlier, in this method, the training happens entirely in floating point. The inference (possibly in an embedded device) happens in fixed point. In between training and fixed point inference, the model goes through the step called Calibration with some sample images. The Calibration happens in PC and Quantized Inference happens in the embedded device. Calibration basically tries to make the quantized output similar to the floating point output - by choosing appropriate activation ranges, weights and biases. The step by step process is as follows:
16 #### Model preparation:
17 - Replace all the ReLU, ReLU6 layers in the model by PACT2. Insert PACT2 after Convolution+BatchNorm if a ReLU is missing after that.  Insert PACT2 anywhere else required - where activation range clipping and range collection is required. For example it can ne after the Fully Connected Layer. We use forward post hooks of PyTorch nn.Modules to call these extra activation functions. Thus we are able to add these extra activations without disturbing the loading of existing pre-trained weights.
18 - Clip the weights to an appropriate range if the weight range is very high.
20 #### Forward iterations:
21 - For each iteration perform a forward in floating point using the original weights and biases. During this pass PACT2 layers will collect output ranges using histogram and running average.
22 - In addition, perform Convolution+BatchNorm merging and quantization of the resulting weights. These quantized and de-quantized weights are used in a forward pass. Ranges collected by PACT2 is used for activation quantization (and de-quantization) to generate quantized output.
23 - The floating point output and quantized output are compared using statistic measures. Using such statistic measures, we can adjust the weights and biases of Convolutions and Batch Normalization layers - so that the quantized output becomes closer to the floating point output.
24 - Within a few iterations, we could get reasonable quantization accuracy for several models that we tried this method on.
26 Depending on how the activation range is collected and Quantization is done, we have a few variants of this basic scheme.  
27 - Simple Calib: Calibration includes PACT2 for activation clipping, running average and range collection. In this method we use min-max for activation range collection (no histogram).
28 - **Advanced Calib**: Calibration includes PACT2 with histogram based ranges, Weight clipping, Bias correction. 
29 - Advanced DW Calib: Calibration includes Per-Channel Quantization of Weights for Depthwise layers, PACT2 with histogram based ranges, Weight clipping, Bias correction. One of the earliest papers that clearly explained the benefits of Per-Channel Quantization for weights only (while the activations are quantized as Per-Tensor) is [6] 
30 - Advanced Per-Chan Calib: Calibration includes Per-Channel Quantization for all layers, PACT2 with histogram based ranges, Weight clipping, Bias correction.
32 Out of these methods, **Advanced Calib** is our recommended Calibration method as of now, as it has the best trade-off between the Accuracy and the features required during fixed point inference. All the Calibration scripts that we have in this page uses "Advanced Calib" by default. Other Calibration methods described here are for information only. 
34 In order to do Calibration easily we have a developed a wrapper module called QuantCalibrateModule, which is located in pytorch_jacinto_ai.xnn.quantize.QuantCalibrateModule. We make use of a kind of Parametric Activation called **PACT2** in order to store the calibrated ranges of activations. PACT2 is a improved form of PACT [1]. **PACT2 uses power of 2 activation ranges** for activation clipping. PACT2 can learn ranges very quickly (using a statistic method) without back propagation - this feature makes it quite attractive for Calibration. Our wrapper module replaces all the ReLUs in the model with PACT2. It also inserts PACT2 in other places where activation ranges need to be collected.  Statistical range clipping in PACT2 improves the Quantized Accuracy over simple min-max range clipping. 
36 As explained, our method of **Calibration does not need ground truth, loss function or back propagation.** However in our script, we make use of ground truth to measure the loss/accuracy even in the Calibration stage - although that is not necessary. 
38 #### How to use  QuantCalibrateModule
39 The section briefly explains how to make use of our helper/wrapper module to do the calibration of your model. For further details, please see pytorch_jacinto_ai.engine.train_classification.py and pytorch_jacinto_ai.engine.train_pixel2pixel.py.
41 ```
42 # create your model here:
43 model = ...
45 # create a dummy input - this is required to analyze the model - fill in the input image size expected by your model.
46 dummy_input = torch.rand((1,3,384,768))
48 #wrap your model in QuantCalibrateModule. Once it is wrapped, the actual model is in model.module
49 model = pytorch_jacinto_ai.xnn.quantize.QuantCalibrateModule(model, dummy_input=dummy_input)
51 # load your pretrained weights here into model.module
52 pretrained_data = torch.load(pretrained_path)
53 model.module.load_state_dict(pretrained_data)
55 # create your dataset here - the ground-truth/target that you provide in the dataset can be dummy and does not affect calibration.
56 my_dataset_train, my_dataset_val = ...
58 # do one epoch of calibration - in practice about 1000 iterations are sufficient.
59 for images, target in my_dataset_train:
60     output = model(images)
62 # save the model - the calibrated module is in model.module
63 torch.save(model.module.state_dict(), os.path.join(save_path,'model.pth'))
64 torch.onnx.export(model.module, dummy_input, os.path.join(save_path,'model.onnx'), export_params=True, verbose=False)
66 ```
68 Few examples of calibration are provided below. These commands are also listed in the file **run_quantization.sh** for convenience.<br>
70 #### Calibration of ImageNet Classification MobileNetV2 model 
71 ```
72 python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification --pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth --batch_size 64 --quantize True --epochs 1 --epoch_size 100
73 ```
75 #### Calibration of ImageNet Classification ResNet50 model 
76 ```
77 python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification --pretrained https://download.pytorch.org/models/resnet50-19c8e357.pth --batch_size 64 --quantize True --epochs 1 --epoch_size 100
78 ```
80 #### Calibration of Cityscapes Semantic Segmentation model 
81 ```
82 python ./scripts/train_segmentation_main.py --phase calibration --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 
83 --pretrained ./data/modelzoo/semantic_segmentation/cityscapes/deeplabv3lite-mobilenetv2/cityscapes_segmentation_deeplabv3lite-mobilenetv2_2019-06-26-08-59-32.pth 
84 --batch_size 12 --quantize True --epochs 1 --epoch_size 100
85 ```
87 ## Trained Quantization
89 As explained in the previous section, Calibration is our preferred method of making a model quantization friendly. However, in exceptional cases, it is possible that the drop in accuracy during calibration is more than acceptable. In this case, Trained Quantization can be used. 
91 Unlike Calibration, Trained Quantization involves ground truth, loss function and back propagation. The most popular method of trained quantization is [4]. It takes care of merging Convolution layers with the adjascent Batch Normalization layers (on-the-fly) during the quantized training (if this merging is not correctly done, quantized training may not improve the accuracy). In addition, we use Straight-Through Estimation (STE) [2,3] to improve the gradient flow in back-propagation. Also, the statistical range clipping in PACT2 improves the Quantized Accuracy over simple min-max range clipping. 
93 Note: Instead of STE and statistical ranges for PACT2, we also tried out approximate gradients for scale and trained quantization thresholds proposed in [5] (We did not use the gradient nomralization and log-domain training mentioned in the paper). We found that method to be able to learn the clipping thresholds for initial few epochs, but became unstable after a few epochs and loss became high. Compared to that learned thresholds method, our statistical PACT2 ranges/thresholds combined with STE is simple and stable. 
95 In order to enable quantized training, we have developed the wrapper class pytorch_jacinto_ai.xnn.quantize.QuantTrainModule. The usage of this module can be seen in pytorch_jacinto_ai.engine.train_classification.py and pytorch_jacinto_ai.engine.train_pixel2pixel.py. 
96 ```
97 model = pytorch_jacinto_ai.xnn.quantize.QuantTrainModule(model, dummy_input=dummy_input)
98 ```
99 The resultant model can then be used for training as usual and it will take care of quantization constraints during the training forward and backward passes.
101 One word of caution is that our current implementation of Trained Quantization is a bit slow. The reason for this slowdown is that our implementation is using the top-level python layer of PyTorch and not the underlying C++ layer. But with PyTorch natively supporting the functionality required for quantization under the hood - we hope that this speed issue can be resolved in a future update. 
103 Example commands for trained quantization: 
104 ```
105 python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification --pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth --batch_size 64 --quantize True --epochs 150 --epoch_size 1000 --lr 5e-5 --evaluate_start False
106 ```
108 ```
109 python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 --pretrained ./data/modelzoo/semantic_segmentation/cityscapes/deeplabv3lite-mobilenetv2/cityscapes_segmentation_deeplabv3lite-mobilenetv2_2019-06-26-08-59-32.pth --batch_size 8 --quantize True --epochs 150 --lr 5e-5 --evaluate_start False
110 ```
112 ## Important Notes
113 **Multi-GPU training/calibration/validation with DataParallel is not yet working with our quantization modules** QuantCalibrateModule, QuantTrainModule. 
114 - For now, we recommend not to wrap the modules in DataParallel if you are training/calibrating with quantization - i.e. if your model is wrapped in QuantCalibrateModule/QuantTrainModule/QuantTestModule. 
115 - This may not be such a problem as calibration and quantization may not take as much time as the original training. 
116 - If your calibration/training crashes with insufficient GPU memory, reduce the batch size and try again.
117 - The original training (without quantization) can use Multi-GPU as usual and we do not have any restrictions on that.
118 - Tools for Calibration and Trained Quantization have started appearing in mainstream Deep Learning training frameworks [7,8]. Using the tools natively provided by these frameworks may be faster compared to an implementation in the Python layer of these frameworks (like we have done) - but they may not be mature currently. 
119 - In order that the activation range estimation is correct, **the same module should not be re-used multiple times within the module**. Unfortunately, in the torchvision ResNet models, the ReLU module in the BasicBlock and BottleneckBlock are re-used multiple times. We have corrected this by defining separate ReLU modules. This change is minor and **does not** affect the loading of existing pretrained weights. See the [our modified ResNet model definition here](./modules/pytorch_jacinto_ai/vision/models/resnet.py).
120 - Use Modules instead of functions as much as possible (we make use of modules to decide whether to do activation range clipping or not). For example use torch.nn.AdaptiveAvgPool2d() instead of torch.nn.functional.adaptive_avg_pool2d(), torch.nn.Flatten() instead of torch.nn.functional.flatten() etc. If you are using functions in your model and is giving poor quantized accuracy, then consider replacing those functions by the corresponding modules.
123 ## Results
125 The table below shows the Quantized Accuracy with various Calibration and methods and also Trained Quantization. Some of the commands used to generate these results are summarized in the file **run_quantization.sh** for convenience.
127 ###### Dataset: ImageNet Classification (Image Classification)
129 |Mode Name               |Backbone   |Stride|Resolution|Float Acc%|Simple Calib Acc%|Advanced Calib Acc%|Advanced DW Calib Acc%|Advanced Per-Chan Calib Acc%|Trained Quant Acc%|
130 |----------              |-----------|------|----------|--------- |---              |---                |---                   |---                         |---               |
131 |ResNet50(TorchVision)   |ResNet50   |32    |224x224   |**76.15** |75.56            |**75.56**          |75.56                 |75.39                       |                  |
132 |MobileNetV2(TorchVision)|MobileNetV2|32    |224x224   |**71.89** |67.77            |**68.39**          |69.34                 |69.46                       |**70.55**         |
133 |MobileNetV2(Shicai)     |MobileNetV2|32    |224x224   |**71.44** |45.60            |**68.81**          |70.65                 |70.75                       |                  |
135 Notes:
136 - For Image Classification, the accuracy measure used is % Top-1 Classification Accuracy. 'Top-1 Classification Accuracy' is abbreviated by Acc in the above table.<br>
137 - MobileNetV2(Shicai) model is from https://github.com/shicai/MobileNet-Caffe (converted from caffe to PyTorch) - this is a tough model for quantization.<br><br>
139 ###### Dataset: Cityscapes Segmentation (Semantic Segmentation)
141 |Mode Name               |Backbone   |Stride|Resolution|Float Acc%|Simple Calib Acc%|Advanced Calib Acc%|Advanced DW Calib Acc%|Advanced Per-Chan Calib Acc%|Trained Quant Acc%|
142 |----------              |-----------|------|----------|----------|---              |---                |---                   |---                         |---               |
143 |DeepLabV3Lite           |MobileNetV2|16    |768x384   |**69.13** |61.71            |**67.95**          |68.47                 |68.56                       |**68.26**         |
145 Notes: 
146  - For Semantic Segmentation, the accuracy measure used in MeanIoU Accuracy. 'MeanIoU Accuracy' is abbreviated by Acc in the above table.
147 <br>
150 ## References 
151 [1] PACT: Parameterized Clipping Activation for Quantized Neural Networks, Jungwook Choi, Zhuo Wang, Swagath Venkataramani, Pierce I-Jen Chuang, Vijayalakshmi Srinivasan, Kailash Gopalakrishnan, arXiv preprint, arXiv:1805.06085, 2018
153 [2] Estimating or propagating gradients through stochastic neurons for conditional computation. Y. Bengio, N. LĂ©onard, and A. Courville. arXiv preprint arXiv:1308.3432, 2013.
155 [3] Understanding Straight-Through Estimator in training activation quantized neural nets, Penghang Yin, Jiancheng Lyu, Shuai Zhang, Stanley Osher, Yingyong Qi, Jack Xin, ICLR 2019
157 [4] Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference, Benoit Jacob Skirmantas Kligys Bo Chen Menglong Zhu, Matthew Tang Andrew Howard Hartwig Adam Dmitry Kalenichenko, arXiv preprint, arXiv:1712.05877
159 [5] Trained quantization thresholds for accurate and efficient fixed-point inference of Deep Learning Neural Networks, Sambhav R. Jain, Albert Gural, Michael Wu, Chris H. Dick, arXiv preprint, arXiv:1903.08066 
161 [6] Quantizing deep convolutional networks for efficient inference: A whitepaper, Raghuraman Krishnamoorthi, arXiv preprint, arXiv:1806.08342
163 [7] TensorFlow / Learn / For Mobile & IoT / Guide / Post-training quantization, https://www.tensorflow.org/lite/performance/post_training_quantization
165 [8] QUANTIZATION / Introduction to Quantization, https://pytorch.org/docs/stable/quantization.html