release commit
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / docs / Quantization.md
1 # Quantization
3 Quantization of a CNN model is the process of converting floating point data & operations to fixed point (integer). This includes quantization of weights, feature maps and all operations (including convolution). The quantization style used in this code is **Power-Of-2, Symmetric, Per-Tensor Quantization** for both **Weights and Activations**. There is also an option to use Per-Channel Weight Quantization for Depthwise Convolution Layers.
5 Accuracy of inference can degrade if the CNN model is quantized to 8bits using simple methods and steps have to be taken to minimize this accuracy loss. The parameters of the model need to be adjusted to suit quantization. This includes adjusting of weights, biases and activation ranges. This adjustment can be done as part of the Calibration or as part of Quantization Aware Training.
7 ## Overview
8 TI Deep Learning Library (TIDL) is a highly optimized runtime for Deep Learning Models on TI’s Jacinto7 TDA4x Devices (eg. TDA4VM). TIDL supports two kinds of Quantization schemes:
9 - Post Training Calibration & Quantization (Calibration): TIDL can accept a floating point model and Calibrate it with a few sample images. The Calibration is done during the import of the model. The current Calibration scheme is fairly simple  but there are plans to improve it substantially.<br>
10 - Quantization Aware Training (QAT): This is needed if accuracy obtained with Calibration is not satisfactory (eg. Quantization Accuracy Drop >2%). QAT operates as a second phase after the initial training in floating point is done. We have provide this PyTorch Jacinto AI DevKit to enable QAT with PyTorch. There also a plan to make TensorFlow Jacinto AI DevKit available. Further Details are available at: [https://github.com/TexasInstruments/jacinto-ai-devkit](https://github.com/TexasInstruments/jacinto-ai-devkit)<br>
12 #### PACT2 activation
13 In order to make the activations quantization friendly, it is important to clip them during Quantization Aware Training. PACT2 activation module has been developed to clip the activations to a power-of-two value. PACT2 is used in the place of commonly used activation functions such as ReLU or ReLU6. Our Quantization Aware Training modules/scripts will automatically insert PACT2 activation functions wherever necessary to constraint the ranges of activations. The following is a block diagram of the PACT2:
14 <p float="left"> <img src="quantization/pact2_activation.png" width="640" hspace="5"/> </p>
15 We use statistical range clipping in PACT2 to improve the Quantized Accuracy (compared to simple min-max range clipping).
17 ## Post Training Calibration For Quantization (Calibration)
18 **Note: this is not our recommended method in PyTorch.**<br>
19 Post Training Calibration or simply Calibration is a method to reduce the accuracy loss with quantization. This is an approximate method and does not require ground truth or back-propagation - hence it is suitable for implementation in an Import/Calibration tool. We have simulated this in PyTorch and can be used as fast method to improve the accuracy with Quantization. If you are interested, you can take a look at the [documentation of Calibration here](Calibration.md).<br>
20 However, in a training frame work such as PyTorch, it is possible to get better accuracy with Quantization Aware Training and we recommend to use that (next section).
22 ## Quantization Aware Training (QAT)
23 Quantization Aware Training (QAT) is easy to incorporate into an existing PyTorch training code. We provide a wrapper module called QuantTrainModule to automate all the tasks required for QAT. The user simply needs to wrap his model in QuantTrainModule and do the training.
25 The overall flow of training is as follows:
26 - Step 1:Train your model in floating point as usual.
27 - Step 2: Starting from the floating point model as pretrained weights, do Quantization Aware Training. In order to do this wrap your model in the wrapper module called  pytorch_jacinto_ai.xnn.quantize.QuantTrainModule and perform training with a small learning rate. About 25 to 50 epochs of training may be required to get the best accuracy.
29 QuantTrainModule does the following operations to the model. Note that QuantTrainModule that will handle these tasks - the only thing that is required is to wrap the user's module in QuantTrainModule as explained in the section "How to use  QuantTrainModule".
30 - Replace all the ReLU, ReLU6 layers in the model by PACT2. Insert PACT2 after Convolution+BatchNorm if a ReLU/ReLU6 is missing after that.  Insert PACT2 anywhere else required - where activation range clipping and range collection is required. For example it can be after the Fully Connected Layer. We use forward post hooks of PyTorch nn.Modules to call these extra activation functions. Thus we are able to add these extra activations without disturbing the loading of existing pre-trained weights.
31 - Clip the weights to an appropriate range if the weight range is very high.
32 - Quantize the weights during the forward pass. Merging Convolution layers with the adjacent Batch Normalization layers (on-the-fly) during the weight quantization is required - if this merging is not correctly done, Quantization Aware Training may not improve accuracy.
33 - Quantize activations during the forward pass.
34 - Other modifications to help the learning process. For example, we use Straight-Through Estimation (STE) [[2,3]] to improve the gradient flow in back-propagation.
36 A block diagram of Quantization Aware Training with QuantTrainModule is shown below:
37 <p float="left"> <img src="quantization/trained_quant_ste.png" width="640" hspace="5"/> </p>
39 #### What happens during Quantization Aware Training?
40 - For each iteration perform a forward in floating point using the original weights and biases. During this pass PACT2 layers will collect output ranges using histogram and running average.
41 - In addition, perform Convolution+BatchNorm merging and quantization of the resulting weights. These quantized and de-quantized weights are used in a forward pass. Ranges collected by PACT2 is used for activation quantization (and de-quantization) to generate quantized output.
42 - Back-propagation with STE will update the parameters of the model to reduce the loss with quantization.
43 - Within a few epochs, we should get reasonable quantization accuracy.
45 #### How to use  QuantTrainModule
46 In order to enable quantized training, we have developed the wrapper class pytorch_jacinto_ai.xnn.quantize.QuantTrainModule. A simple example for using this module is given in the script [examples/quantization_example.py](../examples/quantization_example.py) and calling this is demonstrated in [run_quantization_example.sh](../run_quantization_example.sh). The usage of this module can also be seen in pytorch_jacinto_ai.engine.train_classification.py and pytorch_jacinto_ai.engine.train_pixel2pixel.py. The following is a brief description of how to use this wrapper module:
47 ```
48 from pytorch_jacinto_ai import xnn
50 # create your model here:
51 model = ...
53 # create a dummy input - this is required to analyze the model - fill in the input image size expected by your model.
54 dummy_input = torch.rand((1,3,384,768))
56 # wrap your model in xnn.quantize.QuantTrainModule. 
57 # once it is wrapped, the actual model is in model.module
58 model = xnn.quantize.QuantTrainModule(model, dummy_input=dummy_input)
60 # load your pretrained weights here into model.module
61 pretrained_data = torch.load(pretrained_path)
62 model.module.load_state_dict(pretrained_data)
64 # your training loop here with with loss, backward, optimizer and scheduler. 
65 # this is the usual training loop - but use a lower learning rate such as 5e-5
66 ....
67 ....
69 # save the model - the trained module is in model.module
70 torch.save(model.module.state_dict(), os.path.join(save_path,'model.pth'))
71 torch.onnx.export(model.module, dummy_input, os.path.join(save_path,'model.onnx'), export_params=True, verbose=False)
72 ```
74 As can be seen, it is easy to incorporate QuantTrainModule in your existing training code as the only thing required is to wrap your original model in QuantTrainModule. Careful attention needs to be given to how the parameters of the pretrained model is loaded and trained model is saved as shown in the above code snippet.
76 Optional: We have provided a utility function called pytorch_jacinto_ai.xnn.utils.load_weights() that prints which parameters are loaded correctly and which are not - you can use this load function if needed to ensure that your parameters are loaded correctly.
78 ####  Example commands for QAT
79 ImageNet Classification: *In this example, only a fraction of the training samples are used in each training epoch to speedup training. Remove the argument --epoch_size to use all the training samples.*
80 ```
81 python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification --pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth --batch_size 64 --quantize True --epochs 50 --epoch_size 1000 --lr 1e-5 --evaluate_start False
82 ```
84 Cityscapes Semantic Segmentation:<br>
85 ```
86 python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 --pretrained ./data/modelzoo/semantic_segmentation/cityscapes/deeplabv3lite-mobilenetv2/cityscapes_segmentation_deeplabv3lite-mobilenetv2_2019-06-26-08-59-32.pth --batch_size 8 --quantize True --epochs 50 --lr 1e-5 --evaluate_start False
87 ```
89 For more examples, please see the files run_qunatization_example.sh and examples/quantization_example.py
91 ## Important Notes - read carefully
92 - **Multi-GPU training/calibration/validation with DataParallel is not yet working with our quantization modules** QuantTrainModule/QuantCalibrateModule/QuantTestModule. We recommend not to wrap the modules in DataParallel if you are training/calibrating/testing with quantization - i.e. if your model is wrapped in QuantTrainModule/QuantCalibrateModule/QuantTestModule.<br>
93 - If you get an error during training related to weights and input not being in the same GPU, please check and ensure that you are not using DataParallel with QuantTrainModule/QuantCalibrateModule/QuantTestModule. This may not be such a problem as calibration and quantization may not take as much time as the original floating point training. The original floating point training (without quantization) can use Multi-GPU as usual and we do not have any restrictions on that.<br>
94 - If your calibration/training crashes with insufficient GPU memory, reduce the batch size and try again.
95 - **The same module should not be re-used multiple times within the module** in order that the activation range estimation is correct. Unfortunately, in the torchvision ResNet models, the ReLU module in the BasicBlock and BottleneckBlock are re-used multiple times. We have corrected this by defining separate ReLU modules. This change is minor and **does not** affect the loading of existing pretrained weights. See the [our modified ResNet model definition here](./modules/pytorch_jacinto_ai/vision/models/resnet.py).<br>
96 - **Use Modules instead of functions** (we make use of modules to decide whether to do activation range clipping or not). For example use torch.nn.reLU instead of torch.nn.functional.relu(), torch.nn.AdaptiveAvgPool2d() instead of torch.nn.functional.adaptive_avg_pool2d(), torch.nn.Flatten() instead of torch.nn.functional.flatten() etc. If you are using functions in your model and is giving poor quantized accuracy, then consider replacing those functions by the corresponding modules.<br>
99 ## Results
101 The table below shows the Quantized Accuracy with various Calibration and methods and also QAT. Some of the commands used to generate these results are summarized in the file **run_quantization.sh** for convenience.
103 ###### Dataset: ImageNet Classification (Image Classification)
105 |Mode Name          |Backbone   |Stride|Resolution|Float Acc%|Simple Calib Acc%|Adv Calib Acc%|Adv DW Calib Acc%|QAT Acc% |Acc Drop-Adv Calib|Acc Drop - QAT|
106 |----------         |-----------|------|----------|--------- |---              |---           |---              |---      |---               |---          |
107 |ResNet50(TV)       |ResNet50   |32    |224x224   |**76.15** |75.56            |**75.56**     |75.56            |**76.05**|-0.59             |-0.10        |
108 |MobileNetV2(TV)    |MobileNetV2|32    |224x224   |**71.89** |67.77            |**68.39**     |69.34            |**70.74**|-3.50             |-1.34        |
109 |MobileNetV2(Shicai)|MobileNetV2|32    |224x224   |**71.44** |0.0              |**68.81**     |70.65            |**70.54**|-2.63             |-0.9         |
111 Notes:
112 - For Image Classification, the accuracy measure used is % Top-1 Classification Accuracy. 'Top-1 Classification Accuracy' is abbreviated by Acc in the above table.<br>
113 - (TV) Stands for TochVision: https://github.com/pytorch/vision
114 - MobileNetV2(Shicai) model is from https://github.com/shicai/MobileNet-Caffe (converted from caffe to PyTorch) - this model was selected as this is a tough case for quantization.<br><br>
116 ###### Dataset: Cityscapes Segmentation (Semantic Segmentation)
118 |Mode Name    |Backbone   |Stride|Resolution|Float Acc%|Simple Calib Acc%|Adv Calib Acc%|Adv DW Calib Acc%|QAT Acc% |Acc Drop-Advanced Calib|Acc Drop - QAT|
119 |----------   |-----------|------|----------|----------|---              |---           |---              |---      |---                    |---           |
120 |DeepLabV3Lite|MobileNetV2|16    |768x384   |**69.13** |61.71            |**67.95**     |68.47            |**68.44**|-1.18                  |-0.69         |
122 Note: For Semantic Segmentation, the accuracy measure used in MeanIoU Accuracy. 'MeanIoU Accuracy' is abbreviated by Acc in the above table.
124 **Terminology:**<br>
125 All of these are variants of Power-Of-2, Symmetric, Per-Tensor Quantization, depending on how the parameters are adjusted for Quantization.<br>
126 - Simple Calib: Calibration based on min/max ranges
127 - Adv Calib: Includes histogram based ranges, calibration of weight/bias parameters to compensate for quantization accuracy loss.
128 - Adv DW Calib: Also includes Per-Channel Weight Quantization for Depthwise layers
129 - QAT: Quantization Aware Training with PyTorch Jacinto AI DevKit (Does not use Per-Channel Weight Quantization)
131 **Conclusion based on Simulation Results:**<br>
132 - Advanced Calibration Methods may have >2% Accuracy Drop in some cases.
133 - Quantization Aware Training (QAT) is consistently able to produce <2% Accuracy drop.
136 ## References 
137 [1] PACT: Parameterized Clipping Activation for Quantized Neural Networks, Jungwook Choi, Zhuo Wang, Swagath Venkataramani, Pierce I-Jen Chuang, Vijayalakshmi Srinivasan, Kailash Gopalakrishnan, arXiv preprint, arXiv:1805.06085, 2018
139 [2] Estimating or propagating gradients through stochastic neurons for conditional computation. Y. Bengio, N. Léonard, and A. Courville. arXiv preprint arXiv:1308.3432, 2013.
141 [3] Understanding Straight-Through Estimator in training activation quantized neural nets, Penghang Yin, Jiancheng Lyu, Shuai Zhang, Stanley Osher, Yingyong Qi, Jack Xin, ICLR 2019
143 [4] Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference, Benoit Jacob Skirmantas Kligys Bo Chen Menglong Zhu, Matthew Tang Andrew Howard Hartwig Adam Dmitry Kalenichenko, arXiv preprint, arXiv:1712.05877
145 [5] Trained quantization thresholds for accurate and efficient fixed-point inference of Deep Learning Neural Networks, Sambhav R. Jain, Albert Gural, Michael Wu, Chris H. Dick, arXiv preprint, arXiv:1903.08066 
147 [6] Quantizing deep convolutional networks for efficient inference: A whitepaper, Raghuraman Krishnamoorthi, arXiv preprint, arXiv:1806.08342
149 [7] TensorFlow / Learn / For Mobile & IoT / Guide / Post-training quantization, https://www.tensorflow.org/lite/performance/post_training_quantization
151 [8] QUANTIZATION / Introduction to Quantization, https://pytorch.org/docs/stable/quantization.html