]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - docs/Calibration.md
release commit
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / docs / Calibration.md
2 ## Post Training Calibration For Quantization (not recommended now)
3 Note: Although Calibration methods are described in detail here, this is not our recommended method as this does not yield the best results in some cases. There are specific scenarios where Quantization Aware is not possible (eg. to import a floating point model for embedded inference) and Calibration is mainly meant for such scenarios. However, Calibration methods are improving at a rapid pace and it is possible that Calibration methods will be good enough for all cases in the future - but for now, you can skip to the section on Quantization Aware as that is the recommended method.
5 Post Training Calibration & Quantization can take a model trained in floating point and with a few steps convert it to a model that is friendly for quantized inference. Compared to the alternative (Trained Quantization), the advantages of this method are:
6 - Calibration is fast - a typical calibration finishes in a few minutes.
7 - Ground truth is not required - just input images are sufficient.
8 - Loss function or backward (back-propagation) are not required.
10 The disadvantage is:
11 - This is a fast and approximate method of quantizing the model and may not always yield the best accuracy.
13 In Post Training Calibration, the training happens entirely in floating point. The inference (possibly in an embedded device) happens in fixed point. In between training and fixed point inference, the model goes through the step called Calibration with some sample images. The Calibration happens in PC and Quantized Inference happens in the embedded device. Calibration basically tries to make the quantized output similar to the floating point output - by choosing appropriate activation ranges, weights and biases.
15 A block diagram of Post Training Calibration is shown below:
16 <p float="left"> <img src="quantization/bias_calibration.png" width="640" hspace="5"/> </p>
18 Depending on how the activation range is collected and Quantization is done, we have a few variants of this basic scheme.
19 - Simple Calib: Calibration includes PACT2 for activation clipping, running average and range collection. In this method we use min-max for activation range collection (no histogram).
20 - **Advanced Calib**: Calibration includes PACT2 with histogram based ranges, Weight clipping, Bias correction.
21 - Advanced DW Calib: Calibration includes Per-Channel Quantization of Weights for Depthwise layers, PACT2 with histogram based ranges, Weight clipping, Bias correction. One of the earliest papers that clearly explained the benefits of Per-Channel Quantization for weights only (while the activations are quantized as Per-Tensor) is [[6]]
23 Out of these methods, **Advanced Calib** is our recommended Calibration method as of now, as it has the best trade-off between the Accuracy and the features required during fixed point inference. All the Calibration scripts that we have in this page uses "Advanced Calib" by default. Other Calibration methods described here are for information only.
25 In order to do Calibration easily we have a developed a wrapper module called QuantCalibrateModule, which is located in pytorch_jacinto_ai.xnn.quantize.QuantCalibrateModule. We make use of a kind of Parametric Activation called **PACT2** in order to store the calibrated ranges of activations. PACT2 is a improved form of PACT [[1]]. **PACT2 uses power of 2 activation ranges** for activation clipping. PACT2 can learn ranges very quickly (using a statistic method) without back propagation - this feature makes it quite attractive for Calibration. Our wrapper module replaces all the ReLUs in the model with PACT2. It also inserts PACT2 in other places where activation ranges need to be collected.  Statistical range clipping in PACT2 improves the Quantized Accuracy over simple min-max range clipping.
27 #### What happens during Calibration?
28 - For each iteration perform a forward in floating point using the original weights and biases. During this pass PACT2 layers will collect output ranges using histogram and running average.
29 - In addition, perform Convolution+BatchNorm merging and quantization of the resulting weights. These quantized and de-quantized weights are used in a forward pass. Ranges collected by PACT2 is used for activation quantization (and de-quantization) to generate quantized output.
30 - The floating point output and quantized output are compared using statistic measures. Using such statistic measures, we can adjust the weights and biases of Convolutions and Batch Normalization layers - so that the quantized output becomes closer to the floating point output.
31 - Within a few iterations, we should get reasonable quantization accuracy.
33 #### How to use  QuantCalibrateModule
34 As explained, the method of **Calibration does not need ground truth, loss function or back propagation.** However in the calibration script, we make use of ground truth to measure the loss/accuracy even in the Calibration stage - although that is not necessary.
36 The section briefly explains how to make use of our helper/wrapper module to do the calibration of your model. For further details, please see pytorch_jacinto_ai.engine.train_classification.py and pytorch_jacinto_ai.engine.train_pixel2pixel.py. The step by step process is as follows:
38 ```
39 from pytorch_jacinto_ai.xnn.quantize import QuantCalibrateModule
41 # create your model here:
42 model = ...
44 # create a dummy input - this is required to analyze the model - fill in the input image size expected by your model.
45 dummy_input = torch.rand((1,3,384,768))
47 #wrap your model in QuantCalibrateModule. Once it is wrapped, the actual model is in model.module
48 model = QuantCalibrateModule(model, dummy_input=dummy_input)
50 # load your pretrained weights here into model.module
51 pretrained_data = torch.load(pretrained_path)
52 model.module.load_state_dict(pretrained_data)
54 # create your dataset here - the ground-truth/target that you provide in the dataset can be dummy and does not affect calibration.
55 my_dataset_train, my_dataset_val = ...
57 # do one epoch of calibration - in practice about 1000 iterations are sufficient.
58 for images, target in my_dataset_train:
59     output = model(images)
61 # save the model - the calibrated module is in model.module
62 torch.save(model.module.state_dict(), os.path.join(save_path,'model.pth'))
63 torch.onnx.export(model.module, dummy_input, os.path.join(save_path,'model.onnx'), export_params=True, verbose=False)
65 ```
66 Careful attention needs to be given to how the pretrained model is loaded and trained model is saved as shown in the above code snippet.
68 Few examples of calibration are provided below. These commands are also listed in the file **run_quantization.sh** for convenience.<br>
70 - Calibration of ImageNet Classification MobileNetV2 model
71 ```
72 python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification --pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth --batch_size 64 --quantize True --epochs 1 --epoch_size 100
73 ```
75 - Calibration of ImageNet Classification ResNet50 model
76 ```
77 python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification --pretrained https://download.pytorch.org/models/resnet50-19c8e357.pth --batch_size 64 --quantize True --epochs 1 --epoch_size 100
78 ```
80 - Calibration of Cityscapes Semantic Segmentation model
81 ```
82 python ./scripts/train_segmentation_main.py --phase calibration --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 
83 --pretrained ./data/modelzoo/semantic_segmentation/cityscapes/deeplabv3lite-mobilenetv2/cityscapes_segmentation_deeplabv3lite-mobilenetv2_2019-06-26-08-59-32.pth 
84 --batch_size 12 --quantize True --epochs 1 --epoch_size 100
85 ```