]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - docs/Calibration.md
docs update and minor fixes
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / docs / Calibration.md
2 ## Post Training Calibration For Quantization (a.k.a. Calibration)
3 **Note: this is not our recommended method in PyTorch.**<br>
4 Note: Post Training Calibration or simply Calibration is a method to reduce the accuracy loss with quantization. This is an approximate method and does not require ground truth or back-propagation - hence it is suitable for implementation in an Import/Calibration tool. We have simulated this Calibration scheme in PyTorch and can be used as fast method to improve the accuracy of Quantization. However, in a training frame work such as PyTorch, it is possible to get better accuracy with Quantization Aware Training - hence we suggest to use [Quantization Aware Training](./Quantization.md).
6 Post Training Calibration & Quantization can take a model trained in floating point and with a few steps convert it to a model that is friendly for quantized inference. Compared to the alternative (Trained Quantization), the advantages of this method are:
7 - Calibration is fast - a typical calibration finishes in a few minutes.
8 - Ground truth is not required - just input images are sufficient.
9 - Loss function or backward (back-propagation) are not required.
11 The disadvantage is:
12 - This is a fast and approximate method of quantizing the model and may not always yield the best accuracy.
14 In Post Training Calibration, the training happens entirely in floating point. The inference (possibly in an embedded device) happens in fixed point. In between training and fixed point inference, the model goes through the step called Calibration with some sample images. The Calibration happens in PC and Quantized Inference happens in the embedded device. Calibration basically tries to make the quantized output similar to the floating point output - by choosing appropriate activation ranges, weights and biases.
16 A block diagram of Post Training Calibration is shown below:
17 <p float="left"> <img src="quantization/bias_calibration.png" width="640" hspace="5"/> </p>
19 Depending on how the activation range is collected and Quantization is done, we have a few variants of this basic scheme.
20 - Simple Calib: Calibration includes PACT2 for activation clipping, running average and range collection. In this method we use min-max for activation range collection (no histogram).
21 - **Advanced Calib**: Calibration includes PACT2 with histogram based ranges, Weight clipping, Bias correction.
22 - Advanced DW Calib: Calibration includes Per-Channel Quantization of Weights for Depthwise layers, PACT2 with histogram based ranges, Weight clipping, Bias correction. One of the earliest papers that clearly explained the benefits of Per-Channel Quantization for weights only (while the activations are quantized as Per-Tensor) is [[6]]
24 Out of these methods, **Advanced Calib** is our recommended Calibration method as of now, as it has the best trade-off between the Accuracy and the features required during fixed point inference. All the Calibration scripts that we have in this page uses "Advanced Calib" by default. Other Calibration methods described here are for information only.
26 In order to do Calibration easily we have a developed a wrapper module called QuantCalibrateModule, which is located in pytorch_jacinto_ai.xnn.quantize.QuantCalibrateModule. We make use of a kind of Parametric Activation called **PACT2** in order to store the calibrated ranges of activations. PACT2 is a improved form of PACT [[1]]. **PACT2 uses power of 2 activation ranges** for activation clipping. PACT2 can learn ranges very quickly (using a statistic method) without back propagation - this feature makes it quite attractive for Calibration. Our wrapper module replaces all the ReLUs in the model with PACT2. It also inserts PACT2 in other places where activation ranges need to be collected.  Statistical range clipping in PACT2 improves the Quantized Accuracy over simple min-max range clipping.
28 #### What happens during Calibration?
29 - For each iteration perform a forward in floating point using the original weights and biases. During this pass PACT2 layers will collect output ranges using histogram and running average.
30 - In addition, perform Convolution+BatchNorm merging and quantization of the resulting weights. These quantized and de-quantized weights are used in a forward pass. Ranges collected by PACT2 is used for activation quantization (and de-quantization) to generate quantized output.
31 - The floating point output and quantized output are compared using statistic measures. Using such statistic measures, we can adjust the weights and biases of Convolutions and Batch Normalization layers - so that the quantized output becomes closer to the floating point output.
32 - Within a few iterations, we should get reasonable quantization accuracy.
34 #### How to use  QuantCalibrateModule
35 As explained, the method of **Calibration does not need ground truth, loss function or back propagation.** However in the calibration script, we make use of ground truth to measure the loss/accuracy even in the Calibration stage - although that is not necessary.
37 The section briefly explains how to make use of our helper/wrapper module to do the calibration of your model. For further details, please see pytorch_jacinto_ai.engine.train_classification.py and pytorch_jacinto_ai.engine.train_pixel2pixel.py. The step by step process is as follows:
39 ```
40 from pytorch_jacinto_ai import xnn
42 # create your model here:
43 model = ...
45 # create a dummy input - this is required to analyze the model - fill in the input image size expected by your model.
46 dummy_input = torch.rand((1,3,384,768))
48 #wrap your model in xnn.quantize.QuantCalibrateModule. Once it is wrapped, the actual model is in model.module
49 model = xnn.quantize.QuantCalibrateModule(model, dummy_input=dummy_input)
51 # load your pretrained weights here into model.module
52 pretrained_data = torch.load(pretrained_path)
53 model.module.load_state_dict(pretrained_data)
55 # create your dataset here - the ground-truth/target that you provide in the dataset can be dummy and does not affect calibration.
56 my_dataset_train, my_dataset_val = ...
58 # do one epoch of calibration - in practice about 1000 iterations are sufficient.
59 for images, target in my_dataset_train:
60     output = model(images)
62 # save the model - the calibrated module is in model.module
63 torch.save(model.module.state_dict(), os.path.join(save_path,'model.pth'))
64 torch.onnx.export(model.module, dummy_input, os.path.join(save_path,'model.onnx'), export_params=True, verbose=False, do_constant_folding=True, opset_version=9)
66 ```
68 Careful attention needs to be given to how the pretrained model is loaded and trained model is saved as shown in the above code snippet.
70 Few examples of calibration are provided below. These commands are also listed in the file **run_quantization.sh** for convenience.<br>
72 - Calibration of ImageNet Classification MobileNetV2 model
73 ```
74 python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification --pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth --batch_size 64 --quantize True --epochs 1 --epoch_size 100
75 ```
77 - Calibration of ImageNet Classification ResNet50 model
78 ```
79 python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification --pretrained https://download.pytorch.org/models/resnet50-19c8e357.pth --batch_size 64 --quantize True --epochs 1 --epoch_size 100
80 ```
82 - Calibration of Cityscapes Semantic Segmentation model
83 ```
84 python ./scripts/train_segmentation_main.py --phase calibration --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 
85 --pretrained ./data/modelzoo/pytorch/semantic_segmentation/cityscapes/jacinto_ai/deeplabv3lite_mobilenetv2_tv_resize768x384_best.pth.tar 
86 --batch_size 12 --quantize True --epochs 1 --epoch_size 100
87 ```