release commit
authorManu Mathew <a0393608@ti.com>
Tue, 21 Jan 2020 07:18:07 +0000 (12:48 +0530)
committerManu Mathew <a0393608@ti.com>
Tue, 21 Jan 2020 07:18:07 +0000 (12:48 +0530)
22 files changed:
docs/Calibration.md [new file with mode: 0644]
docs/Quantization.md
examples/quantization_example.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/engine/infer_pixel2pixel.py
modules/pytorch_jacinto_ai/engine/test_classification.py
modules/pytorch_jacinto_ai/engine/train_classification.py
modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py
modules/pytorch_jacinto_ai/vision/models/__init__.py
modules/pytorch_jacinto_ai/vision/models/classification/__init__.py
modules/pytorch_jacinto_ai/vision/models/mnasnet.py
modules/pytorch_jacinto_ai/vision/models/mobilenetv2.py
modules/pytorch_jacinto_ai/vision/models/multi_input_net.py
modules/pytorch_jacinto_ai/vision/models/pixel2pixel/deeplabv3lite.py
modules/pytorch_jacinto_ai/vision/models/pixel2pixel/fpn_pixel2pixel.py
modules/pytorch_jacinto_ai/vision/models/resnet.py
modules/pytorch_jacinto_ai/vision/models/shufflenetv2.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py
modules/pytorch_jacinto_ai/xnn/utils/__init__.py
modules/pytorch_jacinto_ai/xnn/utils/load_weights.py
modules/pytorch_jacinto_ai/xnn/utils/module_utils.py
modules/pytorch_jacinto_ai/xnn/utils/utils_data.py [new file with mode: 0644]
run_quantization_example.sh [new file with mode: 0755]

diff --git a/docs/Calibration.md b/docs/Calibration.md
new file mode 100644 (file)
index 0000000..c3aaa95
--- /dev/null
@@ -0,0 +1,86 @@
+
+## Post Training Calibration For Quantization (not recommended now)
+Note: Although Calibration methods are described in detail here, this is not our recommended method as this does not yield the best results in some cases. There are specific scenarios where Quantization Aware is not possible (eg. to import a floating point model for embedded inference) and Calibration is mainly meant for such scenarios. However, Calibration methods are improving at a rapid pace and it is possible that Calibration methods will be good enough for all cases in the future - but for now, you can skip to the section on Quantization Aware as that is the recommended method.
+
+Post Training Calibration & Quantization can take a model trained in floating point and with a few steps convert it to a model that is friendly for quantized inference. Compared to the alternative (Trained Quantization), the advantages of this method are:
+- Calibration is fast - a typical calibration finishes in a few minutes.
+- Ground truth is not required - just input images are sufficient.
+- Loss function or backward (back-propagation) are not required.
+
+The disadvantage is:
+- This is a fast and approximate method of quantizing the model and may not always yield the best accuracy.
+
+In Post Training Calibration, the training happens entirely in floating point. The inference (possibly in an embedded device) happens in fixed point. In between training and fixed point inference, the model goes through the step called Calibration with some sample images. The Calibration happens in PC and Quantized Inference happens in the embedded device. Calibration basically tries to make the quantized output similar to the floating point output - by choosing appropriate activation ranges, weights and biases.
+
+A block diagram of Post Training Calibration is shown below:
+<p float="left"> <img src="quantization/bias_calibration.png" width="640" hspace="5"/> </p>
+
+Depending on how the activation range is collected and Quantization is done, we have a few variants of this basic scheme.
+- Simple Calib: Calibration includes PACT2 for activation clipping, running average and range collection. In this method we use min-max for activation range collection (no histogram).
+- **Advanced Calib**: Calibration includes PACT2 with histogram based ranges, Weight clipping, Bias correction.
+- Advanced DW Calib: Calibration includes Per-Channel Quantization of Weights for Depthwise layers, PACT2 with histogram based ranges, Weight clipping, Bias correction. One of the earliest papers that clearly explained the benefits of Per-Channel Quantization for weights only (while the activations are quantized as Per-Tensor) is [[6]]
+
+Out of these methods, **Advanced Calib** is our recommended Calibration method as of now, as it has the best trade-off between the Accuracy and the features required during fixed point inference. All the Calibration scripts that we have in this page uses "Advanced Calib" by default. Other Calibration methods described here are for information only.
+
+In order to do Calibration easily we have a developed a wrapper module called QuantCalibrateModule, which is located in pytorch_jacinto_ai.xnn.quantize.QuantCalibrateModule. We make use of a kind of Parametric Activation called **PACT2** in order to store the calibrated ranges of activations. PACT2 is a improved form of PACT [[1]]. **PACT2 uses power of 2 activation ranges** for activation clipping. PACT2 can learn ranges very quickly (using a statistic method) without back propagation - this feature makes it quite attractive for Calibration. Our wrapper module replaces all the ReLUs in the model with PACT2. It also inserts PACT2 in other places where activation ranges need to be collected.  Statistical range clipping in PACT2 improves the Quantized Accuracy over simple min-max range clipping.
+
+#### What happens during Calibration?
+- For each iteration perform a forward in floating point using the original weights and biases. During this pass PACT2 layers will collect output ranges using histogram and running average.
+- In addition, perform Convolution+BatchNorm merging and quantization of the resulting weights. These quantized and de-quantized weights are used in a forward pass. Ranges collected by PACT2 is used for activation quantization (and de-quantization) to generate quantized output.
+- The floating point output and quantized output are compared using statistic measures. Using such statistic measures, we can adjust the weights and biases of Convolutions and Batch Normalization layers - so that the quantized output becomes closer to the floating point output.
+- Within a few iterations, we should get reasonable quantization accuracy.
+
+#### How to use  QuantCalibrateModule
+As explained, the method of **Calibration does not need ground truth, loss function or back propagation.** However in the calibration script, we make use of ground truth to measure the loss/accuracy even in the Calibration stage - although that is not necessary.
+
+The section briefly explains how to make use of our helper/wrapper module to do the calibration of your model. For further details, please see pytorch_jacinto_ai.engine.train_classification.py and pytorch_jacinto_ai.engine.train_pixel2pixel.py. The step by step process is as follows:
+
+```
+from pytorch_jacinto_ai.xnn.quantize import QuantCalibrateModule
+
+# create your model here:
+model = ...
+
+# create a dummy input - this is required to analyze the model - fill in the input image size expected by your model.
+dummy_input = torch.rand((1,3,384,768))
+
+#wrap your model in QuantCalibrateModule. Once it is wrapped, the actual model is in model.module
+model = QuantCalibrateModule(model, dummy_input=dummy_input)
+
+# load your pretrained weights here into model.module
+pretrained_data = torch.load(pretrained_path)
+model.module.load_state_dict(pretrained_data)
+
+# create your dataset here - the ground-truth/target that you provide in the dataset can be dummy and does not affect calibration.
+my_dataset_train, my_dataset_val = ...
+
+# do one epoch of calibration - in practice about 1000 iterations are sufficient.
+for images, target in my_dataset_train:
+    output = model(images)
+
+# save the model - the calibrated module is in model.module
+torch.save(model.module.state_dict(), os.path.join(save_path,'model.pth'))
+torch.onnx.export(model.module, dummy_input, os.path.join(save_path,'model.onnx'), export_params=True, verbose=False)
+
+```
+Careful attention needs to be given to how the pretrained model is loaded and trained model is saved as shown in the above code snippet.
+
+Few examples of calibration are provided below. These commands are also listed in the file **run_quantization.sh** for convenience.<br>
+
+- Calibration of ImageNet Classification MobileNetV2 model
+```
+python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification --pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth --batch_size 64 --quantize True --epochs 1 --epoch_size 100
+```
+
+- Calibration of ImageNet Classification ResNet50 model
+```
+python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification --pretrained https://download.pytorch.org/models/resnet50-19c8e357.pth --batch_size 64 --quantize True --epochs 1 --epoch_size 100
+```
+
+- Calibration of Cityscapes Semantic Segmentation model
+```
+python ./scripts/train_segmentation_main.py --phase calibration --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 
+--pretrained ./data/modelzoo/semantic_segmentation/cityscapes/deeplabv3lite-mobilenetv2/cityscapes_segmentation_deeplabv3lite-mobilenetv2_2019-06-26-08-59-32.pth 
+--batch_size 12 --quantize True --epochs 1 --epoch_size 100
+```
+
index c4cffeb8695aebf043310f232a530e80f0ae227a..8a775f1d641c955fadea93d71f754578107a95ec 100644 (file)
 # Quantization
 
-Quantization is the process of converting floating point data & operations to fixed point (integer). CNNs can be quantized to 8-bits integer data/operations without significant accuracy loss. This includes quantization of weights, feature maps and all operations (including convolution of weights). The quantization style used in this code is **power-of-2, symmetric quantization for both weights and activations**.
+Quantization of a CNN model is the process of converting floating point data & operations to fixed point (integer). This includes quantization of weights, feature maps and all operations (including convolution). The quantization style used in this code is **power-of-2, symmetric quantization for both weights and activations**.
 
-In order to make the activations quantization friendly, it is important to clip them during Calibration or Trained Quantization. PACT2 activation function has been developed to clip the activations to a power-of-two value. PACT2 can be used in the place of commonly used activation functions such as ReLU. 
-<p float="left"> <img src="quantization/pact2_activation.png" width="640" hspace="5"/> </p>
-This code also contains two wrapper modules called QuantCalibrateModule and QuantTrainModule that will handle such tasks as replacing ReLUs with PACT2, quantizing weights, quantizing activations etc - so the user does not need to do any of these things - the only thing that is required is to wrap the user's module in these wapper modules as explained below.
-
-There are two primary methods of quantization - Post Training Calibration For Quantization and Trained Quantization (a.k.a Quantization Aware Training). These required the following (only for information - automatically done by QuantCalibrateModule and QuantTrainModule):
-- Replace all the ReLU, ReLU6 layers in the model by PACT2. Insert PACT2 after Convolution+BatchNorm if a ReLU is missing after that.  Insert PACT2 anywhere else required - where activation range clipping and range collection is required. For example it can ne after the Fully Connected Layer. We use forward post hooks of PyTorch nn.Modules to call these extra activation functions. Thus we are able to add these extra activations without disturbing the loading of existing pre-trained weights.
-- Clip the weights to an appropriate range if the weight range is very high.
-- Quantize the weights and activations during the forward pass.
-- Other modifications required for Calibration / Trained Quantization
-
-Trained Quantization typically provides better accuracy compared to Post Training Calibration. Trained Quantization is also easy to incorporate into your existing PyTorch training code - hence we recommend to use it. Here, we explain both methods for the sake of completeness.
-
-## Post Training Calibration For Quantization
-Note: Although Calibration methods are described in detail here, this is not our recommended method as this does not yield the best results in some cases. There are specific scenarios where Trained Quantization is not possible (eg. to import a floating point model for embedded inference) and Calibration is mainly meant for such scenarios. However, Calibration methods are improving at a rapid pace and it is possible that Calibration methods will be good enough for all cases in the future - but for now, you can skip to the section on Trained Quantization as it is the recommended method.
-Post Training Calibration & Quantization can take a model trained in floating point and with a few steps convert it to a model that is friendly for quantized inference. Compared to the alternative (Trained Quantization), the advantages of this method are:
-- Calibration is fast - a typical calibration finishes in a few minutes.  
-- Ground truth is not required - just input images are sufficient.
-- Loss function or backward (back-propagation) are not required. 
-
-The disadvantage is:
-- This is a fast and approximate method of quantizing the model and may not always yield the best accuracy. 
-
-In Post Training Calibration, the training happens entirely in floating point. The inference (possibly in an embedded device) happens in fixed point. In between training and fixed point inference, the model goes through the step called Calibration with some sample images. The Calibration happens in PC and Quantized Inference happens in the embedded device. Calibration basically tries to make the quantized output similar to the floating point output - by choosing appropriate activation ranges, weights and biases. 
-
-A block diagram of Post Training Calibration is shown below:
-<p float="left"> <img src="quantization/bias_calibration.png" width="640" hspace="5"/> </p>
-
-Depending on how the activation range is collected and Quantization is done, we have a few variants of this basic scheme.  
-- Simple Calib: Calibration includes PACT2 for activation clipping, running average and range collection. In this method we use min-max for activation range collection (no histogram).
-- **Advanced Calib**: Calibration includes PACT2 with histogram based ranges, Weight clipping, Bias correction. 
-- Advanced DW Calib: Calibration includes Per-Channel Quantization of Weights for Depthwise layers, PACT2 with histogram based ranges, Weight clipping, Bias correction. One of the earliest papers that clearly explained the benefits of Per-Channel Quantization for weights only (while the activations are quantized as Per-Tensor) is [[6]]
-
-Out of these methods, **Advanced Calib** is our recommended Calibration method as of now, as it has the best trade-off between the Accuracy and the features required during fixed point inference. All the Calibration scripts that we have in this page uses "Advanced Calib" by default. Other Calibration methods described here are for information only. 
-
-In order to do Calibration easily we have a developed a wrapper module called QuantCalibrateModule, which is located in pytorch_jacinto_ai.xnn.quantize.QuantCalibrateModule. We make use of a kind of Parametric Activation called **PACT2** in order to store the calibrated ranges of activations. PACT2 is a improved form of PACT [[1]]. **PACT2 uses power of 2 activation ranges** for activation clipping. PACT2 can learn ranges very quickly (using a statistic method) without back propagation - this feature makes it quite attractive for Calibration. Our wrapper module replaces all the ReLUs in the model with PACT2. It also inserts PACT2 in other places where activation ranges need to be collected.  Statistical range clipping in PACT2 improves the Quantized Accuracy over simple min-max range clipping. 
-
-#### What happens during Calibration?
-- For each iteration perform a forward in floating point using the original weights and biases. During this pass PACT2 layers will collect output ranges using histogram and running average.
-- In addition, perform Convolution+BatchNorm merging and quantization of the resulting weights. These quantized and de-quantized weights are used in a forward pass. Ranges collected by PACT2 is used for activation quantization (and de-quantization) to generate quantized output.
-- The floating point output and quantized output are compared using statistic measures. Using such statistic measures, we can adjust the weights and biases of Convolutions and Batch Normalization layers - so that the quantized output becomes closer to the floating point output.
-- Within a few iterations, we should get reasonable quantization accuracy.
-
-#### How to use  QuantCalibrateModule
-As explained, the method of **Calibration does not need ground truth, loss function or back propagation.** However in the calibration script, we make use of ground truth to measure the loss/accuracy even in the Calibration stage - although that is not necessary. 
-
-The section briefly explains how to make use of our helper/wrapper module to do the calibration of your model. For further details, please see pytorch_jacinto_ai.engine.train_classification.py and pytorch_jacinto_ai.engine.train_pixel2pixel.py. The step by step process is as follows:
-
-```
-# create your model here:
-model = ...
-
-# create a dummy input - this is required to analyze the model - fill in the input image size expected by your model.
-dummy_input = torch.rand((1,3,384,768))
-
-#wrap your model in QuantCalibrateModule. Once it is wrapped, the actual model is in model.module
-model = pytorch_jacinto_ai.xnn.quantize.QuantCalibrateModule(model, dummy_input=dummy_input)
-
-# load your pretrained weights here into model.module
-pretrained_data = torch.load(pretrained_path)
-model.module.load_state_dict(pretrained_data)
-
-# create your dataset here - the ground-truth/target that you provide in the dataset can be dummy and does not affect calibration.
-my_dataset_train, my_dataset_val = ...
-
-# do one epoch of calibration - in practice about 1000 iterations are sufficient.
-for images, target in my_dataset_train:
-    output = model(images)
-
-# save the model - the calibrated module is in model.module
-torch.save(model.module.state_dict(), os.path.join(save_path,'model.pth'))
-torch.onnx.export(model.module, dummy_input, os.path.join(save_path,'model.onnx'), export_params=True, verbose=False)
-
-```
-Careful attention needs to be given to how the pretrained model is loaded and trained model is saved as shown in the above code snippet.
-
-Few examples of calibration are provided below. These commands are also listed in the file **run_quantization.sh** for convenience.<br>
-
-- Calibration of ImageNet Classification MobileNetV2 model 
-```
-python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification --pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth --batch_size 64 --quantize True --epochs 1 --epoch_size 100
-```
-
-- Calibration of ImageNet Classification ResNet50 model 
-```
-python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification --pretrained https://download.pytorch.org/models/resnet50-19c8e357.pth --batch_size 64 --quantize True --epochs 1 --epoch_size 100
-```
-
-- Calibration of Cityscapes Semantic Segmentation model 
-```
-python ./scripts/train_segmentation_main.py --phase calibration --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 
---pretrained ./data/modelzoo/semantic_segmentation/cityscapes/deeplabv3lite-mobilenetv2/cityscapes_segmentation_deeplabv3lite-mobilenetv2_2019-06-26-08-59-32.pth 
---batch_size 12 --quantize True --epochs 1 --epoch_size 100
-```
+Accuracy of inference can degrade if the CNN model is quantized to 8bits using simple methods and steps have to be taken to minimize this accuracy loss. One such method of reducing accuracy loss is "Quantization Aware Training". When Quantization Aware Training is incorporated into the training flow, the parameters of the model are adjusted to suit quantization. This includes adjusting of weights, biases and activation ranges.
 
-## Trained Quantization a.k.a Quantization Aware Training (recommended method)
-Trained Quantization typically provides better accuracy compared to Post Training Calibration.
+In order to make the activations quantization friendly, it is important to clip them during Quantization Aware Training. PACT2 activation function has been developed to clip the activations to a power-of-two value. PACT2 can be used in the place of commonly used activation functions such as ReLU or ReLU6.
+<p float="left"> <img src="quantization/pact2_activation.png" width="640" hspace="5"/> </p>
+We use statistical range clipping in PACT2 to improve the Quantized Accuracy (compared to simple min-max range clipping).
 
-Unlike Calibration, Trained Quantization involves ground truth, loss function and back propagation - hence it is very similar to the original floating point training.
+## Quantization Aware Training a.k.a Trained Quantization
+Quantization Aware Training (QAT) is easy to incorporate into an existing PyTorch training code. We provide a wrapper module called QuantTrainModule to automate all the tasks required for QAT. The user simply needs to wrap his model in QuantTrainModule and do the training.
 
-The most popular method of trained quantization is [[4]]. It takes care of merging Convolution layers with the adjascent Batch Normalization layers (on-the-fly) during the quantized training (if this merging is not correctly done, quantized training may not improve the accuracy). In addition, we use Straight-Through Estimation (STE) [[2,3]] to improve the gradient flow in back-propagation. Also, the statistical range clipping in PACT2 improves the Quantized Accuracy over simple min-max range clipping. 
+The overall flow of training is as follows:
+- Step 1:Train your model in floating point as usual.
+- Step 2: Starting from the floating point model as pretrained weights, do Quantization Aware Training. In order to do this wrap your model in the wrapper module called  pytorch_jacinto_ai.xnn.quantize.QuantTrainModule and perform training with a small learning rate. About 25 to 50 epochs of training may be required to get the best accuracy.
 
-Note: Instead of STE and statistical ranges for PACT2, we also tried out approximate gradients for scale and trained quantization thresholds proposed in [[5]] (We did not use the gradient nomralization and log-domain training mentioned in the paper). We found that method to be able to learn the clipping thresholds for initial few epochs, but became unstable after a few epochs and loss became high. Compared to that learned thresholds method, our statistical PACT2 ranges/thresholds combined with STE is simple and stable. 
+QuantTrainModule does the following operations to the model. Note that QuantTrainModule that will handle these tasks - the only thing that is required is to wrap the user's module in QuantTrainModule as explained in the section "How to use  QuantTrainModule".
+- Replace all the ReLU, ReLU6 layers in the model by PACT2. Insert PACT2 after Convolution+BatchNorm if a ReLU/ReLU6 is missing after that.  Insert PACT2 anywhere else required - where activation range clipping and range collection is required. For example it can be after the Fully Connected Layer. We use forward post hooks of PyTorch nn.Modules to call these extra activation functions. Thus we are able to add these extra activations without disturbing the loading of existing pre-trained weights.
+- Clip the weights to an appropriate range if the weight range is very high.
+- Quantize the weights during the forward pass. Merging Convolution layers with the adjacent Batch Normalization layers (on-the-fly) during the weight quantization is required - if this merging is not correctly done, Quantization Aware Training may not improve accuracy.
+- Quantize activations during the forward pass.
+- Other modifications to help the learning process. For example, we use Straight-Through Estimation (STE) [[2,3]] to improve the gradient flow in back-propagation.
 
-A block diagram of Trained Quantization is shown below:
+A block diagram of Quantization Aware Training with QuantTrainModule is shown below:
 <p float="left"> <img src="quantization/trained_quant_ste.png" width="640" hspace="5"/> </p>
 
-#### What happens during Trained Quantization?
+#### What happens during Quantization Aware Training?
 - For each iteration perform a forward in floating point using the original weights and biases. During this pass PACT2 layers will collect output ranges using histogram and running average.
 - In addition, perform Convolution+BatchNorm merging and quantization of the resulting weights. These quantized and de-quantized weights are used in a forward pass. Ranges collected by PACT2 is used for activation quantization (and de-quantization) to generate quantized output.
-- Backpropagation with STE will update the parameters of the model to reduce the loss with quantization.
+- Back-propagation with STE will update the parameters of the model to reduce the loss with quantization.
 - Within a few epochs, we should get reasonable quantization accuracy.
 
 #### How to use  QuantTrainModule
-In order to enable quantized training, we have developed the wrapper class pytorch_jacinto_ai.xnn.quantize.QuantTrainModule. The usage of this module can be seen in pytorch_jacinto_ai.engine.train_classification.py and pytorch_jacinto_ai.engine.train_pixel2pixel.py. 
+In order to enable quantized training, we have developed the wrapper class pytorch_jacinto_ai.xnn.quantize.QuantTrainModule. A simple example for using this module is given in the script [examples/quantization_example.py](../examples/quantization_example.py) and calling this is demonstrated in [run_quantization_example.sh](../run_quantization_example.sh). The usage of this module can also be seen in pytorch_jacinto_ai.engine.train_classification.py and pytorch_jacinto_ai.engine.train_pixel2pixel.py. The following is a brief description of how to use this wrapper module:
 ```
+from pytorch_jacinto_ai.xnn.quantize import QuantTrainModule
+
 # create your model here:
 model = ...
 
@@ -125,7 +43,7 @@ model = ...
 dummy_input = torch.rand((1,3,384,768))
 
 #wrap your model in QuantTrainModule. Once it is wrapped, the actual model is in model.module
-model = pytorch_jacinto_ai.xnn.quantize.QuantTrainModule(model, dummy_input=dummy_input)
+model = QuantTrainModule(model, dummy_input=dummy_input)
 
 # load your pretrained weights here into model.module
 pretrained_data = torch.load(pretrained_path)
@@ -141,11 +59,9 @@ torch.save(model.module.state_dict(), os.path.join(save_path,'model.pth'))
 torch.onnx.export(model.module, dummy_input, os.path.join(save_path,'model.onnx'), export_params=True, verbose=False)
 ```
 
-As can be seen, it is very easy to incorporate QuantTrainModule in your existing training code as the only thing required is to wrap your original model in QuantTrainModule. Careful attention needs to be given to how the pretrained model is loaded and trained model is saved as shown in the above code snippet.
+As can be seen, it is easy to incorporate QuantTrainModule in your existing training code as the only thing required is to wrap your original model in QuantTrainModule. Careful attention needs to be given to how the parameters of the pretrained model is loaded and trained model is saved as shown in the above code snippet.
 
-The Model Preparation steps used for Calibration applies for Trained Quantization as well. - Note that this Model preparation is automatically done by QuantTrainModule. The resultant model can then be used for training as usual and it will take care of quantization constraints during the training forward and backward passes.
-
-One word of caution is that our current implementation of Trained Quantization is a bit slow. The reason for this slowdown is that our implementation is using the top-level python layer of PyTorch and not the underlying C++ layer. But with PyTorch natively supporting the functionality required for quantization under the hood - we hope that this speed issue can be resolved in a future update. 
+Optional: We have provided a utility function called xnn.utils.load_weights() that prints which parameters are loaded correctly and which are not - you can use this load function if needed to ensure that your parameters are loaded correctly.
 
 Example commands for trained quantization: 
 ```
@@ -156,12 +72,16 @@ python ./scripts/train_classification_main.py --dataset_name image_folder_classi
 python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 --pretrained ./data/modelzoo/semantic_segmentation/cityscapes/deeplabv3lite-mobilenetv2/cityscapes_segmentation_deeplabv3lite-mobilenetv2_2019-06-26-08-59-32.pth --batch_size 8 --quantize True --epochs 150 --lr 5e-5 --evaluate_start False
 ```
 
+## Calibration
+We also have another method called Calibration to reduce the accuracy loss with quantization. If you are interested, you can take a look at the [documentation of Calibration here](Calibration.md) - but that is not our recommend method now.
+
+
 ## Important Notes - read carefully
-**Multi-GPU training/calibration/validation with DataParallel is not yet working with our quantization modules** QuantCalibrateModule, QuantTrainModule.<br>
-- **We recommend not to wrap the modules in DataParallel if you are training/calibrating/testing with quantization - i.e. if your model is wrapped in QuantCalibrateModule/QuantTrainModule/QuantTestModule.** If you get an error during training related to weights and input not being in the same GPU, please check and ensure that you are not using DataParallel with QuantCalibrateModule/QuantTrainModule/QuantTestModule. This may not be such a problem as calibration and quantization may not take as much time as the original training. If your calibration/training crashes with insufficient GPU memory, reduce the batch size and try again. The original training (without quantization) can use Multi-GPU as usual and we do not have any restrictions on that.<br>
+- **Multi-GPU training/calibration/validation with DataParallel is not yet working with our quantization modules** QuantTrainModule/QuantCalibrateModule/QuantTestModule. We recommend not to wrap the modules in DataParallel if you are training/calibrating/testing with quantization - i.e. if your model is wrapped in QuantTrainModule/QuantCalibrateModule/QuantTestModule.<br>
+- If you get an error during training related to weights and input not being in the same GPU, please check and ensure that you are not using DataParallel with QuantTrainModule/QuantCalibrateModule/QuantTestModule. This may not be such a problem as calibration and quantization may not take as much time as the original floating point training. The original floating point training (without quantization) can use Multi-GPU as usual and we do not have any restrictions on that.<br>
+- If your calibration/training crashes with insufficient GPU memory, reduce the batch size and try again.
 - **The same module should not be re-used multiple times within the module** in order that the activation range estimation is correct. Unfortunately, in the torchvision ResNet models, the ReLU module in the BasicBlock and BottleneckBlock are re-used multiple times. We have corrected this by defining separate ReLU modules. This change is minor and **does not** affect the loading of existing pretrained weights. See the [our modified ResNet model definition here](./modules/pytorch_jacinto_ai/vision/models/resnet.py).<br>
 - **Use Modules instead of functions** (we make use of modules to decide whether to do activation range clipping or not). For example use torch.nn.reLU instead of torch.nn.functional.relu(), torch.nn.AdaptiveAvgPool2d() instead of torch.nn.functional.adaptive_avg_pool2d(), torch.nn.Flatten() instead of torch.nn.functional.flatten() etc. If you are using functions in your model and is giving poor quantized accuracy, then consider replacing those functions by the corresponding modules.<br>
-- Tools for Calibration and Trained Quantization have started appearing in mainstream Deep Learning training frameworks [[7,8]]. Using the tools natively provided by these frameworks may be faster compared to an implementation in the Python layer of these frameworks (like we have done) - but they may not be mature currently.<br>
 
 
 ## Results
diff --git a/examples/quantization_example.py b/examples/quantization_example.py
new file mode 100644 (file)
index 0000000..7d5e548
--- /dev/null
@@ -0,0 +1,464 @@
+# this code is modified from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
+# the changes required for quantizing the model are under the flag args.quantize
+import argparse
+import os
+import random
+import shutil
+import time
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+import torch.optim
+import torch.multiprocessing as mp
+import torch.utils.data
+import torch.utils.data.distributed
+import torchvision.transforms as transforms
+import torchvision.datasets as datasets
+
+# some of the default torchvision models need some minor tweaks to be friendly for
+# quantization aware training. so use models from pytorch_jacinto_ai.vision insead
+#import torchvision.models as models
+
+from pytorch_jacinto_ai import xnn
+from pytorch_jacinto_ai import vision as xvision
+from pytorch_jacinto_ai.vision import models as models
+
+model_names = sorted(name for name in models.__dict__
+    if name.islower() and not name.startswith("__")
+    and callable(models.__dict__[name]))
+
+parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
+parser.add_argument('data', metavar='DIR',
+                    help='path to dataset')
+parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
+                    choices=model_names,
+                    help='model architecture: ' +
+                        ' | '.join(model_names) +
+                        ' (default: resnet18)')
+parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
+                    help='number of data loading workers (default: 4)')
+parser.add_argument('--epochs', default=90, type=int, metavar='N',
+                    help='number of total epochs to run')
+parser.add_argument('--epoch-size', default=0, type=int, metavar='N',
+                    help='number of iterations in one training epoch. 0 (default) means full training epoch')
+parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
+                    help='manual epoch number (useful on restarts)')
+parser.add_argument('-b', '--batch-size', default=256, type=int,
+                    metavar='N',
+                    help='mini-batch size (default: 256), this is the total '
+                         'batch size of all GPUs on the current node when '
+                         'using Data Parallel or Distributed Data Parallel')
+parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
+                    metavar='LR', help='initial learning rate', dest='lr')
+parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
+                    help='momentum')
+parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
+                    metavar='W', help='weight decay (default: 1e-4)',
+                    dest='weight_decay')
+parser.add_argument('-p', '--print-freq', default=100, type=int,
+                    metavar='N', help='print frequency (default: 10)')
+parser.add_argument('--resume', default='', type=str, metavar='PATH',
+                    help='path to latest checkpoint (default: none)')
+parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
+                    help='evaluate model on validation set')
+parser.add_argument('--pretrained', type=str, default=None,
+                    help='use pre-trained model')
+parser.add_argument('--world-size', default=-1, type=int,
+                    help='number of nodes for distributed training')
+parser.add_argument('--rank', default=-1, type=int,
+                    help='node rank for distributed training')
+parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
+                    help='url used to set up distributed training')
+parser.add_argument('--dist-backend', default='nccl', type=str,
+                    help='distributed backend')
+parser.add_argument('--seed', default=None, type=int,
+                    help='seed for initializing training. ')
+parser.add_argument('--gpu', default=None, type=int,
+                    help='GPU id to use.')
+parser.add_argument('--multiprocessing-distributed', action='store_true',
+                    help='Use multi-processing distributed training to launch '
+                         'N processes per node, which has N GPUs. This is the '
+                         'fastest way to use PyTorch for either single node or '
+                         'multi node data parallel training')
+parser.add_argument('--quantize', action='store_true',
+                    help='Enable Quantization')
+best_acc1 = 0
+
+
+def main():
+    args = parser.parse_args()
+
+    if args.seed is not None:
+        random.seed(args.seed)
+        torch.manual_seed(args.seed)
+        cudnn.deterministic = True
+        warnings.warn('You have chosen to seed training. '
+                      'This will turn on the CUDNN deterministic setting, '
+                      'which can slow down your training considerably! '
+                      'You may see unexpected behavior when restarting '
+                      'from checkpoints.')
+
+    if args.gpu is not None:
+        warnings.warn('You have chosen a specific GPU. This will completely '
+                      'disable data parallelism.')
+
+    if args.dist_url == "env://" and args.world_size == -1:
+        args.world_size = int(os.environ["WORLD_SIZE"])
+
+    args.distributed = args.world_size > 1 or args.multiprocessing_distributed
+
+    ngpus_per_node = torch.cuda.device_count()
+    if args.multiprocessing_distributed:
+        # Since we have ngpus_per_node processes per node, the total world_size
+        # needs to be adjusted accordingly
+        args.world_size = ngpus_per_node * args.world_size
+        # Use torch.multiprocessing.spawn to launch distributed processes: the
+        # main_worker process function
+        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
+    else:
+        # Simply call main_worker function
+        main_worker(args.gpu, ngpus_per_node, args)
+
+
+def main_worker(gpu, ngpus_per_node, args):
+    global best_acc1
+    args.gpu = gpu
+
+    if args.gpu is not None:
+        print("Use GPU: {} for training".format(args.gpu))
+
+    if args.distributed:
+        if args.dist_url == "env://" and args.rank == -1:
+            args.rank = int(os.environ["RANK"])
+        if args.multiprocessing_distributed:
+            # For multiprocessing distributed training, rank needs to be the
+            # global rank among all the processes
+            args.rank = args.rank * ngpus_per_node + gpu
+        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                                world_size=args.world_size, rank=args.rank)
+    # create model
+    print("=> creating model '{}'".format(args.arch))
+    model = models.__dict__[args.arch]()
+
+    if args.quantize:
+        # DistributedDataParallel / DataParallel are not supported with quantization
+        dummy_input = torch.rand((1, 3, 224, 224))
+        if args.evaluate:
+            # for validation accuracy check with quantization - can be used to estimate approximate accuracy achieved with quantization
+            model = xnn.quantize.QuantTestModule(model, dummy_input=dummy_input).cuda(args.gpu)
+        else:
+            # for quantization aware training
+            model = xnn.quantize.QuantTrainModule(model, dummy_input=dummy_input).cuda(args.gpu)
+        #
+    else:
+        if args.distributed:
+            # For multiprocessing distributed, DistributedDataParallel constructor
+            # should always set the single device scope, otherwise,
+            # DistributedDataParallel will use all available devices.
+            if args.gpu is not None:
+                torch.cuda.set_device(args.gpu)
+                model.cuda(args.gpu)
+                # When using a single GPU per process and per
+                # DistributedDataParallel, we need to divide the batch size
+                # ourselves based on the total number of GPUs we have
+                args.batch_size = int(args.batch_size / ngpus_per_node)
+                args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
+                model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+            else:
+                model.cuda()
+                # DistributedDataParallel will divide and allocate batch_size to all
+                # available GPUs if device_ids are not set
+                model = torch.nn.parallel.DistributedDataParallel(model)
+        elif args.gpu is not None:
+            torch.cuda.set_device(args.gpu)
+            model = model.cuda(args.gpu)
+        else:
+            # DataParallel will divide and allocate batch_size to all available GPUs
+            if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
+                model.features = torch.nn.DataParallel(model.features)
+                model.cuda()
+            else:
+                model = torch.nn.DataParallel(model).cuda()
+
+    if args.pretrained is not None:
+        model_orig = model.module if isinstance(model, (torch.nn.parallel.DistributedDataParallel, torch.nn.parallel.DataParallel)) else model
+        model_orig = model_orig.module if args.quantize else model_orig
+        print("=> using pre-trained model for {} from {}".format(args.arch, args.pretrained))
+        if hasattr(model_orig, 'load_weights'):
+            model_orig.load_weights(args.pretrained, download_root='./data/downloads')
+        else:
+            xnn.utils.load_weights(model_orig, args.pretrained, download_root='./data/downloads')
+        #
+
+    # define loss function (criterion) and optimizer
+    criterion = nn.CrossEntropyLoss().cuda(args.gpu)
+
+    optimizer = torch.optim.SGD(model.parameters(), args.lr,
+                                momentum=args.momentum,
+                                weight_decay=args.weight_decay)
+
+    # optionally resume from a checkpoint
+    if args.resume:
+        if os.path.isfile(args.resume):
+            print("=> loading checkpoint '{}'".format(args.resume))
+            if args.gpu is None:
+                checkpoint = torch.load(args.resume)
+            else:
+                # Map model to be loaded to specified single gpu.
+                loc = 'cuda:{}'.format(args.gpu)
+                checkpoint = torch.load(args.resume, map_location=loc)
+            args.start_epoch = checkpoint['epoch']
+            best_acc1 = checkpoint['best_acc1']
+            if args.gpu is not None:
+                # best_acc1 may be from a checkpoint from a different GPU
+                best_acc1 = best_acc1.to(args.gpu)
+            model.load_state_dict(checkpoint['state_dict'])
+            optimizer.load_state_dict(checkpoint['optimizer'])
+            print("=> loaded checkpoint '{}' (epoch {})"
+                  .format(args.resume, checkpoint['epoch']))
+        else:
+            print("=> no checkpoint found at '{}'".format(args.resume))
+
+    cudnn.benchmark = True
+
+    # Data loading code
+    traindir = os.path.join(args.data, 'train')
+    valdir = os.path.join(args.data, 'val')
+    normalize = xvision.transforms.NormalizeMeanScale(mean=[123.675, 116.28, 103.53], scale=[0.017125, 0.017507, 0.017429])
+
+    train_dataset = datasets.ImageFolder(
+        traindir,
+        transforms.Compose([
+            transforms.RandomResizedCrop(224),
+            transforms.RandomHorizontalFlip(),
+            xvision.transforms.ToFloat(),
+            transforms.ToTensor(),
+            normalize,
+        ]))
+
+    if args.distributed:
+        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
+    else:
+        train_sampler = None
+
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
+        num_workers=args.workers, pin_memory=True, sampler=train_sampler)
+
+    val_loader = torch.utils.data.DataLoader(
+        datasets.ImageFolder(valdir, transforms.Compose([
+            transforms.Resize(256),
+            transforms.CenterCrop(224),
+            xvision.transforms.ToFloat(),
+            transforms.ToTensor(),
+            normalize,
+        ])),
+        batch_size=args.batch_size, shuffle=False,
+        num_workers=args.workers, pin_memory=True)
+
+    validate(val_loader, model, criterion, args)
+
+    if args.evaluate:
+        return
+
+    for epoch in range(args.start_epoch, args.epochs):
+        if args.distributed:
+            train_sampler.set_epoch(epoch)
+        adjust_learning_rate(optimizer, epoch, args)
+
+        # train for one epoch
+        train(train_loader, model, criterion, optimizer, epoch, args)
+
+        # evaluate on validation set
+        acc1 = validate(val_loader, model, criterion, args)
+
+        # remember best acc@1 and save checkpoint
+        is_best = acc1 > best_acc1
+        best_acc1 = max(acc1, best_acc1)
+
+        model_orig = model.module if isinstance(model, (torch.nn.parallel.DistributedDataParallel, torch.nn.parallel.DataParallel)) else model
+        model_orig = model_orig.module if args.quantize else model_orig
+        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
+                and args.rank % ngpus_per_node == 0):
+            out_basename = args.arch
+            out_basename += ('_quantized_checkpoint.pth.tar' if args.quantize else '_checkpoint.pth.tar')
+            save_filename = os.path.join('./data/checkpoints/quantization', out_basename)
+            save_checkpoint({
+                'epoch': epoch + 1,
+                'arch': args.arch,
+                'state_dict': model_orig.state_dict(),
+                'best_acc1': best_acc1,
+                'optimizer' : optimizer.state_dict(),
+            }, is_best, filename=save_filename)
+
+
+def train(train_loader, model, criterion, optimizer, epoch, args):
+    batch_time = AverageMeter('Time', ':6.3f')
+    data_time = AverageMeter('Data', ':6.3f')
+    losses = AverageMeter('Loss', ':.4e')
+    top1 = AverageMeter('Acc@1', ':6.2f')
+    top5 = AverageMeter('Acc@5', ':6.2f')
+    progress = ProgressMeter(
+        len(train_loader),
+        [batch_time, data_time, losses, top1, top5],
+        prefix="Epoch: [{}]".format(epoch))
+
+    # switch to train mode
+    model.train()
+
+    end = time.time()
+    for i, (images, target) in enumerate(train_loader):
+        # break the epoch at at the iteration epoch_size
+        if args.epoch_size != 0 and i >= args.epoch_size:
+            break
+        # measure data loading time
+        data_time.update(time.time() - end)
+
+        images = images.cuda(args.gpu, non_blocking=True)
+        target = target.cuda(args.gpu, non_blocking=True)
+
+        # compute output
+        output = model(images)
+        loss = criterion(output, target)
+
+        # measure accuracy and record loss
+        acc1, acc5 = accuracy(output, target, topk=(1, 5))
+        losses.update(loss.item(), images.size(0))
+        top1.update(acc1[0], images.size(0))
+        top5.update(acc5[0], images.size(0))
+
+        # compute gradient and do SGD step
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        # measure elapsed time
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if i % args.print_freq == 0:
+            progress.display(i)
+
+
+def validate(val_loader, model, criterion, args):
+    batch_time = AverageMeter('Time', ':6.3f')
+    losses = AverageMeter('Loss', ':.4e')
+    top1 = AverageMeter('Acc@1', ':6.2f')
+    top5 = AverageMeter('Acc@5', ':6.2f')
+    progress = ProgressMeter(
+        len(val_loader),
+        [batch_time, losses, top1, top5],
+        prefix='Test: ')
+
+    # switch to evaluate mode
+    model.eval()
+
+    with torch.no_grad():
+        end = time.time()
+        for i, (images, target) in enumerate(val_loader):
+            images = images.cuda(args.gpu, non_blocking=True)
+            target = target.cuda(args.gpu, non_blocking=True)
+
+            # compute output
+            output = model(images)
+            loss = criterion(output, target)
+
+            # measure accuracy and record loss
+            acc1, acc5 = accuracy(output, target, topk=(1, 5))
+            losses.update(loss.item(), images.size(0))
+            top1.update(acc1[0], images.size(0))
+            top5.update(acc5[0], images.size(0))
+
+            # measure elapsed time
+            batch_time.update(time.time() - end)
+            end = time.time()
+
+            if i % args.print_freq == 0:
+                progress.display(i)
+
+        # TODO: this should also be done with the ProgressMeter
+        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
+              .format(top1=top1, top5=top5))
+
+    return top1.avg
+
+
+def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
+    dirname = os.path.dirname(filename)
+    xnn.utils.makedir_exist_ok(dirname)
+    torch.save(state, filename)
+    if is_best:
+        shutil.copyfile(filename, os.path.splitext(filename)[0]+'_best.pth.tar')
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+    def __init__(self, name, fmt=':f'):
+        self.name = name
+        self.fmt = fmt
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+    def __str__(self):
+        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+        return fmtstr.format(**self.__dict__)
+
+
+class ProgressMeter(object):
+    def __init__(self, num_batches, meters, prefix=""):
+        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
+        self.meters = meters
+        self.prefix = prefix
+
+    def display(self, batch):
+        entries = [self.prefix + self.batch_fmtstr.format(batch)]
+        entries += [str(meter) for meter in self.meters]
+        print('\t'.join(entries))
+
+    def _get_batch_fmtstr(self, num_batches):
+        num_digits = len(str(num_batches // 1))
+        fmt = '{:' + str(num_digits) + 'd}'
+        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
+
+
+def adjust_learning_rate(optimizer, epoch, args):
+    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
+    lr = args.lr * (0.1 ** (epoch // 30))
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+
+def accuracy(output, target, topk=(1,)):
+    """Computes the accuracy over the k top predictions for the specified values of k"""
+    with torch.no_grad():
+        maxk = max(topk)
+        batch_size = target.size(0)
+
+        _, pred = output.topk(maxk, 1, True, True)
+        pred = pred.t()
+        correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+        res = []
+        for k in topk:
+            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+            res.append(correct_k.mul_(100.0 / batch_size))
+        return res
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
index e05ebf1017830d8189617a76b818c68244327f4b..d17decf8a110d177614af1ec0e749b0fbdde469d 100644 (file)
@@ -381,7 +381,7 @@ def main(args):
                         dummy_input=dummy_input)
 
     # load pretrained weights
-    xnn.utils.load_weights_check(get_model_orig(model), pretrained=pretrained_data, change_names_dict=change_names_dict)
+    xnn.utils.load_weights(get_model_orig(model), pretrained=pretrained_data, change_names_dict=change_names_dict)
 
     if args.generate_onnx:
         write_onnx_model(args, model, save_path, name='model_best.onnx')
index 909f0f5411f786eb6391ab82ad21c02c79666e61..534c1110678945e487ba29f903e2a5dd4d4e454f 100644 (file)
@@ -166,7 +166,7 @@ def main(args):
     #
 
     # load pretrained
-    xnn.utils.load_weights_check(get_model_orig(model), pretrained=pretrained_data, change_names_dict=change_names_dict)
+    xnn.utils.load_weights(get_model_orig(model), pretrained=pretrained_data, change_names_dict=change_names_dict)
 
     #################################################
     if args.count_flops:
index 7ba0f2aea453e471135745f9fa89ed7cb970f60d..d1d60995f9fb8f94ab19bd3f8b6225d098995563 100644 (file)
@@ -229,7 +229,7 @@ def main(args):
 
     # load pretrained
     if pretrained_data is not None:
-        xnn.utils.load_weights_check(get_model_orig(model), pretrained=pretrained_data, change_names_dict=change_names_dict)
+        xnn.utils.load_weights(get_model_orig(model), pretrained=pretrained_data, change_names_dict=change_names_dict)
     #
     
     #################################################
@@ -315,7 +315,7 @@ def main(args):
                 args.start_epoch = checkpoint['epoch'] + 1
                 
             args.best_prec1 = checkpoint['best_prec1']
-            model = xnn.utils.load_weights_check(model, checkpoint)
+            model = xnn.utils.load_weights(model, checkpoint)
             optimizer.load_state_dict(checkpoint['optimizer'])
             print("=> resuming from checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
         else:
index c5a000708550b80384a6db5082314f5837f1ab60..237ac7520118e3c4fb63336c2a6c5c4e2e7530a2 100644 (file)
@@ -342,7 +342,7 @@ def main(args):
 
     # load pretrained model
     if pretrained_data is not None:
-        xnn.utils.load_weights_check(get_model_orig(model), pretrained=pretrained_data, change_names_dict=change_names_dict)
+        xnn.utils.load_weights(get_model_orig(model), pretrained=pretrained_data, change_names_dict=change_names_dict)
     #
 
     #################################################
@@ -478,7 +478,7 @@ def main(args):
             print("=> loading checkpoint '{}'".format(args.resume))
 
         checkpoint = torch.load(args.resume)
-        model = xnn.utils.load_weights_check(model, checkpoint)
+        model = xnn.utils.load_weights(model, checkpoint)
             
         if args.start_epoch == 0:
             args.start_epoch = checkpoint['epoch']
index 9ad8680029e0fdcb7e8c361ceb27409470ed0c2c..d9397105a18ba886c9720a670d308fd9daddf635 100644 (file)
@@ -1,4 +1,4 @@
-# pytorch vision models
+# modified from pytorch vision models
 from .alexnet import *
 from .resnet import *
 from .vgg import *
@@ -34,3 +34,7 @@ except: pass
 try: from .flownetbase_internal import *
 except: pass
 
+@property
+def name():
+    return 'pytorch_jacinto_ai.vision.models'
+#
\ No newline at end of file
index 5ba9ff778bb598bce670f7f7d23dbabf886a3a17..0950ce29d1727e06060713d2d4bb23b79428e62a 100644 (file)
@@ -34,7 +34,7 @@ def resnet50_x1(model_config, pretrained=None, width_mult=1.0):
                          '^relu.': 'features.relu.', '^maxpool.': 'features.maxpool.',
                          '^layer': 'features.layer', '^fc.': 'classifier.'}
     if pretrained:
-        model = xnn.utils.load_weights_check(model, pretrained, change_names_dict=change_names_dict)
+        model = xnn.utils.load_weights(model, pretrained, change_names_dict=change_names_dict)
     return model, change_names_dict
 
 
@@ -47,7 +47,7 @@ def mobilenetv1_x1(model_config, pretrained=None):
     model_config = mobilenetv1.get_config().merge_from(model_config)
     model = mobilenetv1.MobileNetV1(model_config=model_config)
     if pretrained:
-        model = xnn.utils.load_weights_check(model, pretrained)
+        model = xnn.utils.load_weights(model, pretrained)
     return model
 
 
@@ -56,7 +56,7 @@ def mobilenetv2_tv_x1(model_config, pretrained=None):
     model_config = mobilenetv2.get_config().merge_from(model_config)
     model = mobilenetv2.MobileNetV2TV(model_config=model_config)
     if pretrained:
-        model = xnn.utils.load_weights_check(model, pretrained)
+        model = xnn.utils.load_weights(model, pretrained)
     return model
 #
 #alias
@@ -69,7 +69,7 @@ def mobilenetv2_tv_x2_t2(model_config, pretrained=None):
     model_config.expand_ratio = 2.0
     model = mobilenetv2.MobileNetV2TV(model_config=model_config)
     if pretrained:
-        model = xnn.utils.load_weights_check(model, pretrained)
+        model = xnn.utils.load_weights(model, pretrained)
     return model
 
 
@@ -78,7 +78,7 @@ def mobilenetv2_tv_gws_x1(model_config, pretrained=None):
     model_config = mobilenetv2_gws_internal.get_config().merge_from(model_config)
     model = mobilenetv2_gws_internal.MobileNetV2TVGWS(model_config=model_config)
     if pretrained:
-        model = xnn.utils.load_weights_check(model, pretrained)
+        model = xnn.utils.load_weights(model, pretrained)
     return model
 
 
@@ -87,7 +87,7 @@ def mobilenetv2_ericsun_x1(model_config, pretrained=None):
     model_config = mobilenetv2_ericsun_internal.get_config().merge_from(model_config)
     model = mobilenetv2_ericsun_internal.MobileNetV2Ericsun(model_config=model_config)
     if pretrained:
-        model = xnn.utils.load_weights_check(model, pretrained)
+        model = xnn.utils.load_weights(model, pretrained)
     return model
 
 
@@ -95,7 +95,7 @@ def mobilenetv2_shicai_x1(model_config, pretrained=None):
     model_config = mobilenetv2_shicai_internal.get_config().merge_from(model_config)
     model = mobilenetv2_shicai_internal.mobilenetv2_shicai(model_config=model_config)
     if pretrained:
-        model = xnn.utils.load_weights_check(model, pretrained)
+        model = xnn.utils.load_weights(model, pretrained)
     return model
 
 
@@ -103,5 +103,5 @@ def flownetslite_base_x1(model_config, pretrained=None):
     model_config = flownetbase_internal.get_config().merge_from(model_config)
     model = flownetbase_internal.flownetslite_base(model_config, pretrained=pretrained)
     if pretrained:
-        model = xnn.utils.load_weights_check(model, pretrained)
+        model = xnn.utils.load_weights(model, pretrained)
     return model
\ No newline at end of file
index 2ad53b277053f188164497e5a613afb8046e469e..1f3ba014a7b7466eeeb1041a9e2b6142fe154749 100644 (file)
@@ -145,7 +145,7 @@ def _load_pretrained(model_name, model, progress):
     # the pretrained model provided by torchvision and what is defined here differs slightly
     # note: that this change_names_dict  will take effect only if the direct load fails
     change_names_dict = {'^layers.': 'features.'}
-    model = xnn.utils.load_weights_check(model, state_dict, change_names_dict=change_names_dict)
+    model = xnn.utils.load_weights(model, state_dict, change_names_dict=change_names_dict)
     return model
 
 
index a47d4b135fd170801195a55cb6dc92d7590c418a..89b59e2331ce24bbf64053394891223a0b50db81 100644 (file)
@@ -3,7 +3,7 @@ from .utils import *
 from ... import xnn
 
 ###################################################
-__all__ = ['MobileNetV2TVBase', 'MobileNetV2TV', 'mobilenet_v2_tv', 'get_config']
+__all__ = ['MobileNetV2TVBase', 'MobileNetV2TV', 'mobilenet_v2', 'mobilenet_v2_tv', 'get_config']
 
 
 ###################################################
@@ -187,3 +187,5 @@ def mobilenet_v2_tv(pretrained=False, progress=True, **kwargs):
         state_dict = load_state_dict_from_url(model_urls['mobilenet_v2_tv'], progress=progress)
         model.load_state_dict(state_dict)
     return model
+
+mobilenet_v2 = mobilenet_v2_tv
\ No newline at end of file
index 16d6d27d0f4dce67f7d532a4d48e6d21c2ecf9ea..4e760586f6ebdc6fa8473b6490940d400d58e416 100644 (file)
@@ -62,9 +62,9 @@ class MultiInputNet(torch.nn.Module):
 
         if model_config.num_inputs>1 and pretrained:
             change_names_dict = {'^features.': ['features.stream{}.'.format(stream) for stream in range(model_config.num_inputs)]}
-            xnn.utils.load_weights_check(self, model_config.pretrained, change_names_dict, ignore_size=True, verbose=True)
+            xnn.utils.load_weights(self, model_config.pretrained, change_names_dict, ignore_size=True, verbose=True)
         elif pretrained:
-            xnn.utils.load_weights_check(self, model_config.pretrained, change_names_dict=None, ignore_size=True, verbose=True)
+            xnn.utils.load_weights(self, model_config.pretrained, change_names_dict=None, ignore_size=True, verbose=True)
 
 
     def _initialize_weights(self):
index b75366ea3cd213ece922141c7ff2be23193db6c8..2252e385ade719a7c839b8a3d3db75dbc2cb62f1 100644 (file)
@@ -166,7 +166,7 @@ def deeplabv3lite_mobilenetv2_tv(model_config, pretrained=None):
     #
 
     if pretrained:
-        model = xnn.utils.load_weights_check(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
+        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
 
     return model, change_names_dict
 
@@ -192,7 +192,7 @@ def deeplabv3lite_mobilenetv2_ericsun(model_config, pretrained=None):
     #
 
     if pretrained:
-        model = xnn.utils.load_weights_check(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
+        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
 
     return model, change_names_dict
 
@@ -240,6 +240,6 @@ def deeplabv3lite_resnet50(model_config, pretrained=None):
     #
 
     if pretrained:
-        model = xnn.utils.load_weights_check(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
+        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
 
     return model, change_names_dict
index 571e70feaa40c81325db5b46c2f3c5843ad4787b..b24a204e07700fec2b093520b2fc2b4a32245f9b 100644 (file)
@@ -33,7 +33,7 @@ def get_config_fpnp2p_mnv2():
     model_config.aspp_chan = 256
     model_config.aspp_dil = (6,12,18)
 
-    model_config.inloop_fpn = False # inloop_fpn means the smooth convs are in the loop, after upsample
+    model_config.inloop_fpn = True #False # inloop_fpn means the smooth convs are in the loop, after upsample
 
     model_config.kernel_size_smooth = 3
     model_config.interpolation_type = 'upsample'
@@ -62,19 +62,22 @@ class FPNPyramid(torch.nn.Module):
         self.shortcut_strides = shortcut_strides
         self.shortcut_channels = shortcut_channels
         self.smooth_convs = torch.nn.ModuleList()
-        self.shortcuts = torch.nn.ModuleList([self.create_shortcut(current_channels, decoder_channels, activation)])
+        self.shortcuts = torch.nn.ModuleList()
         self.upsamples = torch.nn.ModuleList()
 
+        shortcut0 = self.create_shortcut(current_channels, decoder_channels, activation) if (current_channels != decoder_channels) else None
+        self.shortcuts.append(shortcut0)
+
+        smooth_conv0 = None #xnn.layers.ConvDWSepNormAct2d(decoder_channels, decoder_channels, kernel_size=kernel_size_smooth, activation=(activation, activation)) if all_outputs else None
+        self.smooth_convs.append(smooth_conv0)
+
         upstride = 2
         for idx, (s_stride, feat_chan) in enumerate(zip(shortcut_strides, shortcut_channels)):
             shortcut = self.create_shortcut(feat_chan, decoder_channels, activation)
             self.shortcuts.append(shortcut)
             is_last = (idx == len(shortcut_channels)-1)
-            if inloop_fpn or (all_outputs or is_last):
-                smooth_conv = xnn.layers.ConvDWSepNormAct2d(decoder_channels, decoder_channels, kernel_size=kernel_size_smooth, activation=(activation,activation))
-            else:
-                smooth_conv = None
-            #
+            smooth_conv = xnn.layers.ConvDWSepNormAct2d(decoder_channels, decoder_channels, kernel_size=kernel_size_smooth, activation=(activation,activation)) \
+                        if (inloop_fpn or all_outputs or is_last) else None
             self.smooth_convs.append(smooth_conv)
             upsample = xnn.layers.UpsampleTo(decoder_channels, decoder_channels, upstride, interpolation_type, interpolation_mode)
             self.upsamples.append(upsample)
@@ -86,24 +89,31 @@ class FPNPyramid(torch.nn.Module):
         return shortcut
     #
 
-    def forward(self, x_list, in_shape):
+    def forward(self, x_input, x_list):
+        in_shape = x_input.shape
         x = x_list[-1]
-        x = self.shortcuts[0](x)
-        outputs = [x]
-        for idx, (shortcut, smooth_conv, s_stride, short_chan, upsample) in enumerate(zip(self.shortcuts[1:], self.smooth_convs, self.shortcut_strides, self.shortcut_channels, self.upsamples)):
+
+        outputs = []
+        x = self.shortcuts[0](x) if (self.shortcuts[0] is not None) else x
+        y = self.smooth_convs[0](x) if (self.smooth_convs[0] is not None) else x
+        x = y if self.inloop_fpn else x
+        outputs.append(y)
+
+        for idx, (shortcut, smooth_conv, s_stride, short_chan, upsample) in enumerate(zip(self.shortcuts[1:], self.smooth_convs[1:], self.shortcut_strides, self.shortcut_channels, self.upsamples)):
+            # get the feature of lower stride
             shape_s = xnn.utils.get_shape_with_stride(in_shape, s_stride)
             shape_s[1] = short_chan
             x_s = xnn.utils.get_blob_from_list(x_list, shape_s)
             x_s = shortcut(x_s)
+            # updample current output and add to that
             x = upsample((x,x_s))
             x = x + x_s
-            if self.inloop_fpn:
-                x = smooth_conv(x)
-                outputs.append(x)
-            elif (smooth_conv is not None):
-                y = smooth_conv(x)
-                outputs.append(y)
-            #
+            # smooth conv
+            y = smooth_conv(x) if (smooth_conv is not None) else x
+            # use smooth output for next level in inloop_fpn
+            x = y if self.inloop_fpn else x
+            # output
+            outputs.append(y)
         #
         return outputs[::-1]
 
@@ -170,7 +180,7 @@ class FPNPixel2PixelDecoder(torch.nn.Module):
             x_list[-1] = x
         #
 
-        x_list = self.fpn(x_list, in_shape)
+        x_list = self.fpn(x_input, x_list)
         x = x_list[0]
 
         if self.model_config.final_prediction:
@@ -218,7 +228,7 @@ def fpn_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=None):
     #
 
     if pretrained:
-        model = xnn.utils.load_weights_check(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
+        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
     #
     return model, change_names_dict
 
@@ -313,7 +323,7 @@ def fpn_pixel2pixel_aspp_resnet50(model_config, pretrained=None):
     #
 
     if pretrained:
-        model = xnn.utils.load_weights_check(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
+        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
 
     return model, change_names_dict
 
index 0f8e3ee5399ef17bd86ecf37c71c0e4bfbce8a8d..18d0f8edd729098bc01b3c32333b760782b079a8 100644 (file)
@@ -220,16 +220,27 @@ class ResNet(nn.Module):
         return x
 
 
+    # define a load weights fuinction in the module since the module is changed w.r.t. to torchvision
+    # since we want to be able to laod the existing torchvision pretrained weights
+    def load_weights(self, pretrained, change_names_dict=None, download_root=None):
+        if change_names_dict is None:
+            # the pretrained model provided by torchvision and what is defined here differs slightly
+            # note: that this change_names_dict  will take effect only if the direct load fails
+            change_names_dict = {'^conv1.':'features.conv1.', '^bn1.':'features.bn1.',
+                                 '^relu.':'features.relu.', '^maxpool.':'features.maxpool.',
+                                 '^layer':'features.layer' , '^fc.':'classifier.'}
+        #
+        if pretrained is not None:
+            xnn.utils.load_weights(self, pretrained, change_names_dict=change_names_dict, download_root=download_root)
+        return self, change_names_dict
+
+
 def _resnet(arch, block, layers, pretrained, progress, **kwargs):
     model = ResNet(block, layers, **kwargs)
     if pretrained:
-        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
-        # the pretrained model provided by torchvision and what is defined here differs slightly
-        # note: that this change_names_dict  will take effect only if the direct load fails
-        change_names_dict = {'^conv1.':'features.conv1.', '^bn1.':'features.bn1.',
-                             '^relu.':'features.relu.', '^maxpool.':'features.maxpool.',
-                             '^layer':'features.layer' , '^fc.':'classifier.'}
-        model = xnn.utils.load_weights_check(model, state_dict, change_names_dict=change_names_dict)
+        change_names_dict = kwargs.get('change_names_dict', None)
+        download_root = kwargs.get('download_root', None)
+        model.load_weights(pretrained, change_names_dict=change_names_dict, download_root=download_root)
     return model
 
 
index 601a03b387c70b34944cb06bb1fadef6bdffc42c..806528bea8ef82f0472b5ecedd6c6b00567d25db 100644 (file)
@@ -150,7 +150,7 @@ def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
             # note: that this change_names_dict  will take effect only if the direct load fails
             change_names_dict = {'^conv': 'features.conv', '^maxpool.': 'features.maxpool.',
                                  '^stage': 'features.stage', '^fc.': 'classifier.'}
-            model = xnn.utils.load_weights_check(model, state_dict, change_names_dict=change_names_dict)
+            model = xnn.utils.load_weights(model, state_dict, change_names_dict=change_names_dict)
 
     return model
 
index 66b6006e8eb2d5d348d14f66109a4c945265ec12..e491a57a6aee6d03be53f170fd04e7d27baf4015 100644 (file)
@@ -18,6 +18,16 @@ class QuantGraphModule(HookedModule):
         self.register_buffer('iter_in_epoch', torch.tensor(-1.0))
         self.register_buffer('epoch', torch.tensor(-1.0))
 
+        # TBD: is this required
+        # # if the original module has load_weights, add it to the quant module also
+        # if hasattr(module, 'load_weights'):
+        #     def load_weights(m, state_dict, change_names_dict=None):
+        #         utils.load_weights(m.module, state_dict, change_names_dict=change_names_dict)
+        #     #
+        #     self.load_weights = types.MethodType(load_weights, self)
+        # #
+
+
     # create the state object required to keep some quantization parameters that need to be preserved
     # a cuda() has been called on the module - copy the states from that was created for cpu
     def get_state(self):
index c8143cfce10b3ec3454daa578333cbefe85df843..f3a0197c88b0aac25c1ebab27fdd84f74b067132 100644 (file)
@@ -1,5 +1,6 @@
 from .print_utils import *
 from .util_functions import *
+from .utils_data import *
 from .load_weights import *
 from .tensor_utils import *
 from .logger import *
@@ -9,4 +10,4 @@ from .weights_init import *
 from .image_utils import *
 from .module_utils import *
 from .count_flops import forward_count_flops
-from .bn_utils import *
\ No newline at end of file
+from .bn_utils import *
index 2609710a8e9b80e30ca7f49e36b61d6ac6a01ce4..acd39115cfb6dbbaec49bf3cc3c56bca18245e65 100644 (file)
@@ -2,23 +2,27 @@
 # Copyright (c) Texas Instruments 2018
 # ALL RIGHTS RESERVED
 ######################################################
-
 import re
 import torch
 import copy
 from collections import OrderedDict
-from .print_utils import *
-
+from . import print_utils
+from . import utils_data
 
 ######################################################
-def load_weights_check(model, pretrained, change_names_dict=None, keep_original_names=False, width_mult=1.0,
-                       ignore_size=True, verbose=False, num_batches_tracked = None):
+def load_weights(model, pretrained, change_names_dict=None, keep_original_names=False, width_mult=1.0,
+                       ignore_size=True, verbose=False, num_batches_tracked = None, download_root=None, **kwargs):
     if pretrained is None:
-        print_yellow('=> weights could not be loaded. pretrained data given is None')
+        print_utils.print_yellow('=> weights could not be loaded. pretrained data given is None')
         return model
 
     if isinstance(pretrained, str):
-        data = torch.load(pretrained)
+        if pretrained.startswith('http://') or pretrained.startswith('https://'):
+            pretrained_file = utils_data.download_url(pretrained, root=download_root)
+        else:
+            pretrained_file = pretrained
+        #
+        data = torch.load(pretrained_file)
     else:
         data = pretrained
     #
@@ -77,7 +81,7 @@ def load_weights_check(model, pretrained, change_names_dict=None, keep_original_
             try:
                 model.load_state_dict(data, strict=False)
             except:
-                print_yellow('=> WARNING: weights could not be loaded completely.')
+                print_utils.print_yellow('=> WARNING: weights could not be loaded completely.')
         else:
             model.load_state_dict(data, strict=False)
         #
@@ -95,11 +99,11 @@ def check_model_data(model, data, verbose=False, ignore_names=('num_batches_trac
     not_matching_sizes = [k for k in model_dict.keys() if ((k in data.keys()) and (data[k].size() != model_dict[k].size()))]
 
     if missing_weights:
-        print_yellow("=> The following layers in the model could not be loaded from pre-trained: ", missing_weights)
+        print_utils.print_yellow("=> The following layers in the model could not be loaded from pre-trained: ", missing_weights)
     if not_matching_sizes:
-        print_yellow("=> The shape of the following weights did not match: ", not_matching_sizes)
+        print_utils.print_yellow("=> The shape of the following weights did not match: ", not_matching_sizes)
     if extra_weights:
-        print_yellow("=> The following weights in pre-trained were not used: ", extra_weights)
+        print_utils.print_yellow("=> The following weights in pre-trained were not used: ", extra_weights)
                 
     return missing_weights, extra_weights, not_matching_sizes
 
@@ -162,4 +166,3 @@ def widen_model_data(data, factor, verbose = True):
 
     return data
     
-    
index 0e0f7b0e955420a08abdcab15d3400d77bb3fb14..63137a1f8af0a7fb420059c404693a64cbb9ac95 100644 (file)
@@ -86,7 +86,7 @@ def is_fixed_range(op):
 def get_range(op):
     if isinstance(op, layers.PAct2):
         return op.get_clips_act()
-    elif isinstance(op, torch.nn.ReLUN):
+    elif isinstance(op, layers.ReLUN):
         return op.get_clips_act()
     elif isinstance(op, torch.nn.ReLU6):
         return 0.0, 6.0
diff --git a/modules/pytorch_jacinto_ai/xnn/utils/utils_data.py b/modules/pytorch_jacinto_ai/xnn/utils/utils_data.py
new file mode 100644 (file)
index 0000000..f0313f2
--- /dev/null
@@ -0,0 +1,284 @@
+# from torchvision.datasets
+import os
+import os.path
+import hashlib
+import gzip
+import errno
+import tarfile
+import zipfile
+
+import torch
+from torch.utils.model_zoo import tqdm
+
+
+def gen_bar_updater():
+    pbar = tqdm(total=None)
+
+    def bar_update(count, block_size, total_size):
+        if pbar.total is None and total_size:
+            pbar.total = total_size
+        progress_bytes = count * block_size
+        pbar.update(progress_bytes - pbar.n)
+
+    return bar_update
+
+
+def calculate_md5(fpath, chunk_size=1024 * 1024):
+    md5 = hashlib.md5()
+    with open(fpath, 'rb') as f:
+        for chunk in iter(lambda: f.read(chunk_size), b''):
+            md5.update(chunk)
+    return md5.hexdigest()
+
+
+def check_md5(fpath, md5, **kwargs):
+    return md5 == calculate_md5(fpath, **kwargs)
+
+
+def check_integrity(fpath, md5=None):
+    if not os.path.isfile(fpath):
+        return False
+    if md5 is None:
+        return True
+    return check_md5(fpath, md5)
+
+
+def makedir_exist_ok(dirpath):
+    """
+    Python2 support for os.makedirs(.., exist_ok=True)
+    """
+    try:
+        os.makedirs(dirpath)
+    except OSError as e:
+        if e.errno == errno.EEXIST:
+            pass
+        else:
+            raise
+
+
+def download_url(url, root, filename=None, md5=None):
+    """Download a file from a url and place it in root.
+
+    Args:
+        url (str): URL to download file from
+        root (str): Directory to place downloaded file in
+        filename (str, optional): Name to save the file under. If None, use the basename of the URL
+        md5 (str, optional): MD5 checksum of the download. If None, do not check
+    """
+    from six.moves import urllib
+
+    root = os.path.expanduser(root)
+    if not filename:
+        filename = os.path.basename(url)
+    fpath = os.path.join(root, filename)
+
+    makedir_exist_ok(root)
+
+    # downloads file
+    if check_integrity(fpath, md5):
+        print('Using downloaded and verified file: ' + fpath)
+    else:
+        try:
+            print('Downloading ' + url + ' to ' + fpath)
+            urllib.request.urlretrieve(
+                url, fpath,
+                reporthook=gen_bar_updater()
+            )
+        except (urllib.error.URLError, IOError) as e:
+            if url[:5] == 'https':
+                url = url.replace('https:', 'http:')
+                print('Failed download. Trying https -> http instead.'
+                      ' Downloading ' + url + ' to ' + fpath)
+                urllib.request.urlretrieve(
+                    url, fpath,
+                    reporthook=gen_bar_updater()
+                )
+            else:
+                raise e
+
+    return fpath
+
+
+def list_dir(root, prefix=False):
+    """List all directories at a given root
+
+    Args:
+        root (str): Path to directory whose folders need to be listed
+        prefix (bool, optional): If true, prepends the path to each result, otherwise
+            only returns the name of the directories found
+    """
+    root = os.path.expanduser(root)
+    directories = list(
+        filter(
+            lambda p: os.path.isdir(os.path.join(root, p)),
+            os.listdir(root)
+        )
+    )
+
+    if prefix is True:
+        directories = [os.path.join(root, d) for d in directories]
+
+    return directories
+
+
+def list_files(root, suffix, prefix=False):
+    """List all files ending with a suffix at a given root
+
+    Args:
+        root (str): Path to directory whose folders need to be listed
+        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
+            It uses the Python "str.endswith" method and is passed directly
+        prefix (bool, optional): If true, prepends the path to each result, otherwise
+            only returns the name of the files found
+    """
+    root = os.path.expanduser(root)
+    files = list(
+        filter(
+            lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix),
+            os.listdir(root)
+        )
+    )
+
+    if prefix is True:
+        files = [os.path.join(root, d) for d in files]
+
+    return files
+
+
+def download_file_from_google_drive(file_id, root, filename=None, md5=None):
+    """Download a Google Drive file from  and place it in root.
+
+    Args:
+        file_id (str): id of file to be downloaded
+        root (str): Directory to place downloaded file in
+        filename (str, optional): Name to save the file under. If None, use the id of the file.
+        md5 (str, optional): MD5 checksum of the download. If None, do not check
+    """
+    # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
+    import requests
+    url = "https://docs.google.com/uc?export=download"
+
+    root = os.path.expanduser(root)
+    if not filename:
+        filename = file_id
+    fpath = os.path.join(root, filename)
+
+    makedir_exist_ok(root)
+
+    if os.path.isfile(fpath) and check_integrity(fpath, md5):
+        print('Using downloaded and verified file: ' + fpath)
+    else:
+        session = requests.Session()
+
+        response = session.get(url, params={'id': file_id}, stream=True)
+        token = _get_confirm_token(response)
+
+        if token:
+            params = {'id': file_id, 'confirm': token}
+            response = session.get(url, params=params, stream=True)
+
+        _save_response_content(response, fpath)
+
+
+def _get_confirm_token(response):
+    for key, value in response.cookies.items():
+        if key.startswith('download_warning'):
+            return value
+
+    return None
+
+
+def _save_response_content(response, destination, chunk_size=32768):
+    with open(destination, "wb") as f:
+        pbar = tqdm(total=None)
+        progress = 0
+        for chunk in response.iter_content(chunk_size):
+            if chunk:  # filter out keep-alive new chunks
+                f.write(chunk)
+                progress += len(chunk)
+                pbar.update(progress - pbar.n)
+        pbar.close()
+
+
+def _is_tar(filename):
+    return filename.endswith(".tar")
+
+
+def _is_targz(filename):
+    return filename.endswith(".tar.gz")
+
+
+def _is_gzip(filename):
+    return filename.endswith(".gz") and not filename.endswith(".tar.gz")
+
+
+def _is_zip(filename):
+    return filename.endswith(".zip")
+
+
+def extract_archive(from_path, to_path=None, remove_finished=False):
+    if to_path is None:
+        to_path = os.path.dirname(from_path)
+
+    if _is_tar(from_path):
+        with tarfile.open(from_path, 'r') as tar:
+            tar.extractall(path=to_path)
+    elif _is_targz(from_path):
+        with tarfile.open(from_path, 'r:gz') as tar:
+            tar.extractall(path=to_path)
+    elif _is_gzip(from_path):
+        to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
+        with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
+            out_f.write(zip_f.read())
+    elif _is_zip(from_path):
+        with zipfile.ZipFile(from_path, 'r') as z:
+            z.extractall(to_path)
+    else:
+        raise ValueError("Extraction of {} not supported".format(from_path))
+
+    if remove_finished:
+        os.remove(from_path)
+
+
+def download_and_extract_archive(url, download_root, extract_root=None, filename=None,
+                                 md5=None, remove_finished=False):
+    download_root = os.path.expanduser(download_root)
+    if extract_root is None:
+        extract_root = download_root
+    if not filename:
+        filename = os.path.basename(url)
+
+    download_url(url, download_root, filename, md5)
+
+    archive = os.path.join(download_root, filename)
+    print("Extracting {} to {}".format(archive, extract_root))
+    extract_archive(archive, extract_root, remove_finished)
+
+
+def iterable_to_str(iterable):
+    return "'" + "', '".join([str(item) for item in iterable]) + "'"
+
+
+def verify_str_arg(value, arg=None, valid_values=None, custom_msg=None):
+    if not isinstance(value, torch._six.string_classes):
+        if arg is None:
+            msg = "Expected type str, but got type {type}."
+        else:
+            msg = "Expected type str for argument {arg}, but got type {type}."
+        msg = msg.format(type=type(value), arg=arg)
+        raise ValueError(msg)
+
+    if valid_values is None:
+        return value
+
+    if value not in valid_values:
+        if custom_msg is not None:
+            msg = custom_msg
+        else:
+            msg = ("Unknown value '{value}' for argument {arg}. "
+                   "Valid values are {{{valid_values}}}.")
+            msg = msg.format(value=value, arg=arg,
+                             valid_values=iterable_to_str(valid_values))
+        raise ValueError(msg)
+
+    return value
diff --git a/run_quantization_example.sh b/run_quantization_example.sh
new file mode 100755 (executable)
index 0000000..bc61f04
--- /dev/null
@@ -0,0 +1,39 @@
+#!/bin/bash
+
+# ----------------------------------
+base_dir="./data/checkpoints/quantization"
+date_var=`date '+%Y-%m-%d_%H-%M-%S'`
+logdir=$base_dir/"$date_var"_quantization
+logfile=$logdir/run.log
+echo Logging the output to: $logfile
+
+mkdir $base_dir
+mkdir $logdir
+exec &> >(tee -a "$logfile")
+# ----------------------------------
+
+# model names and pretrained paths from torchvision - add more as required
+declare -A model_pretrained=(
+  [mobilenet_v2]=https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
+  [shufflenetv2_x1.0]=https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth
+  [resnet50]=https://download.pytorch.org/models/resnet50-19c8e357.pth
+)
+
+for model in "${!model_pretrained[@]}"; do
+  echo ==========================================================
+  pretrained="${model_pretrained[$model]}"
+
+  echo ----------------------------------------------------------
+  echo Estimating evaluation accuracy without quantization for $model
+  python -u ./examples/quantization_example.py ./data/datasets/image_folder_classification --arch $model --batch-size 64 --evaluate --pretrained $pretrained
+
+  echo ----------------------------------------------------------
+  echo Estimating evaluation accuracy with quantization for $model
+  python -u ./examples/quantization_example.py ./data/datasets/image_folder_classification --arch $model --batch-size 64 --evaluate --quantize --pretrained $pretrained
+
+  echo ----------------------------------------------------------
+  echo Quantization Aware Training for $model
+  # note: this example uses only a part of the training epoch and only 10 such (partial) epochs during quantized training to save time,
+  # but it may necessary to use the full training epoch if the accuracy is not satisfactory.
+  python -u ./examples/quantization_example.py ./data/datasets/image_folder_classification --arch $model --batch-size 64 --lr 5e-5 --epoch-size 1000 --epochs 10 --quantize --pretrained $pretrained
+done
\ No newline at end of file