added mobilenetv3 from torchvision and mobilenetv3_lite models; updated docs
author Manu Mathew <mathew.manu@ti.com>
Wed, 28 Apr 2021 05:43:23 +0000 (11:13 +0530)
committer Manu Mathew <mathew.manu@ti.com>
Wed, 28 Apr 2021 05:45:00 +0000 (11:15 +0530)
229 files changed:
LICENSE
docs/Image_Classification.md
docs/Semantic_Segmentation.md
examples/quantize_example.py [new file with mode: 0644]
examples/write_onnx_model_example.py
modules/pytorch_jacinto_ai/engine/__init__.py
modules/pytorch_jacinto_ai/engine/engine_utils.py
modules/pytorch_jacinto_ai/engine/evaluate_pixel2pixel.py
modules/pytorch_jacinto_ai/engine/infer_classification_onnx_rt.py
modules/pytorch_jacinto_ai/engine/infer_pixel2pixel.py
modules/pytorch_jacinto_ai/engine/infer_pixel2pixel_onnx_rt.py
modules/pytorch_jacinto_ai/engine/test_classification.py
modules/pytorch_jacinto_ai/engine/test_pixel2pixel_onnx.py
modules/pytorch_jacinto_ai/engine/train_classification.py
modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py
modules/pytorch_jacinto_ai/xnn/__init__.py
modules/pytorch_jacinto_ai/xnn/layers/__init__.py
modules/pytorch_jacinto_ai/xnn/layers/activation.py
modules/pytorch_jacinto_ai/xnn/layers/blocks_internal.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/layers/common_blocks.py
modules/pytorch_jacinto_ai/xnn/layers/conv_blocks.py
modules/pytorch_jacinto_ai/xnn/layers/conv_ws_internal.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/layers/deconv_blocks.py
modules/pytorch_jacinto_ai/xnn/layers/function.py
modules/pytorch_jacinto_ai/xnn/layers/functional.py
modules/pytorch_jacinto_ai/xnn/layers/functional_blocks.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/layers/layer_config.py
modules/pytorch_jacinto_ai/xnn/layers/model_utils.py
modules/pytorch_jacinto_ai/xnn/layers/multi_task.py
modules/pytorch_jacinto_ai/xnn/layers/normalization.py
modules/pytorch_jacinto_ai/xnn/layers/quant_ste.py
modules/pytorch_jacinto_ai/xnn/layers/resize_blocks.py
modules/pytorch_jacinto_ai/xnn/layers/rf_blocks.py
modules/pytorch_jacinto_ai/xnn/onnx/__init__.py
modules/pytorch_jacinto_ai/xnn/onnx/onnx2pytorch_internal.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/optim/__init__.py
modules/pytorch_jacinto_ai/xnn/optim/lr_scheduler.py
modules/pytorch_jacinto_ai/xnn/quantize/__init__.py
modules/pytorch_jacinto_ai/xnn/quantize/hooked_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_base_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_calib_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_test_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_train_module.py
modules/pytorch_jacinto_ai/xnn/quantize_torch_internal/__init__.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/quantize_torch_internal/quant_torch_base_module.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/quantize_torch_internal/quant_torch_eagercalib_module.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/quantize_torch_internal/quant_torch_eagerdistill_module.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/quantize_torch_internal/quant_torch_eagertrain_module.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/quantize_torch_internal/quant_torch_qconfig.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/quantize_torch_internal/quant_torch_qconfig_qat.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/quantize_torch_internal/quant_torch_scriptcalib_module.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/utils/__init__.py
modules/pytorch_jacinto_ai/xnn/utils/amp.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/utils/attr_dict.py
modules/pytorch_jacinto_ai/xnn/utils/bn_utils.py
modules/pytorch_jacinto_ai/xnn/utils/count_flops.py
modules/pytorch_jacinto_ai/xnn/utils/data_utils.py
modules/pytorch_jacinto_ai/xnn/utils/depth_utils.py
modules/pytorch_jacinto_ai/xnn/utils/export_utils_internal.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/utils/function_utils.py
modules/pytorch_jacinto_ai/xnn/utils/hist_utils.py
modules/pytorch_jacinto_ai/xnn/utils/image_utils.py
modules/pytorch_jacinto_ai/xnn/utils/load_weights.py
modules/pytorch_jacinto_ai/xnn/utils/logger.py
modules/pytorch_jacinto_ai/xnn/utils/module_utils.py
modules/pytorch_jacinto_ai/xnn/utils/print_utils.py
modules/pytorch_jacinto_ai/xnn/utils/quant_utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/utils/range_estimator_internal.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/utils/range_utils.py
modules/pytorch_jacinto_ai/xnn/utils/tensor_utils.py
modules/pytorch_jacinto_ai/xnn/utils/tensor_utils_internal.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/utils/weights_utils.py
modules/pytorch_jacinto_ai/xreferences/classification/README.md [new file with mode: 0644]
modules/pytorch_jacinto_ai/xreferences/classification/train.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xreferences/classification/utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xreferences/segmentation/README.md [new file with mode: 0644]
modules/pytorch_jacinto_ai/xreferences/segmentation/coco_utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xreferences/segmentation/run_segmentation.sh [new file with mode: 0755]
modules/pytorch_jacinto_ai/xreferences/segmentation/train.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xreferences/segmentation/transforms.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xreferences/segmentation/utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/LICENSE [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/README.rst [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/__init__.py
modules/pytorch_jacinto_ai/xvision/datasets/caltech.py
modules/pytorch_jacinto_ai/xvision/datasets/celeba.py
modules/pytorch_jacinto_ai/xvision/datasets/cifar.py
modules/pytorch_jacinto_ai/xvision/datasets/cityscapes.py
modules/pytorch_jacinto_ai/xvision/datasets/classification/__init__.py
modules/pytorch_jacinto_ai/xvision/datasets/coco.py
modules/pytorch_jacinto_ai/xvision/datasets/fakedata.py
modules/pytorch_jacinto_ai/xvision/datasets/flickr.py
modules/pytorch_jacinto_ai/xvision/datasets/folder.py
modules/pytorch_jacinto_ai/xvision/datasets/hmdb51.py
modules/pytorch_jacinto_ai/xvision/datasets/imagenet.py
modules/pytorch_jacinto_ai/xvision/datasets/kinetics.py
modules/pytorch_jacinto_ai/xvision/datasets/lsun.py
modules/pytorch_jacinto_ai/xvision/datasets/mnist.py
modules/pytorch_jacinto_ai/xvision/datasets/omniglot.py
modules/pytorch_jacinto_ai/xvision/datasets/phototour.py
modules/pytorch_jacinto_ai/xvision/datasets/pixel2pixel/__init__.py
modules/pytorch_jacinto_ai/xvision/datasets/pixel2pixel/a2d2.py
modules/pytorch_jacinto_ai/xvision/datasets/pixel2pixel/ade20k.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/datasets/pixel2pixel/calculate_class_weights.py
modules/pytorch_jacinto_ai/xvision/datasets/pixel2pixel/calculate_class_weights_internal.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/datasets/pixel2pixel/cityscapes_plus.py
modules/pytorch_jacinto_ai/xvision/datasets/pixel2pixel/coco_plus.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/datasets/pixel2pixel/dataset_utils.py
modules/pytorch_jacinto_ai/xvision/datasets/pixel2pixel/flyingchairs.py
modules/pytorch_jacinto_ai/xvision/datasets/pixel2pixel/kitti_depth.py
modules/pytorch_jacinto_ai/xvision/datasets/pixel2pixel/kitti_sceneflow.py
modules/pytorch_jacinto_ai/xvision/datasets/pixel2pixel/mpisintel.py
modules/pytorch_jacinto_ai/xvision/datasets/pixel2pixel/multi_dataset_internal.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/datasets/pixel2pixel/segmentation.py
modules/pytorch_jacinto_ai/xvision/datasets/pixel2pixel/tiad_dataset_internal.py [new file with mode: 0755]
modules/pytorch_jacinto_ai/xvision/datasets/sbd.py
modules/pytorch_jacinto_ai/xvision/datasets/sbu.py
modules/pytorch_jacinto_ai/xvision/datasets/semeion.py
modules/pytorch_jacinto_ai/xvision/datasets/stl10.py
modules/pytorch_jacinto_ai/xvision/datasets/svhn.py
modules/pytorch_jacinto_ai/xvision/datasets/ucf101.py
modules/pytorch_jacinto_ai/xvision/datasets/usps.py
modules/pytorch_jacinto_ai/xvision/datasets/utils.py
modules/pytorch_jacinto_ai/xvision/datasets/video_utils.py
modules/pytorch_jacinto_ai/xvision/datasets/vision.py
modules/pytorch_jacinto_ai/xvision/datasets/voc.py
modules/pytorch_jacinto_ai/xvision/extension.py
modules/pytorch_jacinto_ai/xvision/io/__init__.py
modules/pytorch_jacinto_ai/xvision/io/_video_opt.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/io/image.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/io/video.py
modules/pytorch_jacinto_ai/xvision/losses/__init__.py
modules/pytorch_jacinto_ai/xvision/losses/basic_loss.py
modules/pytorch_jacinto_ai/xvision/losses/flow_loss.py
modules/pytorch_jacinto_ai/xvision/losses/interest_pt_loss.py
modules/pytorch_jacinto_ai/xvision/losses/loss_utils.py
modules/pytorch_jacinto_ai/xvision/losses/norm_loss.py
modules/pytorch_jacinto_ai/xvision/losses/scale_loss.py
modules/pytorch_jacinto_ai/xvision/losses/segmentation_loss.py
modules/pytorch_jacinto_ai/xvision/losses/segmentation_loss_internal.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/losses/unflow_loss_internal.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/models/__init__.py
modules/pytorch_jacinto_ai/xvision/models/_utils.py
modules/pytorch_jacinto_ai/xvision/models/classification/__init__.py
modules/pytorch_jacinto_ai/xvision/models/densenet.py
modules/pytorch_jacinto_ai/xvision/models/detection/_utils.py
modules/pytorch_jacinto_ai/xvision/models/detection/backbone_utils.py
modules/pytorch_jacinto_ai/xvision/models/detection/faster_rcnn.py
modules/pytorch_jacinto_ai/xvision/models/detection/generalized_rcnn.py
modules/pytorch_jacinto_ai/xvision/models/detection/image_list.py
modules/pytorch_jacinto_ai/xvision/models/detection/keypoint_rcnn.py
modules/pytorch_jacinto_ai/xvision/models/detection/mask_rcnn.py
modules/pytorch_jacinto_ai/xvision/models/detection/roi_heads.py
modules/pytorch_jacinto_ai/xvision/models/detection/rpn.py
modules/pytorch_jacinto_ai/xvision/models/detection/transform.py
modules/pytorch_jacinto_ai/xvision/models/mnasnet.py
modules/pytorch_jacinto_ai/xvision/models/mobilenetv1.py
modules/pytorch_jacinto_ai/xvision/models/mobilenetv1_internal.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/models/mobilenetv2.py
modules/pytorch_jacinto_ai/xvision/models/mobilenetv2_densenas_internal.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/models/mobilenetv2_ericsun_internal.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/models/mobilenetv2_internal.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/models/mobilenetv2_shicai_internal.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/models/mobilenetv3.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/models/mobilenetv3_lite.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/models/multi_input_net.py
modules/pytorch_jacinto_ai/xvision/models/pixel2pixel/__init__.py
modules/pytorch_jacinto_ai/xvision/models/pixel2pixel/bifpnlite_pixel2pixel.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/models/pixel2pixel/deeplabv3lite.py
modules/pytorch_jacinto_ai/xvision/models/pixel2pixel/deeplabv3lite_internal.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/models/pixel2pixel/flownet_internal.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/models/pixel2pixel/fpnlite_pixel2pixel.py
modules/pytorch_jacinto_ai/xvision/models/pixel2pixel/pixel2pixelnet.py
modules/pytorch_jacinto_ai/xvision/models/pixel2pixel/pixel2pixelnet_internal.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/models/pixel2pixel/pixel2pixelnet_utils.py
modules/pytorch_jacinto_ai/xvision/models/pixel2pixel/unetlite_pixel2pixel.py
modules/pytorch_jacinto_ai/xvision/models/regnet.py
modules/pytorch_jacinto_ai/xvision/models/resnet.py
modules/pytorch_jacinto_ai/xvision/models/segmentation/_utils.py
modules/pytorch_jacinto_ai/xvision/models/segmentation/deeplabv3.py
modules/pytorch_jacinto_ai/xvision/models/segmentation/segmentation.py
modules/pytorch_jacinto_ai/xvision/models/shufflenetv2.py
modules/pytorch_jacinto_ai/xvision/models/utils.py
modules/pytorch_jacinto_ai/xvision/ops/__init__.py
modules/pytorch_jacinto_ai/xvision/ops/_register_onnx_ops.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/ops/_utils.py
modules/pytorch_jacinto_ai/xvision/ops/boxes.py
modules/pytorch_jacinto_ai/xvision/ops/deform_conv.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/ops/feature_pyramid_network.py
modules/pytorch_jacinto_ai/xvision/ops/misc.py
modules/pytorch_jacinto_ai/xvision/ops/new_empty_tensor.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/ops/poolers.py
modules/pytorch_jacinto_ai/xvision/ops/ps_roi_align.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/ops/ps_roi_pool.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/ops/roi_align.py
modules/pytorch_jacinto_ai/xvision/ops/roi_pool.py
modules/pytorch_jacinto_ai/xvision/transforms/__init__.py
modules/pytorch_jacinto_ai/xvision/transforms/_functional_video.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/transforms/_transforms_video.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/transforms/functional.py
modules/pytorch_jacinto_ai/xvision/transforms/functional_pil.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/transforms/functional_tensor.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/transforms/image_transform_utils.py
modules/pytorch_jacinto_ai/xvision/transforms/image_transforms.py
modules/pytorch_jacinto_ai/xvision/transforms/image_transforms_xv12.py
modules/pytorch_jacinto_ai/xvision/transforms/transforms.py
modules/pytorch_jacinto_ai/xvision/utils.py
requirements.txt
requirements_conda.txt
run_classification.sh
run_depth.sh
run_onnxexport.sh [new file with mode: 0755]
run_quantization.sh
run_quantization_example.sh
run_segmentation.sh
scripts/evaluate_segmentation_main.py
scripts/infer_classification_onnx_rt_main.py
scripts/infer_segmentation_main.py
scripts/infer_segmentation_onnx_main.py
scripts/test_classification_main.py
scripts/train_classification_main.py
scripts/train_depth_main.py
scripts/train_motion_segmentation_main.py
scripts/train_pixel2pixel_multitask_main.py
scripts/train_segmentation_main.py
setup.py
setup.sh
version.py

diff --git a/LICENSE b/LICENSE
index 455a0f8966410037aff7174c0f4def013e5bedd6..4cab812f190818e6c82868c9533a1e10f216fcf0 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,5 +1,5 @@
-Texas Instruments (C) 2018-2019 
-All Rights Reserved
+Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+All Rights Reserved.
 
 Redistribution and use in source and binary forms, with or without
 modification, are permitted provided that the following conditions are met:
diff --git a/docs/Image_Classification.md b/docs/Image_Classification.md
index c972a907cdbdab9aca019fbe7f13ceb5f0fe5414..99170adea6ce11de5604e52dd485191cc0321169 100644
--- a/docs/Image_Classification.md
+++ b/docs/Image_Classification.md
@@ -121,6 +121,8 @@ ImageNet classification results are as follows:
 |ImageNet |ResNet50-0.5       |256x256          |224x224        |1.051                |**72.05**      |resnet50_xp5            |
 |ImageNet |**RegNetX800MF**   |256x256          |224x224        |0.800                |               |regnetx800mf_x1         |
 |.
+|ImageNet |MobileNetV2-QAT*   |256x256          |224x224        |0.296                |71.76          |                        |
+|.
 |ImageNet |MobileNetV1[1]     |256x256          |224x224        |0.569                |70.60          |                        |
 |ImageNet |MobileNetV2[2]     |256x256          |224x224        |0.300                |72.00          |                        |
 |ImageNet |ResNet50[3]        |256x256          |224x224        |4.087                |76.15          |                        |
@@ -128,6 +130,8 @@ ImageNet classification results are as follows:
 |ImageNet |**RegNetX800MF**[4]|256x256          |224x224        |0.800                |**75.2**       |                        |
 |ImageNet |RegNetX1.6GF[4]    |256x256          |224x224        |1.6                  |**77.0**       |                        |
 
+*- Quantization Aware Training (QAT) with 8-bit precision
+
 
 #### Notes
 - As can be seen from the table, the models included in this repository provide a good Accuracy/Complexity tradeoff. 
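
The QAT row above is obtained by wrapping a floating point model for Quantization Aware Training and then fine-tuning it. A minimal sketch of that flow, using the xnn.quantize API from examples/quantize_example.py added in this commit (the architecture name is a placeholder; any entry in xvision.models works):

    import torch
    from pytorch_jacinto_ai import xnn
    from pytorch_jacinto_ai.xvision import models

    model = models.__dict__['mobilenetv2_tv_x1']()   # placeholder architecture name
    dummy_input = torch.rand((1, 3, 224, 224))       # used to trace the model graph
    model = xnn.quantize.QuantTrainModule(model, dummy_input=dummy_input)
    # fine-tune as usual: the wrapper simulates 8-bit quantization in the forward pass
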
diff --git a/docs/Semantic_Segmentation.md b/docs/Semantic_Segmentation.md
index 7203feba3f7d45cd41dca694d74f057689bc536b..31498c242811a7b7ae2c37660587fdaae419ec1b 100644
--- a/docs/Semantic_Segmentation.md
+++ b/docs/Semantic_Segmentation.md
@@ -111,37 +111,57 @@ python ./scripts/infer_segmentation_main.py --phase validation --model_name deep
  
 ## Results
 
-Dataset    |Mode Name                     |Input Size |GigaMACs  |Accuracy%      |Model Name       |
-|----------|------------------------------|-----------|----------|---------------|-----------------|
-|Cityscapes|RegNetX800MF+FPNLite          |768x384    |**8.84**  |**72.01**      |fpnlite_pixel2pixel_aspp_regnetx800mf  |
-|Cityscapes|RegNetX1.6GF+FPNLite          |1024x512   |**24.29** |**75.84**      |fpnlite_pixel2pixel_aspp_regnetx1p6gf  |
-|Cityscapes|RegNetX3.2GF+FPNLite          |1024x512   |**49.40** |**77.24**      |fpnlite_pixel2pixel_aspp_regnetx3p2gf  |
-|Cityscapes|RegNetX3.2FF+FPNLite          |1536x768   |**111.16**|**78.90**      |fpnlite_pixel2pixel_aspp_regnetx3p2gf  |
+|Dataset   |Model Architecture            |Input Size |GigaMACs  |mIoU Accuracy% |Model Name       |Notes |
+|----------|------------------------------|-----------|----------|---------------|-----------------|------|
+|           |**ADE20K32 dataset models**
+|ADE20K32   |MobileNetV2S16+DeepLabV3Lite |512x512    |3.28      |51.01          |                |      | 
+|ADE20K32   |MobileNetV2+UNetLite         |512x512    |2.427     |49.95          |                |      | 
+|ADE20K32   |MobileNetV2+FPNLite          |512x512    |3.481     |50.72          |                |      | 
+|ADE20K32   |MobileNetV2-1.4+FPNLite      |512x512    |6.646     |52.93          |                |      | 
+|- 
+|ADE20K32   |RegNetX400MF+FPNLite         |384x384    |3.1526    |51.03          |                |      | 
+|ADE20K32   |RegNetX800MF+FPNLite         |512x512    |8.0683    |53.29          |                |      | 
+|           |**COCOSeg21 dataset models**
+|COCOSeg21  |MobileNetV2S16+DeepLabV3Lite |512x512    |3.161     |57.77          |                |      | 
+|COCOSeg21  |MobileNetV2+UNetLite         |512x512    |2.009     |57.01          |                |      | 
+|COCOSeg21  |MobileNetV2+FPNLite          |512x512    |3.357     |               |                |      | 
+|COCOSeg21  |RegNetX800MF+FPNLite         |512x512    |7.864     |61.15          |                |      | 
+|           |**Cityscapes dataset models**
+|Cityscapes|RegNetX800MF+FPNLite          |768x384    |**8.84**  |**72.01**      |fpnlite_pixel2pixel_aspp_regnetx800mf  |      | 
+|Cityscapes|RegNetX1.6GF+FPNLite          |1024x512   |**24.29** |**75.84**      |fpnlite_pixel2pixel_aspp_regnetx1p6gf  |      | 
+|Cityscapes|RegNetX3.2GF+FPNLite          |1024x512   |**49.40** |**77.24**      |fpnlite_pixel2pixel_aspp_regnetx3p2gf  |      | 
+|Cityscapes|RegNetX3.2FF+FPNLite          |1536x768   |**111.16**|**78.90**      |fpnlite_pixel2pixel_aspp_regnetx3p2gf  |      | 
 |-
-|Cityscapes|RegNetX400MF+FPNLite          |768x384    |**6.09**  |**68.03**      |fpnlite_pixel2pixel_aspp_regnetx400mf  |
-|Cityscapes|RegNetX400MF+FPNLite          |1536x768   |**24.37** |**73.96**      |fpnlite_pixel2pixel_aspp_regnetx400mf  |
+|Cityscapes|RegNetX400MF+FPNLite          |768x384    |**6.09**  |**68.03**      |fpnlite_pixel2pixel_aspp_regnetx400mf  |      | 
+|Cityscapes|RegNetX400MF+FPNLite          |1536x768   |**24.37** |**73.96**      |fpnlite_pixel2pixel_aspp_regnetx400mf  |      | 
 |-
-|Cityscapes|MobileNetV2S16+DeepLabV3Lite  |768x384    |**3.54**  |**69.13**      |deeplabv3lite_mobilenetv2_tv           |
-|Cityscapes|MobileNetV2+UNetLite          |768x384    |**2.20**  |**68.94**      |unetlite_pixel2pixel_aspp_mobilenetv2_tv |
-|Cityscapes|MobileNetV2+FPNLite           |768x384    |**3.84**  |**70.39**      |fpnlite_pixel2pixel_aspp_mobilenetv2_tv |
-|Cityscapes|MobileNetV2+FPNLite           |1536x768   |**15.07** |**74.61**      |fpnlite_pixel2pixel_aspp_mobilenetv2_tv |
+|Cityscapes|MobileNetV2S16+DeepLabV3Lite  |768x384    |**3.54**  |**69.13**      |deeplabv3lite_mobilenetv2_tv           |      | 
+|Cityscapes|MobileNetV2+UNetLite          |768x384    |**2.20**  |**68.94**      |unetlite_pixel2pixel_aspp_mobilenetv2_tv |      | 
+|Cityscapes|MobileNetV2+FPNLite           |768x384    |**3.84**  |**70.39**      |fpnlite_pixel2pixel_aspp_mobilenetv2_tv |      | 
+|Cityscapes|MobileNetV2+FPNLite           |1536x768   |**15.07** |**74.61**      |fpnlite_pixel2pixel_aspp_mobilenetv2_tv |      | 
 |-
-|Cityscapes|FD-MobileNetV2+FPNLite        |1536x768   |**3.96**  |**71.28**      |fpnlite_pixel2pixel_aspp_mobilenetv2_tv_fd |
-|Cityscapes|FD-MobileNetV2+FPNLite        |2048x1024  |**7.03**  |**72.67**      |fpnlite_pixel2pixel_aspp_mobilenetv2_tv_fd |
+|Cityscapes|FD-MobileNetV2+FPNLite        |1536x768   |**3.96**  |**71.28**      |fpnlite_pixel2pixel_aspp_mobilenetv2_tv_fd |      | 
+|Cityscapes|FD-MobileNetV2+FPNLite        |2048x1024  |**7.03**  |**72.67**      |fpnlite_pixel2pixel_aspp_mobilenetv2_tv_fd |      | 
 |-
-|Cityscapes|MobileNetV2+DeepLab[10]       |769x769    |21.27     |70.71          |                 |
-|Cityscapes|MobileNetV3+DeepLab[10]       |769x769    |15.95     |72.41          |                 |
-|Cityscapes|Xception65+DeepLab[10]        |769x769    |418.64    |78.79          |                 |
-|Cityscapes|ERFNet[8]                     |1024x512   |27.705    |69.7           |                 |
-|Cityscapes|MobileNetV2+SwiftNet[9]       |2048x1024  |41.0      |75.3           |                 |
-|Cityscapes|ResNet50+FCN[3][11]           |1040x520   |285.4     |71.6           |                 |
-|Cityscapes|ResNet50+DeepLabV3[5][11]     |1040x520   |337.5     |73.5           |                 |
+|Cityscapes |MobileNetV2S16+DeepLabV3Lite-QAT* |768x384  |**3.54**  |**68.77**   |                |      |
+|Cityscapes |MobileNetV2+UNetLite-QAT*         |768x384  |**2.20**  |**68.18**   |                |      |
+|Cityscapes |MobileNetV2+FPNLite-QAT*          |768x384  |**3.84**  |**69.88**   |                |      |
+|-
+|Cityscapes|MobileNetV2+DeepLab[10]       |769x769    |21.27     |70.71          |                |      | 
+|Cityscapes|MobileNetV3+DeepLab[10]       |769x769    |15.95     |72.41          |                |      | 
+|Cityscapes|Xception65+DeepLab[10]        |769x769    |418.64    |78.79          |                |      | 
+|Cityscapes|ERFNet[8]                     |1024x512   |27.705    |69.7           |                |      | 
+|Cityscapes|MobileNetV2+SwiftNet[9]       |2048x1024  |41.0      |75.3           |                |      | 
+|Cityscapes|ResNet50+FCN[3][11]           |1040x520   |285.4     |71.6           |                |      | 
+|Cityscapes|ResNet50+DeepLabV3[5][11]     |1040x520   |337.5     |73.5           |                |      | 
+
+*- Quantization Aware Training (QAT) with 8-bit precision
 
 
 #### Notes
 - The suffix 'Lite' in the name of models such as DeepLabV3Lite, FPNLite & UNetLite indicates the use of Depthwise convolutions or Grouped convolutions. If the feature extractor (encoder) uses Depthwise Convolutions, then Depthwise convolutions are used throughout such models, even in the neck and decoder. If the feature extractor (encoder) uses grouped convolutions as in the case of RegNetX, then grouped convolutions (with the same group size as that of the feature extractor) are used even in the neck and decoder.<br>
 - GigaMACS: Complexity in Giga Multiply-Accumulations (lower is better). This is an important metric to watch out for when selecting models for embedded inference.<br>
-- Accuracy%: Original Floating Point Validation Accuracy% obtained after training.<br>
+- mIoU Accuracy%: Original Floating Point Validation Mean IoU Accuracy% obtained after training.<br>
 - Overall, RegNetX based models are highly recommended as they strike a good balance between Complexity (GigaMACS), speed of inference on device and ease of Quantization.<br>
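
As the 'Lite' note above describes, these models replace regular convolutions with Depthwise or Grouped convolutions throughout the network. An illustrative sketch of the depthwise-separable building block in plain PyTorch (a simplification, not the repository's own conv_blocks implementation):

    import torch.nn as nn

    def depthwise_separable_conv(in_ch, out_ch, stride=1):
        # a 3x3 depthwise conv (groups == in_ch) followed by a 1x1 pointwise conv;
        # RegNetX-based models use grouped convs (1 < groups < in_ch) instead
        return nn.Sequential(
            nn.Conv2d(in_ch, in_ch, 3, stride=stride, padding=1, groups=in_ch, bias=False),
            nn.BatchNorm2d(in_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )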
 
 
diff --git a/examples/quantize_example.py b/examples/quantize_example.py
new file mode 100644 (file)
index 0000000..1aa396c
--- /dev/null
@@ -0,0 +1,685 @@
+# Copyright (c) 2018-2021, Texas Instruments
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#####################################################################################
+
+# Post Training Quantization (PTQ) / Quantization Aware Training (QAT) Example
+# this original code is from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
+# the changes required for quantizing the model are under the flag args.quantize
+#
+# BSD 3-Clause License
+#
+# Copyright (c) 2017,
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import argparse
+import os
+import random
+import shutil
+import time
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+import torch.optim
+import torch.multiprocessing as mp
+import torch.utils.data
+import torch.utils.data.distributed
+import torchvision.transforms as transforms
+import torchvision.datasets as datasets
+
+# some of the default torchvision models need minor tweaks to be friendly for
+# quantization aware training, so models from pytorch_jacinto_ai.xvision are used instead
+#import torchvision.models as models
+
+from pytorch_jacinto_ai import xnn
+from pytorch_jacinto_ai import xvision as xvision
+from pytorch_jacinto_ai.xvision import models
+
+model_names = sorted(name for name in models.__dict__
+    if name.islower() and not name.startswith("__")
+    and callable(models.__dict__[name]))
+
+def str2bool(v):
+    if isinstance(v, str):
+        if v.lower() in ("yes", "true", "t", "1"):
+            return True
+        elif v.lower() in ("no", "false", "f", "0"):
+            return False
+        else:
+            return v
+        #
+    else:
+        return v
+
+parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
+parser.add_argument('data', metavar='DIR',
+                    help='path to dataset')
+parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
+                    choices=model_names,
+                    help='model architecture: ' +
+                        ' | '.join(model_names) +
+                        ' (default: resnet18)')
+parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
+                    help='number of data loading workers (default: 4)')
+parser.add_argument('--epochs', default=90, type=int, metavar='N',
+                    help='number of total epochs to run')
+parser.add_argument('--epoch_size', default=0, type=float, metavar='N',
+                    help='fraction of training epoch to use. 0 (default) means full training epoch')
+parser.add_argument('--epoch_size_val', default=0, type=float, metavar='N',
+                    help='fraction of validation epoch to use. 0 (default) means full validation epoch')
+parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
+                    help='manual epoch number (useful on restarts)')
+parser.add_argument('-b', '--batch_size', default=256, type=int,
+                    metavar='N',
+                    help='mini-batch size (default: 256), this is the total '
+                         'batch size of all GPUs on the current node when '
+                         'using Data Parallel or Distributed Data Parallel')
+parser.add_argument('--lr', '--learning_rate', default=0.1, type=float,
+                    metavar='LR', help='initial learning rate', dest='lr')
+parser.add_argument('--lr_step_size', default=30, type=int,
+                    metavar='N', help='number of steps before learning rate is reduced')
+parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
+                    help='momentum')
+parser.add_argument('--wd', '--weight_decay', default=1e-4, type=float,
+                    metavar='W', help='weight decay (default: 1e-4)',
+                    dest='weight_decay')
+parser.add_argument('-p', '--print_freq', default=100, type=int,
+                    metavar='N', help='print frequency (default: 100)')
+parser.add_argument('--resume', default='', type=str, metavar='PATH',
+                    help='path to latest checkpoint (default: none)')
+parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
+                    help='evaluate model on validation set')
+parser.add_argument('--pretrained', type=str, default=None,
+                    help='use pre-trained model')
+parser.add_argument('--world_size', default=-1, type=int,
+                    help='number of nodes for distributed training')
+parser.add_argument('--rank', default=-1, type=int,
+                    help='node rank for distributed training')
+parser.add_argument('--dist_url', default='tcp://224.66.41.62:23456', type=str,
+                    help='url used to set up distributed training')
+parser.add_argument('--dist_backend', default='nccl', type=str,
+                    help='distributed backend')
+parser.add_argument('--seed', default=None, type=int,
+                    help='seed for initializing training. ')
+parser.add_argument('--use_gpu', default=None, type=str2bool,
+                    help='whether to use gpu or not (default: enabled unless quantizing)')
+parser.add_argument('--gpu', default=None, type=int,
+                    help='GPU id to use.')
+parser.add_argument('--multiprocessing_distributed', action='store_true',
+                    help='Use multi-processing distributed training to launch '
+                         'N processes per node, which has N GPUs. This is the '
+                         'fastest way to use PyTorch for either single node or '
+                         'multi node data parallel training')
+parser.add_argument('--save_path', type=str, default='./data/checkpoints/quantization',
+                    help='path to save the logs and models')
+parser.add_argument('--quantize', default=False, type=str2bool,
+                    choices=[False, 'ptq', 'distill', 'qat', True],
+                    help='Enable Quantization: ptq, distill, qat or true (true selects qat)')
+parser.add_argument('--quantize_torch', default=False, type=str2bool,
+                    help='Enable PyTorch Quantization')
+parser.add_argument('--opset_version', default=11, type=int,
+                    help='opset version for onnx export')
+
+best_acc1 = 0
+
+
+def main():
+    args = parser.parse_args()
+
+    args.cur_lr = args.lr
+
+    if args.use_gpu is None:
+        args.use_gpu = (not args.quantize)
+
+    args.do_ptq = args.quantize in ('ptq', 'distill')
+
+    if args.seed is not None:
+        random.seed(args.seed)
+        torch.manual_seed(args.seed)
+        cudnn.deterministic = True
+        warnings.warn('You have chosen to seed training. '
+                      'This will turn on the CUDNN deterministic setting, '
+                      'which can slow down your training considerably! '
+                      'You may see unexpected behavior when restarting '
+                      'from checkpoints.')
+
+    if args.gpu is not None:
+        warnings.warn('You have chosen a specific GPU. This will completely '
+                      'disable data parallelism.')
+
+    if args.dist_url == "env://" and args.world_size == -1:
+        args.world_size = int(os.environ["WORLD_SIZE"])
+
+    args.distributed = args.world_size > 1 or args.multiprocessing_distributed
+
+    ngpus_per_node = torch.cuda.device_count()
+    if args.multiprocessing_distributed:
+        # Since we have ngpus_per_node processes per node, the total world_size
+        # needs to be adjusted accordingly
+        args.world_size = ngpus_per_node * args.world_size
+        # Use torch.multiprocessing.spawn to launch distributed processes: the
+        # main_worker process function
+        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
+    else:
+        # Simply call main_worker function
+        main_worker(args.gpu, ngpus_per_node, args)
+
+
+def main_worker(gpu, ngpus_per_node, args):
+    global best_acc1
+    args.gpu = gpu
+
+    if args.gpu is not None:
+        print("Use GPU: {} for training".format(args.gpu))
+
+    if args.distributed:
+        if args.dist_url == "env://" and args.rank == -1:
+            args.rank = int(os.environ["RANK"])
+        if args.multiprocessing_distributed:
+            # For multiprocessing distributed training, rank needs to be the
+            # global rank among all the processes
+            args.rank = args.rank * ngpus_per_node + gpu
+        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                                world_size=args.world_size, rank=args.rank)
+    # create model
+    print("=> creating model '{}'".format(args.arch))
+    model = models.__dict__[args.arch]()
+
+    if args.quantize is not False:
+        if args.use_gpu:
+            warnings.warn('quantized inference/test may fail as it is not yet supported on gpu: '
+                          'use_gpu should not be set while quantizing')
+        #
+        # DistributedDataParallel / DataParallel are not supported with quantization
+        dummy_input = torch.rand((1, 3, 224, 224))
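+        # the quantization wrappers below use this dummy input to trace and analyze the model graph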
+        if args.quantize_torch:
+            # GPU/CUDA is not yet supported for Torch quantization
+            if args.evaluate:
+                model = xnn.quantize_torch.QuantTorchEagerTestModule(model, dummy_input=dummy_input)
+            elif (args.do_ptq and args.quantize == 'distill'):
+                model = xnn.quantize_torch.QuantTorchEagerDistillModule(model, dummy_input=dummy_input, learning_rate=args.lr)
+            elif args.do_ptq:
+                model = xnn.quantize_torch.QuantTorchEagerCalibrateModule(model, dummy_input=dummy_input)
+            else:
+                model = xnn.quantize_torch.QuantTorchEagerTrainModule(model, dummy_input=dummy_input)
+            #
+        else:
+            if args.evaluate:
+                model = xnn.quantize.QuantTestModule(model, dummy_input=dummy_input)
+            elif args.do_ptq:
+                model = xnn.quantize.QuantCalibrateModule(model, dummy_input=dummy_input)
+            else:
+                model = xnn.quantize.QuantTrainModule(model, dummy_input=dummy_input)
+            #
+        #
+        if args.use_gpu:
+            model = model.cuda(args.gpu)
+        #
+
+    if args.use_gpu:
+        if args.distributed and (not args.do_ptq):
+            # For multiprocessing distributed, DistributedDataParallel constructor
+            # should always set the single device scope, otherwise,
+            # DistributedDataParallel will use all available devices.
+            if args.gpu is not None:
+                torch.cuda.set_device(args.gpu)
+                model.cuda(args.gpu)
+                # When using a single GPU per process and per
+                # DistributedDataParallel, we need to divide the batch size
+                # ourselves based on the total number of GPUs we have
+                args.batch_size = int(args.batch_size / ngpus_per_node)
+                args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
+                model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+            else:
+                model.cuda()
+                # DistributedDataParallel will divide and allocate batch_size to all
+                # available GPUs if device_ids are not set
+                model = torch.nn.parallel.DistributedDataParallel(model)
+        elif args.gpu is not None:
+            torch.cuda.set_device(args.gpu)
+            model = model.cuda(args.gpu)
+        elif (not args.do_ptq):
+            # DataParallel will divide and allocate batch_size to all available GPUs
+            if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
+                model.features = torch.nn.DataParallel(model.features)
+                model.cuda()
+            else:
+                model = torch.nn.DataParallel(model).cuda()
+
+    if args.pretrained is not None:
+        model_orig = model.module if isinstance(model, (torch.nn.parallel.DistributedDataParallel, torch.nn.parallel.DataParallel)) else model
+        model_orig = model_orig.module if args.quantize else model_orig
+        print("=> using pre-trained model for {} from {}".format(args.arch, args.pretrained))
+        if hasattr(model_orig, 'load_weights'):
+            model_orig.load_weights(args.pretrained, download_root='./data/downloads')
+        else:
+            xnn.utils.load_weights(model_orig, args.pretrained, download_root='./data/downloads')
+        #
+
+    if args.quantize:
+        model_quant = model.module if isinstance(model, (torch.nn.parallel.DistributedDataParallel, torch.nn.parallel.DataParallel)) else model
+        if hasattr(model_quant, 'fuse_model') and callable(model_quant.fuse_model):
+            model_quant.fuse_model()
+
+        if hasattr(model_quant, 'prepare') and callable(model_quant.prepare):
+            model_quant.prepare()
+
+    # define loss function (criterion) and optimizer
+    criterion = nn.CrossEntropyLoss().cuda(args.gpu) if args.use_gpu else nn.CrossEntropyLoss()
+
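+    # PTQ/calibration only runs forward passes (train() skips the optimizer step), so no optimizer is needed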
+    if not args.do_ptq:
+        optimizer = torch.optim.SGD(model.parameters(), args.lr,
+                                    momentum=args.momentum,
+                                    weight_decay=args.weight_decay)
+    else:
+        optimizer = None
+
+    # optionally resume from a checkpoint
+    if args.resume:
+        if os.path.isfile(args.resume):
+            print("=> loading checkpoint '{}'".format(args.resume))
+            if args.gpu is None:
+                checkpoint = torch.load(args.resume)
+            else:
+                # Map model to be loaded to specified single gpu.
+                loc = 'cuda:{}'.format(args.gpu)
+                checkpoint = torch.load(args.resume, map_location=loc)
+            args.start_epoch = checkpoint['epoch']
+            best_acc1 = checkpoint['best_acc1']
+            if args.gpu is not None:
+                # best_acc1 may be from a checkpoint from a different GPU
+                best_acc1 = best_acc1.to(args.gpu)
+            model.load_state_dict(checkpoint['state_dict'])
+            if optimizer is not None and checkpoint['optimizer'] is not None:
+                optimizer.load_state_dict(checkpoint['optimizer'])
+            print("=> loaded checkpoint '{}' (epoch {})"
+                  .format(args.resume, checkpoint['epoch']))
+        else:
+            print("=> no checkpoint found at '{}'".format(args.resume))
+
+    cudnn.benchmark = True
+
+    # Data loading code
+    traindir = os.path.join(args.data, 'train')
+    valdir = os.path.join(args.data, 'val')
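+    # NormalizeMeanScale operates on 0..255 inputs (see ToFloat below); scale is ~1/(255*std)
+    # of the usual ImageNet mean/std normalization (std ~ [0.229, 0.224, 0.225])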
+    normalize = xvision.transforms.NormalizeMeanScale(mean=[123.675, 116.28, 103.53], scale=[0.017125, 0.017507, 0.017429])
+
+    train_dataset = datasets.ImageFolder(
+        traindir,
+        transforms.Compose([
+            transforms.RandomResizedCrop(224),
+            transforms.RandomHorizontalFlip(),
+            xvision.transforms.ToFloat(),   # converting to float avoids the division by 255 in ToTensor()
+            transforms.ToTensor(),
+            normalize,
+        ]))
+
+    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
+        transforms.Resize(256),
+        transforms.CenterCrop(224),
+        xvision.transforms.ToFloat(),  # converting to float avoids the division by 255 in ToTensor()
+        transforms.ToTensor(),
+        normalize,
+    ]))
+
+    if args.distributed:
+        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
+        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
+    else:
+        train_sampler = get_dataset_sampler(train_dataset, args.epoch_size) if args.epoch_size != 0 else None
+        val_sampler = get_dataset_sampler(val_dataset, args.epoch_size_val) if args.epoch_size_val != 0 else None
+
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
+        num_workers=args.workers, pin_memory=True, sampler=train_sampler)
+
+    val_loader = torch.utils.data.DataLoader(
+        val_dataset, batch_size=args.batch_size, shuffle=(val_sampler is None),
+        num_workers=args.workers, pin_memory=True, sampler=val_sampler)
+
+    if args.evaluate:
+        validate(val_loader, model, criterion, args)
+        return
+
+    for epoch in range(args.start_epoch, args.epochs):
+        model_orig = model.module if isinstance(model,
+            (torch.nn.parallel.DistributedDataParallel, torch.nn.parallel.DataParallel)) else model
+
+        if args.distributed:
+            train_sampler.set_epoch(epoch)
+        adjust_learning_rate(optimizer, epoch, args)
+
+        # distill mode needs learning rate inside the model.
+        if hasattr(model_orig, 'set_learning_rate') and callable(model_orig.set_learning_rate):
+            model_orig.set_learning_rate(learning_rate=args.cur_lr)
+        #
+
+        # train for one epoch
+        train(train_loader, model, criterion, optimizer, epoch, args)
+
+        # evaluate on validation set
+        acc1 = validate(val_loader, model, criterion, args)
+
+        # remember best acc@1 and save checkpoint
+        is_best = acc1 > best_acc1
+        best_acc1 = max(acc1, best_acc1)
+
+        model_orig = model_orig.module if args.quantize else model_orig
+        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
+                and args.rank % ngpus_per_node == 0):
+            out_basename = args.arch + ('_checkpoint_quantized.pth' if args.quantize else '_checkpoint.pth')
+            save_filename = os.path.join(args.save_path, out_basename)
+            checkpoint_dict = {
+                'epoch': epoch + 1,
+                'arch': args.arch,
+                'state_dict': model_orig.state_dict(),
+                'best_acc1': best_acc1,
+                'optimizer': (optimizer.state_dict() if optimizer is not None else None),
+            }
+            save_checkpoint(checkpoint_dict, is_best, filename=save_filename)
+            # onnx model cannot be exported for torch quantization mode
+            if not args.quantize_torch:
+                save_onnxname = os.path.splitext(save_filename)[0]+'.onnx'
+                write_onnx_model(args, model_orig, is_best, filename=save_onnxname)
+            #
+
+    if args.quantize and hasattr(model_quant, 'convert') and callable(model_quant.convert):
+        model_quant.eval()
+        model_quant.cpu()
+        model_quant.convert()
+        if hasattr(model_quant, 'export') and callable(model_quant.export):
+            model_quant = model_quant.export(dummy_input)
+        else:
+            model_quant = torch.jit.trace(model_quant, dummy_input)
+        #
+
+        print('########################################')
+        print(model_quant)
+
+        print('########################################')
+        print(model_quant.graph)
+
+        print('########################################')
+        print('saving the quantized model')
+        out_basename = args.arch + ('_model_quantized.pth' if args.quantize else '_model.pth')
+        save_filename = os.path.join(args.save_path, out_basename)
+        model_quant.save(save_filename)
+
+        print('########################################')
+        # evaluate on validation set
+        acc1 = validate(val_loader, model_quant, criterion, args)
+        print('########################################')
+
+
+def train(train_loader, model, criterion, optimizer, epoch, args):
+    batch_time = AverageMeter('Time', ':6.3f')
+    data_time = AverageMeter('Data', ':6.3f')
+    losses = AverageMeter('Loss', ':.4e')
+    top1 = AverageMeter('Acc@1', ':6.2f')
+    top5 = AverageMeter('Acc@5', ':6.2f')
+    progress = ProgressMeter(
+        len(train_loader),
+        [batch_time, data_time, losses, top1, top5],
+        prefix="Epoch: [{}]".format(epoch))
+
+    # switch to train mode
+    model.train()
+
+    # freeze the quantization params and bn
+    if args.quantize:
+        if epoch > 2 and epoch > ((args.epochs//2)-1):
+            xnn.utils.print_once('Freezing BN for subseq epochs')
+            # model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
+            # this is a more generic version
+            xnn.utils.freeze_bn(model)
+        #
+        if epoch > 4 and epoch >= ((args.epochs//2)+1):
+            xnn.utils.print_once('Freezing range observer for subseq epochs')
+            # model.apply(torch.quantization.disable_observer)
+            # this is a more generic version
+            xnn.layers.freeze_quant_range(model)
+        #
+    #
+
+    end = time.time()
+    for i, (images, target) in enumerate(train_loader):
+        # measure data loading time
+        data_time.update(time.time() - end)
+
+        if args.use_gpu:
+            # GPU/CUDA is not yet supported for Torch quantization
+            images = images.cuda(args.gpu, non_blocking=True)
+            target = target.cuda(args.gpu, non_blocking=True)
+        #
+
+        # compute output
+        output = model(images)
+        loss = criterion(output, target)
+
+        # measure accuracy and record loss
+        acc1, acc5 = accuracy(output, target, topk=(1, 5))
+        losses.update(loss.item(), images.size(0))
+        top1.update(acc1[0], images.size(0))
+        top5.update(acc5[0], images.size(0))
+
+        if not args.do_ptq:
+            # compute gradient and do SGD step
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+        #
+
+        # measure elapsed time
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if i % args.print_freq == 0:
+            progress.display(i, args.cur_lr)
+
+
+def validate(val_loader, model, criterion, args):
+    batch_time = AverageMeter('Time', ':6.3f')
+    losses = AverageMeter('Loss', ':.4e')
+    top1 = AverageMeter('Acc@1', ':6.2f')
+    top5 = AverageMeter('Acc@5', ':6.2f')
+    progress = ProgressMeter(
+        len(val_loader),
+        [batch_time, losses, top1, top5],
+        prefix='Test: ')
+
+    # switch to evaluate mode
+    model.eval()
+
+    with torch.no_grad():
+        end = time.time()
+        for i, (images, target) in enumerate(val_loader):
+            if args.use_gpu:
+                # GPU/CUDA is not yet supported for Torch quantization
+                images = images.cuda(args.gpu, non_blocking=True)
+                target = target.cuda(args.gpu, non_blocking=True)
+            #
+
+            # compute output
+            output = model(images)
+            loss = criterion(output, target)
+
+            # measure accuracy and record loss
+            acc1, acc5 = accuracy(output, target, topk=(1, 5))
+            losses.update(loss.item(), images.size(0))
+            top1.update(acc1[0], images.size(0))
+            top5.update(acc5[0], images.size(0))
+
+            # measure elapsed time
+            batch_time.update(time.time() - end)
+            end = time.time()
+
+            if i % args.print_freq == 0:
+                progress.display(i, args.cur_lr)
+
+        # TODO: this should also be done with the ProgressMeter
+        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
+              .format(top1=top1, top5=top5))
+
+    return top1.avg
+
+
+def save_checkpoint(state, is_best, filename='checkpoint.pth'):
+    dirname = os.path.dirname(filename)
+    xnn.utils.makedir_exist_ok(dirname)
+    torch.save(state, filename)
+    if is_best:
+        shutil.copyfile(filename, os.path.splitext(filename)[0]+'_best.pth')
+
+
+def create_rand_inputs(is_cuda):
+    dummy_input = torch.rand((1, 3, 224, 224))
+    dummy_input = dummy_input.cuda() if is_cuda else dummy_input
+    return dummy_input
+
+
+def write_onnx_model(args, model, is_best, filename='checkpoint.onnx'):
+    model.eval()
+    is_cuda = next(model.parameters()).is_cuda
+    dummy_input = create_rand_inputs(is_cuda)
+    torch.onnx.export(model, dummy_input, filename, export_params=True, verbose=False,
+                      do_constant_folding=True, opset_version=args.opset_version)
+    if is_best:
+        shutil.copyfile(filename, os.path.splitext(filename)[0]+'_best.onnx')
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+    def __init__(self, name, fmt=':f'):
+        self.name = name
+        self.fmt = fmt
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+    def __str__(self):
+        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+        return fmtstr.format(**self.__dict__)
+
+
+class ProgressMeter(object):
+    def __init__(self, num_batches, meters, prefix=""):
+        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
+        self.lr_fmtstr = self._get_lr_fmtstr()
+        self.meters = meters
+        self.prefix = prefix
+
+    def display(self, batch, cur_lr):
+        entries = [self.prefix + self.batch_fmtstr.format(batch), self.lr_fmtstr.format(cur_lr)]
+        entries += [str(meter) for meter in self.meters]
+        print('\t'.join(entries))
+
+    def _get_batch_fmtstr(self, num_batches):
+        num_digits = len(str(num_batches))
+        fmt = '{:' + str(num_digits) + 'd}'
+        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
+
+    def _get_lr_fmtstr(self):
+        fmt = 'LR {:g}'
+        return fmt
+
+def adjust_learning_rate(optimizer, epoch, args):
+    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
+    lr = args.lr * (0.1 ** (epoch // args.lr_step_size))
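+    # e.g. with the defaults lr=0.1 and lr_step_size=30: epochs 0-29 use 0.1,
+    # epochs 30-59 use 0.01, and epochs 60-89 use 0.001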
+    args.cur_lr = lr
+    if not args.do_ptq:
+        for param_group in optimizer.param_groups:
+            param_group['lr'] = lr
+
+
+def accuracy(output, target, topk=(1,)):
+    """Computes the accuracy over the k top predictions for the specified values of k"""
+    with torch.no_grad():
+        maxk = max(topk)
+        batch_size = target.size(0)
+
+        _, pred = output.topk(maxk, 1, True, True)
+        pred = pred.t()
+        correct = pred.eq(target.view(1, -1).expand_as(pred)).contiguous()
+
+        res = []
+        for k in topk:
+            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+            res.append(correct_k.mul_(100.0 / batch_size))
+        return res
+
+
+def get_dataset_sampler(dataset_object, epoch_size):
+    num_samples = len(dataset_object)
+    epoch_size = num_samples if (epoch_size == 0) else \
+        (int(epoch_size * num_samples) if epoch_size < 1.0 else int(epoch_size))
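+    # epoch_size < 1.0 is a fraction of the dataset; epoch_size >= 1.0 is an absolute sample count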
+    print('=> creating a random sampler as epoch_size is specified')
+    dataset_sampler = torch.utils.data.sampler.RandomSampler(data_source=dataset_object, replacement=True, num_samples=epoch_size)
+    return dataset_sampler
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
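
A typical QAT invocation of the example above (dataset path, weights path and architecture name are placeholders):

    python ./examples/quantize_example.py /path/to/imagenet --arch mobilenetv2_tv_x1 --pretrained /path/to/float_checkpoint.pth --quantize qat --epochs 10 --lr 0.001

Passing --evaluate together with --pretrained runs validation only, and --quantize ptq runs calibration-based Post Training Quantization, which needs no optimizer.
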
diff --git a/examples/write_onnx_model_example.py b/examples/write_onnx_model_example.py
index 6c8362253ef907abb9e1925dc591e5a6526c7300..2b2f864a64f08ff428c0be80a6a7df32429892ac 100644
--- a/examples/write_onnx_model_example.py
+++ b/examples/write_onnx_model_example.py
@@ -1,3 +1,31 @@
+# Copyright (c) 2018-2021, Texas Instruments
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
 import os
 import torch
 import datetime
@@ -5,7 +33,7 @@ import torchvision as xvision
 # from pytorch_jacinto_ai import xvision
 
 # dependencies
-# Anaconda Python 3.7 for Linux - download and install from: https://www.anaconda.com/distribution/
+# Python 3.7 (other versions may also work)
 # pytorch, torchvision - install using: 
 # conda install pytorch torchvision -c pytorch
 
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..ddbc3696e7fe90db558297392e49e282888e4494 100644 (file)
@@ -0,0 +1,27 @@
+# Copyright (c) 2018-2021, Texas Instruments
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
index e9b727df5a2adc76e912d02ecdc48ccf3219c4de..3d0f773f515ff2cf08f044de59ede003f889fdc4 100644 (file)
@@ -1,3 +1,31 @@
+# Copyright (c) 2018-2021, Texas Instruments
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
 import os
 import numpy as np
 import torch
@@ -13,7 +41,7 @@ def shape_as_string(shape=[]):
 
 
 def write_tensor_int(m=[], tensor=[], suffix='op', bitwidth=8, power2_scaling=True, file_format='bin',
-                     rnd_type='rnd_sym'):
+                     rnd_type='rnd_sym', force_data_type=None, save_path=None):
     mn = tensor.min()
     mx = tensor.max()
 
@@ -21,7 +49,8 @@ def write_tensor_int(m=[], tensor=[], suffix='op', bitwidth=8, power2_scaling=Tr
         '{:6}, {:32}, {:10}, {:7.2f}, {:7.2f}'.format(suffix, m.name, m.__class__.__name__, tensor.min(), tensor.max()),
         end=" ")
 
-    [tensor_scale, clamp_limits] = xnn.quantize.compute_tensor_scale(tensor, mn, mx, bitwidth, power2_scaling)
+    [tensor_scale, clamp_limits, tensor_signed] = xnn.utils.compute_tensor_scale(tensor, mn, mx, bitwidth, power2_scaling, force_data_type=force_data_type)
+    #tensor_signed = min(mn, mx) < 0
     print("{:30} : {:15} : {:8.2f}".format(str(tensor.shape), str(tensor.dtype), tensor_scale), end=" ")
 
     print_weight_bias = False
@@ -33,33 +62,41 @@ def write_tensor_int(m=[], tensor=[], suffix='op', bitwidth=8, power2_scaling=Tr
             print("tensor_scale: ", tensor_scale)
             print(tensor[no_idx])
         if tensor.dtype != torch.int64:
-            tensor = xnn.quantize.symmetric_round_tensor(tensor * tensor_scale)
+            tensor = xnn.utils.symmetric_round_tensor(tensor * tensor_scale)
         if suffix == 'weight' and print_weight_bias:
             print(tensor[no_idx])
     else:
         # for activation use HW friendly rounding
         if tensor.dtype != torch.int64:
-            tensor = xnn.quantize.upward_round_tensor(tensor * tensor_scale)
+            tensor = xnn.utils.upward_round_tensor(tensor * tensor_scale)
     tensor = tensor.clamp(clamp_limits[0], clamp_limits[1]).float()
 
-    if bitwidth == 8:
+    if bitwidth == 8 and tensor_signed:
         data_type = np.int8
-    elif bitwidth == 16:
+        str_data_type = 'int8'
+    elif bitwidth == 16 and tensor_signed:
         data_type = np.int16
-    elif bitwidth == 32:
+        str_data_type = 'int16'
+    elif bitwidth == 32 and tensor_signed:
         data_type = np.int32
+        str_data_type = 'int32'
+    elif bitwidth == 8 and not tensor_signed:
+        data_type = np.uint8
+        str_data_type = 'uint8'
+    elif bitwidth == 16 and not tensor_signed:
+        data_type = np.uint16
+        str_data_type = 'uint16'
+    elif bitwidth == 32 and not tensor_signed:
+        data_type = np.uint32
+        str_data_type = 'uint32'
     else:
         exit("Bit width other 8,16,32 not supported for writing layer level op")
 
     tensor = tensor.cpu().numpy().astype(data_type)
 
     print("{:7} : {:7d} : {:7d}".format(str(tensor.dtype), tensor.min(), tensor.max()))
-
     root = os.getcwd()
-    tensor_dir = root + '/checkpoints/debug/test_vecs/' + '{}_{}_{}_scale_{:010.4f}'.format(m.name,
-                                                                                            m.__class__.__name__,
-                                                                                            suffix, tensor_scale)
-
+    tensor_dir = os.path.join(root, save_path, '{}_{}_{}_{}_scale_{:010.4f}'.format(m.name, m.__class__.__name__, suffix, str_data_type, tensor_scale))
     if not os.path.exists(tensor_dir):
         os.makedirs(tensor_dir)
 
@@ -71,9 +108,10 @@ def write_tensor_int(m=[], tensor=[], suffix='op', bitwidth=8, power2_scaling=Tr
         np.save(tensor_name, tensor)
 
     # utils_hist.comp_hist_tensor3d(x=tensor, name=m.name, en=True, dir=m.name, log=True, ch_dim=0)
+    return tensor_scale
 
 
-def write_tensor_float(m=[], tensor=[], suffix='op'):
+def write_tensor_float(m=[], tensor=[], suffix='op', save_path=None):
     mn = tensor.min()
     mx = tensor.max()
 
@@ -90,40 +128,37 @@ def write_tensor_float(m=[], tensor=[], suffix='op'):
 
 
 def write_tensor(data_type='int', m=[], tensor=[], suffix='op', bitwidth=8, power2_scaling=True, file_format='bin',
-                 rnd_type='rnd_sym'):
+                 rnd_type='rnd_sym', force_data_type=None, save_path=None):
+    tensor_scale = None
     if data_type == 'int':
-        write_tensor_int(m=m, tensor=tensor, suffix=suffix, rnd_type=rnd_type, file_format=file_format)
+        tensor_scale = write_tensor_int(m=m, tensor=tensor, suffix=suffix, rnd_type=rnd_type, bitwidth=bitwidth,
+            file_format=file_format, force_data_type=force_data_type, save_path=save_path)
     elif data_type == 'float':
-        write_tensor_float(m=m, tensor=tensor, suffix=suffix)
-
-
-enable_hook_function = True
-
+        # float dumps have no integer scale; tensor_scale stays None
+        write_tensor_float(m=m, tensor=tensor, suffix=suffix, save_path=save_path)
+    return tensor_scale
 
-def write_tensor_hook_function(m, inp, out, file_format='none'):
-    if not enable_hook_function:
-        return
+def write_tensor_hook_function(m, inp, out, save_path=None, file_format='bin'):
 
     # Output
     if isinstance(out, (torch.Tensor)):
-        write_tensor(m=m, tensor=out, suffix='op', rnd_type='rnd_up', file_format=file_format)
+        tensor_scale_op = write_tensor(m=m, tensor=out, suffix='op', rnd_type='rnd_up', file_format=file_format, save_path=save_path)
 
     # Input(s)
     if type(inp) is tuple:
         # if there is more than one input
         for index, sub_ip in enumerate(inp[0]):
             if isinstance(sub_ip, (torch.Tensor)):
-                write_tensor(m=m, tensor=sub_ip, suffix='ip_{}'.format(index), rnd_type='rnd_up',
-                             file_format=file_format)
+                tensor_scale_ip = write_tensor(m=m, tensor=sub_ip, suffix='ip_{}'.format(index), rnd_type='rnd_up',
+                             file_format=file_format, save_path=save_path)
     elif isinstance(inp, (torch.Tensor)):
-        write_tensor(m=m, tensor=inp, suffix='ip', rnd_type='rnd_up', file_format=file_format)
+        tensor_scale_ip = write_tensor(m=m, tensor=inp, suffix='ip', rnd_type='rnd_up', file_format=file_format, save_path=save_path)
 
     # weights
     if hasattr(m, 'weight'):
         if isinstance(m.weight, torch.Tensor):
-            write_tensor(m=m, tensor=m.weight, suffix='weight', rnd_type='rnd_sym', file_format=file_format)
+            tensor_scale_wt = write_tensor(m=m, tensor=m.weight, suffix='weight', rnd_type='rnd_sym', file_format=file_format, save_path=save_path)
 
     # bias
     if hasattr(m, 'bias'):
         if m.bias is not None:
-            write_tensor(m=m, tensor=m.bias, suffix='bias', rnd_type='rnd_sym', file_format=file_format)
+            tensor_scale_bias = write_tensor(m=m, tensor=m.bias, suffix='bias', rnd_type='rnd_sym', bitwidth=16,
+                force_data_type = 'signed', file_format=file_format, save_path=save_path)
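A minimal sketch of registering the hook above on every module of a toy model, binding save_path with functools.partial; the model and path below are hypothetical:

import functools
import torch
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
hook = functools.partial(write_tensor_hook_function, save_path='checkpoints/debug/test_vecs')
for name, module in model.named_modules():
    module.name = name    # the hook reads m.name when building output file names
    module.register_forward_hook(hook)
model(torch.randn(1, 3, 32, 32))    # each forward now dumps ip/op/weight/bias tensors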
index 5b10a02a390a402115d00ec987139417300e4c8d..78b5fc9fe3d4af081d0fe1cd19db55d917175e42 100644 (file)
@@ -1,3 +1,31 @@
+# Copyright (c) 2018-2021, Texas Instruments
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
 import sys
 import torch
 
index 2249e25755e94be08e2f9623588428ca68857d03..7e6a675c148f03340cf9465e323a6965c5f1b346 100644 (file)
@@ -1,3 +1,31 @@
+# Copyright (c) 2018-2021, Texas Instruments
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
 import os
 import sys
 import shutil
@@ -339,7 +367,7 @@ def write_tensor_int(m=[], tensor=[], suffix='op', bitwidth=8, power2_scaling=Tr
         '{:6}, {:32}, {:10}, {:7.2f}, {:7.2f}'.format(suffix, m.name, m.__class__.__name__, tensor.min(), tensor.max()),
         end=" ")
 
-    [tensor_scale, clamp_limits] = xnn.quantize.compute_tensor_scale(tensor, mn, mx, bitwidth, power2_scaling)
+    [tensor_scale, clamp_limits] = xnn.utils.compute_tensor_scale(tensor, mn, mx, bitwidth, power2_scaling)
     print("{:30} : {:15} : {:8.2f}".format(str(tensor.shape), str(tensor.dtype), tensor_scale), end=" ")
 
     print_weight_bias = False
@@ -351,13 +379,13 @@ def write_tensor_int(m=[], tensor=[], suffix='op', bitwidth=8, power2_scaling=Tr
             print("tensor_scale: ", tensor_scale)
             print(tensor[no_idx])
         if tensor.dtype != torch.int64:
-            tensor = xnn.quantize.symmetric_round_tensor(tensor * tensor_scale)
+            tensor = xnn.utils.symmetric_round_tensor(tensor * tensor_scale)
         if suffix == 'weight' and print_weight_bias:
             print(tensor[no_idx])
     else:
         # for activation use HW friendly rounding
         if tensor.dtype != torch.int64:
-            tensor = xnn.quantize.upward_round_tensor(tensor * tensor_scale)
+            tensor = xnn.utils.upward_round_tensor(tensor * tensor_scale)
     tensor = tensor.clamp(clamp_limits[0], clamp_limits[1]).float()
 
     if bitwidth == 8:
index e56f474f4d5bec64b537ffaf8d4803d2825ffc74..ccc8d0a21eeefdeef5822df59049672bd61c7b2d 100644 (file)
@@ -1,3 +1,31 @@
+# Copyright (c) 2018-2021, Texas Instruments
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
 import os
 import time
 import sys
@@ -34,6 +62,8 @@ def get_config():
 
     args.model = None
     args.model_config = xnn.utils.ConfigNode()
+    args.model_config.enable_fp16 = False               # faster training/inference if the GPU supports fp16
+
     args.model_name = 'deeplabv2lite_mobilenetv2'       # model architecture, overwritten if pretrained is specified
     args.dataset_name = 'flying_chairs'              # dataset type
     args.transforms = None
@@ -44,6 +74,7 @@ def get_config():
 
     args.model_config.output_type = ['flow']                # the network is used to predict flow or depth or sceneflow')
     args.model_config.output_channels = None                 # number of output channels
+    args.model_config.prediction_channels = None        # intermediate number of channels before final output_channels
     args.model_config.input_channels = None                  # number of input channels
     args.model_config.num_classes = None                       # number of classes (for segmentation)
     args.model_config.output_range = None  # max range of output
@@ -116,7 +147,8 @@ def get_config():
     args.learn_scaled_values_interest_pt = True
     args.save_mod_files = False                 # saves modified files after last commit. Also  stores commit id.
     args.gpu_mode = True                        #False will make inference run on CPU
-    args.write_layer_ip_op= False               #True will make it tap inputs outputs for layers
+    args.write_layer_ip_op = False              # True taps layer inputs/outputs for dumping
+    args.write_layer_ip_op_names = None         # names of the layers to write out
     args.file_format = 'none'                   # file format for tapped layer inputs/outputs; 'none': nothing is written, but the prints still appear
     args.save_onnx = True
     args.remove_ignore_lbls_in_pred = False     #True: if in the pred where GT has ignore label do not visualize for GT visualization
@@ -125,14 +157,13 @@ def get_config():
     args.visualize_gt = False                   #to vis pred or GT
     args.viz_depth_color_type = 'plasma'        #color type for depth visualization
     args.depth = [False]
-
     args.palette = None
     args.label_infer = False
     args.viz_op_type = None
     args.car_mask = None
     args.en_accuracy_measurement = True         #enabling accuracy measurement makes whole operation sequential and hence slows down inference significantly.
-
     args.opset_version = 9                      # onnx opset version
+    args.prob_color_to_gray = 0.0               # probability of color-to-gray augmentation during inference
     return args
 
 
@@ -302,16 +333,8 @@ def main(args):
     model = model.cuda()
 
     #################################################
-    if args.write_layer_ip_op:
-        # for dumping module outputs
-        for name, module in model.named_modules():
-            module.name = name
-            print(name)
-            #if 'module.encoder.features.0.' in name:
-            module.register_forward_hook(write_tensor_hook_function)
-        print('{:7} {:33} {:12} {:8} {:6} {:30} : {:17} : {:4} : {:11} : {:7} : {:7}'.format("type",  "name", "layer", "min", "max", "tensor_shape", "dtype", "scale", "dtype", "min", "max"))
-
-    #################################################
+    assign_write_layer_ip_op_hook(model=model, save_path=save_path, args=args, file_format='npy')
+
     args.loss_modules = copy.deepcopy(args.losses)
     for task_dx, task_losses in enumerate(args.losses):
         for loss_idx, loss_fn in enumerate(task_losses):
@@ -469,7 +492,7 @@ def validate(args, val_dataset, val_loader, model, epoch, infer_path):
                     print('{}/{}'.format((args.batch_size * iter + index), len(val_dataset)))
                     output_name = os.path.join(infer_path[task_index], seq + '_' + input_path[-1][index].split('/')[-3] + '_' + base_file)
                     output_name_short = os.path.join(infer_path[task_index], os.path.basename(input_path[-1][index]))
-                    wrapper_write_desc(args=args, task_index=task_index, outputs=outputs, index=index, output_name=output_name, output_name_short=output_name_short)
+                    wrapper_write_desc(args=args, target_list=target_list, task_index=task_index, outputs=outputs, index=index, output_name=output_name, output_name_short=output_name_short)
                     
                 if args.model_config.output_type[task_index] == 'depth':
                     output_name = os.path.join(infer_path[task_index], seq + '_' + input_path[-1][index].split('/')[-3] + '_' + base_file)
@@ -480,6 +503,9 @@ def validate(args, val_dataset, val_loader, model, epoch, infer_path):
                     prediction_size = (prediction.shape[0], prediction.shape[1], 3)
                     output_image = args.palette[task_index-1][prediction.ravel()].reshape(prediction_size)
                     input_bgr = cv2.imread(input_path[-1][index]) #Read the actual RGB image
+                    if args.img_border_crop is not None:
+                        t, l, h, w = args.img_border_crop
+                        input_bgr = input_bgr[t:t+h, l:l+w]
                     input_bgr = cv2.resize(input_bgr, dsize=(prediction.shape[1],prediction.shape[0]))
                     output_image = xnn.utils.chroma_blend(input_bgr, output_image)
                     output_name = os.path.join(infer_path[task_index], input_path[-1][index].split('/')[-4] + '_' + input_path[-1][index].split('/')[-3] + '_' +os.path.basename(input_path[-1][index]))
@@ -620,6 +646,9 @@ def viz_depth(prediction = [], args=[], output_name=[], input_name=[]):
         print(output_image.max())
         #output_image[label == 1] = 0
         input_bgr = cv2.imread(input_name)  # Read the actual RGB image
+        if args.img_border_crop is not None:
+            t, l, h, w = args.img_border_crop
+            input_bgr = input_bgr[t:t+h, l:l+w]
         input_bgr = cv2.resize(input_bgr, dsize=(prediction.shape[1], prediction.shape[0]))
         if args.sky_dir:
             label_file = os.path.join(args.sky_dir, seq, seq + '_image_00_' + base_file)
@@ -654,7 +683,7 @@ def viz_depth(prediction = [], args=[], output_name=[], input_name=[]):
         cv2.imwrite(output_name, output_image)
 
 
-def wrapper_write_desc(args=[], task_index=0, outputs=[], index=0, output_name=[], output_name_short=[]):
+def wrapper_write_desc(args=[], target_list=None, task_index=0, outputs=[], index=0, output_name=[], output_name_short=[]):
     if args.write_desc_type == 'GT':
         # write GT desc
         tensor_to_write = target_list[task_index]
@@ -702,6 +731,7 @@ def get_transforms(args):
     args.image_scale = np.array(args.image_scale, dtype=np.float32)
     image_prenorm = xvision.transforms.image_transforms.NormalizeMeanScale(mean=args.image_mean, scale=args.image_scale) if args.image_prenorm else None
     image_postnorm = xvision.transforms.image_transforms.NormalizeMeanScale(mean=args.image_mean, scale=args.image_scale) if (not image_prenorm) else None
+    color_2_gray = xvision.transforms.image_transforms.RandomColor2Gray(is_flow=args.is_flow, random_threshold=args.prob_color_to_gray[1]) if args.prob_color_to_gray != 0.0 else None
 
     #target size must be according to output_size. prediction will be resized to output_size before evaluation.
     test_transform = xvision.transforms.image_transforms.Compose([
@@ -710,6 +740,7 @@ def get_transforms(args):
         xvision.transforms.image_transforms.MaskTarget(args.target_mask, 0),
         xvision.transforms.image_transforms.CropRect(args.img_border_crop),
         xvision.transforms.image_transforms.Scale(args.img_resize, target_size=args.output_size, is_flow=args.is_flow),
+        color_2_gray,
         image_postnorm,
         xvision.transforms.image_transforms.ConvertToTensor()
         ])
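Note that color_2_gray above may be None when the augmentation is disabled, so the repo's Compose is assumed to skip None entries; a minimal stand-in with that behavior (the (images, targets) call signature is an assumption):

class ComposeSkipNone:
    def __init__(self, transforms):
        # drop disabled (None) transforms up front
        self.transforms = [t for t in transforms if t is not None]
    def __call__(self, images, targets):
        for t in self.transforms:
            images, targets = t(images, targets)
        return images, targets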
@@ -872,6 +903,30 @@ def write_onnx_model(args, model, save_path, name='checkpoint.onnx'):
                       do_constant_folding=True, opset_version=args.opset_version)
     # torch onnx export does not update names. Do it using onnx.save
 
+def assign_write_layer_ip_op_hook(model=None, save_path=None, args=None, file_format='bin'):
+    if args.write_layer_ip_op:
+        def write_tensor_hook_function_save_path(m, inp, out):
+            write_tensor_hook_function(m, inp, out, save_path=save_path, file_format=file_format)
+
+        # for dumping module outputs
+        for name, module in model.named_modules():
+            module.name = name
+            print(name)
+
+            en_write_layer = False
+            if args.write_layer_ip_op_names is None:
+                # write all layers
+                en_write_layer = True
+            else:
+                for layer_name_to_write in args.write_layer_ip_op_names:
+                    if layer_name_to_write in name:
+                        en_write_layer = True
+                        break
+
+            if en_write_layer:
+                module.register_forward_hook(write_tensor_hook_function_save_path)
+        #
+        # print the header once, as the pre-patch code did
+        print('{:7} {:33} {:12} {:8} {:6} {:30} : {:17} : {:4} : {:11} : {:7} : {:7}'.format("type",  "name", "layer", "min", "max", "tensor_shape", "dtype", "scale", "dtype", "min", "max"))
+
 
 
 if __name__ == '__main__':
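A minimal usage sketch for the hook installer defined above; the args fields shown are the ones it reads, and the values are hypothetical:

import torch
from pytorch_jacinto_ai import xnn
args = xnn.utils.ConfigNode()
args.write_layer_ip_op = True
args.write_layer_ip_op_names = ['encoder.features.0', 'decoder']    # substring match against module names
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3))
assign_write_layer_ip_op_hook(model=model, save_path='debug/test_vecs', args=args, file_format='npy')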
index 4532440c8417cd7b0e838d474a4cfe25a622a7b0..2df0077972dd494d98a558eb5e71c68e8a0b8855 100644 (file)
@@ -1,3 +1,31 @@
+# Copyright (c) 2018-2021, Texas Instruments
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
 import os
 import time
 import sys
index 2a6e8069f8fe04c4dd02ca0c14a22b829dd3b1ec..43866ea1ef63b6fd92d7a6a81e88d4c2f61870c3 100644 (file)
@@ -1,3 +1,31 @@
+# Copyright (c) 2018-2021, Texas Instruments
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
 import os
 import sys
 import shutil
@@ -65,12 +93,13 @@ def get_config():
     args.write_layer_ip_op = False
 
     args.quantize = False                               # apply quantized inference or not
-    #args.model_surgery = None                           # replace activations with PAct2 activation module. Helpful in quantized training.
+    #args.model_surgery = None                          # replace activations with PAct2 activation module. Helpful in quantized training.
     args.bitwidth_weights = 8                           # bitwidth for weights
     args.bitwidth_activations = 8                       # bitwidth for activations
     args.histogram_range = True                         # histogram range for calibration
     args.per_channel_q = False                          # apply separate quantization factor for each channel in depthwise layers or not
-    args.bias_calibration = False                        # apply bias correction during quantized inference calibration
+    args.bias_calibration = False                       # apply bias correction during quantized inference calibration
+    args.constrain_bias = None                          # constrain bias according to the constraints of convolution engine
 
     args.opset_version = 9                              # onnx opset version
     return args
@@ -151,17 +180,18 @@ def main(args):
         if 'training' in args.phase:
             model = xnn.quantize.QuantTrainModule(model, per_channel_q=args.per_channel_q,
                         bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
-                        dummy_input=dummy_input)
+                        constrain_bias=args.constrain_bias, dummy_input=dummy_input)
         elif 'calibration' in args.phase:
             model = xnn.quantize.QuantCalibrateModule(model, per_channel_q=args.per_channel_q,
                         bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
-                        bias_calibration=args.bias_calibration, dummy_input=dummy_input, lr_calib=args.lr_calib)
+                        bias_calibration=args.bias_calibration, constrain_bias=args.constrain_bias,
+                        dummy_input=dummy_input, lr_calib=args.lr_calib)
         elif 'validation' in args.phase:
             # Note: bias_calibration is not enabled in test
             model = xnn.quantize.QuantTestModule(model, per_channel_q=args.per_channel_q,
                         bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
-                        histogram_range=args.histogram_range, dummy_input=dummy_input,
-                        model_surgery_quantize=model_surgery_quantize)
+                        histogram_range=args.histogram_range, constrain_bias=args.constrain_bias,
+                        dummy_input=dummy_input, model_surgery_quantize=model_surgery_quantize)
         else:
             assert False, f'invalid phase {args.phase}'
     #
@@ -419,7 +449,7 @@ def write_tensor_int(m=[], tensor=[], suffix='op', bitwidth=8, power2_scaling=Tr
         '{:6}, {:32}, {:10}, {:7.2f}, {:7.2f}'.format(suffix, m.name, m.__class__.__name__, tensor.min(), tensor.max()),
         end=" ")
 
-    [tensor_scale, clamp_limits] = xnn.quantize.compute_tensor_scale(tensor, mn, mx, bitwidth, power2_scaling)
+    [tensor_scale, clamp_limits] = xnn.utils.compute_tensor_scale(tensor, mn, mx, bitwidth, power2_scaling)
     print("{:30} : {:15} : {:8.2f}".format(str(tensor.shape), str(tensor.dtype), tensor_scale), end=" ")
 
     print_weight_bias = False
@@ -431,13 +461,13 @@ def write_tensor_int(m=[], tensor=[], suffix='op', bitwidth=8, power2_scaling=Tr
             print("tensor_scale: ", tensor_scale)
             print(tensor[no_idx])
         if tensor.dtype != torch.int64:
-            tensor = xnn.quantize.symmetric_round_tensor(tensor * tensor_scale)
+            tensor = xnn.utils.symmetric_round_tensor(tensor * tensor_scale)
         if suffix == 'weight' and print_weight_bias:
             print(tensor[no_idx])
     else:
         # for activation use HW friendly rounding
         if tensor.dtype != torch.int64:
-            tensor = xnn.quantize.upward_round_tensor(tensor * tensor_scale)
+            tensor = xnn.utils.upward_round_tensor(tensor * tensor_scale)
     tensor = tensor.clamp(clamp_limits[0], clamp_limits[1]).float()
 
     if bitwidth == 8:
@@ -498,7 +528,7 @@ def write_tensor(data_type='int', m=[], tensor=[], suffix='op', bitwidth=8, powe
 
 
 enable_hook_function = True
-def write_tensor_hook_function(m, inp, out, file_format='bin'):
+def write_tensor_hook_function(m, inp, out, file_format='npy'):
     if not enable_hook_function:
         return
 
index 5dde0194d62b5a9e1e6773e7b46a4408901bc2cd..dfe2f831c49431fd8df9ded335dfbcd22a65c5a3 100644 (file)
@@ -1,3 +1,31 @@
+# Copyright (c) 2018-2021, Texas Instruments
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
 import os
 import time
 import sys
index bc208a4f6e4d0d88b69ed9cd589b898713d2d7bf..9bfde8e6b7bcfc440167b5d69862096917671d0d 100644 (file)
@@ -1,3 +1,31 @@
+# Copyright (c) 2018-2021, Texas Instruments
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
 import os
 import shutil
 import time
@@ -32,31 +60,37 @@ def get_config():
     args = xnn.utils.ConfigNode()
     args.model_config = xnn.utils.ConfigNode()
     args.dataset_config = xnn.utils.ConfigNode()
+    args.model_config.input_channels = 3                # num input channels
+    args.model_config.output_type = 'classification'
+    args.model_config.output_channels = None
+    args.model_config.strides = None                    # (2,2,2,2,2)
     args.model_config.num_tiles_x = int(1)
     args.model_config.num_tiles_y = int(1)
     args.model_config.en_make_divisible_by8 = True
-    args.model_config.input_channels = 3                # num input channels
+    args.model_config.enable_fp16 = False               # FP16 half precision mode
 
     args.input_channel_reverse = False                  # rgb to bgr
     args.data_path = './data/datasets/ilsvrc'           # path to dataset
     args.model_name = 'mobilenetv2_tv_x1'              # model architecture
     args.model = None                                   # if the model is created externally
     args.dataset_name = 'imagenet_classification'       # image folder classification
+    args.transforms = None                              # the transforms itself can be given from outside
     args.save_path = None                               # checkpoints save path
     args.phase = 'training'                             # training/calibration/validation
     args.date = None                                    # date to add to save path. if this is None, current date will be added.
 
-    args.workers = 8                                   # number of data loading workers (default: 8)
+    args.workers = 12                                   # number of data loading workers (default: 8)
     args.logger = None                                  # logger stream to output into
 
-    args.epochs = 90                                    # number of total epochs to run
-    args.warmup_epochs = None                           # number of epochs to warm up by linearly increasing lr
+    args.epochs = 150                                   # number of total epochs to run: recommended 100 or 150
+    args.warmup_epochs = 5                              # number of epochs to warm up by linearly increasing lr
+    args.warmup_factor = 1e-3                           # lr for the first warmup epoch, as a factor of the initial lr (a floor, not a cap)
 
     args.epoch_size = 0                                 # fraction of training epoch to use each time. 0 indicates full
     args.epoch_size_val = 0                             # manual epoch size (will match dataset size if not specified)
     args.start_epoch = 0                                # manual epoch number to start
     args.stop_epoch = None                              # manual epoch number to stop
-    args.batch_size = 256                               # mini_batch size (default: 256)
+    args.batch_size = 512                               # mini_batch size (default: 256)
     args.total_batch_size = None                        # accumulated batch size. total_batch_size = batch_size*iter_size
     args.iter_size = 1                                  # iteration size. total_batch_size = batch_size*iter_size
 
@@ -64,7 +98,7 @@ def get_config():
     args.lr_clips = None                                # use args.lr itself if it is None
     args.lr_calib = 0.05                                # lr for bias calibration
     args.momentum = 0.9                                 # momentum
-    args.weight_decay = 1e-4                            # weight decay (default: 1e-4)
+    args.weight_decay = 4e-5                            # weight decay (default: 1e-4)
     args.bias_decay = None                              # bias decay (default: 0.0)
 
     args.shuffle = True                                 # shuffle or not
@@ -78,18 +112,18 @@ def get_config():
     args.dist_url = 'tcp://224.66.41.62:23456'          # url used to set up distributed training
     args.dist_backend = 'gloo'                          # distributed backend
 
-    args.optimizer = 'sgd'                              # solver algorithms, choices=['adam','sgd','sgd_nesterov','rmsprop']
-    args.scheduler = 'step'                             # help='scheduler algorithms, choices=['step','poly','exponential', 'cosine']
+    args.optimizer = 'sgd'                              # optimizer algorithm, choices=['adam','sgd','sgd_nesterov','rmsprop']
+    args.scheduler = 'cosine'                           # scheduler algorithm, choices=['step','poly','exponential','cosine']
     args.milestones = (30, 60, 90)                      # epochs at which learning rate is divided
     args.multistep_gamma = 0.1                          # multi step gamma (default: 0.1)
     args.polystep_power = 1.0                           # poly step gamma (default: 1.0)
-    args.step_size = 1,                                 # step size for exp lr decay
+    args.step_size = 1                                  # step size for exp lr decay
 
     args.beta = 0.999                                   # beta parameter for adam
     args.pretrained = None                              # path to pre_trained model
     args.img_resize = 256                               # image resize
     args.img_crop = 224                                 # image crop
-    args.rand_scale = (0.08,1.0)                        # random scale range for training
+    args.rand_scale = (0.2,1.0)                         # random scale range for training
     args.data_augument = 'inception'                    # data augmentation method, choices=['inception','resize','adaptive_resize']
     args.count_flops = True                             # count flops and report
 
@@ -110,11 +144,12 @@ def get_config():
     args.histogram_range = True                         # histogram range for calibration
     args.bias_calibration = True                        # apply bias correction during quantized inference calibration
     args.per_channel_q = False                          # apply separate quantization factor for each channel in depthwise layers or not
+    args.constrain_bias = None                          # constrain bias according to the constraints of convolution engine
 
     args.freeze_bn = False                              # freeze the statistics of bn
     args.save_mod_files = False                         # saves modified files after last commit. Also  stores commit id.
 
-    args.opset_version = 9                             # onnx opset_version
+    args.opset_version = 11                             # onnx opset_version
     return args
 
 
@@ -246,12 +281,13 @@ def main(args):
         if 'training' in args.phase:
             model = xnn.quantize.QuantTrainModule(model, per_channel_q=args.per_channel_q,
                         histogram_range=args.histogram_range, bitwidth_weights=args.bitwidth_weights,
-                        bitwidth_activations=args.bitwidth_activations,
+                        bitwidth_activations=args.bitwidth_activations, constrain_bias=args.constrain_bias,
                         dummy_input=dummy_input)
         elif 'calibration' in args.phase:
             model = xnn.quantize.QuantCalibrateModule(model, per_channel_q=args.per_channel_q,
                         bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
-                        histogram_range=args.histogram_range, bias_calibration=args.bias_calibration, dummy_input=dummy_input,
+                        histogram_range=args.histogram_range,  constrain_bias=args.constrain_bias,
+                        bias_calibration=args.bias_calibration, dummy_input=dummy_input,
                         lr_calib=args.lr_calib)
         elif 'validation' in args.phase:
             # Note: bias_calibration is not used in test
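For reference, a minimal sketch of the same QuantTrainModule wrapping outside the engine, mirroring the arguments in the call above on a torchvision model (values are illustrative):

import torch
import torchvision
from pytorch_jacinto_ai import xnn
model = torchvision.models.mobilenet_v2()
dummy_input = torch.rand(1, 3, 224, 224)    # fixes the shape used to trace the graph
model = xnn.quantize.QuantTrainModule(model, per_channel_q=False,
            histogram_range=True, bitwidth_weights=8,
            bitwidth_activations=8, constrain_bias=None,
            dummy_input=dummy_input)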
@@ -278,7 +314,7 @@ def main(args):
         count_flops(args, model)
 
     #################################################
-    if args.save_onnx and (any(p in args.phase for p in ('training','calibration')) or (args.run_soon == False)):
+    if args.save_onnx:
         write_onnx_model(args, get_model_orig(model), save_path)
     #
 
@@ -374,12 +410,14 @@ def main(args):
         close(args)
         return
 
+    grad_scaler = torch.cuda.amp.GradScaler() if args.model_config.enable_fp16 else None
+
     for epoch in range(args.start_epoch, args.stop_epoch):
         if args.distributed:
             train_loader.sampler.set_epoch(epoch)
 
         # train for one epoch
-        train(args, train_loader, model, criterion, optimizer, epoch)
+        train(args, train_loader, model, criterion, optimizer, epoch, grad_scaler)
 
         # evaluate on validation set
         prec1 = validate(args, val_loader, model, criterion, epoch)
@@ -414,6 +452,7 @@ def is_valid_phase(phase):
 
 def close(args):
     if args.logger is not None:
+        args.logger.close()
         del args.logger
         args.logger = None
     #
@@ -467,7 +506,7 @@ def write_onnx_model(args, model, save_path, name='checkpoint.onnx'):
         onnx.save(onnx.shape_inference.infer_shapes(onnx.load(path)), path)
 
 
-def train(args, train_loader, model, criterion, optimizer, epoch):
+def train(args, train_loader, model, criterion, optimizer, epoch, grad_scaler):
     # actual training code
     batch_time = AverageMeter()
     data_time = AverageMeter()
@@ -477,9 +516,16 @@ def train(args, train_loader, model, criterion, optimizer, epoch):
 
     # switch to train mode
     model.train()
-    if args.freeze_bn:
+
+    # freeze bn and range after some epochs during quantization
+    if args.freeze_bn or (args.quantize and epoch > 2 and epoch >= ((args.epochs//2)-1)):
+        xnn.utils.print_once('Freezing BN for subsequent epochs')
         xnn.utils.freeze_bn(model)
     #
+    if (args.quantize and epoch > 4 and epoch >= ((args.epochs//2)+1)):
+        xnn.utils.print_once('Freezing ranges for subsequent epochs')
+        xnn.layers.freeze_quant_range(model)
+    #
 
     num_iters = len(train_loader)
     progress_bar = progiter.ProgIter(np.arange(num_iters), chunksize=1)
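A quick trace of the freeze points introduced above, for a hypothetical args.epochs=10 QAT run: BN statistics freeze from epoch 4 ((10//2)-1) and quantization ranges from epoch 6 ((10//2)+1):

epochs = 10
for epoch in range(epochs):
    freeze_bn = epoch > 2 and epoch >= (epochs // 2) - 1
    freeze_ranges = epoch > 4 and epoch >= (epochs // 2) + 1
    print(epoch, freeze_bn, freeze_ranges)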
@@ -525,14 +571,22 @@ def train(args, train_loader, model, criterion, optimizer, epoch):
         top5.update(prec5[0], input_size[0])
 
         if 'training' in args.phase:
-            # zero gradients so that we can accumulate gradients
-            if (iteration % args.iter_size) == 0:
-                optimizer.zero_grad()
-
-            loss.backward()
+            if args.model_config.enable_fp16:
+                grad_scaler.scale(loss).backward()
+            else:
+                loss.backward()
+            #
 
             if ((iteration+1) % args.iter_size) == 0:
-                optimizer.step()
+                if args.model_config.enable_fp16:
+                    grad_scaler.step(optimizer)
+                    grad_scaler.update()
+                else:
+                    optimizer.step()
+                #
+                # setting grads to None is a faster alternative to optimizer.zero_grad()
+                xnn.utils.clear_grad(model)
+            #
         #
 
         # measure elapsed time
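The GradScaler calls above follow the canonical AMP step sequence; here is a standalone sketch (hypothetical model and data, CUDA assumed; the engine additionally accumulates gradients over iter_size iterations):

import torch
model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
grad_scaler = torch.cuda.amp.GradScaler()
x = torch.randn(8, 10).cuda()
y = torch.randint(0, 2, (8,)).cuda()
with torch.cuda.amp.autocast():    # fp16 forward; the engine's autocast placement is not shown in this hunk
    loss = torch.nn.functional.cross_entropy(model(x), y)
grad_scaler.scale(loss).backward()
grad_scaler.step(optimizer)
grad_scaler.update()
optimizer.zero_grad(set_to_none=True)    # same effect as xnn.utils.clear_grad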
@@ -677,17 +731,17 @@ class AverageMeter(object):
 
 def adjust_learning_rate(args, optimizer, epoch):
     """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
-    cur_lr = args.cur_lr if hasattr(args, 'cur_lr') else args.lr
+    cur_lr = args.lr
 
-    if (args.warmup_epochs is not None) and (epoch < (args.warmup_epochs-1)):
-        cur_lr = (epoch + 1) * args.lr / args.warmup_epochs
+    if (args.warmup_epochs is not None) and (epoch <= args.warmup_epochs):
+        cur_lr = epoch * args.lr / args.warmup_epochs
+        if epoch == 0 and args.warmup_factor is not None:
+            cur_lr = max(cur_lr, args.lr * args.warmup_factor)
+        #
     elif args.scheduler == 'poly':
         epoch_frac = (args.epochs - epoch) / args.epochs
         epoch_frac = max(epoch_frac, 0)
         cur_lr = args.lr * (epoch_frac ** args.polystep_power)
-        for param_group in optimizer.param_groups:
-            param_group['lr'] = cur_lr
-        #
     elif args.scheduler == 'step':                                            # step
         num_milestones = 0
         for m in args.milestones:
@@ -701,7 +755,7 @@ def adjust_learning_rate(args, optimizer, epoch):
             cur_lr = args.lr
         else:
             lr_min = 0
-            cur_lr = (args.lr - lr_min) * (1 + math.cos(math.pi * epoch / args.epochs)) / 2.0  + lr_min
+            cur_lr = (args.lr - lr_min) * (1 + math.cos(math.pi * epoch / args.epochs)) / 2.0 + lr_min
         #
     else:
         raise ValueError('Unknown scheduler {}'.format(args.scheduler))
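Tracing the warmup-plus-cosine path above with hypothetical settings (lr=0.1, warmup_epochs=5, warmup_factor=1e-3, epochs=30, scheduler='cosine'):

import math
lr, warmup_epochs, warmup_factor, epochs = 0.1, 5, 1e-3, 30
for epoch in range(epochs):
    if epoch <= warmup_epochs:
        cur_lr = epoch * lr / warmup_epochs
        if epoch == 0:
            cur_lr = max(cur_lr, lr * warmup_factor)    # floor for the very first epoch
    else:
        cur_lr = lr * (1 + math.cos(math.pi * epoch / epochs)) / 2.0
    print(epoch, round(cur_lr, 6))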
@@ -724,7 +778,7 @@ def accuracy(output, target, topk=(1,)):
 
         res = []
         for k in topk:
-            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+            correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
             res.append(correct_k.mul_(100.0 / batch_size))
         return res
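A quick check of accuracy() above on toy logits (4 samples, 5 classes; the last row is misclassified, so top-1 is 75% and top-5 is 100%):

import torch
output = torch.tensor([[0.1, 0.9, 0.0, 0.0, 0.0],
                       [0.8, 0.1, 0.0, 0.0, 0.1],
                       [0.0, 0.1, 0.7, 0.2, 0.0],
                       [0.3, 0.2, 0.1, 0.4, 0.0]])
target = torch.tensor([1, 0, 2, 0])
top1, top5 = accuracy(output, target, topk=(1, 5))
print(top1.item(), top5.item())    # 75.0 100.0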
 
@@ -792,10 +846,10 @@ def get_transforms(args):
     always_use_val_transform = (args.rand_scale[0] == 0)
     train_transform = get_validation_transform(args) if always_use_val_transform else get_train_transform(args)
     val_transform = get_validation_transform(args)
-    return train_transform, val_transform
+    return (train_transform, val_transform)
 
 def get_data_loaders(args):
-    train_transform, val_transform = get_transforms(args)
+    train_transform, val_transform = get_transforms(args) if args.transforms is None else (args.transforms[0], args.transforms[1])
 
     train_dataset, val_dataset = xvision.datasets.classification.__dict__[args.dataset_name](args.dataset_config, args.data_path, transforms=(train_transform,val_transform))
 
index bd031238d63a9297672ed19c622517d9a05d6314..d9ee0e750e498195ea9fa74e075d63cdde098c21 100644 (file)
@@ -1,3 +1,31 @@
+# Copyright (c) 2018-2021, Texas Instruments
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
 import os
 import shutil
 import time
@@ -45,14 +73,18 @@ def get_config():
     args.model_config = xnn.utils.ConfigNode()
     args.model_config.output_type = ['segmentation']   # the network is used to predict flow or depth or sceneflow
     args.model_config.output_channels = None            # number of output channels
+    args.model_config.prediction_channels = None        # intermediate number of channels before final output_channels
     args.model_config.input_channels = None             # number of input channels
+    args.model_config.final_upsample = True             # use final upsample to input resolution or not
     args.model_config.output_range = None               # max range of output
     args.model_config.num_decoders = None               # number of decoders to use. [options: 0, 1, None]
     args.model_config.freeze_encoder = False            # do not update encoder weights
     args.model_config.freeze_decoder = False            # do not update decoder weights
     args.model_config.multi_task_type = 'learned'       # find out loss multiplier by learning, choices=[None, 'learned', 'uncertainty', 'grad_norm', 'dwa_grad_norm']
     args.model_config.target_input_ratio = 1            # Keep target size same as input size
-    args.model_config.input_nv12 = False
+    args.model_config.input_nv12 = False                # convert input to nv12 format
+    args.model_config.enable_fp16 = False               # faster training if the GPU supports fp16
+
     args.model = None                                   # the model itself can be given from outside
     args.model_name = 'deeplabv2lite_mobilenetv2'       # model architecture, overwritten if pretrained is specified
     args.dataset_name = 'cityscapes_segmentation'       # dataset type
@@ -71,7 +103,7 @@ def get_config():
     args.split_files = None                             # split list files. eg: train.txt val.txt
     args.split_value = None                             # test_val split proportion (between 0 (only test) and 1 (only train))
 
-    args.solver = 'adam'                                # solver algorithms, choices=['adam','sgd']
+    args.optimizer = 'adam'                             # optimizer algorithm, choices=['adam','sgd']
     args.scheduler = 'step'                             # scheduler algorithms, choices=['step','poly', 'cosine']
     args.workers = 8                                    # number of data loading workers
 
@@ -88,6 +120,7 @@ def get_config():
     args.lr_clips = None                                # use args.lr itself if it is None
     args.lr_calib = 0.05                                # lr for bias calibration
     args.warmup_epochs = 5                              # number of epochs to warmup
+    args.warmup_factor = 1e-3                           # lr for the first warmup epoch, as a factor of the initial lr (a floor, not a cap)
 
     args.momentum = 0.9                                 # momentum for sgd, alpha parameter for adam
     args.beta = 0.999                                   # beta parameter for adam
@@ -159,6 +192,7 @@ def get_config():
     args.histogram_range = True                         # histogram range for calibration
     args.bias_calibration = True                        # apply bias correction during quantized inference calibration
     args.per_channel_q = False                          # apply separate quantization factor for each channel in depthwise layers or not
+    args.constrain_bias = None                          # constrain bias according to the constraints of convolution engine
 
     args.save_mod_files = False                         # saves modified files after the last commit; also stores the commit id
     args.make_score_zero_mean = False                   # make score zero mean while learning
@@ -171,9 +205,10 @@ def get_config():
     args.print_train_class_iou = False
     args.print_val_class_iou = False
     args.freeze_layers = None
+    args.opset_version = 11                             # onnx opset_version
+    args.prob_color_to_gray = (0.0,0.0)                 # color-to-gray augmentation probabilities: (train, val)
 
-    args.opset_version = 9                              # onnx opset_version
-
+    args.interpolation = None                           # interpolation method for resize: one of the cv2.INTER_* constants
     return args
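
A minimal driver sketch (hypothetical, not part of this commit) showing how the options added above could be set before calling main(); only args.interpolation needs cv2:

    import cv2

    args = get_config()
    args.model_config.enable_fp16 = True     # mixed-precision training via torch.cuda.amp
    args.optimizer = 'sgd'                   # renamed from args.solver
    args.warmup_factor = 1e-3                # lr factor for the first warmup epoch
    args.prob_color_to_gray = (0.5, 0.0)     # (train, val) color-to-gray probabilities
    args.interpolation = cv2.INTER_LINEAR    # resize interpolation for the transforms
    main(args)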
 
 
@@ -205,6 +240,11 @@ def main(args):
     # resume has higher priority
     args.pretrained = None if (args.resume is not None) else args.pretrained
 
+    # prob_color_to_gray controls the color-to-gray augmentation: (train, val)
+    if 'tiad' in args.dataset_name and args.prob_color_to_gray == (0.0, 0.0):
+        # override the default for 'tiad' datasets
+        args.prob_color_to_gray = (0.5, 0.0)
+
     if args.save_path is None:
         save_path = get_save_path(args)
     else:
@@ -359,18 +399,20 @@ def main(args):
         #
         if 'training' in args.phase:
             model = xnn.quantize.QuantTrainModule(model, per_channel_q=args.per_channel_q,
-                        histogram_range=args.histogram_range, bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations, dummy_input=dummy_input)
+                        histogram_range=args.histogram_range, bitwidth_weights=args.bitwidth_weights,
+                        bitwidth_activations=args.bitwidth_activations, constrain_bias=args.constrain_bias,
+                        dummy_input=dummy_input)
         elif 'calibration' in args.phase:
             model = xnn.quantize.QuantCalibrateModule(model, per_channel_q=args.per_channel_q,
                         bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
-                        histogram_range=args.histogram_range, bias_calibration=args.bias_calibration,
-                        dummy_input=dummy_input, lr_calib=args.lr_calib)
+                        histogram_range=args.histogram_range, constrain_bias=args.constrain_bias,
+                        bias_calibration=args.bias_calibration, dummy_input=dummy_input, lr_calib=args.lr_calib)
         elif 'validation' in args.phase:
             # Note: bias_calibration is not enabled
             model = xnn.quantize.QuantTestModule(model, per_channel_q=args.per_channel_q,
                         bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
-                        histogram_range=args.histogram_range, dummy_input=dummy_input,
-                        model_surgery_quantize=model_surgery_quantize)
+                        histogram_range=args.histogram_range, constrain_bias=args.constrain_bias,
+                        dummy_input=dummy_input, model_surgery_quantize=model_surgery_quantize)
         else:
             assert False, f'invalid phase {args.phase}'
     #
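
The three wrappers above share the new constrain_bias argument; a standalone sketch with a toy model (the keyword names come from this diff, the 8-bit values are assumed defaults):

    import torch
    from pytorch_jacinto_ai import xnn

    base = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3, padding=1), torch.nn.ReLU())
    dummy_input = torch.rand(1, 3, 32, 32)
    qat_model = xnn.quantize.QuantTrainModule(base, per_channel_q=False,
                    histogram_range=True, bitwidth_weights=8, bitwidth_activations=8,
                    constrain_bias=None, dummy_input=dummy_input)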
@@ -393,7 +435,7 @@ def main(args):
         count_flops(args, model)
 
     #################################################
-    if args.save_onnx and (any(args.phase in p for p in ('training','calibration')) or (args.run_soon == False)):
+    if args.save_onnx:
         write_onnx_model(args, get_model_orig(model), save_path, save_traced_model=False)
     #
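
write_onnx_model (defined elsewhere in this file) performs the export; a bare-bones equivalent with the new default opset, using stand-ins for the model and input:

    import torch

    model_orig = torch.nn.Conv2d(3, 8, 3)    # stand-in for get_model_orig(model)
    dummy_input = torch.rand(1, 3, 32, 32)
    torch.onnx.export(model_orig, dummy_input, 'model.onnx', opset_version=11)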
 
@@ -437,6 +479,8 @@ def main(args):
                     kw_args.update({arg:args.model_config.output_channels[task_dx]})
                 elif arg == 'sparse':
                     kw_args.update({arg:args.sparse})
+                elif arg == 'enable_fp16':
+                    kw_args.update({arg:args.model_config.enable_fp16})
                 #
             #
             loss_fn_raw = xvision.losses.__dict__[loss_fn](**kw_args)
@@ -462,7 +506,10 @@ def main(args):
                     kw_args.update({arg:args.model_config.output_channels[task_dx]})
                 elif arg == 'sparse':
                     kw_args.update({arg:args.sparse})
-
+                elif arg == 'enable_fp16':
+                    kw_args.update({arg:args.model_config.enable_fp16})
+                #
+            #
             metric_fn_raw = xvision.losses.__dict__[metric_fn](**kw_args)
             if args.parallel_criterion:
                 metric_fn = torch.nn.DataParallel(metric_fn_raw).cuda()
@@ -483,8 +530,8 @@ def main(args):
         return
 
     #################################################
-    assert(args.solver in ['adam', 'sgd'])
-    print('=> setting {} solver'.format(args.solver))
+    assert(args.optimizer in ['adam', 'sgd'])
+    print('=> setting {} optimizer'.format(args.optimizer))
     if args.lr_clips is not None:
         learning_rate_clips = args.lr_clips if 'training' in args.phase else 0.0
         clips_decay = args.bias_decay if (args.bias_decay is not None and args.bias_decay != 0.0) else args.weight_decay
@@ -497,19 +544,21 @@ def main(args):
     #
 
     learning_rate = args.lr if ('training' in args.phase) else 0.0
-    if args.solver == 'adam':
+    if args.optimizer == 'adam':
         optimizer = torch.optim.Adam(param_groups, learning_rate, betas=(args.momentum, args.beta))
-    elif args.solver == 'sgd':
+    elif args.optimizer == 'sgd':
         optimizer = torch.optim.SGD(param_groups, learning_rate, momentum=args.momentum)
     else:
-        raise ValueError('Unknown optimizer type{}'.format(args.solver))
+        raise ValueError('Unknown optimizer type {}'.format(args.optimizer))
     #
 
     #################################################
     max_iter = args.epochs * len(train_loader)
-    scheduler = xnn.optim.lr_scheduler.SchedulerWrapper(args.scheduler, optimizer, args.epochs, args.start_epoch, \
-                                                            args.warmup_epochs, max_iter=max_iter, polystep_power=args.polystep_power,
-                                                            milestones=args.milestones, multistep_gamma=args.multistep_gamma)
+    scheduler = xnn.optim.lr_scheduler.SchedulerWrapper(scheduler_type=args.scheduler, optimizer=optimizer,
+                    epochs=args.epochs, start_epoch=args.start_epoch,
+                    warmup_epochs=args.warmup_epochs, warmup_factor=args.warmup_factor,
+                    max_iter=max_iter, polystep_power=args.polystep_power,
+                    milestones=args.milestones, multistep_gamma=args.multistep_gamma)
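
A small numeric sketch of what warmup_factor is meant to do; the exact ramp lives in xnn.optim.lr_scheduler, and the linear form below is an assumption:

    lr, warmup_factor, warmup_epochs = 1e-4, 1e-3, 5
    for epoch in range(warmup_epochs):
        factor = warmup_factor + (1.0 - warmup_factor) * epoch / warmup_epochs
        print(f'epoch {epoch}: lr = {lr * factor:.2e}')  # starts at lr * warmup_factor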
 
     # optionally resume from a checkpoint
     if args.resume:
@@ -543,6 +592,8 @@ def main(args):
         with torch.no_grad():
             validate(args, val_dataset, val_loader, model, args.start_epoch, val_writer)
 
+    grad_scaler = torch.cuda.amp.GradScaler() if args.model_config.enable_fp16 else None
+
     for epoch in range(args.start_epoch, args.epochs):
         # epoch is needed to seed shuffling in DistributedSampler, every epoch.
         # otherwise seed of 0 is used every epoch, which seems incorrect.
@@ -552,7 +603,7 @@ def main(args):
             val_sampler.set_epoch(epoch)
 
         # train for one epoch
-        train(args, train_dataset, train_loader, model, optimizer, epoch, train_writer, scheduler)
+        train(args, train_dataset, train_loader, model, optimizer, epoch, train_writer, scheduler, grad_scaler)
 
         # evaluate on validation set
         with torch.no_grad():
@@ -603,7 +654,7 @@ def is_valid_phase(phase):
 
 
 ###################################################################
-def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writer, scheduler):
+def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writer, scheduler, grad_scaler):
     batch_time = xnn.utils.AverageMeter()
     data_time = xnn.utils.AverageMeter()
     # if the loss/ metric is already an average, no need to further average
@@ -614,10 +665,17 @@ def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writ
     ##########################
     # switch to train mode
     model.train()
-    if args.freeze_bn:
+
+    # freeze bn and range after some epochs during quantization
+    if args.freeze_bn or (args.quantize and epoch > 2 and epoch >= ((args.epochs//2)-1)):
+        xnn.utils.print_once('Freezing BN for subsequent epochs')
         xnn.utils.freeze_bn(model)
     #
-    
+    if (args.quantize and epoch > 4 and epoch >= ((args.epochs//2)+1)):
+        xnn.utils.print_once('Freezing ranges for subsequent epochs')
+        xnn.layers.freeze_quant_range(model)
+    #
+
     # freeze layers
     if args.freeze_layers is not None:
         # 'freeze_layer_name' could be part of 'name', i.e. 'name' need not be exact same as 'freeze_layer_name'
@@ -695,15 +753,25 @@ def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writ
             xnn.layers.set_losses(model, loss_list_orig)
 
         if 'training' in args.phase:
-            # zero gradients so that we can accumulate gradients
-            if (iter_id % args.iter_size) == 0:
-                optimizer.zero_grad()
-
             # accumulate gradients
-            loss_total.backward()
+            if args.model_config.enable_fp16:
+                grad_scaler.scale(loss_total).backward()
+            else:
+                loss_total.backward()
+            #
+
             # optimization step
             if ((iter_id+1) % args.iter_size) == 0:
-                optimizer.step()
+                if args.model_config.enable_fp16:
+                    grad_scaler.step(optimizer)
+                    grad_scaler.update()
+                else:
+                    optimizer.step()
+                #
+                # reset gradients for the next accumulation window;
+                # setting grads to None is faster than optimizer.zero_grad()
+                xnn.utils.clear_grad(model)
+            #
         #
 
         # record loss.
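
The fp16 branch above follows the standard torch.cuda.amp recipe; a self-contained sketch (needs a CUDA device; the autocast placement here is an assumption, since in this engine the model itself handles fp16 when enable_fp16 is set):

    import torch

    model = torch.nn.Linear(8, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    grad_scaler = torch.cuda.amp.GradScaler()
    iter_size = 2                                  # gradient accumulation window
    for iter_id in range(4):
        x = torch.randn(4, 8, device='cuda')
        with torch.cuda.amp.autocast():
            loss = model(x).mean()
        grad_scaler.scale(loss).backward()         # accumulate scaled gradients
        if (iter_id + 1) % iter_size == 0:
            grad_scaler.step(optimizer)            # unscales; skips the step on inf/nan
            grad_scaler.update()
            optimizer.zero_grad(set_to_none=True)  # the grad=None idiom behind clear_grad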
@@ -1137,6 +1205,7 @@ def get_train_transform(args):
     image_prenorm = xvision.transforms.image_transforms.NormalizeMeanScale(mean=image_mean, scale=image_scale) if args.image_prenorm else None
     image_postnorm = xvision.transforms.image_transforms.NormalizeMeanScale(mean=image_mean, scale=image_scale) if (not image_prenorm) else None
     reverse_channels = xvision.transforms.image_transforms.ReverseImageChannels() if args.input_channel_reverse else None
+    color_2_gray = xvision.transforms.image_transforms.RandomColor2Gray(is_flow=args.is_flow, random_threshold=args.prob_color_to_gray[0]) if args.prob_color_to_gray[0] != 0.0 else None
 
     # crop size used only for training
     image_train_output_scaling = xvision.transforms.image_transforms.Scale(args.rand_resize, target_size=args.rand_output_size, is_flow=args.is_flow) \
@@ -1144,14 +1213,14 @@ def get_train_transform(args):
     train_transform = xvision.transforms.image_transforms.Compose([
         reverse_channels,
         image_prenorm,
-        xvision.transforms.image_transforms.AlignImages(),
+        xvision.transforms.image_transforms.AlignImages(interpolation=args.interpolation),
         xvision.transforms.image_transforms.MaskTarget(args.target_mask, 0),
         xvision.transforms.image_transforms.CropRect(args.img_border_crop),
         xvision.transforms.image_transforms.RandomRotate(args.transform_rotation, is_flow=args.is_flow) if args.transform_rotation else None,
-        xvision.transforms.image_transforms.RandomScaleCrop(args.rand_resize, scale_range=args.rand_scale, is_flow=args.is_flow),
+        xvision.transforms.image_transforms.RandomScaleCrop(args.rand_resize, scale_range=args.rand_scale, is_flow=args.is_flow, interpolation=args.interpolation),
         xvision.transforms.image_transforms.RandomHorizontalFlip(is_flow=args.is_flow),
         xvision.transforms.image_transforms.RandomCrop(args.rand_crop),
-        xvision.transforms.image_transforms.RandomColor2Gray(is_flow=args.is_flow, random_threshold=0.5) if 'tiad' in args.dataset_name else None,
+        color_2_gray,
         image_train_output_scaling,
         image_postnorm,
         xvision.transforms.image_transforms.ConvertToTensor()
@@ -1166,15 +1235,17 @@ def get_validation_transform(args):
     image_prenorm = xvision.transforms.image_transforms.NormalizeMeanScale(mean=image_mean, scale=image_scale) if args.image_prenorm else None
     image_postnorm = xvision.transforms.image_transforms.NormalizeMeanScale(mean=image_mean, scale=image_scale) if (not image_prenorm) else None
     reverse_channels = xvision.transforms.image_transforms.ReverseImageChannels() if args.input_channel_reverse else None
+    color_2_gray = xvision.transforms.image_transforms.RandomColor2Gray(is_flow=args.is_flow, random_threshold=args.prob_color_to_gray[1]) if args.prob_color_to_gray[1] != 0.0 else None
 
     # prediction is resized to output_size before evaluation.
     val_transform = xvision.transforms.image_transforms.Compose([
         reverse_channels,
         image_prenorm,
-        xvision.transforms.image_transforms.AlignImages(),
+        xvision.transforms.image_transforms.AlignImages(interpolation=args.interpolation),
         xvision.transforms.image_transforms.MaskTarget(args.target_mask, 0),
         xvision.transforms.image_transforms.CropRect(args.img_border_crop),
-        xvision.transforms.image_transforms.Scale(args.img_resize, target_size=args.output_size, is_flow=args.is_flow),
+        xvision.transforms.image_transforms.Scale(args.img_resize, target_size=args.output_size, is_flow=args.is_flow, interpolation=args.interpolation),
+        color_2_gray,
         image_postnorm,
         xvision.transforms.image_transforms.ConvertToTensor()
         ])
index 6cf402db5d0d6743c9882aba4a8f8a972b1a1fa8..b28b4009f36b887078ff18dee6593afb634d97e7 100644 (file)
@@ -1,7 +1,39 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
 from . import layers
 from . import optim
 from . import utils
 from . import quantize
 from . import onnx
 
-
+try:
+    from . import quantize_torch_internal as quantize_torch
+except ImportError:
+    pass
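
A quick probe of the guarded import above; quantize_torch is only present when the internal module ships with the package:

    from pytorch_jacinto_ai import xnn
    print('torch-native quantization available:', hasattr(xnn, 'quantize_torch'))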
index 464dbb0e46f7001c749f4bd6787ac4062e945fa6..a4fb415f7960dbf70f4232a1c4e2e889ccca4a22 100644 (file)
@@ -1,3 +1,35 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
+
 from . import functional
 from .layer_config import *
 from .normalization import *
@@ -7,6 +39,7 @@ from .common_blocks import *
 from .conv_blocks import *
 from .deconv_blocks import *
 from .resize_blocks import *
+from .functional_blocks import *
 
 from .multi_task import *
 from .rf_blocks import *
@@ -19,4 +52,4 @@ from .quant_ste import *
 try:
     from .blocks_internal import *
 except:
-    pass
\ No newline at end of file
+    pass
index d89e6eea496517cbc2c320905e891abf2ba20b15..5e6a7c6c4f4419d15846e4bfeb725cb038c3fa14 100644 (file)
@@ -1,3 +1,34 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
 import numpy as np
 import torch
 
@@ -5,17 +36,16 @@ from .functional import *
 from .. import utils
 
 
-###############################################################
 # Parametric Activation (PACT) with clip values being power of 2
 # Supports learned mode, estimated mode or fixed range
 class PAct2(torch.nn.Module):
-    # constants
-    PACT2_RANGE_LEARN = False   # False : Running Avg, True  : Backprop
-    PACT2_RANGE_SHRINK = 0.01   # 0.01
-    PACT2_RANGE_INIT = 8.0      # this is the starting range
-    PACT2_RANGE_EXPANSION = 1.0 # expand the calculated range for margin
+    # constants - default/init values
+    PACT2_RANGE_LEARN_MODE = False      # False : Running Avg, True  : Backprop
+    PACT2_RANGE_SHRINK_DEFAULT = 0.01   # default range-shrink percentile
+    PACT2_RANGE_INIT = 8.0              # this is the starting range
+    PACT2_RANGE_EXPANSION_FACTOR = 1.0  # expand the calculated range for margin
 
-    def __init__(self, inplace=False, signed=None, range_shrink_percentile=PACT2_RANGE_SHRINK, clip_range=None,
+    def __init__(self, inplace=False, signed=None, range_shrink_activations=PACT2_RANGE_SHRINK_DEFAULT, clip_range=None,
                  power2_activation_range=True, **kwargs):
         super().__init__()
         if (clip_range is not None) and (signed is not None):
@@ -24,9 +54,9 @@ class PAct2(torch.nn.Module):
         self.inplace = inplace
         self.clip_range = clip_range
         self.signed = signed if (clip_range is None) else (clip_range[0]<0.0)
-        self.range_shrink_percentile = range_shrink_percentile # range shrinking factor
+        self.range_shrink_activations = range_shrink_activations # range shrinking factor
         self.fixed_range = (clip_range is not None)
-        self.learn_range = (self.PACT2_RANGE_LEARN and (not self.fixed_range))
+        self.learn_range = (self.PACT2_RANGE_LEARN_MODE and (not self.fixed_range))
         self.eps = np.power(2.0, -16.0)
         self.power2_activation_range = power2_activation_range   # power of 2 ranges
         self.log_base = None # 2.0  # log is used only in learned mode if log_base is not None
@@ -55,7 +85,7 @@ class PAct2(torch.nn.Module):
             self.register_buffer('num_batches_tracked', torch.tensor(-1.0, dtype=torch.float32))
 
             if utils.has_range_estimator:
-                self.range_estimator = utils.RangeEstimator(range_shrink_percentile=range_shrink_percentile,
+                self.range_estimator = utils.RangeEstimator(range_shrink_percentile=range_shrink_activations,
                                                             range_update_factor_min=self.range_update_factor_min)
             #
         #
@@ -86,6 +116,10 @@ class PAct2(torch.nn.Module):
         return '{}(inplace={}, signed={}, clips={})'.format(self.__class__.__name__, self.inplace, self.signed, clips)
 
 
+    def freeze_range(self):
+        self.fixed_range = True
+
+
     def convert_to_log(self, x):
         if (not self.learn_range) or (self.log_base is None):
             return x
@@ -102,8 +136,8 @@ class PAct2(torch.nn.Module):
 
     def update_clips_act(self, x):
         if self.learn_range or (self.range_estimator is None):
-            x_min, x_max = utils.extrema_fast(x, range_shrink_percentile=self.range_shrink_percentile)
-            x_min, x_max = self.convert_to_log(x_min*self.PACT2_RANGE_EXPANSION), self.convert_to_log(x_max*self.PACT2_RANGE_EXPANSION)
+            x_min, x_max = utils.extrema_fast(x, range_shrink_percentile=self.range_shrink_activations)
+            x_min, x_max = self.convert_to_log(x_min*self.PACT2_RANGE_EXPANSION_FACTOR), self.convert_to_log(x_max*self.PACT2_RANGE_EXPANSION_FACTOR)
             # exponential update factor
             update_factor = 1.0 / float(self.num_batches_tracked if self.num_batches_tracked else 1.0)
             update_factor = max(update_factor, self.range_update_factor_min)
@@ -124,7 +158,7 @@ class PAct2(torch.nn.Module):
         clip_max = torch.clamp(clip_max, min=self.eps)
         clip_max = self.convert_to_linear(clip_max)
         # in range learning mode + training - this power2_activation_range is taken care in the quantize function
-        is_learning_range = (self.PACT2_RANGE_LEARN and self.training)
+        is_learning_range = (self.PACT2_RANGE_LEARN_MODE and self.training)
         use_power2_activation_range = (self.power2_activation_range and (not is_learning_range))
         clip_max2 = ceil2_g(clip_max) if use_power2_activation_range else clip_max
         clip_min2 = (-clip_max2 if signed else clip_max2*0.0)
@@ -174,6 +208,7 @@ class QAct(torch.nn.Module):
         return x
 
 
+###############################################################
 # Never quantized activation function.
 # Also if the next block is this, the previous block output is also not quantized.
 # Inserting this activation function is a simple way to avoid quantization at certain places.
@@ -184,4 +219,15 @@ class NoQAct(torch.nn.Module):
         self.signed = signed
 
     def forward(self, x):
-        return x
\ No newline at end of file
+        return x
+
+
+###############################################################
+def freeze_quant_range(module):
+    def _freeze_range_op(op):
+        if isinstance(op, PAct2):
+            op.freeze_range()
+        #
+    #
+    module.apply(_freeze_range_op)
+    module.apply(torch.quantization.disable_observer)
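
Usage sketch for the new helper, reusing the qat_model from the QuantTrainModule sketch earlier in this patch; after this call the PAct2 clip values stay fixed:

    xnn.layers.freeze_quant_range(qat_model)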
diff --git a/modules/pytorch_jacinto_ai/xnn/layers/blocks_internal.py b/modules/pytorch_jacinto_ai/xnn/layers/blocks_internal.py
new file mode 100644 (file)
index 0000000..9b12ac3
--- /dev/null
@@ -0,0 +1,109 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
+import torch
+import torch.nn.functional as F
+
+from .layer_config import *
+from . import conv_blocks
+from . import common_blocks
+
+
+###########################################################
+def ConvGWSepNormAct2d(in_planes, out_planes, stride=1, kernel_size=None, groups=1, dilation=1, bias=False, \
+                   first_1x1=False, normalization=(DefaultNorm2d,DefaultNorm2d), activation=(DefaultAct2d,DefaultAct2d),
+                   shuffle=True, **kwargs):
+    num_ch_in_dws = 4
+    groups_dws = in_planes//num_ch_in_dws
+    layers = [conv_blocks.ConvNormAct2d(in_planes, in_planes, groups=groups_dws, kernel_size=kernel_size, bias=bias, dilation=dilation,
+                normalization=normalization[0], activation=activation[0])]
+    ###########################################################################
+    # shuffle between the first 3x3 and the 1x1, with group = Ni/4
+    # Example: Ni = 64, No = 64, G = 4
+    # 3x3 with G = 16 operates on [0,1,2,3], [4,5,6,7] ... [60,61,62,63]
+    # after shuffle with G = 16 the channel order becomes
+    # [0,4,8,...,60], [1,5,9,...,61], [2,6,10,...,62], [3,7,11,...,63]
+    # so the following 1x1 with G = 4 mixes the groups correctly
+    ###########################################################################
+    if shuffle and (groups != 1):
+        layers += [common_blocks.ShuffleBlock(groups=groups_dws)]
+
+    layers += [conv_blocks.ConvNormAct2d(in_planes, out_planes, groups=groups, kernel_size=1, bias=bias,
+                normalization=normalization[1], activation=activation[1])]
+
+    layers = torch.nn.Sequential(*layers)
+    return layers
+
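+
A quick numeric check of the ordering described in the shuffle comment above (16 channels, groups_dws = 4), assuming the usual view/transpose channel shuffle:

    import torch
    x = torch.arange(16)
    print(x.view(4, 4).t().reshape(-1))
    # tensor([ 0,  4,  8, 12,  1,  5,  9, 13,  2,  6, 10, 14,  3,  7, 11, 15])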
+
+######################################################
+# this is called a lite block because the dilated branches use the
+# grouped ConvGWSepNormAct2d above instead of heavier fully-separable blocks
+class GWASPPLiteBlock(torch.nn.Module):
+    def __init__(self, in_chs, aspp_chs, out_chs, dilation=(6, 12, 18), groups=1, avg_pool=False, activation=DefaultAct2d, linear_dw=False):
+        super().__init__()
+
+        self.aspp_chs = aspp_chs
+        self.avg_pool = avg_pool
+        self.last_chns = aspp_chs * (4 + (1 if self.avg_pool else 0))
+
+        if self.avg_pool:
+            self.gave_pool = torch.nn.Sequential(activation(inplace=False), torch.nn.AdaptiveAvgPool2d((1, 1)),
+                                           torch.nn.Conv2d(in_chs, aspp_chs, kernel_size=1), activation(inplace=True))
+        #
+
+        self.conv1x1 = conv_blocks.ConvNormAct2d(in_chs, aspp_chs, kernel_size=1, activation=activation)
+        normalizations_dw = ((not linear_dw), True)
+        activations_dw = (False if linear_dw else activation, activation)
+        self.aspp_bra1 = ConvGWSepNormAct2d(in_chs, aspp_chs, kernel_size=3, dilation=dilation[0], normalization=normalizations_dw, activation=activations_dw, groups=groups)
+        self.aspp_bra2 = ConvGWSepNormAct2d(in_chs, aspp_chs, kernel_size=3, dilation=dilation[1], normalization=normalizations_dw, activation=activations_dw, groups=groups)
+        self.aspp_bra3 = ConvGWSepNormAct2d(in_chs, aspp_chs, kernel_size=3, dilation=dilation[2], normalization=normalizations_dw, activation=activations_dw, groups=groups)
+
+        self.dropout = torch.nn.Dropout2d(p=0.2, inplace=True)
+        self.aspp_out = conv_blocks.ConvNormAct2d(self.last_chns, out_chs, kernel_size=1, groups=1, activation=activation)
+        self.cat = common_blocks.CatBlock()
+
+    def forward(self, x):
+        x1 = self.conv1x1(x)
+        b1 = self.aspp_bra1(x)
+        b2 = self.aspp_bra2(x)
+        b3 = self.aspp_bra3(x)
+
+        if self.avg_pool:
+            xavg = F.interpolate(self.gave_pool(x), size=x.shape[2:], mode='bilinear')
+            branches = [xavg, x1, b1, b2, b3]
+        else:
+            branches = [x1, b1, b2, b3]
+        #
+
+        cat = self.cat(branches)
+        cat = self.dropout(cat)
+        out = self.aspp_out(cat)
+        return out
+#
+
+
index f7e3b9e294583cd63a0b0e6b93ea80df2dd30d03..1c091452c934bc315e10de66693b8746290e83d8 100644 (file)
@@ -1,3 +1,34 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
 import torch
 from . import functional
 
index 28077be6ce65ff477d2662c6b558a2014fea2291..4039c0645ec4abca8855f7e2df15e316c883907b 100644 (file)
@@ -1,3 +1,34 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
 import torch
 from .layer_config import *
 from . import functional
diff --git a/modules/pytorch_jacinto_ai/xnn/layers/conv_ws_internal.py b/modules/pytorch_jacinto_ai/xnn/layers/conv_ws_internal.py
new file mode 100644 (file)
index 0000000..ef2a6c6
--- /dev/null
@@ -0,0 +1,74 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
+import torch
+
+# Similar to Weight Normalization: https://arxiv.org/abs/1602.07868, but there are differences.
+# Also similar to Weight Standardization: https://arxiv.org/abs/1903.10520, but there are differences.
+# In this implementation:
+# (1) Small std values of the weights are clamped to eps, instead of blindly adding eps to the std
+# (2) The whole tensor can be jointly standardized (optional), instead of each output channel separately
+# (3) The standardized weights are saved into the parameter in the eval pass, so the stored weights work with regular convolution as well.
+# (4) ONNX export does not export the standardization operations, only the standardized weights with a regular convolution.
+# Make sure that model state_dict saving and ONNX export are done in eval mode.
+# Also make sure that your training does an eval pass at the end, so that the standardized weights are available in params.
+class ConvWS2d(torch.nn.Conv2d):
+    def __init__(self, *args, per_channel=True, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.per_channel = per_channel
+
+    def forward(self, input):
+        weight = standardize_weight(self.weight, per_channel=self.per_channel)
+        if not self.training:
+            # store the standardized weight in the parameter
+            # so that the stored weights can work with regular convolution as well.
+            # note: this storing is done only in eval mode to save time
+            # make sure that an eval mode run is done before saving the weights
+            self.weight.data.copy_(weight.data)
+            # detach the weight in eval mode to make sure that
+            # the onnx graph does not have the above operations
+            weight = weight.data
+        #
+        return self.conv2d_forward(input, weight)
+
+
+def standardize_weight(weight, per_channel=True):
+    if per_channel:
+        wsz0 = weight.size(0)
+        weight_mean = weight.view(wsz0, -1).mean(dim=1).view(wsz0, 1, 1, 1)
+        weight_std = weight.view(wsz0, -1).std(dim=1).view(wsz0, 1, 1, 1)
+    else:
+        weight_mean = weight.mean()
+        weight_std = weight.std()
+    #
+    weight_std = torch.clamp(weight_std, min=1e-5)
+    weight = (weight - weight_mean) / weight_std
+    return weight
\ No newline at end of file
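
Usage sketch for ConvWS2d (note: conv2d_forward is the torch-1.x internal this repo targets; newer torch renamed it _conv_forward, which is outside this commit):

    import torch

    conv = ConvWS2d(8, 16, kernel_size=3, padding=1, per_channel=True)
    conv.train(); _ = conv(torch.randn(1, 8, 32, 32))   # standardizes on the fly
    conv.eval();  _ = conv(torch.randn(1, 8, 32, 32))   # writes standardized weights back
    torch.save(conv.state_dict(), 'conv_ws.pth')        # now loadable into a plain Conv2d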
index f5faa9b90d5f66149ebf501d7ed44e2605bdee63..321725df81b835798e2284d71340478bb98c3511 100644 (file)
@@ -1,3 +1,34 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
 import torch
 from .conv_blocks import *
 from .layer_config import *
index 97095ee4afffc5e5e8050ebfa3a684cddb725aaa..7665585f527e9c3ea1b16704c19e77af4c5f8ffb 100644 (file)
@@ -1,3 +1,34 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
 import torch
 import numpy as np
 
@@ -127,7 +158,7 @@ class Floor2G(torch.autograd.Function):
 
 class QuantizeDequantizeG(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x, scale_tensor, width_min, width_max, power2, round_type='round_up'):
+    def forward(ctx, x, scale_tensor, width_min, width_max, power2, axis, round_type='round_up'):
         # apply quantization
         y, x_scaled_round = QuantizeDequantizeG.quant_dequant(ctx, x, scale_tensor, width_min, width_max, power2, round_type)
         # save for backward
@@ -186,9 +217,32 @@ class QuantizeDequantizeG(torch.autograd.Function):
         ds = dy * ds_local
 
         # return
-        return dx, ds, None, None, None, None
+        return dx, ds, None, None, None, None, None
+
+
+    @staticmethod
+    def symbolic(g, x, scale_tensor, width_min, width_max, power2, axis, round_type='round_up'):
+        return g.op("QuantizeDequantizeG",  x, scale_tensor)
+
+
+class TorchQuantizeDequantizeG(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, scale_tensor, width_min, width_max, power2, axis, round_type='round_up'):
+        # apply quantization
+        if scale_tensor.dim()>0:
+            device = x.device
+            axis_size = int(x.size(axis))
+            scale_tensor = scale_tensor.reshape(axis_size)
+            zero_point = torch.zeros(axis_size).to(device=device, dtype=torch.long)
+            y = torch.fake_quantize_per_channel_affine(x, scale=scale_tensor, zero_point=zero_point, axis=axis,
+                    quant_min=int(width_min), quant_max=int(width_max))
+        else:
+            y = torch.fake_quantize_per_tensor_affine(x, scale=float(scale_tensor), zero_point=0,
+                    quant_min=int(width_min), quant_max=int(width_max))
+        #
+        return y
 
 
     @staticmethod
     def symbolic(g, x, scale_tensor, width_min, width_max, power2, axis, round_type='round_up'):
-        return g.op("QuantizeDequantizeG",  x, scale_tensor)
\ No newline at end of file
+        return g.op("TorchQuantizeDequantizeG",  x, scale_tensor)
\ No newline at end of file
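
A minimal check of the per-tensor path used in TorchQuantizeDequantizeG.forward above (signed 8-bit, zero_point fixed at 0 as in the code):

    import torch

    x = torch.randn(2, 4, 8, 8)
    y = torch.fake_quantize_per_tensor_affine(x, scale=1.0/128, zero_point=0,
            quant_min=-128, quant_max=127)
    assert (y * 128).round().unique().numel() <= 256   # values land on the int8 grid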
index 16f1f3da0e1eea1e4c60b8f9c388bef4536fdbe2..31bdc14f6197842872abf145060b75de7af50909 100644 (file)
@@ -1,3 +1,33 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
 import functools
 import torch
 from . import function
@@ -19,6 +49,7 @@ ceil2_g = quant_ste.PropagateQuantTensorSTE(function.Ceil2G.apply)
 # by using quantize_backward_type = 'qte' in QuantTrainPAct2
 # TODO: when using QTE here, we need to register this OP for ONNX export to work
 # and even then the exported model may not be clean.
+#quantize_dequantize_g = quant_ste.PropagateQuantTensorSTE(function.TorchQuantizeDequantizeG.apply)
 quantize_dequantize_g = quant_ste.PropagateQuantTensorSTE(function.QuantizeDequantizeG.apply)
 
 
diff --git a/modules/pytorch_jacinto_ai/xnn/layers/functional_blocks.py b/modules/pytorch_jacinto_ai/xnn/layers/functional_blocks.py
new file mode 100644 (file)
index 0000000..cc994b1
--- /dev/null
@@ -0,0 +1,46 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
+import torch
+
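+# Wraps torch.nn.quantized.FloatFunctional so a functional op (add, mul, cat, ...)
+# can appear as a module that torch quantization observers can attach to.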
+class FloatFunctionalBlock(torch.nn.Module):
+    def __init__(self, func_name):
+        super().__init__()
+        self.func_name = func_name
+        self.func = torch.nn.quantized.FloatFunctional()
+
+    def forward(self, *inputs):
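+        # allow calling with a single list/tuple of tensors: block([a, b])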
+        if isinstance(inputs, (list,tuple)) and len(inputs) == 1:
+            inputs = inputs[0]
+        #
+        func = getattr(self.func, self.func_name)
+        return func(*inputs)
+
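
Usage sketch: the block accepts either separate tensors or a single list, which the len == 1 check above unpacks into the same call:

    import torch

    add_block = FloatFunctionalBlock('add')
    a, b = torch.randn(2, 3), torch.randn(2, 3)
    out1 = add_block(a, b)      # FloatFunctional.add(a, b)
    out2 = add_block([a, b])    # single list is unpacked to the same call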
index 1108af0f1cbeca308ac2cfdbc1495a6a76a1858b..c107ad5e37bce507928417443c4a5448d8b95ee9 100644 (file)
@@ -1,3 +1,34 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
 import torch
 from .normalization import *
 
index bdffbbdcebc1e7fcf416bedacce48a2723886950..8902f2fac3379bd6dc4a2a600ffc1eb16b9a9585 100644 (file)
@@ -1,3 +1,34 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
 import torch
 
 def _get_last_bias_module_sequential(module):
index 8dd2afb4cf58009223d7ba1b4d36187afee157c1..934dee2dce851793eb47d1e80899067b84126374 100644 (file)
@@ -1,3 +1,34 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
 import torch
 import torch.optim as optim
 import numpy as np
@@ -53,10 +84,10 @@ class MultiTask(torch.nn.Module):
                 param_groups = [{'params':self.loss_scales}]
             elif self.multi_task_type == 'uncertainty':
                 param_groups = [{'params': self.uncertainty_factors}]
-            self.gradnorm_solver = 'sgd' #'adam' #'sgd'
-            if self.gradnorm_solver == 'adam':
+            self.gradnorm_optimizer = 'sgd'  # options: 'sgd' or 'adam'
+            if self.gradnorm_optimizer == 'adam':
                 self.optimizer = torch.optim.Adam(param_groups, self.lr, betas=(self.momentum, self.beta))
-            elif self.gradnorm_solver == 'sgd':
+            elif self.gradnorm_optimizer == 'sgd':
                 self.optimizer = torch.optim.SGD(param_groups, self.lr, momentum=self.momentum)
 
     def forward(self, x):
index 63aa9045fcfbbea4a50540b7147a9242ab75b202..c81991c5ced20f9f26e5b569773f7658d31d9c60 100644 (file)
@@ -1,3 +1,34 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
 import torch
 
 ###############################################################
index d3ab4b1fb1961ddb8d228f5834e71b940cbd3f58..da1d1287e4564200822c11bd377490f5f0f16078 100644 (file)
@@ -1,3 +1,34 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
 import torch
 
 ###################################################
index 8a02a21634e34297db3dbe7c94d065d224ddc070..3c187820acf36374e69f1fa28e85852d21c49ff2 100644 (file)
@@ -1,3 +1,34 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
 import warnings
 import torch
 from .deconv_blocks import *
@@ -10,13 +41,19 @@ from .deconv_blocks import *
 ##############################################################################################
 
 # resize with output size or scale factor
-def resize_with(x, size=None, scale_factor=None, mode='nearest', align_corners=None):
+def resize_with(x, size=None, scale_factor=None, mode='nearest', align_corners=None, force_scale_factor=False):
     assert size is None or scale_factor is None, 'both size and scale_factor must not be specified'
     assert size is not None or scale_factor is not None, 'at least one of size or scale factor must be specified'
     assert isinstance(x, torch.Tensor), 'must provide a single tensor as input'
     try:
         # Newer PyTorch versions support recompute_scale_factor = False, that exports a clean onnx graph
         # Attempt it first. Works with onnx opset_version=9 & opset_version=11
+        if scale_factor is None and force_scale_factor:
+            size = size[-2:] if len(size) > 2 else size
+            x_size = x.data.size()[-2:]
+            scale_factor = [float(torch.true_divide(size[0],x_size[0])), float(torch.true_divide(size[1],x_size[1]))]
+            size = None
+        #
         y = torch.nn.functional.interpolate(x, size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners, recompute_scale_factor=False)
     except:
         if torch.onnx.is_in_onnx_export():
@@ -35,22 +72,8 @@ def resize_with(x, size=None, scale_factor=None, mode='nearest', align_corners=N
 
 # always use scale factor to do the rescaling. if scale factor is not provided, generate it from the size.
 def resize_with_scale_factor(x, size=None, scale_factor=None, mode='nearest', align_corners=None):
-    assert size is None or scale_factor is None, 'both size and scale_factor must not be specified'
-    assert size is not None or scale_factor is not None, 'at least one of size or scale factor must be specified'
-    assert isinstance(x, torch.Tensor), 'must provide a single tensor as input'
-    if scale_factor is None:
-        if isinstance(size, torch.Tensor):
-            size = [float(s) for s in size]
-        elif isinstance(size, (int,float)):
-            size = [size,size]
-        #
-        if isinstance(size, (list,tuple)) and len(size) > 2:
-            size = size[-2:]
-        #
-        x_size = [float(s) for s in x.size()][-2:]
-        scale_factor = [float(s)/float(x_s) for (s,x_s) in zip(size,x_size)]
-    #
-    y = resize_with(x, scale_factor=scale_factor, mode=mode, align_corners=align_corners)
+    y = resize_with(x, size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners, force_scale_factor=True)
+    return y
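
A usage sketch of the change above (assuming the layers package re-exports resize_blocks): forcing a scale factor makes torch.onnx.export emit a Resize node driven by constant scales instead of a Shape/Gather subgraph that computes the output size at runtime.

import torch
from pytorch_jacinto_ai.xnn.layers import resize_with_scale_factor

x = torch.randn(1, 8, 32, 32)
# size is converted internally to scale_factor = [2.0, 2.0] before interpolate runs
y = resize_with_scale_factor(x, size=(64, 64), mode='nearest')
assert y.shape[-2:] == (64, 64)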
 
 
 class ResizeWith(torch.nn.Module):
@@ -81,6 +104,10 @@ def UpsampleWith(input_channels=None, output_channels=None, upstride=None, inter
         final_norm = (False if is_final_layer else True)
         normalization = (True, final_norm)
         activation = (False, final_activation)
+        # TODO: this override should be removed
+        if interpolation_type == 'upsample_dwconv' and not is_final_layer:
+            interpolation_type = 'upsample_dwconv3_dil3'
+
         if interpolation_type == 'deconv':
             upsample = [DeConvDWSepNormAct2d(input_channels, output_channels, kernel_size=upstride * 2, stride=upstride,
                                       normalization=normalization, activation=activation)]
@@ -88,6 +115,14 @@ def UpsampleWith(input_channels=None, output_channels=None, upstride=None, inter
             upsample = [ResizeWith(scale_factor=upstride, mode=interpolation_mode),
                         ConvDWSepNormAct2d(input_channels, output_channels, kernel_size=int(upstride * 1.5 + 1),
                                       normalization=normalization, activation=activation)]
+        elif interpolation_type == 'upsample_dwconv':
+            upsample = [ResizeWith(scale_factor=upstride, mode=interpolation_mode),
+                        ConvDWNormAct2d(input_channels, output_channels, kernel_size=int(upstride * 1.5 + 1),
+                                           normalization=False, activation=final_activation, bias=True)]
+        elif interpolation_type == 'upsample_dwconv3_dil3':
+            upsample = [ResizeWith(scale_factor=upstride, mode=interpolation_mode),
+                        ConvDWNormAct2d(input_channels, output_channels, kernel_size=3, dilation=3,
+                                           normalization=False, activation=final_activation, bias=True)]
         elif (interpolation_type == 'subpixel_conv' or interpolation_type == 'pixel_shuffle'):
             upsample = [ConvDWSepNormAct2d(input_channels, output_channels*upstride*upstride, kernel_size=int(upstride + 1),
                                       normalization=normalization, activation=activation),
index 368a077108d6b3181df3f9a3e5fafb5068dca177..0efeee80c9d7174eba6a3ed06ad159ef5a6bad4d 100644 (file)
@@ -1,3 +1,34 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
 # blocks to improve the receptive field. eg. ASPP, LargeHead
 
 import torch
index ba04b03e5ad9945bda5262a6f9b94c9009b2d6b4..e08087cc7e45e8919c83be51ade5da161bd73239 100644 (file)
@@ -1,3 +1,34 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
 try:
     from .onnx2pytorch_internal import import_onnx
 except:
diff --git a/modules/pytorch_jacinto_ai/xnn/onnx/onnx2pytorch_internal.py b/modules/pytorch_jacinto_ai/xnn/onnx/onnx2pytorch_internal.py
new file mode 100644 (file)
index 0000000..ee1829f
--- /dev/null
@@ -0,0 +1,497 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
+# Modified from: https://gist.github.com/qinjian623/6aa777037534c1c1dccbb66f832e93b8
+
+import onnx
+import struct
+import torch
+import torch.nn as nn
+import torchvision as tv
+import warnings
+
+
+# enum DataType {
+#     UNDEFINED = 0;
+#     // Basic types.
+#     FLOAT = 1;   // float
+#     UINT8 = 2;   // uint8_t
+#     INT8 = 3;    // int8_t
+#     UINT16 = 4;  // uint16_t
+#     INT16 = 5;   // int16_t
+#     INT32 = 6;   // int32_t
+#     INT64 = 7;   // int64_t
+#     STRING = 8;  // string
+#     BOOL = 9;    // bool
+#
+#     // IEEE754 half-precision floating-point format (16 bits wide).
+#     // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
+#     FLOAT16 = 10;
+#
+#     DOUBLE = 11;
+#     UINT32 = 12;
+#     UINT64 = 13;
+#     COMPLEX64 = 14;     // complex with float32 real and imaginary components
+#     COMPLEX128 = 15;    // complex with float64 real and imaginary components
+#
+#     // Non-IEEE floating-point format based on IEEE754 single-precision
+#     // floating-point number truncated to 16 bits.
+#     // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
+#     BFLOAT16 = 16;
+#
+#     // Future extensions go here.
+#   }
+
+# TODO more types maybe?
+data_type_tab = {
+    1: ['f', 4],
+    2: ['B', 1],
+    3: ['b', 1],
+    4: ['H', 2],
+    5: ['h', 2],
+    6: ['i', 4],
+    7: ['q', 8],
+    10: ['e', 2],
+    11: ['d', 8],
+    12: ['I', 4],
+    13: ['Q', 8]
+}
+
+
+def empty(x):
+    return x
+
+
+# TODO pytorch only accepts 2-value list for padding.
+def _slim422(l4):
+    assert len(l4) == 4
+
+    p0, p1 = l4[::2]
+    if l4[0] == 0:  # TODO bad code
+        p0 = l4[2] // 2
+        if l4[2] == 1:
+            p0 = 1
+    if l4[1] == 0:  # TODO bad code
+        p1 = l4[3] // 2
+        if l4[3] == 1:
+            p1 = 1
+    return p0, p1
+
+
+def _check_attr(attrs, map):
+    for attr in attrs:
+        if attr.name not in map:
+            warnings.warn("Missing {} in parser's attr_map.".format(attr.name))
+
+
+def unpack_weights(initializer):
+    ret = {}
+    for i in initializer:
+        name = i.name
+        dtype = i.data_type
+        shape = list(i.dims)
+        if dtype not in data_type_tab:
+            warnings("This data type {} is not supported yet.".format(dtype))
+        fmt, size = data_type_tab[dtype]
+        if len(i.raw_data) == 0:
+            if dtype == 1:
+                data_list = i.float_data
+            elif dtype == 7:
+                data_list = i.int64_data
+            else:
+                warnings.warn("No-raw-data type {} not supported yet.".format(dtype))
+        else:
+            data_list = struct.unpack('<' + fmt * (len(i.raw_data) // size), i.raw_data)
+        t = torch.tensor(data_list)
+        if len(shape) != 0:
+            t = t.view(*shape)
+        ret[name] = t
+    return ret
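
For reference, a self-contained sketch of what the struct path above does for a float32 initializer (data_type 1) carried as raw bytes:

import struct
import torch

raw = struct.pack('<4f', 1.0, 2.0, 3.0, 4.0)                 # 4 little-endian float32 values
fmt, size = 'f', 4                                           # data_type_tab[1]
values = struct.unpack('<' + fmt * (len(raw) // size), raw)  # -> (1.0, 2.0, 3.0, 4.0)
t = torch.tensor(values).view(2, 2)                          # reshape to the initializer dims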
+
+
+def rebuild_lrn(node, weights):
+    # size, alpha = 1e-4, beta = 0.75, k = 1.
+    rebuild_lrn.lrn_attr_map = {
+        'size': 'size',
+        'alpha': 'alpha',
+        'beta': 'beta',
+        'bias': 'k'
+    }
+    kwargs = {}
+    for att in node.attribute:
+        kwargs[rebuild_lrn.lrn_attr_map[att.name]] = att.f if att.name != 'size' else att.i
+    return nn.LocalResponseNorm(**kwargs), node.input, node.output
+
+
+def rebuild_conv(node, weights):
+    rebuild_conv.conv_attr_map = {
+        "pads": "padding",
+        "strides": "stride",
+        "kernel_shape": "kernel_size",
+        "group": "groups",
+        "dilations": "dilation"
+    }
+    assert len(node.output) == 1
+    with_bias = False
+    if len(node.input) == 3:
+        with_bias = True
+        bias_name = node.input[2]
+        bias = weights[bias_name]
+
+    weight_name = node.input[1]
+    weight = weights[weight_name]
+    in_channels = weight.shape[1]
+    out_channels = weight.shape[0]
+    kwargs = {}
+    for att in node.attribute:
+        kwargs[rebuild_conv.conv_attr_map[att.name]] = list(att.ints) if att.name != 'group' else att.i
+    if 'padding' in kwargs:
+        kwargs["padding"] = _slim422(kwargs["padding"])
+    groups = 1 if 'groups' not in kwargs else kwargs['groups']
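+    # ONNX stores conv weights as (out_channels, in_channels // groups, kH, kW),
+    # so multiply by groups to recover the in_channels that nn.Conv2d expects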
+    in_channels *= groups
+    conv = nn.Conv2d(in_channels, out_channels, **kwargs, bias=with_bias)
+    conv.weight.data.copy_(weight)
+    if with_bias:
+        conv.bias.data.copy_(bias)
+    return conv, node.input[:1], node.output
+
+
+def rebuild_dropout(node, weights):
+    ratio = node.attribute[0].f
+    return nn.Dropout2d(p=ratio), node.input, node.output
+
+
+def rebuild_batchnormalization(node, weights):
+    rebuild_batchnormalization.bn_attr_map = {
+        "epsilon": "eps",
+        "momentum": "momentum"
+    }
+    assert len(node.input) == 5
+    assert len(node.output) == 1
+    weight = weights[node.input[1]]
+    bias = weights[node.input[2]]
+    running_mean = weights[node.input[3]]
+    running_var = weights[node.input[4]]
+    dim = weight.shape[0]
+    kwargs = {}
+    _check_attr(node.attribute, rebuild_batchnormalization.bn_attr_map)
+    for att in node.attribute:
+        if att.name in rebuild_batchnormalization.bn_attr_map:
+            kwargs[rebuild_batchnormalization.bn_attr_map[att.name]] = att.f
+
+    bn = nn.BatchNorm2d(num_features=dim)
+    bn.weight.data.copy_(weight)
+    bn.bias.data.copy_(bias)
+    bn.running_mean.data.copy_(running_mean)
+    bn.running_var.data.copy_(running_var)
+    return bn, node.input[:1], node.output
+
+
+def rebuild_relu(node, weights):
+    return nn.ReLU(), node.input, node.output
+
+
+def rebuild_clip(node, weights):
+    clip_vals = node.attribute[0].floats
+    if len(clip_vals) == 0:
+        clip_vals = node.attribute[0].ints
+        # Uncomment this to make ReLU6-based ONNX models work
+        #clip_vals = [0.0, 6.0]
+    #
+    return nn.Hardtanh(*clip_vals), node.input, node.output
+
+
+def rebuild_maxpool(node, weights):
+    rebuild_maxpool.mp_attr_map = {
+        "pads": "padding",
+        "strides": "stride",
+        "kernel_shape": "kernel_size",
+    }
+    kwargs = {}
+    for att in node.attribute:
+        kwargs[rebuild_maxpool.mp_attr_map[att.name]] = list(att.ints)
+    if 'padding' in kwargs:
+        kwargs["padding"] = _slim422(kwargs["padding"])
+    mp = nn.MaxPool2d(**kwargs)
+    return mp, node.input, node.output
+
+
+def rebuild_add(node, weights):
+    class Add(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+        def forward(self, *args):
+            return args[0] + args[1]
+    #
+    return Add(), node.input, node.output
+
+
+def rebuild_globalaveragepool(node, weights):
+    avg_pool = nn.AdaptiveAvgPool2d((1, 1))
+    return avg_pool, node.input, node.output
+
+
+def rebuild_transpose(node, weights):
+    perm = node.attribute[0].ints
+
+    def transpose(x):
+        x = x.permute(*perm)
+        return x
+    return transpose, node.input, node.output
+
+
+def rebuild_flatten(node, weights):
+    if len(node.attribute) == 0:
+        d = 1
+    else:
+        d = node.attribute[0].i
+
+    class Flatten(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+        def forward(self, x):
+            return torch.flatten(x, d)
+        #
+    return Flatten(), node.input, node.output
+
+
+def rebuild_gemm(node, weights):
+    weight = weights[node.input[1]]
+    bias = weights[node.input[2]]
+    in_feats = weight.shape[1]
+    out_feats = weight.shape[0]
+    linear = nn.Linear(in_features=in_feats, out_features=out_feats)
+    linear.weight.data.copy_(weight)
+    linear.bias.data.copy_(bias)
+    return linear, node.input[:1], node.output
+
+
+def rebuild_concat(node, weights):
+    dim = node.attribute[0].i
+
+    def concat(*inputs):
+        # for i in inputs:
+        #     print(i.shape)
+        ret = torch.cat(inputs, dim)
+        # print(ret.shape)
+        # exit()
+        return ret
+    return concat, node.input, node.output
+
+
+def rebuild_pad(node, weights):
+    mode = node.attribute[0].s
+    pads = list(node.attribute[1].ints)
+    value = node.attribute[2].f
+    assert mode == b'constant'  # TODO constant only
+    assert sum(pads[:4]) == 0  # TODO pad2d only
+    pad = nn.ConstantPad2d(pads[4:], value)
+    return pad, node.input, node.output
+
+
+def rebuild_constant(node, weights):
+    raw_data = node.attribute[0].t.raw_data
+    data_type = node.attribute[0].t.data_type
+    fmt, size = data_type_tab[data_type]
+    data = struct.unpack('<' + fmt * (len(raw_data) // size), raw_data)
+    if len(data) == 1:
+        data = data[0]
+
+    def constant():
+        return torch.tensor(data)
+    return constant, [], node.output
+
+
+def rebuild_sum(node, weights):
+    def sum(*inputs):
+        ret = inputs[0]
+        for i in inputs[1:]:
+            ret = ret + i  # out-of-place add; += would mutate the first input tensor
+        return ret
+    return sum, node.input, node.output
+
+
+def rebuild_shape(node, weights):
+    def shape(x):
+        return torch.tensor(list(x.shape))
+    return shape, node.input, node.output
+
+
+def rebuild_gather(node, weights):
+    axis = node.attribute[0].i
+
+    def gather(x, idx):
+        return torch.gather(x, axis, idx)
+    return gather, node.input, node.output
+
+
+def _nd_unsqueeze(x, dims):
+    dims = sorted(dims)
+    for d in dims:
+        x = torch.unsqueeze(x, dim=d)
+    return x
+
+
+def rebuild_unsqueeze(node, weights):
+    axes = node.attribute[0].ints
+
+    def unsqueeze(x):
+        return _nd_unsqueeze(x, axes)
+
+    return unsqueeze, node.input, node.output
+
+
+def rebuild_mul(node, weights):
+    def mul(a, b):
+        return a * b
+    return mul, node.input, node.output
+
+
+def rebuild_softmax(node, weights):
+    def f_softmax(x):
+        return x.softmax(dim=1, dtype=torch.double).float()
+    return f_softmax, node.input, node.output
+
+
+def rebuild_reshape(node, weights):
+    def reshape(x, s):
+        data_shape = x.shape
+        onnx_shape = s.tolist()
+        pt_shape = []
+        for idx, d in enumerate(onnx_shape):
+            if d == 0:
+                pt_shape.append(data_shape[idx])
+            else:
+                pt_shape.append(d)
+        return torch.reshape(x, pt_shape)
+    return reshape, node.input, node.output
+
+
+def rebuild_averagepool(node, weights):
+    rebuild_averagepool.avg_attr_map = {
+        "pads": "padding",
+        "strides": "stride",
+        "kernel_shape": "kernel_size",
+    }
+    kwargs = {}
+
+    for att in node.attribute:
+        kwargs[rebuild_averagepool.avg_attr_map[att.name]] = list(att.ints)
+    if 'padding' in kwargs:
+        kwargs["padding"] = _slim422(kwargs["padding"])
+    ap = nn.AvgPool2d(**kwargs)
+    return ap, node.input, node.output
+
+
+def rebuild_op(node, weights):
+    op_type = node.op_type
+    return globals()['rebuild_'+op_type.lower()](node, weights)
+
+
+def construct_pytorch_nodes(graph, weights):
+    ret = []
+    for single_node in graph.node:
+        ret.append(rebuild_op(single_node, weights))
+    return ret
+
+class ONNXImportedModule(nn.Module):
+    def __init__(self, onnx_model, input_name=None):
+        super(ONNXImportedModule, self).__init__()
+        self.deps = {}
+        weights = unpack_weights(onnx_model.graph.initializer)
+        nodes = construct_pytorch_nodes(onnx_model.graph, weights)
+        for idx, (node, inputs, outputs) in enumerate(nodes):
+            if isinstance(node, nn.Module):
+                self.add_module(str(idx), node)
+            for output_name in outputs:
+                self.deps[output_name] = (str(idx), inputs)
+            #
+
+        self.input_name = onnx_model.graph.input[0].name    # TODO: only a single input is supported
+        self.output_name = onnx_model.graph.output[0].name  # TODO: only a single output is supported
+        if input_name is not None:
+            self.input_name = input_name
+
+
+    def forward(self, input):
+        inter_tensors = {}
+        inter_tensors[self.input_name] = input
+        self.resolve_deps(self.output_name, inter_tensors)
+        return inter_tensors[self.output_name]
+
+
+    def resolve_deps(self, name, inter_tensors):
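+        # depth-first: each dependency is computed once and memoized in inter_tensors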
+        if name in inter_tensors:
+            return
+        else:
+            op_name, deps_names = self.deps[name]
+            op = self._modules[op_name]
+            args = []
+            for deps_name in deps_names:
+                self.resolve_deps(deps_name, inter_tensors)
+                args.append(inter_tensors[deps_name])
+            result = op(*args)
+            inter_tensors[name] = result
+
+
+def import_onnx(onnx_file):
+    onnx_model = onnx.load(onnx_file)
+    reconstruct_model = ONNXImportedModule(onnx_model)
+    reconstruct_model.eval()
+    return reconstruct_model
+
+
+def test_net(original_model, onnx_file):
+    import time
+    original_model.eval()
+    onnx_model = onnx.load(onnx_file)
+    reconstruct_model = ONNXImportedModule(onnx_model)
+    reconstruct_model.eval()
+    input = torch.randn(3, 3, 224, 224)
+    s = time.time()
+    r1 = original_model(input)
+    print("Original:", time.time() - s)
+
+    s = time.time()
+    r = reconstruct_model(input)
+    print("ONNXImportedModule:", time.time() - s)
+
+    print("Max error for", onnx_file, ":", (r - r1).abs().max().item())
+
+
+def main():
+    test_net(tv.models.resnet18(True), "res18.onnx")
+    test_net(tv.models.resnet50(True), "res50.onnx")
+    test_net(tv.models.densenet121(True), "dense121.onnx")
+
+
+if __name__ == '__main__':
+    main()
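
A minimal usage sketch of the importer (using the import path wired up in onnx/__init__.py above; the file name is a placeholder):

import torch
from pytorch_jacinto_ai.xnn.onnx import import_onnx

model = import_onnx('model.onnx')   # rebuilt as a torch.nn.Module, returned in eval() mode
with torch.no_grad():
    output = model(torch.randn(1, 3, 224, 224))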
index 0f3061e98115ddd861f25eaa2a83b104c05266c8..2fc1e71db0d4060674dda52078eace81c054779c 100644 (file)
@@ -1,3 +1,34 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
 from . import lr_scheduler
 
 
index 2c9c945dccea515d0fcbe7c0b299a2f2aae04023..ecf1a4e749b9eb42e04d4c59b8567a7deb88f8ca 100644 (file)
@@ -1,12 +1,44 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
 import numpy as np
 import torch
 
 class SchedulerWrapper(torch.optim.lr_scheduler._LRScheduler):
-    def __init__(self, scheduler_type, optimizer, epochs, start_epoch=0, warmup_epochs=5, max_iter=None, \
-                 polystep_power=1.0, milestones=None, multistep_gamma=0.5):
+    def __init__(self, scheduler_type, optimizer, epochs, start_epoch=0, warmup_epochs=5, warmup_factor=None,
+                 max_iter=None, polystep_power=1.0, milestones=None, multistep_gamma=0.5):
 
         self.scheduler_type = scheduler_type
         self.warmup_epochs = warmup_epochs
+        self.warmup_factor = warmup_factor
 
         if scheduler_type == 'step':
             lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=multistep_gamma, last_epoch=start_epoch-1)
@@ -19,9 +51,12 @@ class SchedulerWrapper(torch.optim.lr_scheduler._LRScheduler):
             raise ValueError('Unknown scheduler {}'.format(scheduler_type))
         #
         self.lr_scheduler = lr_scheduler
+        self.lr_backup = [param_group['lr'] for param_group in self.lr_scheduler.optimizer.param_groups]
         if start_epoch > 0:
+            # adjust the learning rate to that of start_epoch
             for step in range(start_epoch):
                 self.step()
+            #
         else:
             # to take care of first iteration and set warmup lr in param_group
             self.get_lr()
@@ -29,9 +64,13 @@ class SchedulerWrapper(torch.optim.lr_scheduler._LRScheduler):
 
 
     def get_lr(self):
-        epoch = self.lr_scheduler.last_epoch + 1
+        epoch = self.lr_scheduler.last_epoch
         if self.warmup_epochs and epoch <= self.warmup_epochs:
             lr = [(epoch * l_rate / self.warmup_epochs) for l_rate in self.lr_scheduler.base_lrs]
+            if epoch == 0 and self.warmup_factor is not None:
+                warmup_lr = [w_rate*self.warmup_factor for w_rate in self.lr_scheduler.base_lrs]
+                lr = [max(l_rate, w_rate) for l_rate, w_rate in zip(lr,warmup_lr)]
+            #
         else:
             lr = [param_group['lr'] for param_group in self.lr_scheduler.optimizer.param_groups]
         #
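
A worked sketch of the warmup logic above, assuming a single param group with base lr 0.1, warmup_epochs=5 and warmup_factor=1e-3:

base_lr, warmup_epochs, warmup_factor = 0.1, 5, 1e-3
for epoch in range(warmup_epochs + 1):
    lr = epoch * base_lr / warmup_epochs        # linear ramp: 0.0, 0.02, ..., 0.1
    if epoch == 0:
        lr = max(lr, base_lr * warmup_factor)   # warmup_factor floor applies only at epoch 0 -> 1e-4
    print(epoch, lr)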
@@ -43,7 +82,16 @@ class SchedulerWrapper(torch.optim.lr_scheduler._LRScheduler):
 
 
     def step(self):
+        # some scheduler implementations in torch.optim are recursive (the new lr depends on the
+        # previous lr), e.g. cosine, so restore the lr that the scheduler last set before stepping it
+        for param_group, l_rate in zip(self.lr_scheduler.optimizer.param_groups, self.lr_backup):
+            param_group['lr'] = l_rate
+        #
+        # actual scheduler call
         self.lr_scheduler.step()
+        # backup the lr from scheduler
+        self.lr_backup = [param_group['lr'] for param_group in self.lr_scheduler.optimizer.param_groups]
+        # return the lr - warmup will be applied in this step
         return self.get_lr()
 
 
index b55c304f2634676a2a6c929b05bcaea7046089e6..629a324f96b7ca14fd8a186dcb17992a197a9919 100644 (file)
@@ -1,4 +1,34 @@
-from .quant_utils import *
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
 from .quant_train_module import *
 from .quant_calib_module import *
 from .quant_test_module import *
index ecaffe0c47f5648969f9261132736a457f99bac4..2142a825ed95bc60a7a0c66dfe238f5e20ee5e2f 100644 (file)
@@ -1,3 +1,34 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
 import types
 import warnings
 import torch
index 2996f3951e5ec59a735d396431e1db8bd036ac87..fafbd0e33ab290588b1c388e7f20b86ba48ed557 100644 (file)
@@ -1,12 +1,62 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
 import copy
 from .quant_graph_module import *
 
+
+###########################################################
+# default settings for quantization
+
+class ConstrainBiasType:
+    CONSTRAIN_BIAS_TYPE_NONE = 0
+    CONSTRAIN_BIAS_TYPE_SATURATE = 1
+    CONSTRAIN_BIAS_TYPE_REDUCE_WEIGHT_SCALE = 2
+    CONSTRAIN_BIAS_TYPE_REDUCE_FEATURE_SCALE = 3
+
+
+class QuantDefaults:
+    RANGE_SHRINK_WEIGHTS_DEFAULT = 0.0
+    POWER2_WEIGHT_RANGE_DEFAULT = True
+    POWER2_ACTIVATION_RANGE_DEFAULT = True
+    CONSTRAIN_BIAS_DEFAULT = ConstrainBiasType.CONSTRAIN_BIAS_TYPE_SATURATE
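
These defaults feed the quantization wrappers built on this base class. A minimal usage sketch, assuming the QuantTrainModule wrapper from this repo and a torchvision model as a stand-in (unspecified knobs fall back to QuantDefaults):

import torch
import torchvision
from pytorch_jacinto_ai import xnn

model = torchvision.models.mobilenet_v2()
dummy_input = torch.rand(1, 3, 224, 224)
# wrap for quantization-aware training
model = xnn.quantize.QuantTrainModule(model, dummy_input=dummy_input,
                                      bitwidth_weights=8, bitwidth_activations=8,
                                      per_channel_q=False)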
+
+
 ###########################################################
 # base module to be use for all quantization modules
 class QuantBaseModule(QuantGraphModule):
     def __init__(self, module, dummy_input, *args, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
                  histogram_range=True, bias_calibration=False, constrain_weights=False, constrain_bias=None,
-                 model_surgery_quantize=True, power2_weight_range=None, power2_activation_range=None, **kwargs):
+                 range_shrink_weights=None, range_shrink_activations=None,
+                 power2_weight_range=None, power2_activation_range=None, model_surgery_quantize=True, **kwargs):
         super().__init__(module)
         self.bitwidth_weights = bitwidth_weights
         self.bitwidth_activations = bitwidth_activations
@@ -14,17 +64,38 @@ class QuantBaseModule(QuantGraphModule):
         self.histogram_range = histogram_range
         self.constrain_weights = constrain_weights
         self.bias_calibration = bias_calibration
-        self.power2_weight_range = True if (power2_weight_range is None) else power2_weight_range
-        self.power2_activation_range = True if (power2_activation_range is None) else power2_activation_range
+
+        self.power2_weight_range = power2_weight_range if (power2_weight_range is not None) else \
+            QuantDefaults.POWER2_WEIGHT_RANGE_DEFAULT
+        self.power2_activation_range = power2_activation_range if (power2_activation_range is not None) else \
+            QuantDefaults.POWER2_ACTIVATION_RANGE_DEFAULT
         # range shrink - 0.0 indicates no shrink
-        self.range_shrink_percentile = (layers.PAct2.PACT2_RANGE_SHRINK if histogram_range else 0.0)
+        self.range_shrink_weights = range_shrink_weights if (range_shrink_weights is not None) else \
+            QuantDefaults.RANGE_SHRINK_WEIGHTS_DEFAULT
+        self.range_shrink_activations = range_shrink_activations if (range_shrink_activations is not None) else \
+            layers.PAct2.PACT2_RANGE_SHRINK_DEFAULT
         # constrain_bias means bias that is being added to accumulator is limited to 16bit (for 8bit quantization).
         # scale factor to be used for constrain_bias is the product of scale factors of weight and input
-        self.constrain_bias = True if (constrain_bias is None) else constrain_bias
+        self.constrain_bias = constrain_bias if (constrain_bias is not None) else \
+            QuantDefaults.CONSTRAIN_BIAS_DEFAULT
+        if self.per_channel_q and self.constrain_bias == ConstrainBiasType.CONSTRAIN_BIAS_TYPE_SATURATE:
+            warnings.warn('Per-channel quantization can increase the weight scale significantly, causing heavy \
+                bias saturation when constrain_bias is enabled. Too much bias saturation can hurt accuracy. \
+                Consider passing constrain_bias=CONSTRAIN_BIAS_TYPE_REDUCE_WEIGHT_SCALE to reduce the weight \
+                scale and avoid bias saturation.')
+        #
+
         # using per_channel_q when constrain_bias is set may not be good for accuracy.
         if self.constrain_bias and self.per_channel_q:
             warnings.warn('using per_channel_q when constrain_bias is set may not be good for accuracy.')
         #
+
+        if (self.per_channel_q == 'all'):
+            assert self.constrain_bias == ConstrainBiasType.CONSTRAIN_BIAS_TYPE_NONE, \
+                f'constrain_bias must be {ConstrainBiasType.CONSTRAIN_BIAS_TYPE_NONE} \
+                when per_channel_q is all. Got {self.constrain_bias}'
+        #
+        
         # for help in debug/print
         utils.add_module_names(self)
         # put in eval mode before analyze
@@ -45,9 +116,9 @@ class QuantBaseModule(QuantGraphModule):
         # set attributes to all modules - can control the behaviour from here
         utils.apply_setattr(self, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
                             histogram_range=histogram_range, bias_calibration=self.bias_calibration, per_channel_q=self.per_channel_q,
-                            range_shrink_percentile=self.range_shrink_percentile, constrain_weights=self.constrain_weights,
+                            range_shrink_weights=self.range_shrink_weights, range_shrink_activations=self.range_shrink_activations,
                             power2_weight_range=self.power2_weight_range, power2_activation_range=self.power2_activation_range,
-                            constrain_bias=self.constrain_bias)
+                            constrain_weights=self.constrain_weights, constrain_bias=self.constrain_bias)
 
     def add_activation_hooks(self):
         # add a forward hook to call the extra activation that we added
index 4ed49f7c62942f560a55906eba8005ca73cf0b25..c5687c20cec3a1966885a26ce50e5a7d1f145d51 100644 (file)
@@ -1,3 +1,34 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
 ###########################################################
 # Approximate quantized floating point simulation with gradients.
 # Can be used for quantized training of models.
@@ -19,6 +50,7 @@ warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)
 class QuantCalibrateModule(QuantTrainModule):
     def __init__(self, module, dummy_input, *args, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
                  histogram_range=True, bias_calibration=True, constrain_weights=None,
+                 range_shrink_weights=None, range_shrink_activations=None,
                  power2_weight_range=None, power2_activation_range=None, constrain_bias=None, lr_calib=0.05, **kwargs):
         self.weights_calibration = False
         self.lr_calib = lr_calib
@@ -29,8 +61,11 @@ class QuantCalibrateModule(QuantTrainModule):
         self.update_activation_range = True
         constrain_weights = (bias_calibration and (not per_channel_q)) if constrain_weights is None else constrain_weights
         super().__init__(module, dummy_input, *args, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
-                         per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration, constrain_weights=constrain_weights,
-                         power2_weight_range=power2_weight_range, power2_activation_range=power2_activation_range, constrain_bias=constrain_bias, **kwargs)
+                         per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
+                         constrain_weights=constrain_weights, constrain_bias=constrain_bias,
+                         range_shrink_weights=range_shrink_weights, range_shrink_activations=range_shrink_activations,
+                         power2_weight_range=power2_weight_range, power2_activation_range=power2_activation_range,
+                         **kwargs)
         self.calib_stats = dict()
 
 
@@ -39,6 +74,8 @@ class QuantCalibrateModule(QuantTrainModule):
         with torch.no_grad():
             # counters such as num_batches_tracked are used. update them.
             self.update_counters()
+            # bitwidth warmup via adjust_bitwidth()
+            self.adjust_bitwidth()
 
             # backup the current state
             training = self.training
index c3433c3a549df179f0c582391e40ff3a489df6f4..6916dea1e47a628f31a7df3daac1ff6868903b91 100644 (file)
@@ -1,3 +1,34 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
 import warnings
 import torch
 import copy
@@ -25,6 +56,9 @@ class QuantGraphModule(HookedModule):
         # quantize the input to a block (under  a certain conditions of the input was not already quantized)
         self.quantize_in = True
 
+        # whether to quantize the output prediction module or not
+        self.quantize_out = True
+
         # TBD: is this required
         # # if the original module has load_weights, add it to the quant module also
         # if hasattr(module, 'load_weights'):
@@ -138,8 +172,7 @@ class QuantGraphModule(HookedModule):
                 #
             elif qparams.quantize_in:
                 if not hasattr(module, 'activation_in'):
-                    # TODO: set range_shrink_percentile=0.0 to avoid shrinking of input range, if needed.
-                    activation_in = layers.PAct2(signed=None)
+                    activation_in = layers.PAct2(signed=None, range_shrink_activations=self.range_shrink_activations)
                     activation_in.train(self.training)
                     module.activation_in = activation_in
                 #
@@ -166,6 +199,7 @@ class QuantGraphModule(HookedModule):
         self.get_qstate().qparams = self_copy.get_qstate().qparams
 
     def _forward_analyze_modules_impl(self, inputs, *args, **kwargs):
+        self.layer_index = -1
         self.start_call()
         self.add_call_hook(self, self._analyze_modules_op)
         forward_analyze_method_name = kwargs.pop('forward_analyze_method', None)
@@ -181,6 +215,7 @@ class QuantGraphModule(HookedModule):
         return output
 
     def _analyze_modules_op(self, op, inputs, *args, **kwargs):
+        self.layer_index = self.layer_index + 1
         inputs = utils.squeeze_list2(inputs)
         self.start_node(op)
         self.add_node(op, inputs)
@@ -203,6 +238,7 @@ class QuantGraphModule(HookedModule):
             self.get_qstate().qparams[module_hash].previous_node = []
             self.get_qstate().qparams[module_hash].next_node = []
             self.get_qstate().qparams[module_hash].current_node = module_hash
+            self.get_qstate().qparams[module_hash].layer_index = self.layer_index
 
         current_node = self.get_qstate().qparams[module_hash].current_node
         for inp in inputs:
@@ -233,22 +269,35 @@ class QuantGraphModule(HookedModule):
     ################################################################
     def analyze_connections(self):
         first_module = None
-        prediction_module = None
         for module_hash, qparams in self.get_qstate().qparams.items():
             module = self.get_module(module_hash)
             if utils.is_conv_deconv_linear(module) or utils.is_normalization(module) or utils.is_activation(module):
                 first_module = module if first_module is None else first_module
-                prediction_module = module
             #
         #
+
         for module_hash, qparams in self.get_qstate().qparams.items():
             module = self.get_module(module_hash)
             is_first_module = (first_module is module)
-            is_prediction_module = (prediction_module is module)
-            self._analyse_connections_op(module_hash, module, qparams, is_first_module, is_prediction_module)
+            self._analyse_connections_op(module_hash, module, qparams, is_first_module)
+        #
+
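+        # find the layer_index of the last module whose output is quantized; if the
+        # module-level quantize_out flag is False, output quantization is disabled
+        # for that final (prediction) layer only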
+        last_quantize_layer_index = -1
+        for module_hash, qparams in self.get_qstate().qparams.items():
+            if self.get_qstate().qparams[module_hash].layer_index > last_quantize_layer_index and \
+                    self.get_qstate().qparams[module_hash].quantize_out:
+                last_quantize_layer_index = self.get_qstate().qparams[module_hash].layer_index
+            #
+        #
+        for module_hash, qparams in self.get_qstate().qparams.items():
+            #module = self.get_module(module_hash)
+            if self.get_qstate().qparams[module_hash].layer_index == last_quantize_layer_index and \
+                    (not self.quantize_out):
+                self.get_qstate().qparams[module_hash].quantize_out = False
+            #
         #
 
-    def _analyse_connections_op(self, module_hash, module, qparams, is_first_module, is_prediction_module):
+    def _analyse_connections_op(self, module_hash, module, qparams, is_first_module):
         previous_modules = [self.get_module(p) for p in qparams.previous_node] if len(qparams.previous_node)>0 else []
         next_modules = [self.get_module(n) for n in qparams.next_node] if len(qparams.next_node)>0 else []
 
@@ -314,7 +363,6 @@ class QuantGraphModule(HookedModule):
         qparams.is_dwconv = utils.is_dwconv(module)
         qparams.next_modules = next_modules
         qparams.is_first_module = is_first_module
-        qparams.is_prediction_module = is_prediction_module
 
 
     ################################################################
index add1384e49bb6b2403c6dc66a795d0cb4d14729d..8fac77c2b39100c4836a422807ef7dccacc1a7f7 100644 (file)
@@ -1,22 +1,47 @@
-import torch
-import math
-import copy
-import warnings
-import numpy as np
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
 
-
-########################################################################
 from .quant_train_module import *
 
 class QuantTestModule(QuantTrainModule):
     def __init__(self, module, dummy_input, *args, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
-                 histogram_range=True, bias_calibration=False, constrain_weights=None, model_surgery_quantize=True,
+                 histogram_range=True, bias_calibration=False, constrain_weights=False, model_surgery_quantize=True,
+                 range_shrink_weights=None, range_shrink_activations=None,
                  power2_weight_range=None, power2_activation_range=None, constrain_bias=None, **kwargs):
         super().__init__(module, dummy_input, *args, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
                          per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
-                         constrain_weights=constrain_weights,
+                         constrain_weights=constrain_weights, constrain_bias=constrain_bias,
+                         range_shrink_weights=range_shrink_weights, range_shrink_activations=range_shrink_activations,
                          power2_weight_range=power2_weight_range, power2_activation_range=power2_activation_range,
-                         constrain_bias=constrain_bias, **kwargs)
+                         **kwargs)
         assert model_surgery_quantize == True, f'{self.__class__.__name__} does not support model_surgery_quantize=False. please use a QAT or calibrated module.'
         self.eval()
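+        # usage sketch (hypothetical names): wrap an already trained/calibrated float
+        # model and run eval-mode inference to simulate quantization:
+        #   model = QuantTestModule(trained_model, dummy_input=torch.rand(1,3,224,224))
+        #   with torch.no_grad():
+        #       output = model(input_batch)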
 
@@ -25,495 +50,3 @@ class QuantTestModule(QuantTrainModule):
         assert mode == False, 'QuantTestModule cannot be used in train mode'
         super().train(mode)
 
-
-########################################################################
-from .quant_base_module import *
-from .quant_utils import *
-
-
-class QuantEstimateModule(QuantBaseModule):
-    '''
-    QuantEstimateModule  can be used to estimate the quantization accuracy of a float model
-    that has not gone through QAT or Calibration. However, this is an approximate method.
-    '''
-    def __init__(self, module, dummy_input, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
-                 histogram_range=True, range_calibration_online=False, model_surgery_quantize=True,
-                 power2_weight_range=None, power2_activation_range=None, constrain_bias=None):
-        super().__init__(module, dummy_input=dummy_input, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
-                         per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=False,
-                         constrain_weights=False, model_surgery_quantize=model_surgery_quantize,
-                         power2_weight_range=power2_weight_range, power2_activation_range=power2_activation_range, constrain_bias=constrain_bias)
-        assert False, 'recommend to use QuantTestModule instead'
-        # whether to do online adjustment of calibration using previous frame range
-        self.range_calibration_online = range_calibration_online
-        # number of offline calibration iters. during offline calibration, current frame range is used
-        self.range_calibration_offline_iters = 25 #10
-
-        # minimum speed for range update
-        self.range_update_factor_min = 0.001 #0.1
-        # range expansion is not needed now as the ranges are not computed based on the actual floating point values.
-        # earlier it was based on quantized values - that's when the expansion was needed.
-        self.range_expansion_factor = 1.0
-
-        # set these to 0 to use faster min/max based range computation (lower accuracy) instead of histogram based range.
-        # shrink range: 0.01 means 0.01 range_shrink_percentile, not 1 range_shrink_percentile
-        self.range_shrink_percentile_activations = (layers.PAct2.PACT2_RANGE_SHRINK if histogram_range else 0)
-        # range shrinking of weight is hurting in some models
-        self.range_shrink_percentile_weights = 0 #(0.01 if histogram_range else 0)
-
-        self.idx_large_mse_for_act = 0
-
-
-    def model_surgery_quantize(self, dummy_input):
-        super().model_surgery_quantize(dummy_input)
-
-        def replace_func(op):
-            for name, m in op._modules.items():
-                if isinstance(m, layers.QAct):
-                    new_m = layers.PAct2(signed=None)
-                else:
-                    new_m = None
-                #
-                if new_m is not None:
-                    for attr in dir(m):
-                        value = getattr(m,attr)
-                        if isinstance(value,torch.Tensor) and value is not None:
-                            getattr(new_m,attr).data.copy_(value.data)
-                        elif isinstance(value,torch.nn.Module) and value is not None:
-                            setattr(new_m, attr, getattr(m,attr))
-                        elif attr in ('weight','bias','eps','clips_act','clips_w'): # add more attribute name here
-                            setattr(new_m, attr, getattr(m, attr))
-                        #
-                        new_m.train(m.training)
-                        setattr(op, name, new_m)
-                    #
-                #
-            #
-        #
-        # apply recursively
-        self.apply(replace_func)
-
-        # clear
-        self.clear_qstate()
-    #
-
-
-    def forward(self, inputs):
-        # analyze - need to merge_weights - so call analyze_graph() instead of just update_counters()
-        self.analyze_graph(inputs=inputs, force_update=True, merge_weights=True)
-
-        # batch_size = inputs[0].size(0) if utils.is_list(inputs) else inputs.size(0)
-        # if batch_size != 1:
-        #     warnings.warn('suggest (not mandatory) to set batchsize to 1 for quantized inference to simulate a realistic scenario')
-        # #
-
-        # calibration does not need gradients
-        with torch.no_grad():
-            # quantize
-            outputs = self.forward_quantize(inputs)
-            # start and new frame, copy the qparams for previous frame of inference
-            self.get_qstate().qparams_prev = self.copy_qparams(self.get_qstate().qparams, inputs)
-            # return
-            return outputs
-        #
-
-
-    def _forward_quantize_hook(self, op, *inputs_orig):
-        inputs = utils.squeeze_list2(inputs_orig)
-        self.start_node(op)
-        self.start_quantize(op)
-
-        if (self.iter_in_epoch == 0):
-            self.process_weights(op, inputs)
-        #
-        self.process_inputs(op, inputs, None)
-
-        outputs = op.__forward_orig__(*inputs_orig)
-
-        self.process_outputs(op, inputs, outputs)
-        self.finish_node(op, inputs, outputs)
-        return outputs
-    #
-
-    def forward_quantize(self, inputs):
-        self.start_call()
-        self.add_call_hook(self.module, self._forward_quantize_hook)
-        self.current_scale = 1.0
-        outputs = self.module(inputs)
-        self.remove_call_hook(self.module)
-        self.finish_call()
-        return outputs
-    #
-
-
-    # implement this in a derived class to clamp weights
-    def apply_constrain_weights(self, module):
-        pass
-
-
-    # implement this in a derived class to do bias calibration
-    def calibrate_bias(self, inputs):
-        pass
-
-
-    def start_quantize(self, op):
-        qparams = self.get_qparams(op)
-        qparams.qrange_in = []
-        qparams.qrange_out = []
-
-
-    def process_weights(self, module, inputs, outputs=None):
-        weight = module.weight if hasattr(module, 'weight') else None
-        bias = module.bias if hasattr(module, 'bias') else None
-        qparams = self.get_qparams(module)
-        if (self.bitwidth_weights is None) or (not qparams.quantize_w):
-            return
-
-        if qparams.quantize_w and weight is not None:
-            qparams.qrange_w = Dict()
-            self.quantize_weights_tensor(module, weight, qparams.qrange_w)
-        else:
-            qparams.qrange_w = None
-
-        if qparams.quantize_b and bias is not None:
-            qparams.qparams_b = Dict()
-            self.quantize_bias_tensor(module, bias, qparams.qparams_b)
-        else:
-            qparams.qparams_b = None
-
-
-    def process_inputs(self, module, inputs, outputs=None):
-        if self.bitwidth_activations is None:
-            return
-
-        inputs = self.format_tensors(inputs)
-        outputs = self.format_tensors(outputs)
-        qparams = self.get_qparams(module)
-        qparams_prev = self.get_qparams_prev(module)
-
-        # track the scale across non-modules (eg. functionals) via current_scale
-        for inp in inputs:
-            inp.scale = inp.scale  if hasattr(inp,'scale') else self.current_scale
-
-        qrange_cur = self.quantize_input_tensors(module, inputs, outputs, qparams_prev, qparams)
-
-        # create the current scale in proccess_inputs instead of process_outputs.
-        # otherwise exit condition for aggregate modules (eg. torch.nn.Sequential, Bottleneck in ResNet) will cause trouble.
-        # all the inputs scales are assumed to be aligned at this point (see align_input_tensors)
-        # any module that needs special handling needs to be considered in quantize_input_tensors / align_input_tensors.
-        has_weight_scale = (hasattr(module,'weight') and (module.weight is not None) and hasattr(module.weight,'scale'))
-        if has_weight_scale:
-            is_dw = utils.is_dwconv(module)
-            use_per_channel_q = (is_dw and self.per_channel_q is True) or (self.per_channel_q == 'all')
-            if use_per_channel_q:
-                #different scale for different channels
-                self.current_scale = [inputs[0].scale * module.weight.scale[chan] for chan in range(module.weight.shape[0])]
-            else:
-                self.current_scale = (inputs[0].scale * module.weight.scale)
-            #
-        else:
-            self.current_scale = inputs[0].scale
-        #
-
-        # update range
-        if qparams.quantize_in:
-            # in the first frame we cannot do running update. after that we can do that.
-            running_update = (qparams_prev is not None) and len(qparams_prev.qrange_in)>0
-            for idx, inp in enumerate(inputs):
-                qrange_prev = qparams_prev.qrange_in[idx] if running_update else (0,0)
-                qrange_running = self._update_activation_ranges(module, inp, running_update, qrange_cur[idx], qrange_prev)
-                qparams.qrange_in.append(qrange_running)
-
-
-    def process_outputs(self, module, inputs, outputs):
-        if self.bitwidth_activations is None:
-            return
-
-        inputs = self.format_tensors(inputs)
-        output = self.format_tensors(outputs)
-        qparams = self.get_qparams(module)
-        qparams_prev = self.get_qparams_prev(module)
-
-        # already adjusted the scale due to weights, in process_inputs
-        for idx, opt in enumerate(output):
-            opt.scale = self.current_scale
-
-        qrange_cur = self.quantize_output_tensors(module, inputs, output, qparams_prev, qparams)
-        self.viz_act(en=False, opt_q=output[0], opt_fl=inputs[0])
-        self.current_scale = output[0].scale
-
-        # update range
-        if qparams.quantize_out or qparams.unquantize_out:
-            # in the first frame we cannot do running update. after that we can do that.
-            running_update = (qparams_prev is not None) and len(qparams_prev.qrange_out)>0
-            for idx, opt in enumerate(output):
-                if isinstance(opt, (torch.LongTensor, torch.cuda.LongTensor)):
-                    continue
-                #
-                qrange_prev = qparams_prev.qrange_out[idx] if running_update else None
-                qrange_running = self._update_activation_ranges(module, opt, running_update, qrange_cur[idx], qrange_prev)
-                qparams.qrange_out.append(qrange_running)
-
-        self.unquantize_outputs(module, inputs, output, qparams_prev, qparams)
-        self.current_scale = output[0].scale
-
-
-    def compute_tensor_range(self, module, tensor_in, range_shrink_percentile):
-        if hasattr(tensor_in, 'scale') and utils.is_list(tensor_in.scale):
-            scale_inv = [(1/s) for s in tensor_in.scale]
-            tensor_scale_inv = torch.tensor(scale_inv).view(1,-1,1,1).to(tensor_in.device)
-            tensor_scaled = tensor_in * tensor_scale_inv
-            (mn, mx) = self._compute_tensor_range_noscale(module, tensor_scaled, range_shrink_percentile)
-        else:
-            scale = tensor_in.scale if hasattr(tensor_in, 'scale') else 1.0
-            (mn, mx) = self._compute_tensor_range_noscale(module, tensor_in, range_shrink_percentile)
-            (mn, mx) = (mn / scale, mx / scale)
-        #
-        return mn, mx
-
-
-    def _compute_tensor_range_noscale(self, module, tensor, range_shrink_percentile):
-        mn, mx = utils.extrema_fast(tensor.data, range_shrink_percentile)
-        return mn, mx
-
-
-    def _update_activation_ranges(self, module, tensor_in, running_update, qrange_cur, qrange_prev):
-        is_calibration = (self.iter_in_epoch < self.range_calibration_offline_iters)
-        update_activation_range = (is_calibration or self.range_calibration_online)
-        if update_activation_range:
-            # in the case of fixed range module, we do not expand the ranges
-            fixed_range_module = utils.is_fixed_range(module)
-            if fixed_range_module:
-                qrange_running = qrange_cur
-            else:
-                (mn, mx) = (float(qrange_cur.min)*self.range_expansion_factor, float(qrange_cur.max)*self.range_expansion_factor)
-                # in the first frame we cannot do running update. after that we can do that.
-                if running_update:
-                    update_factor = (1.0 / (self.iter_in_epoch + 1))
-                    update_factor = max(update_factor, self.range_update_factor_min) if self.range_update_factor_min else update_factor
-                    mn = update_factor * mn + (1 - update_factor) * qrange_prev.min
-                    mx = update_factor * mx + (1 - update_factor) * qrange_prev.max
-                #
-                qrange_running = Dict()
-                qrange_running.min = mn; qrange_running.max = mx
-            #
-        else:
-            qrange_running = qrange_prev
-
-        return qrange_running
-
-
-    def get_bitwidth_weights(self, module):
-        bitwidth_weights_last = (self.bitwidth_weights[2] if utils.is_list(self.bitwidth_weights) else self.bitwidth_weights)
-        bitwidth_weights_dw = (self.bitwidth_weights[1] if utils.is_list(self.bitwidth_weights) else self.bitwidth_weights)
-        bitwidth_weights_nodw = (self.bitwidth_weights[0] if utils.is_list(self.bitwidth_weights) else self.bitwidth_weights)
-        bitwidth_weights = bitwidth_weights_last if self.is_last_conv(module) else \
-            (bitwidth_weights_dw if utils.is_dwconv(module) else bitwidth_weights_nodw)
-        return bitwidth_weights
-
-
-    def get_bitwidth_activations(self, module):
-        bitwidth_activations_last = (self.bitwidth_activations[2] if utils.is_list(self.bitwidth_activations) else self.bitwidth_activations)
-        bitwidth_activations_dw = (self.bitwidth_activations[1] if utils.is_list(self.bitwidth_activations) else self.bitwidth_activations)
-        bitwidth_activations_nodw = (self.bitwidth_activations[0] if utils.is_list(self.bitwidth_activations) else self.bitwidth_activations)
-        bitwidth_activations = bitwidth_activations_last if self.is_last_conv(module) else \
-            (bitwidth_activations_dw if utils.is_dwconv(module) else bitwidth_activations_nodw)
-        return bitwidth_activations
-
-
-    def quantize_weights_tensor(self, module, tensor_in, qrange):
-        self.apply_constrain_weights(module)
-
-        bitwidth_weights = self.get_bitwidth_weights(module)
-        with torch.no_grad():
-            is_dw = utils.is_dwconv(module)
-            use_per_channel_q = (self.per_channel_q == 'all' or (bool(self.per_channel_q) == True and is_dw))
-            if use_per_channel_q:
-                qrange.min = []
-                qrange.max = []
-                tensor_in.scale = []
-                for chan in range(tensor_in.shape[0]):
-                    # Range
-                    mn, mx = self.compute_tensor_range(module, tensor_in[chan], range_shrink_percentile=self.range_shrink_percentile_weights)
-                    tensor_scale, clamp_limits = compute_tensor_scale(tensor_in[chan], mn, mx, bitwidth_weights, self.power2_weight_range)
-                    qrange.min.append(mn)
-                    qrange.max.append(mx)
-                    # Quantize
-                    tensor = symmetric_round_tensor(tensor_in[chan] * tensor_scale)
-                    tensor = tensor.clamp(clamp_limits[0], clamp_limits[1]).float()
-                    # Convert back to float - since this module does only simulation
-                    tensor_in[chan].data[...] = (tensor.data / tensor_scale)
-                    tensor_in.scale.append(1.0)
-                #
-            else:
-                # Range
-                mn, mx = self.compute_tensor_range(module, tensor_in, range_shrink_percentile=self.range_shrink_percentile_weights)
-                tensor_scale, clamp_limits = compute_tensor_scale(tensor_in, mn, mx, bitwidth_weights, self.power2_weight_range)
-                qrange.min = mn
-                qrange.max = mx
-                # Quantize
-                tensor = symmetric_round_tensor(tensor_in * tensor_scale)
-                tensor = tensor.clamp(clamp_limits[0], clamp_limits[1])
-                #Convert back to float - since this module does only simulation
-                tensor_in.data = (tensor.data / tensor_scale)
-                tensor_in.scale = 1.0
-
-
-    def quantize_bias_tensor(self, module, tensor_in, qparams):
-        quant_for_bias = True
-        if quant_for_bias:
-            bitwidth_weights = self.get_bitwidth_weights(module)
-
-            #use same bitwidth as weight
-            bitwidth_bias = bitwidth_weights
-            
-            mn, mx = self.compute_tensor_range(module, tensor_in, range_shrink_percentile=0.0)
-            tensor_scale, clamp_limits = compute_tensor_scale(tensor_in, mn, mx, bitwidth_bias, self.power2_weight_range)
-
-            # --
-            tensor = symmetric_round_tensor(tensor_in * tensor_scale)
-            tensor = tensor.clamp(clamp_limits[0], clamp_limits[1])
-
-            # Convert back to float - since this module does only simulation
-            tensor_in.data = (tensor.data / tensor_scale)
-            tensor_in.scale = 1.0
-        else:    
-            tensor_in.scale = 1.0
-
-
-    def quantize_input_tensors(self, module, input, output, qparams_prev, qparams):
-        qrange_cur = []
-        use_current_range = (not self.range_calibration_offline_iters) or (self.iter_in_epoch < self.range_calibration_offline_iters)
-        for idx, inp in enumerate(input):
-            if qparams.quantize_in:
-                qrange_tensor_approx = None if use_current_range else qparams_prev.qrange_in[idx]
-                qrange_tensor = self._quantize_activation(module, inp, qrange_tensor_approx)
-                qrange_cur.append(qrange_tensor)
-
-        return qrange_cur
-
-
-    def quantize_output_tensors(self, module, input, output, qparams_prev, qparams):
-        qrange_cur = []
-        use_current_range = (not self.range_calibration_offline_iters) or (self.iter_in_epoch < self.range_calibration_offline_iters)
-        for idx, opt in enumerate(output):
-            if qparams.quantize_out:
-                qrange_tensor_approx = None if use_current_range else qparams_prev.qrange_out[idx]
-                qrange_tensor = self._quantize_activation(module, opt, qrange_tensor_approx)
-                qrange_cur.append(qrange_tensor)
-
-        return qrange_cur
-
-
-    def unquantize_outputs(self, module, input, output, qparams_prev, qparams):
-        pass
-
-
-    def _quantize_activation(self, module, tensor_in, qrange):
-        bitwidth_activations = self.get_bitwidth_activations(module)
-        with torch.no_grad():
-            if qrange:
-                # after calibration, we use the range obtained from previous frame directly
-                mn = qrange.min
-                mx = qrange.max
-            else:
-                # range expansion is not required when quantizing using the current frame range (calibration)
-                # for fixed range modules, we use that range directly.
-                fixed_range_module = utils.is_fixed_range(module)
-                if fixed_range_module:
-                    op_range = utils.get_range(module)
-                    mn = op_range[0]
-                    mx = op_range[1]
-                else:
-                    mn, mx = self.compute_tensor_range(module, tensor_in, range_shrink_percentile=self.range_shrink_percentile_activations)
-
-            tensor_scale, clamp_limits = compute_tensor_scale(None, mn, mx, bitwidth_activations, True)
-            tensor = upward_round_tensor(tensor_in*tensor_scale)
-            tensor = tensor.clamp(clamp_limits[0], clamp_limits[1])
-
-            #Convert back to float - since this module does only simulation
-            tensor_in.data = tensor.data/tensor_scale
-            tensor_in.scale = 1.0
-            qrange_tensor = Dict(); qrange_tensor.min = mn; qrange_tensor.max = mx
-            return qrange_tensor
-
-
-    def wt_mse_based_clip(self, tensor_wt_fl, bitwidth_weights=8):
-        mn, mx = utils.extrema_fast(tensor_wt_fl)
-        mn = mn.cpu().numpy()
-        mx = mx.cpu().numpy()
-        mx_abs = max(abs(mn), abs(mx))
-        # print("******** New Wt Tensor Starts *******")
-        # print("mn,mx: ", mn,mx, end = ' ')
-        if mx_abs == 0:
-            best_clip_value = 0.0
-            # print("All weights 0. Dead channel !!! Check")
-        else:
-            # How many clip value needs to be searched for?
-            serach_num_clips = 100
-            step_size = mx_abs / serach_num_clips
-            # print("step_size: ", step_size)
-            min_mse = float("inf")
-            bset_clip_value = 0
-            for clip_value in np.arange(step_size, mx_abs + step_size, step_size):
-                # print("Target clip_value: {:05f}".format(clip_value), end=' ')
-                [mse, actual_clip] = self.compute_mse_for_quant(tensor_wt_fl=tensor_wt_fl, mn=-clip_value,
-                                                                mx=clip_value, bitwidth_weights=bitwidth_weights)
-                if mse < min_mse:
-                    min_mse = mse
-                    best_clip_value = actual_clip
-                # print("clip_value: mse : {:05f} : {:8.5f}".format(actual_clip, mse))
-                # print("******** Clip value Ends *******")
-            #if min_mse < mse:
-            #    print("best_clip_value: {:8.5f} min_mse : {:0.7f} clip_value: {:8.5f}  mse : {:0.7f} ".format(
-            #        best_clip_value, min_mse, actual_clip, mse))
-        # [mn,mx]
-        return [-best_clip_value, best_clip_value]
-
-
-    def compute_mse_for_quant(self, tensor_wt_fl, mn=0, mx=0, bitwidth_weights=8):
-        tensor_scale, clamp_limits = compute_tensor_scale(tensor_wt_fl, mn, mx, bitwidth_weights, self.power2_weight_range)
-
-        # print("mn : mx  {} {}".format(mn, mx))
-
-        # print("tensor_wt_fl: " , tensor_wt_fl.cpu().numpy().flatten()[0:20])
-        tensor_wt_q = symmetric_round_tensor(tensor_wt_fl * tensor_scale)
-        # print("tensor_wt_q:fl*scale " , tensor_wt_q.cpu().numpy().flatten()[0:20])
-        tensor_wt_q = tensor_wt_q.clamp(clamp_limits[0], clamp_limits[1])
-        # print("tensor_wt_q:clamp(flt*scale) " , tensor_wt_q.cpu().numpy().flatten()[0:20])
-
-        # Convert back to float - since this module does only simulation
-        tensor_wt_q.data = (tensor_wt_q.data / tensor_scale)
-        tensor_wt_q.scale = 1.0
-        # print("tensor_wt_q:Final " , tensor_wt_q.cpu().numpy().flatten()[0:20])
-
-        mse = ((tensor_wt_fl.cpu().numpy() - tensor_wt_q.cpu().numpy()) ** 2).mean(axis=None)
-        actual_clip = clamp_limits[1] / tensor_scale
-        return [mse, actual_clip]
-
-
-    def viz_act(self, en=False, opt_q=[], opt_fl=[]):
-        if not en:
-            return
-        opt_q = opt_q.cpu().numpy().flatten()
-        opt_fl = opt_fl.cpu().numpy().flatten()
-        # act_mse = ((opt_q - opt_fl)**2).mean(axis=None)
-        if True:  # (act_mse > 1E-4):
-            # print("act_mse: {:.6f}".format( act_mse))
-            if (self.idx_large_mse_for_act >= 0):
-                mn = opt_fl.min()
-                mx = opt_fl.max()
-                hist_fl = utils.hist_weight_tensor2D(x_ch=opt_fl, log=True, dir='act_study_fl',
-                                                     name='act_{:03d}_fl_{:.3f}_{:.3f}'.format(
-                                                         self.idx_large_mse_for_act, mn, mx), ch=0, en=True)
-
-                mn = opt_q.min()
-                mx = opt_q.max()
-                hist_q = utils.hist_weight_tensor2D(x_ch=opt_q, log=True, dir='act_study_q',
-                                                    name='act_{:03d}_q_{:.3f}_{:.3f}'.format(self.idx_large_mse_for_act,
-                                                                                             mn, mx), ch=0, en=True)
-                # print('hist_fl: ', hist_fl)#
-                # print('hist_q: ', hist_q)
-            self.idx_large_mse_for_act += 1
-
-
index eda0bf91131110ec57ff1325aad58446e9656602..3f58b9cb3dd8472b8639e3822c545106be745091 100644 (file)
@@ -1,3 +1,34 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
 ###########################################################
 # Approximate quantized floating point simulation with gradients.
 # Can be used for quantized training of models.
@@ -10,7 +41,6 @@ import warnings
 
 from .. import layers
 from .. import utils
-from . import quant_utils
 from .quant_base_module import *
 
 warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)
@@ -20,21 +50,29 @@ warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)
 class QuantTrainModule(QuantBaseModule):
     def __init__(self, module, dummy_input, *args, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
                  histogram_range=True, bias_calibration=False, constrain_weights=None,
+                 range_shrink_weights=None, range_shrink_activations=None,
                  power2_weight_range=None, power2_activation_range=None, constrain_bias=None, **kwargs):
         constrain_weights = (not per_channel_q) if constrain_weights is None else constrain_weights
         super().__init__(module, dummy_input, *args, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
                          per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
-                         constrain_weights=constrain_weights, model_surgery_quantize=True,
+                         constrain_weights=constrain_weights, constrain_bias=constrain_bias,
+                         range_shrink_weights=range_shrink_weights, range_shrink_activations=range_shrink_activations,
                          power2_weight_range=power2_weight_range, power2_activation_range=power2_activation_range,
-                         constrain_bias=constrain_bias, **kwargs)
+                         model_surgery_quantize=True, **kwargs)
 
     def forward(self, inputs, *args, **kwargs):
         # counters such as num_batches_tracked are used. update them.
         self.update_counters()
+        # bitwidth_warmup
+        self.adjust_bitwidth()
+
         # outputs
         outputs = self.module(inputs, *args, **kwargs)
         return outputs
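+
+    # usage sketch (hypothetical names): wrap a float model for quantization-aware
+    # training and train as usual; the surgically inserted QuantTrain layers
+    # fake-quantize weights and activations:
+    #   qat_model = QuantTrainModule(float_model, dummy_input=torch.rand(1,3,224,224))
+    #   loss = criterion(qat_model(images), target)
+    #   loss.backward()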
 
+    def adjust_bitwidth(self):
+        # adjust bitwidth here, if needed, to do gradual bitwidth adjustment
+        pass
 
     def model_surgery_quantize(self, dummy_input, *args, **kwargs):
         super().model_surgery_quantize(dummy_input, *args, **kwargs)
@@ -60,21 +98,21 @@ class QuantTrainModule(QuantBaseModule):
                 elif isinstance(m, layers.PAct2):
                     new_m = QuantTrainPAct2(inplace=m.inplace, signed=m.signed, clip_range=m.clip_range,
                                              bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
-                                            per_channel_q=self.per_channel_q, range_shrink_percentile=self.range_shrink_percentile,
+                                            per_channel_q=self.per_channel_q, range_shrink_activations=self.range_shrink_activations,
                                             power2_weight_range=self.power2_weight_range, power2_activation_range=self.power2_activation_range)
                 elif isinstance(m, layers.QAct):
                     new_m = QuantTrainPAct2(signed=None, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
-                                             per_channel_q=self.per_channel_q, range_shrink_percentile=self.range_shrink_percentile,
+                                             per_channel_q=self.per_channel_q, range_shrink_activations=self.range_shrink_activations,
                                             power2_weight_range=self.power2_weight_range, power2_activation_range=self.power2_activation_range)
                 elif isinstance(m, (torch.nn.ReLU, torch.nn.ReLU6)):
                     new_m = QuantTrainPAct2(signed=False, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
-                                             per_channel_q=self.per_channel_q, range_shrink_percentile=self.range_shrink_percentile,
+                                             per_channel_q=self.per_channel_q, range_shrink_activations=self.range_shrink_activations,
                                             power2_weight_range=self.power2_weight_range, power2_activation_range=self.power2_activation_range)
                 else:
                     new_m = None
                 #
                 if new_m is not None:
-                    copy_attr_list = ('weight', 'bias', 'eps', 'clips_act', 'clips_w', 'range_shrink_percentile')
+                    copy_attr_list = ('weight', 'bias', 'eps', 'clips_act', 'clips_w', 'range_shrink_activations')
                     for attr in dir(m):
                         value = getattr(m,attr)
                         if isinstance(value,torch.Tensor) and value is not None:
@@ -99,8 +137,6 @@ class QuantTrainModule(QuantBaseModule):
     #
 
 
-
-
 ###########################################################
 class QuantTrainParams:
     pass
@@ -215,8 +251,8 @@ class QuantTrainLinear(torch.nn.Linear):
         y.qparams = qparams
         return y
     #
-       
-       
+
+
 ###########################################################
 class QuantTrainBatchNorm2d(torch.nn.BatchNorm2d):
     def __init__(self, *args, **kwargs):
@@ -250,9 +286,10 @@ class QuantTrainBatchNorm2d(torch.nn.BatchNorm2d):
 # fake quantized PAct2 for training
 class QuantTrainPAct2(layers.PAct2):
     def __init__(self, inplace=False, signed=False, clip_range=None, bitwidth_weights=None, bitwidth_activations=None,
-                 per_channel_q=False, range_shrink_percentile=layers.PAct2.PACT2_RANGE_SHRINK, power2_weight_range=True,
+                 per_channel_q=False, range_shrink_activations=layers.PAct2.PACT2_RANGE_SHRINK_DEFAULT, power2_weight_range=True,
                  power2_activation_range=True):
-        super().__init__(inplace=inplace, signed=signed, clip_range=clip_range, range_shrink_percentile=range_shrink_percentile,
+        super().__init__(inplace=inplace, signed=signed, clip_range=clip_range,
+                         range_shrink_activations=range_shrink_activations,
                          power2_activation_range=power2_activation_range)
 
         self.bitwidth_weights = bitwidth_weights
@@ -267,11 +304,7 @@ class QuantTrainPAct2(layers.PAct2):
         #   (as yq is directly propagated then). Uses: PropagateQuantTensorQTE
         self.quantize_backward_type = 'ste'
 
-        # weight shrinking is done by clamp weights - set this factor to zero.
-        # this must me zero - as in pact we do not have the actual weight param, but just a temporary tensor
-        # so any clipping we do here is not stored int he weight params
-        self.range_shrink_weights = 0.0
-        self.round_dither = 0.0
+        self.range_shrink_weights = None
         self.update_activation_range = True
         self.quantize_enable = True
         self.quantize_weights = True
@@ -280,15 +313,17 @@ class QuantTrainPAct2(layers.PAct2):
         self.constrain_bias = None
         self.constrain_weights = True
         self.bias_calibration = False
-        # start bias constrain at this iteration
-        self.constrain_bias_start_iter = 75
+        # iteration at which the weights are constrained (clipped)
+        self.constrain_weights_iter = 0
+        # start bias constraint at this iteration
+        self.constrain_bias_start_iter = 85
         # storing of weights at this iteration
-        self.store_weights_iter = 0
-
+        self.store_weights_iter = 0 #85
 
     def forward(self, x):
         assert (self.bitwidth_weights is not None) and (self.bitwidth_activations is not None), \
                         'bitwidth_weights and bitwidth_activations must not be None'
+
         # the pact range update happens here - but range clipping depends on quantize_enable
         y = super().forward(x, update_activation_range=self.update_activation_range, enable=self.quantize_enable)
 
@@ -335,10 +370,11 @@ class QuantTrainPAct2(layers.PAct2):
         #
 
         if (self.quantize_enable and self.quantize_activations):
-            clip_min, clip_max, scale, scale_inv = self.get_clips_scale_act()
+            use_per_channel_q = (self.per_channel_q == 'all')
+            clip_min, clip_max, scale, scale_inv = self.get_clips_scale_act(xq, use_per_channel_q=use_per_channel_q)
             width_min, width_max = self.get_widths_act()
             # no need to call super().forward here as clipping with width_min/width_max-1 after scaling has the same effect.
-            yq = layers.quantize_dequantize_g(xq, scale, width_min, width_max-1, self.power2_activation_range, 'round_up')
+            yq = layers.quantize_dequantize_g(xq, scale, width_min, width_max-1, self.power2_activation_range, 1, 'round_up')
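+            # note: the extra positional argument (1) appears to be the per-channel axis
+            # (dim 1 for NCHW activations), matching per_channel_q_axis used for weights below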
         else:
             yq = super().forward(xq, update_activation_range=False, enable=True)
         #
@@ -358,15 +394,24 @@ class QuantTrainPAct2(layers.PAct2):
 
 
     def apply_constrain_weights(self, merged_weight):
-        return quant_utils.constrain_weight(merged_weight)
+        return utils.constrain_weight(merged_weight)
+
+
+    def reduce_quantize_weight_scale(self):
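+        # True when the bias constraint is enforced by shrinking the weight scale,
+        # as opposed to saturating the bias directly (CONSTRAIN_BIAS_TYPE_SATURATE)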
+        return (self.constrain_bias == ConstrainBiasType.CONSTRAIN_BIAS_TYPE_REDUCE_WEIGHT_SCALE)
 
 
     def merge_quantize_weights(self, qparams, conv, bn):
         num_batches_tracked = int(self.num_batches_tracked)
-        is_constrain_weights_iter = self.training and (num_batches_tracked == self.store_weights_iter)
+        # constrain/clip the weights to reduce their dynamic range
+        is_constrain_weights_iter = self.training and (num_batches_tracked == self.constrain_weights_iter)
+        # store weights once in training after constraining
         is_store_weights_iter = self.training and (num_batches_tracked == self.store_weights_iter)
+        # note: we do not modify the bias here, but rather the weight scale, so that the bias doesn't overflow after scaling.
+        # the weight scale adjustment for the bias constraint must happen in both train and val
+        is_constrain_bias_iter = (not self.training) or (num_batches_tracked >= self.constrain_bias_start_iter)
+        # store the constrained bias if needed
         is_store_bias_iter = self.training and (num_batches_tracked == self.constrain_bias_start_iter)
-        is_constrain_bias_iter = self.training and (num_batches_tracked >= self.constrain_bias_start_iter)
 
         # merge weight and bias (if possible) across layers
         if conv is not None and bn is not None:
@@ -404,58 +449,66 @@ class QuantTrainPAct2(layers.PAct2):
             merged_weight = 0.0
             merged_bias = 0.0
         #
-
         # quantize weight and bias
         if (conv is not None):
+            is_dw = utils.is_dwconv(conv)
+            is_deconv = utils.is_deconv(conv)
+            use_per_channel_q = (is_dw and self.per_channel_q is True) or (self.per_channel_q == 'all')
+            # quantize the bias
+            if (self.quantize_enable and self.quantize_bias):
+                bias_width_min, bias_width_max = self.get_widths_bias()
+                bias_clip_min, bias_clip_max, bias_scale2, bias_scale_inv2 = self.get_clips_scale_bias(merged_bias)
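+                # the bias scale is effectively weight_scale * input_scale, so it can be a
+                # power of two only if both the weight and activation ranges are powers of two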
+                power2_bias_range = (self.power2_weight_range and self.power2_activation_range)
+                merged_bias = layers.quantize_dequantize_g(merged_bias, bias_scale2, bias_width_min, bias_width_max - 1,
+                                                           power2_bias_range, 0, 'round_sym')
+            #
+            # quantize the weights
             if (self.quantize_enable and self.quantize_weights):
+                # clip/constrain the weights
                 if self.constrain_weights and is_constrain_weights_iter:
                     with torch.no_grad():
                         # clamp merged weights, invert the bn and copy to conv weight
                         constrained_weight = self.apply_constrain_weights(merged_weight.data)
                         merged_weight.data.copy_(constrained_weight.data)
-                        # store clipped weight after inverting bn - not really needed as there is a saving below as well
-                        # conv.weight.data.copy_(merged_weight.data * merged_scale_inv)
+                    #
+                #
+                # get the weight scale values (multiple values in case of use_per_channel_q)
+                weight_clip_min, weight_clip_max, weight_scale2, weight_scale_inv2 = self.get_clips_scale_w(merged_weight,
+                                        use_per_channel_q=use_per_channel_q, is_deconv=is_deconv)
+
+                # in some cases, bias quantization has additional restrictions - for example, when
+                # the bias that is added to the accumulator is limited to 16 bits.
+                if self.quantize_enable and self.constrain_bias and is_constrain_bias_iter and self.reduce_quantize_weight_scale():
+                    # use the bias to determine the bias scale allowed due to additional joint constraints
+                    clips_scale_joint = self.get_clips_scale_joint(merged_bias)
+                    # get the input scale if it is available
+                    clips_scale_input = self.get_clips_scale_input(qparams)
+                    if clips_scale_input is not None:
+                        # scale factor to be used for bias is the product of scale factors of weight and input
+                        # using input_scale, work backwards and find the maximum allowed weight scale
+                        scale2_joint = clips_scale_joint[2]
+                        scale2_input = clips_scale_input[2]
+                        scale2_input = torch.clamp(scale2_input, min=self.eps)
+                        scale2_weight_max_joint = scale2_joint / scale2_input
+                        # limit the weight scale to maximum allowed weight scale
+                        weight_scale2 = torch.min(weight_scale2, scale2_weight_max_joint)
+                        weight_scale_inv2 = weight_scale2.pow(-1)
                     #
                 #
 
-                is_dw = utils.is_dwconv(conv)
-                is_deconv = utils.is_deconv(conv)
-                use_per_channel_q = (is_dw and self.per_channel_q is True) or (self.per_channel_q == 'all')
-                clip_min, clip_max, scale2, scale_inv2 = self.get_clips_scale_w(merged_weight,
-                                                use_per_channel_q=use_per_channel_q, is_deconv=is_deconv)
+                # do fake quantization of weights
                 width_min, width_max = self.get_widths_w()
-                merged_weight = layers.quantize_dequantize_g(merged_weight, scale2, width_min, width_max-1,
-                                                             self.power2_weight_range, 'round_sym')
-            #
-
-            if (self.quantize_enable and self.quantize_bias):
-                bias_width_min, bias_width_max = self.get_widths_bias()
-                bias_clip_min, bias_clip_max, bias_scale2, bias_scale_inv2 = self.get_clips_scale_bias(merged_bias)
-                power2_bias_range = (self.power2_weight_range and self.power2_activation_range)
-                merged_bias = layers.quantize_dequantize_g(merged_bias, bias_scale2, bias_width_min, bias_width_max - 1,
-                                                           power2_bias_range, 'round_sym')
-            #
-
-            # in some cases, bias quantization can have additional restrictions if for example,
-            # bias that is being added to accumulator is limited to 16bit.
-            # scale factor to be used for bias is the product of scale factors of weight and input
-            if self.quantize_enable and self.constrain_bias and is_constrain_bias_iter:
-                clips_scale_joint = self.get_clips_scale_joint(qparams, merged_weight, merged_bias,
-                                                use_per_channel_q=use_per_channel_q, is_deconv=is_deconv)
-                if clips_scale_joint is not None:
-                    bias_width_min, bias_width_max = self.get_widths_joint()
-                    bias_clip_min, bias_clip_max, bias_scale2, bias_scale_inv2 = clips_scale_joint
-                    power2_bias_range = (self.power2_weight_range and self.power2_activation_range)
-                    merged_bias = layers.quantize_dequantize_g(merged_bias, bias_scale2, bias_width_min,
-                                                  bias_width_max - 1, power2_bias_range, 'round_sym')
-                #
+                per_channel_q_axis = 1 if is_deconv else 0
+                merged_weight = layers.quantize_dequantize_g(merged_weight, weight_scale2, width_min, width_max-1,
+                                                             self.power2_weight_range, per_channel_q_axis, 'round_sym')
             #
-
             # invert the bn operation and store weights/bias
             if self.quantize_enable and self.quantize_weights and is_store_weights_iter:
                 conv.weight.data.copy_(merged_weight.data * merged_scale_inv)
             #
-            if self.quantize_enable and self.quantize_bias and is_store_bias_iter:
+            # store the constrained bias if needed
+            if self.quantize_enable and self.quantize_bias and is_store_bias_iter and \
+                    self.constrain_bias == ConstrainBiasType.CONSTRAIN_BIAS_TYPE_SATURATE:
                 if conv.bias is not None:
                     if bn is not None:
                         conv_bias = (merged_bias - bn_bias) * merged_scale_inv.view(-1) + bn.running_mean
@@ -488,7 +541,7 @@ class QuantTrainPAct2(layers.PAct2):
         clip_max = torch.max(torch.abs(w_min), torch.abs(w_max))
         clip_max = torch.clamp(clip_max, min=self.eps)
         # in range learning mode + training - this power2_weight_range is taken care of in the quantize function
-        use_power2 = (self.power2_weight_range and (not (self.PACT2_RANGE_LEARN and self.training)))
+        use_power2 = (self.power2_weight_range and (not (self.PACT2_RANGE_LEARN_MODE and self.training)))
         clip_max2 = layers.functional.ceil2_g(clip_max) if use_power2 else clip_max
         clip_min2 = -clip_max2
         return (clip_min2, clip_max2)
@@ -544,13 +597,33 @@ class QuantTrainPAct2(layers.PAct2):
         return width_min, width_max
 
 
-    def get_clips_scale_act(self):
+    def get_clips_scale_act(self, tensor=None, use_per_channel_q=False):
         clip_min, clip_max = self.get_clips_act()
         width_min, width_max = self.get_widths_act()
         scale2 = width_max / clip_max
         scale2 = torch.clamp(scale2, min=self.eps)
         scale_inv2 = scale2.pow(-1.0)
-        return (clip_min, clip_max, scale2, scale_inv2)
+        if not use_per_channel_q:
+            return (clip_min, clip_max, scale2, scale_inv2)
+        #
+        # the remaining part of the function is only used in the case of use_per_channel_q
+        # compute the per-channel activation scale.
+        # restrict the scale factor of a channel from becoming extremely large.
+        scale_factor_ratio_max = 256  # None
+        channels = int(tensor.size(1))
+        scale2_array = torch.zeros(1, channels, 1, 1).to(tensor.device)
+        scale_inv2_array = torch.zeros(1, channels, 1, 1).to(tensor.device)
+        for chan_id in range(channels):
+            tensor_channel = tensor[:, chan_id, ...]
+            _, _, scale2_value, scale_inv2_value = self.get_clips_scale_act(tensor_channel)
+            scale2_value = torch.min(scale2_value, scale2 * scale_factor_ratio_max) \
+                if (scale_factor_ratio_max is not None) else scale2_value
+            scale2_value = torch.clamp(scale2_value, min=self.eps)
+            scale_inv2_value = scale2_value.pow(-1.0)
+            scale2_array[0, chan_id, 0, 0] = scale2_value
+            scale_inv2_array[0, chan_id, 0, 0] = scale_inv2_value
+        #
+        return (clip_min, clip_max, scale2_array, scale_inv2_array)
 
 
     ###########################################################
@@ -572,8 +645,18 @@ class QuantTrainPAct2(layers.PAct2):
 
 
     ###########################################################
+    def get_clips_scale_joint(self, tensor):
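+        # scale for the quantized bias under the joint (weight x input) constraint:
+        # clips come from the bias range, width from get_widths_joint()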
+        clip_min, clip_max = self.get_clips_bias(tensor)
+        width_min, width_max = self.get_widths_joint()
+        scale2 = (width_max / clip_max)
+        scale2 = torch.clamp(scale2, min=self.eps)
+        scale_inv2 = scale2.pow(-1.0)
+        return (clip_min, clip_max, scale2, scale_inv2)
+
+
     def get_widths_joint(self):
-        bw = (2*self.bitwidth_weights-1)
+        bw = (4*self.bitwidth_weights-1) if (self.constrain_bias == ConstrainBiasType.CONSTRAIN_BIAS_TYPE_NONE) \
+            else (2*self.bitwidth_weights-1)
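+        # e.g. with 8-bit weights: 2*bw-1 = 15 keeps the quantized bias within a signed
+        # 16-bit range, while 4*bw-1 = 31 corresponds to a 32-bit accumulator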
         width_max = np.power(2.0, bw)
         width_min = -width_max
         return (width_min, width_max)
@@ -608,27 +691,3 @@ class QuantTrainPAct2(layers.PAct2):
         #
 
 
-    def get_clips_scale_joint(self, qparams, weights, bias, use_per_channel_q=False, is_deconv=False):
-        clips_scale_input = self.get_clips_scale_input(qparams)
-        if clips_scale_input is None:
-            return None
-        #
-        clip_min_input, clip_max_input, scale2_input, scale_inv2_input = clips_scale_input
-        clip_min_bias, clip_max_bias, scale2_bias, scale_inv2_bias = self.get_clips_scale_bias(bias)
-        if not use_per_channel_q:
-            clip_min_w, clip_max_w, scale2_w, scale_inv2_w = self.get_clips_scale_w(weights)
-            return (clip_min_bias, clip_max_bias, scale2_w * scale2_input, scale_inv2_w * scale_inv2_input)
-        #
-        # the remaining part of the function is only used in the case of use_per_channel_q
-        channels = int(weights.size(1)) if is_deconv else int(weights.size(0))
-        clip_min_w, clip_max_w, scale2_w, scale_inv2_w = self.get_clips_scale_w(weights,
-                                use_per_channel_q=use_per_channel_q, is_deconv=is_deconv)
-        scale2_joint_array = torch.zeros(channels).to(weights.device)
-        scale_joint_inv2_array = torch.zeros(channels).to(weights.device)
-        for chan_id in range(channels):
-            scale_value = scale2_w[0, chan_id, 0, 0] if is_deconv else scale2_w[chan_id, 0, 0, 0]
-            scale_inv_value = scale_inv2_w[0, chan_id, 0, 0] if is_deconv else scale_inv2_w[chan_id, 0, 0, 0]
-            scale2_joint_array[chan_id] = scale_value * scale2_input
-            scale_joint_inv2_array[chan_id] = scale_inv_value * scale_inv2_input
-        #
-        return (clip_min_bias, clip_max_bias, scale2_joint_array, scale_joint_inv2_array)
diff --git a/modules/pytorch_jacinto_ai/xnn/quantize_torch_internal/__init__.py b/modules/pytorch_jacinto_ai/xnn/quantize_torch_internal/__init__.py
new file mode 100644 (file)
index 0000000..7baf751
--- /dev/null
@@ -0,0 +1,36 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
+from .quant_torch_eagercalib_module import *
+from .quant_torch_eagertrain_module import *
+from .quant_torch_scriptcalib_module import *
+from .quant_torch_eagerdistill_module import *
+
diff --git a/modules/pytorch_jacinto_ai/xnn/quantize_torch_internal/quant_torch_base_module.py b/modules/pytorch_jacinto_ai/xnn/quantize_torch_internal/quant_torch_base_module.py
new file mode 100644 (file)
index 0000000..f2dd64b
--- /dev/null
@@ -0,0 +1,321 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
+###########################################################
+# Approximate quantized floating point simulation with gradients.
+# Can be used for quantized training of models.
+###########################################################
+import copy
+import torch
+
+from .. import layers
+from .. import utils
+from .quant_torch_qconfig import *
+from .quant_torch_qconfig_qat import *
+
+###########################################################
+class QuantTorchBaseModule(torch.nn.Module):
+    def __init__(self, module, dummy_input, *args, backend='fbgemm', symmetric=True, per_channel_q=False,  # per_channel_q: False, True or 'depthwise'
+                 with_fakequantize=True, is_qat=False, power2_weight_range=True, power2_activation_range=True,
+                 histogram=False, constrain_weights=False, freeze_bn=False, clamp_params=False, **kwargs):
+        super().__init__()
+        self.dummy_input = dummy_input
+        self.backend = backend
+        self.with_fakequantize = with_fakequantize
+        self.is_qat = is_qat
+        self.histogram = histogram
+        self.symmetric = symmetric
+        # check if we have set the depthwise_only mode for per_channel quantization
+        self.per_channel_q_depthwise_only = (per_channel_q == 'depthwise')
+        # the following flag excludes the depthwise-only mode of per_channel quantization - that mode is handled separately
+        self.per_channel_q = (per_channel_q is True)
+        self.power2_weight_range = power2_weight_range
+        self.power2_activation_range = power2_activation_range
+        self.constrain_weights = (not per_channel_q) if constrain_weights is None else constrain_weights
+        self.freeze_bn = freeze_bn
+        self.clamp_params = clamp_params
+        self.module = module
+        self.module.quant_in = torch.quantization.QuantStub()
+        self.module.dequant_out = torch.quantization.DeQuantStub()
+
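+    # A minimal end-to-end usage sketch (hypothetical names; assumes the eager-mode
+    # torch.quantization flow used throughout this file):
+    #   qmodel = QuantTorchBaseModule(model, dummy_input, is_qat=True)
+    #   qmodel.fuse_model()       # fuse conv+bn(+relu) patterns
+    #   qmodel.prepare()          # attach observers / fake-quantize modules
+    #   ... train or calibrate by calling qmodel(inputs) ...
+    #   qmodel.convert()          # replace modules with quantized versions
+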
+    def fuse_model(self, inplace=True):
+        if self.is_qat or self.with_fakequantize:
+            # fuse in train mode for QAT to retain BNs
+            self.train()
+        else:
+            # fuse in eval mode to merge BNs upfront - typically used in PTQ
+            self.eval()
+        #
+        if hasattr(self.module, 'fuse_model'):
+            self.module.fuse_model()
+        else:
+            device = next(self.module.parameters()).device
+            dummy_input = self.dummy_input.to(device=device)
+            fuse_list = self._get_fuse_list(self.module, dummy_input)
+            self.module = torch.quantization.fuse_modules(self.module, fuse_list, inplace=inplace)
+        #
+        for p in self.modules():
+            for n, m in p.named_children():
+                if isinstance(m, layers.AddBlock):
+                    setattr(p, n, layers.FloatFunctionalBlock('add'))
+                elif isinstance(m, layers.MultBlock):
+                    setattr(p, n, layers.FloatFunctionalBlock('mul'))
+                elif isinstance(m, layers.CatBlock):
+                    setattr(p, n, layers.FloatFunctionalBlock('cat'))
+                #
+            #
+        #
+        return self.module
+
+    def prepare(self):
+        torch.backends.quantized.engine = self.backend
+        qconfig_func = get_custom_qconfig_qat if self.is_qat \
+            else (get_custom_qconfig_with_fakequantize if self.with_fakequantize else get_custom_qconfig)
+        qconfig_args = dict(histogram=self.histogram, symmetric=self.symmetric, per_channel=self.per_channel_q,
+                            power2_weight_range=self.power2_weight_range,
+                            power2_activation_range=self.power2_activation_range)
+        self.module.qconfig = qconfig_func(**qconfig_args)
+        if self.with_fakequantize or self.is_qat:
+            torch.quantization.prepare_qat(self.module, inplace=True)
+        else:
+            torch.quantization.prepare(self.module, inplace=True)
+        #
+        if self.per_channel_q_depthwise_only:
+            self._force_per_channel_depthwise_only(qconfig_func, qconfig_args)
+        #
+        if self.clamp_params and (not self.constrain_weights):
+            self.clamp_params_backup()
+        #
+        if self.constrain_weights:
+            self.apply_constrain_weights()
+        #
+        return
+
+    def _force_per_channel_depthwise_only(self, qconfig_func, qconfig_args):
+        for m in self.modules():
+            per_channel_modules = (torch.nn.Conv2d,
+                torch.nn.intrinsic.ConvBnReLU2d, torch.nn.intrinsic.ConvBn2d, torch.nn.intrinsic.ConvReLU2d,
+                torch.nn.intrinsic.qat.ConvBnReLU2d, torch.nn.intrinsic.qat.ConvBn2d, torch.nn.intrinsic.qat.ConvReLU2d)
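+            # a conv whose weight has one input channel per group, i.e. weight shape
+            # (out_channels, 1, kH, kW), is treated as depthwise here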
+            if isinstance(m, per_channel_modules) and m.weight.size()[1] == 1 and hasattr(m, 'qconfig'):
+                qconfig_args_depthwise = copy.deepcopy(qconfig_args)
+                qconfig_args_depthwise.update(dict(per_channel=True))
+                m.qconfig = qconfig_func(**qconfig_args_depthwise)
+                if hasattr(m, 'weight_fake_quant'):
+                    m.weight_fake_quant = m.qconfig.weight()
+                #
+                if hasattr(m, 'activation_post_process'):
+                    m.activation_post_process = m.qconfig.activation()
+                #
+            #
+        #
+        return
+
+    def forward(self, inputs, *args, **kwargs):
+        # freeze batchnorms in the model. clamp_params also needs this freezing
+        if self.freeze_bn or self.clamp_params:
+            self.freeze_model(freeze_bn_stats=True)
+        #
+        inputs = self.module.quant_in(inputs)
+        outputs = self.module(inputs, *args, **kwargs)
+        outputs = self.module.dequant_out(outputs)
+        # clamp the weights to within a few quantization deltas of the original weights for faster convergence.
+        # if constrain_weights is used, we cannot clamp, as the weights are already significantly modified from the original
+        if self.clamp_params and (not self.constrain_weights):
+            for n, m in self.module.named_modules():
+                self.clamp_module_with_delta(n, m)
+            #
+        #
+        return outputs
+
+    def convert(self):
+        torch.quantization.convert(self.module, inplace=True)
+
+    def _get_fuse_list(self, module, dummy_input):
+        for name, m in module.named_modules():
+            m.__track_modules_name__ = name
+        #
+        def _track_modules1(m, inp, oup):
+            prev_module = inp.__track_modules_m__[-1] if hasattr(inp, '__track_modules_m__') else None
+            if prev_module is not None:
+                if hasattr(prev_module, '__track_modules_next__'):
+                    prev_module.__track_modules_next__.append(m)
+                else:
+                    prev_module.__track_modules_next__ = [m]
+                #
+                if hasattr(m, '__track_modules_prev__'):
+                    m.__track_modules_prev__.append(prev_module)
+                else:
+                    m.__track_modules_prev__ = [prev_module]
+                #
+            #
+            if hasattr(oup, '__track_modules_m__'):
+                oup.__track_modules_m__.append(m)
+            else:
+                oup.__track_modules_m__ = [m]
+            #
+        #
+        def _track_modules(m, inp, oup):
+            inp = inp if isinstance(inp, (list,tuple)) else [inp]
+            oup = oup if isinstance(oup, (list,tuple)) else [oup]
+            for input in inp:
+                for output in oup:
+                    _track_modules1(m, input, output)
+                #
+            #
+        #
+        for m in module.modules():
+            m.__track_modules_m_hook__ = m.register_forward_hook(_track_modules)
+        #
+        module(dummy_input)
+        # analyze
+        fuse_list = []
+        for m in module.modules():
+            if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
+                m_next = None
+                m_next2 = None
+                if hasattr(m, '__track_modules_next__') and len(m.__track_modules_next__) == 1:
+                    m_next = m.__track_modules_next__[-1]
+                    if hasattr(m_next, '__track_modules_next__') and len(m_next.__track_modules_next__) == 1:
+                        m_next2 = m_next.__track_modules_next__[-1]
+                    #
+                #
+                if isinstance(m_next, torch.nn.BatchNorm2d) and isinstance(m_next2, (torch.nn.ReLU,torch.nn.ReLU6)):
+                    fuse_list.append([m.__track_modules_name__, m_next.__track_modules_name__, m_next2.__track_modules_name__])
+                elif isinstance(m_next, torch.nn.BatchNorm2d):
+                    fuse_list.append([m.__track_modules_name__, m_next.__track_modules_name__])
+                elif isinstance(m_next, (torch.nn.ReLU,torch.nn.ReLU6)):
+                    fuse_list.append([m.__track_modules_name__, m_next.__track_modules_name__])
+                #
+            # elif isinstance(m, layers.FloatFunctionalBlock):
+            #     if isinstance(m_next, (torch.nn.ReLU,torch.nn.ReLU6)):
+            #         fuse_list.append([m.__track_modules_name__, m_next.__track_modules_name__])
+            #     #
+            # #
+        #
+        for m in module.modules():
+            if hasattr(m, '__track_modules_m_hook__'):
+                m.__track_modules_m_hook__.remove()
+                del m.__track_modules_m_hook__
+            #
+            if hasattr(m, '__track_modules_m__'):
+                del m.__track_modules_m__
+            #
+            if hasattr(m, '__track_modules_prev__'):
+                del m.__track_modules_prev__
+            #
+            if hasattr(m, '__track_modules_next__'):
+                del m.__track_modules_next__
+            #
+            if hasattr(m, '__track_modules_name__'):
+                del m.__track_modules_name__
+            #
+        #
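+        # fuse_list holds sequences of module names in the format expected by
+        # torch.quantization.fuse_modules, e.g. (illustrative names only):
+        #   [['conv1', 'bn1', 'relu1'], ['conv2', 'bn2']]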
+        return fuse_list
+
+    def freeze_model(self, disable_observer=False, freeze_bn_stats=True):
+        if disable_observer:
+            # Freeze quantizer parameters
+            self.module.apply(torch.quantization.disable_observer)
+        #
+        if freeze_bn_stats:
+            # Freeze batch norm mean and variance estimates
+            self.module.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
+        #
+
+    def apply_constrain_weights(self):
+        for n, m in self.module.named_modules():
+            if isinstance(m, (torch.nn.intrinsic.ConvBn2d, torch.nn.intrinsic.ConvBnReLU2d)):
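+                # fold the BN scale into the conv weight before constraining:
+                # w_folded = w * gamma / sqrt(running_var + eps)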
+                # non-qat fused modules index their children: m[0] is the conv, m[1] the bn
+                running_std = torch.sqrt(m[1].running_var + m[1].eps)
+                scale_factor = m[1].weight / running_std
+                scaled_weight = m[0].weight * scale_factor.reshape([-1, 1, 1, 1])
+                clamped_weight = utils.constrain_weight(scaled_weight)
+                unscaled_weight = clamped_weight / scale_factor.reshape([-1, 1, 1, 1])
+                m[0].weight.data.copy_(unscaled_weight)
+            elif isinstance(m, (torch.nn.intrinsic.qat.ConvBn2d, torch.nn.intrinsic.qat.ConvBnReLU2d)):
+                running_std = torch.sqrt(m.bn.running_var + m.bn.eps)
+                scale_factor = m.bn.weight / running_std
+                scaled_weight = m.weight * scale_factor.reshape([-1, 1, 1, 1])
+                clamped_weight = utils.constrain_weight(scaled_weight)
+                unscaled_weight = clamped_weight / scale_factor.reshape([-1, 1, 1, 1])
+                m.weight.data.copy_(unscaled_weight)
+            elif isinstance(m, torch.nn.Conv2d):
+                clamped_weight = utils.constrain_weight(m.weight)
+                m.weight.data.copy_(clamped_weight)
+            #
+        #
+
+    def clamp_params_backup(self):
+        self.parameters_backup = dict()
+        for n, m in self.module.named_modules():
+            if isinstance(m, (torch.nn.intrinsic.ConvBn2d, torch.nn.intrinsic.ConvBnReLU2d)):
+                # non-qat fused modules index their children: m[0] is the conv, m[1] the bn
+                running_std = torch.sqrt(m[1].running_var + m[1].eps)
+                scale_factor = m[1].weight / running_std
+                scaled_weight = m[0].weight * scale_factor.reshape([-1, 1, 1, 1])
+                # detach so the backup is a fixed snapshot, not a view that tracks updates
+                self.parameters_backup[n] = scaled_weight.detach()
+            elif isinstance(m, (torch.nn.intrinsic.qat.ConvBn2d, torch.nn.intrinsic.qat.ConvBnReLU2d)):
+                running_std = torch.sqrt(m.bn.running_var + m.bn.eps)
+                scale_factor = m.bn.weight / running_std
+                scaled_weight = m.weight * scale_factor.reshape([-1, 1, 1, 1])
+                self.parameters_backup[n] = scaled_weight.detach()
+            elif isinstance(m, torch.nn.Conv2d):
+                # clone: storing the parameter itself would track subsequent weight updates
+                self.parameters_backup[n] = m.weight.data.clone()
+            #
+
+    def clamp_module_with_delta(self, n, m):
+        if isinstance(m, (torch.nn.intrinsic.ConvBn2d, torch.nn.intrinsic.ConvBnReLU2d)):
+            running_std = torch.sqrt(m[1].running_var + m[1].eps)
+            scale_factor = m[1].weight / running_std
+            scaled_weight = m[0].weight * scale_factor.reshape([-1, 1, 1, 1])
+            scaled_weight_start = self.parameters_backup[n]
+            clamped_weight = self.clamp_param_with_delta(scaled_weight, scaled_weight_start)
+            unscaled_weight = clamped_weight / scale_factor.reshape([-1, 1, 1, 1])
+            m[0].weight.data.copy_(unscaled_weight)
+        elif isinstance(m, (torch.nn.intrinsic.qat.ConvBn2d, torch.nn.intrinsic.qat.ConvBnReLU2d)):
+            running_std = torch.sqrt(m.bn.running_var + m.bn.eps)
+            scale_factor = m.bn.weight / running_std
+            scaled_weight = m.weight * scale_factor.reshape([-1, 1, 1, 1])
+            scaled_weight_start = self.parameters_backup[n]
+            clamped_weight = self.clamp_param_with_delta(scaled_weight, scaled_weight_start)
+            unscaled_weight = clamped_weight / scale_factor.reshape([-1, 1, 1, 1])
+            m.weight.data.copy_(unscaled_weight)
+        elif isinstance(m, torch.nn.Conv2d):
+            weight_start = self.parameters_backup[n]
+            clamped_weight = self.clamp_param_with_delta(m.weight, weight_start)
+            m.weight.data.copy_(clamped_weight)
+        #
+
+    def clamp_param_with_delta(self, p, p_start):
+        # clamp the weights to within a few quantization delta steps of the original weights.
+        # weight is a signed quantity, so for an 8-bit signed range 1.0/128.0 is one quantization delta
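+        # e.g. if max|p_start| == 0.5, then p_delta = 0.5 * 2.0/128.0 ~= 0.0078,
+        # so each parameter may drift at most +/-0.0078 from its starting value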
+        p_max = torch.max(torch.abs(p_start.data))
+        p_delta = p_max * 2.0 / 128.0
+        p_new = torch.min(torch.max(p.data, p_start.data - p_delta), p_start.data + p_delta)
+        return p_new
\ No newline at end of file
diff --git a/modules/pytorch_jacinto_ai/xnn/quantize_torch_internal/quant_torch_eagercalib_module.py b/modules/pytorch_jacinto_ai/xnn/quantize_torch_internal/quant_torch_eagercalib_module.py
new file mode 100644 (file)
index 0000000..0f5db77
--- /dev/null
@@ -0,0 +1,56 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################
+
+###########################################################
+# Approximate quantized floating point simulation.
+# Can be used for post-training calibration of models.
+###########################################################
+
+import copy
+import torch
+
+from .quant_torch_qconfig import *
+from .quant_torch_base_module import *
+
+#warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)
+
+###########################################################
+class QuantTorchEagerCalibrateModule(QuantTorchBaseModule):
+    def __init__(self, module, dummy_input, *args, histogram=True, with_fakequantize=True, **kwargs):
+        '''
+        Quantize after collecting the activation/weight ranges.
+        If with_fakequantize is set, the ranges are collected with fake-quantization
+        of weights and activations applied - this is observed to help accuracy.
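+
+        A minimal calibration sketch (hypothetical model / data-loader names):
+            qmodel = QuantTorchEagerCalibrateModule(model, dummy_input)
+            qmodel.fuse_model()
+            qmodel.prepare()
+            with torch.no_grad():
+                for images, _ in calib_loader:
+                    qmodel(images)
+            qmodel.convert()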
+        '''
+        super().__init__(module, dummy_input, *args,  histogram=histogram, with_fakequantize=with_fakequantize,
+                         constrain_weights=False, **kwargs)
+
+
diff --git a/modules/pytorch_jacinto_ai/xnn/quantize_torch_internal/quant_torch_eagerdistill_module.py b/modules/pytorch_jacinto_ai/xnn/quantize_torch_internal/quant_torch_eagerdistill_module.py
new file mode 100644 (file)
index 0000000..b25fc8b
--- /dev/null
@@ -0,0 +1,145 @@
+#################################################################################
+# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
+# All Rights Reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are