]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - docs/Quantization.md
188f7a443dd09ea18b6c3ac7d464fea533fda393
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / docs / Quantization.md
1 # Quantization
3 Quantization of a CNN model is the process of converting floating point data & operations to fixed point (integer). This includes quantization of weights, feature maps and all operations (including convolution). The quantization style used in this code is **Power-Of-2, Symmetric, Per-Tensor Quantization** for both **Weights and Activations**. There is also an option to use Per-Channel Weight Quantization for Depthwise Convolution Layers.
5 Accuracy of inference can degrade if the CNN model is quantized to 8bits using simple methods and steps have to be taken to minimize this accuracy loss. The parameters of the model need to be adjusted to suit quantization. This includes adjusting of weights, biases and activation ranges. This adjustment can be done as part of the Calibration or as part of Quantization Aware Training.
7 ## Overview
8 - Inference engines use fixed point arithmetic to implement neural networks. For example TI Deep Learning Library (TIDL) for TI’s Jacinto7 TDA4x Devices (eg. TDA4VM) supports 16-bit and 8-bit fixed point inference modes.
9 - Fixed point mode, especially the 8-bit mode can have accuracy degradation. The tools and guidelines provided here help to avoid accuracy degradation with quantization.
10 - If you are getting accuracy degradation with 8-bit inference, the first thing to check is 16-bit inference. If 16-bit inference provides accuracy close to floating point and 8-bit has an accuracy degradation, there it is likely that the degradation si due to quantization. However, if there is substantial accuracy degradation with 16-bit inference itself, then it is likely that there there is some issue other than quantization.  
12 #### Quantization Schemes
13 - Post Training Calibration & Quantization (Calibration): Calibration often involves range estimation for weights and activations and also minor tweaks to the model (such as bias adjustments). Fixed point inference engines such as TIDL can accept a floating point model and Calibrate it using a few sample images. The Calibration is done during the import of the model, in the case of TIDL.<br>
14 - Quantization Aware Training (QAT): This is needed if accuracy obtained with Calibration is not satisfactory (eg. Quantization Accuracy Drop >2%). QAT operates as a second phase after the initial training in floating point is done. 
16 ## Guidelines For Training To Get Best Accuracy With Quantization
17 - **These are important** - we are listing these guidelines upfront because it is important to follow these.
18 - We recommend that the training uses sufficient amount of regularization / weight decay. Regularization / weight decay ensures that the weights, biases and other parameters (if any) are small and compact - this is good for quantization. These features are supported in most of the popular training framework.<br>
19 - We have noticed that some training code bases do not use weight decay for biases. Some other code bases do not use weight decay for the parameters in Depthwise convolution layers. All these are bad strategies for quantization. These poor choices done (probably to get a 0.1% accuracy lift with floating point) will result in a huge degradation in fixed point - sometimes several percentage points. The weight decay factor should not be too small. We have used a weight decay factor of 1e-4 for training several networks and we highly recommend a similar value. Please do no use small values such as 1e-5.<br>
20 - We also highly recommend to use Batch Normalization immediately after every Convolution layer. This helps the feature map to be properly regularized/normalized. If this is not done, there can be accuracy degradation with quantization. This especially true for Depthwise Convolution layers. However applying Batch Normalization to the very last Convolution layer (for example, the prediction layer in segmentation/object detection network) may hurt accuracy and can be avoided.<br>
22 To get best accuracy at the quantization stage, it is important that the model is trained carefully, following the guidelines (even during the floating point training). Having spent a lot of time solving quantization issues, we would like to highlight that following these guidelines are of at most importance. Otherwise, there is a high risk that the tools and techniques described here may not be completely effective in solving the accuracy drop due to quantization. To summarize, if you are getting poor accuracy with quntization, please check the following:<br>
23 - Weight decay is applied to all layers / parameters and that weight decay factor is good.<br>
24 - Ensure that the Convolution layers in the network have Batch Normalization layers immediately after that. The only exception allowed to this rule is for the very last Convolution layer in the network (for example the prediction layer in a segmentation network or detection network, where adding Batch normalization might hurt the floating point accuracy).<br>
26 ## Implementation Notes, Limitations & Recommendations
27 - **Please read carefully** - closely following these recommendations can save hours or days of debug related to quantization accuracy issues.
28 - **Use Modules instead of functions** (by Module we mean classes derived from torch.nn.Module). We make use of Modules heavily in our quantization tools - in order to do range collection, in order to merge Convolution/BN/ReLU in order to decide whether to quantize a certain tensor and so on. For example use torch.nn.ReLU instead of torch.nn.functional.relu(), torch.nn.AdaptiveAvgPool2d() instead of torch.nn.functional.adaptive_avg_pool2d(), torch.nn.Flatten() instead of torch.nn.functional.flatten() etc.<br>
29 - **The same module should not be re-used multiple times within the module** in order that the feature map range estimation is correct. Unfortunately, in the torchvision ResNet models, the ReLU module in the BasicBlock and BottleneckBlock are re-used multiple times. We have corrected this by defining separate ReLU modules. This change is minor and **does not** affect the loading of existing pretrained weights. See the [our modified ResNet model definition here](./modules/pytorch_jacinto_ai/vision/models/resnet.py).<br>
30 - If you have done QAT and is getting poor accuracy either in the Python code or during inference in the platform, please inspect your model carefully to see if the above recommendations have been followed - some of these can be easily missed by oversight - and can result in painful debugging that could have been avoided.<br>
31 - However, if a function does not change the range of feature map, it is not critical to use it in Module form. An example of this is torch.nn.functional.interpolate<br>
32 - **Multi-GPU training/calibration/validation with DataParallel is not yet working with our quantization modules** QuantTrainModule/QuantCalibrateModule/QuantTestModule. We recommend not to wrap the modules in DataParallel if you are training/calibrating/testing with quantization - i.e. if your model is wrapped in QuantTrainModule/QuantCalibrateModule/QuantTestModule.<br>
33 - If you get an error during training related to weights and input not being in the same GPU, please check and ensure that you are not using DataParallel with QuantTrainModule/QuantCalibrateModule/QuantTestModule. This may not be such a problem as calibration and quantization may not take as much time as the original floating point training. The original floating point training (without quantization) can use Multi-GPU as usual and we do not have any restrictions on that.<br>
34 - If your calibration/training crashes with insufficient GPU memory, reduce the batch size and try again.
35 - This repository has several useful functions and Modules as part of the xnn python module. Most notable ones are: [xnn.layers.resize_with, xnn.layers.ResizeWith](../modules/pytorch_jacinto_ai/xnn/resize_blocks.py) to export a clean resize/interpolate/upsamle graph, [xnn.layers.AddBlock, xnn.layers.CatBlock](../modules/pytorch_jacinto_ai/xnn/common_blocks.py) to do elementwise addition & concatenation in a torch.nn.Module form.
36 - If you are using TIDL to infer a model trained using QAT (or calibratied using PTQ) tools provided in this repository, please set the following in the import config file:<br>
37 **quantizationStyle = 3** to use power of 2 quantization.<br>
38 **foldPreBnConv2D = 0** to avoid an issue in folding of BatchNormalization that comes before Convolution.<br>
40 ## Post Training Calibration For Quantization (PTQ a.k.a. Calibration)
41 **Note: this is not our recommended method in PyTorch.**<br>
42 Post Training Calibration or simply Calibration is a method to reduce the accuracy loss with quantization. This is an approximate method and does not require ground truth or back-propagation - hence it is suitable for implementation in an Import/Calibration tool. We have simulated this in PyTorch and can be used as fast method to improve the accuracy with Quantization. If you are interested, you can take a look at the [documentation of Calibration here](Calibration.md).<br>
43 However, in a training frame work such as PyTorch, it is possible to get better accuracy with Quantization Aware Training and we recommend to use that (next section).
45 ## Quantization Aware Training (QAT)
46 Quantization Aware Training (QAT) is easy to incorporate into an existing PyTorch training code. We provide a wrapper module called QuantTrainModule to automate all the tasks required for QAT. The user simply needs to wrap his model in QuantTrainModule and do the training.
48 The overall flow of training is as follows:
49 - Step 1:Train your model in floating point as usual.
50 - Step 2: Starting from the floating point model as pretrained weights, do Quantization Aware Training. In order to do this wrap your model in the wrapper module called  pytorch_jacinto_ai.xnn.quantize.QuantTrainModule and perform training with a small learning rate. About 25 to 50 epochs of training may be required to get the best accuracy.
52 QuantTrainModule does the following operations to the model. Note that QuantTrainModule that will handle these tasks - the only thing that is required is to wrap the user's module in QuantTrainModule as explained in the section "How to use  QuantTrainModule".
53 - Replace all the ReLU, ReLU6 layers in the model by PACT2. Insert PACT2 after Convolution+BatchNorm if a ReLU/ReLU6 is missing after that.  Insert PACT2 anywhere else required - where activation range clipping and range collection is required. For example it can be after the Fully Connected Layer. We use forward post hooks of PyTorch nn.Modules to call these extra activation functions. Thus we are able to add these extra activations without disturbing the loading of existing pre-trained weights.
54 - Clip the weights to an appropriate range if the weight range is very high.
55 - Quantize the weights during the forward pass. Merging Convolution layers with the adjacent Batch Normalization layers (on-the-fly) during the weight quantization is required - if this merging is not correctly done, Quantization Aware Training may not improve accuracy.
56 - Quantize activations during the forward pass.
57 - Other modifications to help the learning process. For example, we use Straight-Through Estimation (STE) [[2,3]] to improve the gradient flow in back-propagation.
59 A block diagram of Quantization Aware Training with QuantTrainModule is shown below:
60 <p float="left"> <img src="quantization/trained_quant_ste.png" width="640" hspace="5"/> </p>
62 #### PACT2 activation
63 In order to make the activations quantization friendly, it is important to clip them during Quantization Aware Training. PACT2 activation module has been developed to clip the activations to a power-of-two value. PACT2 is used in the place of commonly used activation functions such as ReLU or ReLU6. Our Quantization Aware Training modules/scripts will automatically insert PACT2 activation functions wherever necessary to constraint the ranges of activations. The following is a block diagram of the PACT2:
64 <p float="left"> <img src="quantization/pact2_activation.png" width="640" hspace="5"/> </p>
65 We use statistical range clipping in PACT2 to improve the Quantized Accuracy (compared to simple min-max range clipping).
67 #### What happens during Quantization Aware Training?
68 - For each iteration perform a forward in floating point using the original weights and biases. During this pass PACT2 layers will collect output ranges using histogram and running average.
69 - In addition, perform Convolution+BatchNorm merging and quantization of the resulting weights. These quantized and de-quantized weights are used in a forward pass. Ranges collected by PACT2 is used for activation quantization (and de-quantization) to generate quantized output.
70 - Back-propagation with STE will update the parameters of the model to reduce the loss with quantization.
71 - Within a few epochs, we should get reasonable quantization accuracy.
73 #### How to use  QuantTrainModule
74 In order to enable quantized training, we have developed the wrapper class pytorch_jacinto_ai.xnn.quantize.QuantTrainModule. A simple example for using this module is given in the script [examples/quantization_example.py](../examples/quantization_example.py) and calling this is demonstrated in [run_quantization_example.sh](../run_quantization_example.sh). The usage of this module can also be seen in pytorch_jacinto_ai.engine.train_classification.py and pytorch_jacinto_ai.engine.train_pixel2pixel.py. The following is a brief description of how to use this wrapper module:
75 ```
76 from pytorch_jacinto_ai import xnn
78 # create your model here:
79 model = ...
81 # create a dummy input - this is required to analyze the model - fill in the input image size expected by your model.
82 dummy_input = torch.rand((1,3,384,768))
84 # wrap your model in xnn.quantize.QuantTrainModule. 
85 # once it is wrapped, the actual model is in model.module
86 model = xnn.quantize.QuantTrainModule(model, dummy_input=dummy_input)
88 # load your pretrained weights here into model.module
89 pretrained_data = torch.load(pretrained_path)
90 model.module.load_state_dict(pretrained_data)
92 # your training loop here with with loss, backward, optimizer and scheduler. 
93 # this is the usual training loop - but use a lower learning rate such as 5e-5
94 ....
95 ....
97 # save the model - the trained module is in model.module
98 torch.save(model.module.state_dict(), os.path.join(save_path,'model.pth'))
99 torch.onnx.export(model.module, dummy_input, os.path.join(save_path,'model.onnx'), export_params=True, verbose=False, do_constant_folding=True, opset_version=9)
100 ```
102 As can be seen, it is easy to incorporate QuantTrainModule in your existing training code as the only thing required is to wrap your original model in QuantTrainModule. Careful attention needs to be given to how the parameters of the pretrained model is loaded and trained model is saved as shown in the above code snippet.
104 Optional: We have provided a utility function called pytorch_jacinto_ai.xnn.utils.load_weights() that prints which parameters are loaded correctly and which are not - you can use this load function if needed to ensure that your parameters are loaded correctly.
106 ####  Example commands for QAT
107 ImageNet Classification: *In this example, only a fraction of the training samples are used in each training epoch to speedup training. Remove the argument --epoch_size to use all the training samples.*
108 ```
109 python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification --pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth --batch_size 64 --quantize True --epochs 50 --epoch_size 1000 --lr 1e-5 --evaluate_start False
110 ```
112 Cityscapes Semantic Segmentation:<br>
113 ```
114 python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 --pretrained ./data/modelzoo/pytorch/semantic_segmentation/cityscapes/jacinto_ai/deeplabv3lite_mobilenetv2_tv_resize768x384_best.pth.tar --batch_size 8 --quantize True --epochs 50 --lr 1e-5 --evaluate_start False
115 ```
117 For more examples, please see the files run_qunatization_example.sh and examples/quantization_example.py
119 ## Results
121 The table below shows the Quantized Accuracy with various Calibration and methods and also QAT. Some of the commands used to generate these results are summarized in the file **run_quantization.sh** for convenience.
123 ###### Dataset: ImageNet Classification (Image Classification)
125 |Mode Name          |Backbone   |Stride|Resolution|Float Acc%|Simple Calib Acc%|Adv Calib Acc%|Adv DW Calib Acc%|QAT Acc% |Acc Drop-Adv Calib|Acc Drop - QAT|
126 |----------         |-----------|------|----------|--------- |---              |---           |---              |---      |---               |---          |
127 |ResNet50(TV)       |ResNet50   |32    |224x224   |**76.15** |75.56            |**75.56**     |75.56            |**76.05**|-0.59             |-0.10        |
128 |MobileNetV2(TV)    |MobileNetV2|32    |224x224   |**71.89** |67.77            |**68.39**     |69.34            |**70.74**|-3.50             |-1.34        |
129 |MobileNetV2(Shicai)|MobileNetV2|32    |224x224   |**71.44** |0.0              |**68.81**     |70.65            |**70.54**|-2.63             |-0.9         |
131 Notes:
132 - For Image Classification, the accuracy measure used is % Top-1 Classification Accuracy. 'Top-1 Classification Accuracy' is abbreviated by Acc in the above table.<br>
133 - (TV) Stands for TochVision: https://github.com/pytorch/vision
134 - MobileNetV2(Shicai) model is from https://github.com/shicai/MobileNet-Caffe (converted from caffe to PyTorch) - this model was selected as this is a tough case for quantization.<br><br>
136 ###### Dataset: Cityscapes Segmentation (Semantic Segmentation)
138 |Mode Name    |Backbone   |Stride|Resolution|Float Acc%|Simple Calib Acc%|Adv Calib Acc%|Adv DW Calib Acc%|QAT Acc% |Acc Drop-Advanced Calib|Acc Drop - QAT|
139 |----------   |-----------|------|----------|----------|---              |---           |---              |---      |---                    |---           |
140 |DeepLabV3Lite|MobileNetV2|16    |768x384   |**69.13** |61.71            |**67.95**     |68.47            |**68.44**|-1.18                  |-0.69         |
142 Note: For Semantic Segmentation, the accuracy measure used in MeanIoU Accuracy. 'MeanIoU Accuracy' is abbreviated by Acc in the above table.
144 **Terminology:**<br>
145 All of these are variants of Power-Of-2, Symmetric, Per-Tensor Quantization, depending on how the parameters are adjusted for Quantization.<br>
146 - Simple Calib: Calibration based on min/max ranges
147 - Adv Calib: Includes histogram based ranges, calibration of weight/bias parameters to compensate for quantization accuracy loss.
148 - Adv DW Calib: Also includes Per-Channel Weight Quantization for Depthwise layers
149 - QAT: Quantization Aware Training with PyTorch Jacinto AI DevKit (Does not use Per-Channel Weight Quantization)
151 **Conclusion based on Simulation Results:**<br>
152 - Advanced Calibration Methods may have >2% Accuracy Drop in some cases.
153 - Quantization Aware Training (QAT) is consistently able to produce <2% Accuracy drop.
156 ## References 
157 [1] PACT: Parameterized Clipping Activation for Quantized Neural Networks, Jungwook Choi, Zhuo Wang, Swagath Venkataramani, Pierce I-Jen Chuang, Vijayalakshmi Srinivasan, Kailash Gopalakrishnan, arXiv preprint, arXiv:1805.06085, 2018
159 [2] Estimating or propagating gradients through stochastic neurons for conditional computation. Y. Bengio, N. Léonard, and A. Courville. arXiv preprint arXiv:1308.3432, 2013.
161 [3] Understanding Straight-Through Estimator in training activation quantized neural nets, Penghang Yin, Jiancheng Lyu, Shuai Zhang, Stanley Osher, Yingyong Qi, Jack Xin, ICLR 2019
163 [4] Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference, Benoit Jacob Skirmantas Kligys Bo Chen Menglong Zhu, Matthew Tang Andrew Howard Hartwig Adam Dmitry Kalenichenko, arXiv preprint, arXiv:1712.05877
165 [5] Trained quantization thresholds for accurate and efficient fixed-point inference of Deep Learning Neural Networks, Sambhav R. Jain, Albert Gural, Michael Wu, Chris H. Dick, arXiv preprint, arXiv:1903.08066 
167 [6] Quantizing deep convolutional networks for efficient inference: A whitepaper, Raghuraman Krishnamoorthi, arXiv preprint, arXiv:1806.08342
169 [7] TensorFlow / Learn / For Mobile & IoT / Guide / Post-training quantization, https://www.tensorflow.org/lite/performance/post_training_quantization
171 [8] QUANTIZATION / Introduction to Quantization, https://pytorch.org/docs/stable/quantization.html