Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/commitdiff
Cleaned up STE for QAT. Added RegNetX models
author Manu Mathew <a0393608@ti.com>
Thu, 6 Aug 2020 19:03:01 +0000 (00:33 +0530)
committer Manu Mathew <a0393608@ti.com>
Thu, 6 Aug 2020 19:04:18 +0000 (00:34 +0530)
41 files changed:
docs/Image_Classification.md
docs/Semantic_Segmentation.md
modules/pytorch_jacinto_ai/engine/train_classification.py
modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py
modules/pytorch_jacinto_ai/xnn/layers/__init__.py
modules/pytorch_jacinto_ai/xnn/layers/activation.py
modules/pytorch_jacinto_ai/xnn/layers/function.py
modules/pytorch_jacinto_ai/xnn/layers/functional.py
modules/pytorch_jacinto_ai/xnn/layers/quant_ste.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/layers/resize_blocks.py
modules/pytorch_jacinto_ai/xnn/layers/rf_blocks.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_base_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_calib_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_test_module.py
modules/pytorch_jacinto_ai/xnn/quantize/quant_train_module.py
modules/pytorch_jacinto_ai/xnn/utils/__init__.py
modules/pytorch_jacinto_ai/xnn/utils/count_flops.py
modules/pytorch_jacinto_ai/xnn/utils/data_utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/utils/depth_utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/utils/function_utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/utils/hist_utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/utils/load_weights.py
modules/pytorch_jacinto_ai/xnn/utils/range_utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/utils/tensor_utils.py
modules/pytorch_jacinto_ai/xvision/models/classification/__init__.py
modules/pytorch_jacinto_ai/xvision/models/multi_input_net.py
modules/pytorch_jacinto_ai/xvision/models/pixel2pixel/deeplabv3lite.py
modules/pytorch_jacinto_ai/xvision/models/pixel2pixel/fpnlite_pixel2pixel.py
modules/pytorch_jacinto_ai/xvision/models/pixel2pixel/pixel2pixelnet_utils.py
modules/pytorch_jacinto_ai/xvision/models/pixel2pixel/unetlite_pixel2pixel.py
modules/pytorch_jacinto_ai/xvision/models/regnet.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xvision/transforms/image_transform_utils.py
modules/pytorch_jacinto_ai/xvision/transforms/image_transforms.py
modules/pytorch_jacinto_ai/xvision/transforms/transforms.py
requirements.txt
run_classification.sh
run_quantization.sh
run_segmentation.sh
scripts/train_classification_main.py
scripts/train_segmentation_main.py

index 41716ae5a193ebfa7f6e4d3bcce14a4ff1d81b52..bac12dd5deb0824231bdd2c0120b0b57941e970f 100644 (file)
     rm ./valprep.sh
     ```
 
-* Training can be started by the following command from the base folder of the repository:<br>
+* Training with **MobileNetV2** model can be started by the following command from the base folder of the repository:<br>
     ```
     python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification
     ```
 
-* Training with ResNet50:<br>
+* Training with **ResNet50** model:<br>
     ```
     python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name resnet50_x1 --data_path ./data/datasets/image_folder_classification
     ```
 
+* Training with **RegNet800MF model and BGR image input transform**:<br>
+    ```
+    python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name regnetx800mf_x1 --data_path ./data/datasets/image_folder_classification --input_channel_reverse True --image_mean 103.53 116.28 123.675 --image_scale 0.017429 0.017507 0.017125
+    ```
+
 * If the dataset is in a different location, it can be specified by the --data_path option, but dataset_name must be *image_folder_classification* for folder based classification.
 
 
 * ImageNet classification results are as follows:
 
 
-|Dataset  |Mode Name     |Resize Resolution|Crop Resolution|Complexity (GigaMACS)|Top1 Accuracy% |Model Configuration Name|
-|---------|----------    |-----------      |----------     |--------             |--------       |------------------------|
-|ImageNet |MobileNetV1   |256x256          |224x224        |0.568                |**71.83**      |mobilenetv1_x1          |
-|ImageNet |MobileNetV2   |256x256          |224x224        |0.296                |**72.13**      |mobilenetv2_tv_x1       |
-|ImageNet |ResNet50-0.5  |256x256          |224x224        |1.051                |**72.05**      |resnet50_xp5            |
+|Dataset  |Model Name         |Resize Resolution|Crop Resolution|Complexity (GigaMACS)|Top1 Accuracy% |Model Configuration Name|
+|---------|----------         |-----------      |----------     |--------             |--------       |------------------------|
+|ImageNet |MobileNetV1        |256x256          |224x224        |0.568                |**71.83**      |mobilenetv1_x1          |
+|ImageNet |MobileNetV2        |256x256          |224x224        |0.296                |**72.13**      |mobilenetv2_tv_x1       |
+|ImageNet |ResNet50-0.5       |256x256          |224x224        |1.051                |**72.05**      |resnet50_xp5            |
+|ImageNet |**RegNetX800MF**   |256x256          |224x224        |0.800                |               |regnetx800mf_x1         |
 |.
-|ImageNet |MobileNetV1[1]|256x256          |224x224        |0.569                |70.60          |                        |
-|ImageNet |MobileNetV2[2]|256x256          |224x224        |0.300                |72.00          |                        |
-|ImageNet |ResNet50[3]   |256x256          |224x224        |4.087                |76.15          |                        |
+|ImageNet |MobileNetV1[1]     |256x256          |224x224        |0.569                |70.60          |                        |
+|ImageNet |MobileNetV2[2]     |256x256          |224x224        |0.300                |72.00          |                        |
+|ImageNet |ResNet50[3]        |256x256          |224x224        |4.087                |76.15          |                        |
+|ImageNet |**RegNetX800MF**[4]|256x256          |224x224        |0.800                |**75.2**       |                        |
+|ImageNet |RegNetX1.6GF[4]    |256x256          |224x224        |1.6                  |**77.0**       |                        |
+
+
+Notes:
+- As can be seen from the table, the models included in this repository provide a good Accuracy/Complexity tradeoff. 
+- However, the Complexity (in GigaMACS) does not always indicate the speed of inference on an embedded device. We also have to consider the fact that regular convolutions and Grouped convolutions typically utilize the available compute resources more efficiently (as they have more compute per data transfer) compared to Depthwise convolutions.
+- Hence, although the MobileNetV2 based models may have lower GigaMACS compared to the RegNetX models, the RegNetX based models may not be slower in practice. Overall, RegNetX based models are highly recommended as they strike a good balance between Complexity (GigaMACS), Compute Efficiency (more compute per data transfer) and ease of Quantization.
 
 
 ## References
 
 [1] MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications, Howard AG, Zhu M, Chen B, Kalenichenko D, Wang W, Weyand T, Andreetto M, Adam H, arXiv:1704.04861, 2017
 
 
+
 [2] MobileNetV2: Inverted Residuals and Linear Bottlenecks, Sandler M, Howard A, Zhu M, Zhmoginov A, Chen LC. arXiv preprint. arXiv:1801.04381, 2018.
-[3] PyTorch TorchVision Model Zoo: https://pytorch.org/docs/stable/torchvision/models.html
\ No newline at end of file
+
+[3] PyTorch TorchVision Model Zoo: https://pytorch.org/docs/stable/torchvision/models.html
+
+[4] Designing Network Design Spaces, Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár, Facebook AI Research (FAIR), https://arxiv.org/pdf/2003.13678.pdf, https://github.com/facebookresearch/pycls
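
The --input_channel_reverse True option together with the caffe-style mean/scale values in the RegNet800MF command above amounts to the following preprocessing. This is a minimal standalone sketch, not the repository's own transform classes; the function name and dummy image are purely illustrative:

```
import numpy as np

def bgr_normalize(img_rgb, mean=(103.53, 116.28, 123.675), scale=(0.017429, 0.017507, 0.017125)):
    img = img_rgb[..., ::-1].astype(np.float32)    # reverse channels: RGB -> BGR
    mean = np.asarray(mean, dtype=np.float32)
    scale = np.asarray(scale, dtype=np.float32)
    return (img - mean) * scale                    # per-channel (x - mean) * scale

img = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
print(bgr_normalize(img).shape)                    # (224, 224, 3)
```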
index 3f83f86715bae9d1cb2f34efc0f7e3bcb912a117..c0dda137ec62df4b7baa2c881ce150506812391e 100644 (file)
@@ -45,17 +45,22 @@ Whether to use multiple inputs or how many decoders to use are fully configurabl
 ## Training
 * These examples use two GPUs because we observed slightly higher accuracy when we restricted the number of GPUs used.
 
-* **Cityscapes Segmentation Training** can be done as follows:<br>
+* **Cityscapes Segmentation Training** with MobileNetV2 backbone and DeeplabV3Lite decoder can be done as follows:<br>
     ```
     python ./scripts/train_segmentation_main.py --model_name deeplabv3lite_mobilenetv2_tv --dataset_name cityscapes_segmentation --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth --gpus 0 1
   ```
-  
-  * It is possible to use a different image size. For example, we trained for 1536x768 resolution by the following. (We used a smaller crop size compared to the image resize resolution to reduce GPU memory usage). <br>
+
+* Cityscapes Segmentation Training with **RegNet800MF backbone and FPN decoder** can be done as follows:<br>
+    ```
+    python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name fpnlite_pixel2pixel_aspp_regnetx800mf --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 --pretrained https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906036/RegNetX-800MF_dds_8gpu.pyth
+  ```
+
+  * It is possible to use a **different image size**. For example, we trained for 1536x768 resolution by the following. (We used a smaller crop size compared to the image resize resolution to reduce GPU memory usage). <br>
     ```
     python ./scripts/train_segmentation_main.py --model_name deeplabv3lite_mobilenetv2_tv --dataset_name cityscapes_segmentation --data_path ./data/datasets/cityscapes/data --img_resize 768 1536 --rand_crop 512 1024 --output_size 1024 2048 --pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth --gpus 0 1
     ```
 
-* Train FPNPixel2Pixel model at 1536x768 resolution (use 1024x512 crop to reduce memory usage):<br>
+* Train **FPNPixel2Pixel model at 1536x768 resolution** (use 1024x512 crop to reduce memory usage):<br>
     ```
     python ./scripts/train_segmentation_main.py --model_name fpnlite_pixel2pixel_aspp_mobilenetv2_tv --dataset_name cityscapes_segmentation --data_path ./data/datasets/cityscapes/data --img_resize 768 1536 --rand_crop 512 1024 --output_size 1024 2048 --pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth --gpus 0 1
     ```
@@ -95,18 +100,29 @@ Inference can be done as follows (fill in the path to the pretrained model):<br>
 |Cityscapes |FPNLitePixel2Pixel with DWASPP|MobileNetV2    |32             |1536x768   |**15.37**            |**74.98** |**fpnlite_pixel2pixel_aspp_mobilenetv2_tv**   |
 |Cityscapes |FPNLitePixel2Pixel with DWASPP|FD-ResNet50    |64             |1536x768   |30.91                |-         |fpnlite_pixel2pixel_aspp_resnet50_fd          |
 |Cityscapes |FPNLitePixel2Pixel with DWASPP|ResNet50       |32             |1536x768   |114.42               |-         |fpnlite_pixel2pixel_aspp_resnet50             |
+|.
+|Cityscapes |DeepLabV3Lite GroupedConvASPP |RegNet800MF [9]|32             |768x384    |**11.19**            |**70.22** |**deeplabv3lite_pixel2pixel_aspp_regnetx800mf**|
+|Cityscapes |DeepLabV3Lite GroupedConvASPP |RegNet800MF [9]|32             |768x384    |**7.29**             |          |**fpnlite_pixel2pixel_aspp_regnetx800mf**     |
+|Cityscapes |DeepLabV3Lite GroupedConvASPP |RegNet800MF [9]|32             |768x384    |**6.09**             |          |**unetlite_pixel2pixel_aspp_regnetx800mf**    |
 
 
-|Dataset    |Mode Architecture         |Backbone Model |Backbone Stride|Resolution |Complexity (GigaMACS)|MeanIoU%  |Model Configuration Name                  |
-|---------  |----------                |-----------    |-------------- |-----------|--------             |----------|----------------------------------------  |
-|Cityscapes |ERFNet[[4]]               |-              |-              |1024x512   |27.705               |69.7      |-                                         |
-|Cityscapes |SwiftNetMNV2[[5]]         |MobileNetV2    |-              |2048x1024  |41.0                 |75.3      |-                                         |
-|Cityscapes |DeepLabV3Plus[[6,7]]      |MobileNetV2    |16             |           |21.27                |70.71     |-                                         |
-|Cityscapes |DeepLabV3Plus[[6,7]]      |Xception65     |16             |           |418.64               |78.79     |-                                         |
+
+For comparison, here we list a few models from the literature:
+
+|Dataset    |Model Architecture            |Backbone Model |Backbone Stride|Resolution |Complexity (GigaMACS)|MeanIoU%  |Model Configuration Name                     |
+|---------  |----------                    |-----------    |-------------- |-----------|--------             |----------|------------------------|
+|Cityscapes |ERFNet [4]                    |-              |-              |1024x512   |27.705               |69.7      |-                       |
+|Cityscapes |SwiftNetMNV2 [5]              |MobileNetV2    |-              |2048x1024  |41.0                 |75.3      |-                       |
+|Cityscapes |DeepLabV3Plus [6,7]           |MobileNetV2    |16             |           |21.27                |70.71     |-                       |
+|Cityscapes |DeepLabV3Plus [6,7]           |Xception65     |16             |           |418.64               |78.79     |-                       |
 
 Notes:
 - The suffix **'Lite'** in the model names indicates complexity optimizations in the Decoder part of the model - especially the use of DepthWise Separable Convolutions instead of regular convolutions.
 - (\*2\*2) in the above table represents two additional Depthwise Separable Convolutions with strides (at the end of the backbone encoder). 
 - FD-MobileNetV2 Backbone uses a stride of 64 (this is used in some rows of the above table) and is achieved by Fast Downsampling Strategy [8]
 
+- As can be seen from the table, the models included in this repository provide a good Accuracy/Complexity tradeoff. 
+- However, the Complexity (in GigaMACS) does not always indicate the speed of inference on an embedded device. We also have to consider the fact that regular convolutions and Grouped convolutions typically utilize the available compute resources more efficiently (as they have more compute per data transfer) compared to Depthwise convolutions.
+- Hence, although the MobileNetV2 based models may have lower GigaMACS compared to the RegNetX models, the RegNetX based models may not be slower in practice. Overall, RegNetX based models are highly recommended as they strike a good balance between Complexity (GigaMACS), Compute Efficiency (more compute per data transfer) and ease of Quantization.
+
 
 ## References
 [1] The Cityscapes Dataset for Semantic Urban Scene Understanding, Marius Cordts, Mohamed Omran, Sebastian Ramos, Timo Rehfeld, Markus Enzweiler, Rodrigo Benenson, Uwe Franke, Stefan Roth, Bernt Schiele, CVPR 2016, https://www.cityscapes-dataset.com/
 
@@ -127,4 +143,6 @@ International Journal of Computer Vision, 88(2), 303-338, 2010, http://host.robo
 
 [8] FD-MobileNet: Improved MobileNet with a Fast Downsampling Strategy, Zheng Qin, Zhaoning Zhang, Xiaotao Chen, Yuxing Peng - https://arxiv.org/abs/1802.03750)
 
 
+[9] Designing Network Design Spaces, Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár, Facebook AI Research (FAIR), https://arxiv.org/pdf/2003.13678.pdf, https://github.com/facebookresearch/pycls
+
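
The Complexity column above counts multiply-accumulate operations (GigaMACS). A rough back-of-the-envelope sketch of why the 'Lite' decoders built on Depthwise Separable Convolutions are cheaper than regular convolutions, using illustrative feature-map sizes only (not numbers from the table):

```
def conv_macs(h, w, c_in, c_out, k=3):
    # regular convolution: every output pixel mixes all input channels
    return h * w * c_in * c_out * k * k

def dw_separable_macs(h, w, c_in, c_out, k=3):
    # depthwise kxk (one filter per channel) followed by a 1x1 pointwise conv
    return h * w * c_in * k * k + h * w * c_in * c_out

h, w, c_in, c_out = 96, 192, 256, 256
print(conv_macs(h, w, c_in, c_out) / 1e9)          # ~10.87 GigaMACS
print(dw_separable_macs(h, w, c_in, c_out) / 1e9)  # ~1.25 GigaMACS
```

As the notes above caution, fewer MACs do not automatically mean faster inference: depthwise layers move relatively more data per MAC, so regular or grouped convolutions can use the available compute more efficiently.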
 
 
index 06d418bb3460fb2c345a72f5b54c2ed105e6d9bd..bc208a4f6e4d0d88b69ed9cd589b898713d2d7bf 100644 (file)
@@ -35,9 +35,9 @@ def get_config():
     args.model_config.num_tiles_x = int(1)
     args.model_config.num_tiles_y = int(1)
     args.model_config.en_make_divisible_by8 = True
-
     args.model_config.input_channels = 3                # num input channels
 
+    args.input_channel_reverse = False                  # rgb to bgr
     args.data_path = './data/datasets/ilsvrc'           # path to dataset
     args.model_name = 'mobilenetv2_tv_x1'               # model architecture
     args.model = None                                   # if model is created externally
@@ -265,7 +265,12 @@ def main(args):
 
     # load pretrained
     if pretrained_data is not None and not is_onnx_model:
 
-        xnn.utils.load_weights(get_model_orig(model), pretrained=pretrained_data, change_names_dict=change_names_dict)
+        model_orig = get_model_orig(model)
+        if hasattr(model_orig, 'load_weights'):
+            model_orig.load_weights(pretrained=pretrained_data, change_names_dict=change_names_dict)
+        else:
+            xnn.utils.load_weights(model_orig, pretrained=pretrained_data, change_names_dict=change_names_dict)
+        #
     #
     
     #################################################
@@ -751,10 +756,12 @@ def get_train_transform(args):
     normalize = xvision.transforms.NormalizeMeanScale(mean=args.image_mean, scale=args.image_scale) \
         if (args.image_mean is not None and args.image_scale is not None) else None
     multi_color_transform = xvision.transforms.MultiColor(args.multi_color_modes) if (args.multi_color_modes is not None) else None
+    reverse_channels = xvision.transforms.ReverseChannels() if args.input_channel_reverse else None
 
     train_resize_crop_transform = xvision.transforms.RandomResizedCrop(size=args.img_crop, scale=args.rand_scale) \
         if args.rand_scale else xvision.transforms.RandomCrop(size=args.img_crop)
 
-    train_transform = xvision.transforms.Compose([train_resize_crop_transform,
+    train_transform = xvision.transforms.Compose([reverse_channels,
+                                                 train_resize_crop_transform,
                                                  xvision.transforms.RandomHorizontalFlip(),
                                                  multi_color_transform,
                                                  xvision.transforms.ToFloat(),
@@ -766,10 +773,12 @@ def get_validation_transform(args):
     normalize = xvision.transforms.NormalizeMeanScale(mean=args.image_mean, scale=args.image_scale) \
         if (args.image_mean is not None and args.image_scale is not None) else None
     multi_color_transform = xvision.transforms.MultiColor(args.multi_color_modes) if (args.multi_color_modes is not None) else None
+    reverse_channels = xvision.transforms.ReverseChannels() if args.input_channel_reverse else None
 
     # pass tuple to Resize() to resize to exact size without respecting aspect ratio (typical caffe style)
     val_resize_crop_transform = xvision.transforms.Resize(size=args.img_resize) if args.img_resize else xvision.transforms.Bypass()
 
-    val_transform = xvision.transforms.Compose([val_resize_crop_transform,
+    val_transform = xvision.transforms.Compose([reverse_channels,
+                                               val_resize_crop_transform,
                                                xvision.transforms.CenterCrop(size=args.img_crop),
                                                multi_color_transform,
                                                xvision.transforms.ToFloat(),
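In the transform hunks above, reverse_channels is None when --input_channel_reverse is not set, and that None entry is passed straight into the transform Compose. A minimal sketch of a None-tolerant compose, assuming this is how the repository's xvision.transforms.Compose behaves (it may instead filter the list elsewhere); the class name here is hypothetical:

```
class ComposeSkipNone:
    # apply a list of transforms in order, silently skipping None entries
    def __init__(self, transforms):
        self.transforms = [t for t in transforms if t is not None]

    def __call__(self, x):
        for t in self.transforms:
            x = t(x)
        return x

reverse_channels = None                 # e.g. args.input_channel_reverse is False
pipeline = ComposeSkipNone([reverse_channels, lambda x: x * 2])
print(pipeline(3))                      # 6 - the optional transform simply disappears
```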
index d87a850117a12ae66fd829e2752a2d2fd95b1945..bd031238d63a9297672ed19c622517d9a05d6314 100644 (file)
@@ -57,6 +57,7 @@ def get_config():
     args.model_name = 'deeplabv2lite_mobilenetv2'       # model architecture, overwritten if pretrained is specified
     args.dataset_name = 'cityscapes_segmentation'       # dataset type
     args.transforms = None                              # the transforms itself can be given from outside
+    args.input_channel_reverse = False                  # reverse input channels, for example RGB to BGR
 
     args.data_path = './data/cityscapes'                # 'path to dataset'
     args.save_path = None                               # checkpoints save path
 
@@ -376,9 +377,15 @@ def main(args):
 
     # load pretrained model
     if pretrained_data is not None and not is_onnx_model:
 
+        model_orig = get_model_orig(model)
         for (p_data,p_file) in zip(pretrained_data, pretrained_files):
             print("=> using pretrained weights from: {}".format(p_file))
-            xnn.utils.load_weights(get_model_orig(model), pretrained=p_data, change_names_dict=change_names_dict)
+            if hasattr(model_orig, 'load_weights'):
+                model_orig.load_weights(pretrained=p_data, change_names_dict=change_names_dict)
+            else:
+                xnn.utils.load_weights(get_model_orig(model), pretrained=p_data, change_names_dict=change_names_dict)
+            #
+        #
     #
 
     #################################################
@@ -1129,11 +1136,13 @@ def get_train_transform(args):
     image_scale = np.array(args.image_scale, dtype=np.float32)
     image_prenorm = xvision.transforms.image_transforms.NormalizeMeanScale(mean=image_mean, scale=image_scale) if args.image_prenorm else None
     image_postnorm = xvision.transforms.image_transforms.NormalizeMeanScale(mean=image_mean, scale=image_scale) if (not image_prenorm) else None
+    reverse_channels = xvision.transforms.image_transforms.ReverseImageChannels() if args.input_channel_reverse else None
 
     # crop size used only for training
     image_train_output_scaling = xvision.transforms.image_transforms.Scale(args.rand_resize, target_size=args.rand_output_size, is_flow=args.is_flow) \
         if (args.rand_output_size is not None and args.rand_output_size != args.rand_resize) else None
     train_transform = xvision.transforms.image_transforms.Compose([
+        reverse_channels,
         image_prenorm,
         xvision.transforms.image_transforms.AlignImages(),
         xvision.transforms.image_transforms.MaskTarget(args.target_mask, 0),
@@ -1156,9 +1165,11 @@ def get_validation_transform(args):
     image_scale = np.array(args.image_scale, dtype=np.float32)
     image_prenorm = xvision.transforms.image_transforms.NormalizeMeanScale(mean=image_mean, scale=image_scale) if args.image_prenorm else None
     image_postnorm = xvision.transforms.image_transforms.NormalizeMeanScale(mean=image_mean, scale=image_scale) if (not image_prenorm) else None
+    reverse_channels = xvision.transforms.image_transforms.ReverseImageChannels() if args.input_channel_reverse else None
 
     # prediction is resized to output_size before evaluation.
     val_transform = xvision.transforms.image_transforms.Compose([
+        reverse_channels,
         image_prenorm,
         xvision.transforms.image_transforms.AlignImages(),
         xvision.transforms.image_transforms.MaskTarget(args.target_mask, 0),
index a623edd0a1c4d01a6c3febfb5cda2ff9049f774c..464dbb0e46f7001c749f4bd6787ac4062e945fa6 100644 (file)
@@ -13,6 +13,8 @@ from .rf_blocks import *
 
 from .model_utils import *
 
 
+from .quant_ste import *
+
 # optional/experimental
 try:
     from .blocks_internal import *
index fafb6af71b68430bc983883cfaae1026e9df2373..d89e6eea496517cbc2c320905e891abf2ba20b15 100644 (file)
@@ -15,7 +15,7 @@ class PAct2(torch.nn.Module):
     PACT2_RANGE_INIT = 8.0      # this is the starting range
     PACT2_RANGE_EXPANSION = 1.0 # expand the calculated range for margin
 
-    def __init__(self, inplace=False, signed=None, percentile_range_shrink=PACT2_RANGE_SHRINK, clip_range=None,
+    def __init__(self, inplace=False, signed=None, range_shrink_percentile=PACT2_RANGE_SHRINK, clip_range=None,
                  power2_activation_range=True, **kwargs):
         super().__init__()
         if (clip_range is not None) and (signed is not None):
@@ -24,12 +24,13 @@ class PAct2(torch.nn.Module):
         self.inplace = inplace
         self.clip_range = clip_range
         self.signed = signed if (clip_range is None) else (clip_range[0]<0.0)
-        self.percentile_range_shrink = percentile_range_shrink # range shrinking factor
+        self.range_shrink_percentile = range_shrink_percentile # range shrinking factor
         self.fixed_range = (clip_range is not None)
         self.learn_range = (self.PACT2_RANGE_LEARN and (not self.fixed_range))
         self.eps = np.power(2.0, -16.0)
         self.power2_activation_range = power2_activation_range   # power of 2 ranges
         self.log_base = None # 2.0  # log is used only in learned mode if log_base is not None
+        self.range_estimator = None
 
         # any validation before at least one iteration of training will use default clip values.
         clip_init = max(abs(np.array(clip_range))) if (clip_range is not None) else self.PACT2_RANGE_INIT
 
@@ -39,19 +40,24 @@ class PAct2(torch.nn.Module):
             clip_signed_log = self.convert_to_log(torch.tensor(clip_init2))
             default_clips = (-clip_signed_log, clip_signed_log) \
                 if (self.signed == True or self.signed is None) else (0.0, clip_signed_log)
-            self.register_parameter('clips_act', torch.nn.Parameter(torch.tensor(default_clips)))
+            self.register_parameter('clips_act', torch.nn.Parameter(torch.tensor(default_clips, dtype=torch.float32)))
             # Initially ranges will be dominated by running average, but eventually the update factor becomes too small.
             # Then the backprop updates will have dominance.
             self.range_update_factor_min = 0.0
-            self.register_buffer('num_batches_tracked', torch.tensor(-1.0))
+            self.register_buffer('num_batches_tracked', torch.tensor(-1.0, dtype=torch.float32))
         else:
             default_clips = (-clip_init2, clip_init2) \
                 if (self.signed == True or self.signed is None) else (0.0, clip_init2)
-            self.register_buffer('clips_act', torch.tensor(default_clips))
+            self.register_buffer('clips_act', torch.tensor(default_clips, dtype=torch.float32))
             # range_update_factor_min is the lower bound for exponential update factor.
             # using 0.0 will freeze the ranges, since the update_factor becomes too small after some time
             self.range_update_factor_min = 0.001
-            self.register_buffer('num_batches_tracked', torch.tensor(-1.0))
+            self.register_buffer('num_batches_tracked', torch.tensor(-1.0, dtype=torch.float32))
+
+            if utils.has_range_estimator:
+                self.range_estimator = utils.RangeEstimator(range_shrink_percentile=range_shrink_percentile,
+                                                            range_update_factor_min=self.range_update_factor_min)
+            #
         #
 
 
@@ -61,7 +67,7 @@ class PAct2(torch.nn.Module):
             # even in learn_range mode - do this for a few iterations to get a good starting point
             if (not self.fixed_range) and ((not self.learn_range) or (self.num_batches_tracked < 100)):
                 with torch.no_grad():
-                    self.update_scale_act(x.data)
+                    self.update_clips_act(x.data)
                 #
             #
         #
@@ -94,16 +100,21 @@ class PAct2(torch.nn.Module):
         return utils.signed_pow(x, self.log_base)
 
 
-    def update_scale_act(self, x):
-        # compute the new scale
-        x_min, x_max = utils.extrema_fast(x, percentile_range_shrink=self.percentile_range_shrink)
-        x_min, x_max = self.convert_to_log(x_min*self.PACT2_RANGE_EXPANSION), self.convert_to_log(x_max*self.PACT2_RANGE_EXPANSION)
-        # exponential update factor
-        update_factor = 1.0 / float(self.num_batches_tracked if self.num_batches_tracked else 1.0)
-        update_factor = max(update_factor, self.range_update_factor_min)
-        # exponential moving average update
-        self.clips_act[0].data.mul_(1.0-update_factor).add_(x_min * update_factor)
-        self.clips_act[1].data.mul_(1.0-update_factor).add_(x_max * update_factor)
+    def update_clips_act(self, x):
+        if self.learn_range or (self.range_estimator is None):
+            x_min, x_max = utils.extrema_fast(x, range_shrink_percentile=self.range_shrink_percentile)
+            x_min, x_max = self.convert_to_log(x_min*self.PACT2_RANGE_EXPANSION), self.convert_to_log(x_max*self.PACT2_RANGE_EXPANSION)
+            # exponential update factor
+            update_factor = 1.0 / float(self.num_batches_tracked if self.num_batches_tracked else 1.0)
+            update_factor = max(update_factor, self.range_update_factor_min)
+            # exponential moving average update
+            self.clips_act[0].data.mul_(1.0-update_factor).add_(x_min * update_factor)
+            self.clips_act[1].data.mul_(1.0-update_factor).add_(x_max * update_factor)
+        else:
+            mn, mx = self.range_estimator(x)
+            self.clips_act[0].data.fill_(mn)
+            self.clips_act[1].data.fill_(mx)
+        #
 
 
     def get_clips_act(self):
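The exponential moving average update in update_clips_act above uses an update factor of 1/num_batches_tracked, floored at range_update_factor_min. A small standalone sketch of that update rule (illustrative only, not the repository API; the helper name is made up):

```
def ema_update(clip_val, observed, num_batches_tracked, update_factor_min=0.001):
    # early batches: large factor, the range follows the data closely
    # later batches: the factor shrinks until it is floored at update_factor_min
    update_factor = 1.0 / float(num_batches_tracked if num_batches_tracked else 1.0)
    update_factor = max(update_factor, update_factor_min)
    return clip_val * (1.0 - update_factor) + observed * update_factor

clip_max = 8.0   # PACT2_RANGE_INIT
for batch, x_max in enumerate([3.2, 4.1, 3.8, 5.0], start=1):
    clip_max = ema_update(clip_max, x_max, batch)
print(round(clip_max, 3))   # 4.025
```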
 
 
index 509d67bfaeee303eb75ab7918b06e90f2a940f50..97095ee4afffc5e5e8050ebfa3a684cddb725aaa 100644 (file)
@@ -1,6 +1,7 @@
 import torch
 import numpy as np
 
+###################################################
 # round with gradient - assumed to have unit gradient.
 # note: this rounds towards even (bankers round / unbiased round - like numpy / pytorch).
 class RoundG(torch.autograd.Function):
@@ -124,21 +125,23 @@ class Floor2G(torch.autograd.Function):
         return g.op("Floor2G", x)
 
 
-# quantize - use power2 only in forward. use numerical gradient as analytical expression is difficult.
 class QuantizeDequantizeG(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x, scale_tensor, width_min, width_max, power2, round_type='round_up'):
+        # apply quantization
         y, x_scaled_round = QuantizeDequantizeG.quant_dequant(ctx, x, scale_tensor, width_min, width_max, power2, round_type)
         # save for backward
         ctx.save_for_backward(x, scale_tensor, x_scaled_round)
         ctx.width_min, ctx.width_max, ctx.power2, ctx.round_type = width_min, width_max, power2, round_type
         return y
 
+
     @staticmethod
     def quant_dequant(ctx, x, scale_tensor, width_min, width_max, power2, round_type='round_up'):
         # clip values need ceil2 and scale values need floor2
         scale_tensor = Floor2G.apply(scale_tensor) if power2 else scale_tensor
         x_scaled = (x * scale_tensor)
+
         # round
         if round_type == 'round_up':    # typically for activations
             rand_val = 0.5
@@ -150,13 +153,16 @@ class QuantizeDequantizeG(torch.autograd.Function):
             x_scaled_round = torch.round(x_scaled)
         #
         # invert the scale
-        scale_inv = 1.0/scale_tensor
+        scale_inv = scale_tensor.pow(-1.0)
         # clamp
-        y = torch.clamp(x_scaled_round, width_min, width_max)*scale_inv
+        x_clamp = torch.clamp(x_scaled_round, width_min, width_max)
+        y = x_clamp * scale_inv
         return y, x_scaled_round
 
+
     @staticmethod
     def backward(ctx, dy):
+        # use numerical gradient as analytical expression is difficult.
         # this includes gradient of scale, round and clip
         x, scale_tensor, x_scaled_round = ctx.saved_tensors
         width_min, width_max, power2, round_type = ctx.width_min, ctx.width_max, ctx.power2, ctx.round_type
@@ -182,6 +188,7 @@ class QuantizeDequantizeG(torch.autograd.Function):
         # return
         return dx, ds, None, None, None, None
 
+
     @staticmethod
     def symbolic(g,  x, scale_tensor, width_min, width_max, power2, round_type='round_up'):
         return g.op("QuantizeDequantizeG",  x, scale_tensor)
\ No newline at end of file
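For reference, the forward path of quant_dequant above is a scale -> round -> clamp -> rescale chain. A minimal standalone illustration with plain torch ops (no autograd.Function, no power-of-2 handling, 'round_up' style rounding as in the code above):

```
import torch

def quant_dequant_ref(x, scale, width_min, width_max):
    x_scaled = x * scale                          # map to the integer grid
    x_rounded = torch.floor(x_scaled + 0.5)       # 'round_up' rounding
    x_clamped = torch.clamp(x_rounded, width_min, width_max)
    return x_clamped / scale                      # back to the original range

x = torch.tensor([-1.7, -0.3, 0.26, 1.9])
# 8-bit signed example: scale chosen so that |x| <= 2 maps into [-128, 127]
print(quant_dequant_ref(x, scale=64.0, width_min=-128, width_max=127))
# tensor([-1.7031, -0.2969,  0.2656,  1.9062])
```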
index 996b4c92457ab0b634c4b53ef512f93480d9e889..16f1f3da0e1eea1e4c60b8f9c388bef4536fdbe2 100644 (file)
@@ -1,17 +1,25 @@
+import functools
 import torch
 from . import function
+from . import quant_ste
 
 
 ###################################################
 
 
-round_g = function.RoundG.apply
-round_sym_g = function.RoundSymG.apply
-round_up_g = function.RoundUpG.apply
-round2_g = function.Round2G.apply
-
-ceil_g = function.CeilG.apply
-ceil2_g = function.Ceil2G.apply
-
-quantize_dequantize_g = function.QuantizeDequantizeG.apply
+round_g = quant_ste.PropagateQuantTensorSTE(function.RoundG.apply)
+round_sym_g = quant_ste.PropagateQuantTensorSTE(function.RoundSymG.apply)
+round_up_g = quant_ste.PropagateQuantTensorSTE(function.RoundUpG.apply)
+round2_g = quant_ste.PropagateQuantTensorSTE(function.Round2G.apply)
+ceil_g = quant_ste.PropagateQuantTensorSTE(function.CeilG.apply)
+ceil2_g = quant_ste.PropagateQuantTensorSTE(function.Ceil2G.apply)
+
+# Wrapping with PropagateQuantTensorSTE here is optional: it causes the backward method
+# of QuantizeDequantizeG to be skipped (straight through estimation).
+# Replace it with PropagateQuantTensorQTE to allow the gradient to flow back through that backward method.
+# Note: QTE here has an effect only if QTE is also used in forward() of QuantTrainPAct2 in quant_train_module.py
+# (by setting quantize_backward_type = 'qte' in QuantTrainPAct2).
+# TODO: when using QTE here, this op needs to be registered for ONNX export to work,
+# and even then the exported model may not be clean.
+quantize_dequantize_g = quant_ste.PropagateQuantTensorSTE(function.QuantizeDequantizeG.apply)
 
 
 ###################################################
 
 
diff --git a/modules/pytorch_jacinto_ai/xnn/layers/quant_ste.py b/modules/pytorch_jacinto_ai/xnn/layers/quant_ste.py
new file mode 100644 (file)
index 0000000..d3ab4b1
--- /dev/null
@@ -0,0 +1,51 @@
+import torch
+
+###################################################
+class PropagateQuantTensor(object):
+    QUANTIZED_THROUGH_ESTIMATION = 0
+    STRAIGHT_THROUGH_ESTIMATION = 1
+    NO_QUANTIZATION_ESTIMATION = 2
+
+    def __init__(self, func, quant_estimation_type, copy_always=True):
+        self.func = func
+        self.quant_estimation_type = quant_estimation_type
+        self.copy_always = copy_always
+
+    def __call__(self, x, *args, **kwargs):
+        if self.func is not None:
+            y = self.func(x, *args, **kwargs)
+        else:
+            y = args[0]
+        #
+        if self.quant_estimation_type == self.QUANTIZED_THROUGH_ESTIMATION:
+            # backprop will flow through the backward function
+            # as the quantized tensor is forwarded as it is.
+            return y
+        elif self.quant_estimation_type == self.STRAIGHT_THROUGH_ESTIMATION:
+            # forward the quantized data, but in the container x_copy
+            # here a copy of x is made so that the original is not altered (x is a reference in Python).
+            # the backward will directly flow through x_copy and to x instead of going through y.
+            # copy_always can be set to False to avoid this copy if possible.
+            x_copy = x.clone() if self.copy_always or isinstance(x, torch.nn.Parameter) else x
+            x_copy.data = y.data
+            return x_copy
+        elif self.quant_estimation_type == self.NO_QUANTIZATION_ESTIMATION:
+            # beware! no quantization performed in this case.
+            return x
+        else:
+            assert False, f'unknown quant_estimation_type {self.quant_estimation_type}'
+
+
+class PropagateQuantTensorQTE(PropagateQuantTensor):
+    def __init__(self, func):
+        # QUANTIZED_THROUGH_ESTIMATION: backprop will flow through
+        # the backward functions wrapped in this class
+        super().__init__(func, quant_estimation_type=PropagateQuantTensor.QUANTIZED_THROUGH_ESTIMATION)
+
+
+class PropagateQuantTensorSTE(PropagateQuantTensor):
+    def __init__(self, func):
+        # STRAIGHT_THROUGH_ESTIMATION: backprop will NOT flow through
+        # the backward functions wrapped in this class
+        super().__init__(func, quant_estimation_type=PropagateQuantTensor.STRAIGHT_THROUGH_ESTIMATION)
+
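A small usage sketch of the wrappers above, showing the STE behavior: the forward value is the quantized tensor, but the gradient bypasses the wrapped function and flows straight back to the input with unit gradient. The import path assumes the repository is installed as the pytorch_jacinto_ai package; torch.round stands in for the autograd Functions in function.py:

```
import torch
from pytorch_jacinto_ai.xnn.layers import quant_ste

round_ste = quant_ste.PropagateQuantTensorSTE(torch.round)

x = torch.tensor([0.3, 1.7, 2.5], requires_grad=True)
y = round_ste(x)          # forward: rounded values, carried in a copy of x
print(y.detach())         # tensor([0., 2., 2.])
y.sum().backward()
print(x.grad)             # tensor([1., 1., 1.]) - straight-through (unit) gradient
```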
index 8bab95da30ba8f0bcc0da3d1853e6fe7b0de9edb..8a02a21634e34297db3dbe7c94d065d224ddc070 100644 (file)
@@ -88,7 +88,7 @@ def UpsampleWith(input_channels=None, output_channels=None, upstride=None, inter
             upsample = [ResizeWith(scale_factor=upstride, mode=interpolation_mode),
                         ConvDWSepNormAct2d(input_channels, output_channels, kernel_size=int(upstride * 1.5 + 1),
                                       normalization=normalization, activation=activation)]
-        elif interpolation_type == 'subpixel_conv':
+        elif (interpolation_type == 'subpixel_conv' or interpolation_type == 'pixel_shuffle'):
             upsample = [ConvDWSepNormAct2d(input_channels, output_channels*upstride*upstride, kernel_size=int(upstride + 1),
                                       normalization=normalization, activation=activation),
                         torch.nn.PixelShuffle(upscale_factor=int(upstride))]
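The 'pixel_shuffle' alias added above ends in a torch.nn.PixelShuffle step, which trades channels for spatial resolution. A quick standalone shape check of that rearrangement (unrelated to the surrounding Conv blocks):

```
import torch

upstride = 2
ps = torch.nn.PixelShuffle(upscale_factor=upstride)
x = torch.randn(1, 64 * upstride * upstride, 32, 32)   # the conv produced C*r*r channels
print(ps(x).shape)                                     # torch.Size([1, 64, 64, 64])
```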
index b6399bbdd6c3aa1b9aefe202c82e28d108dcaa58..368a077108d6b3181df3f9a3e5fafb5068dca177 100644 (file)
@@ -58,7 +58,8 @@ class ASPPBlock(torch.nn.Module):
 # this is called a lite block because the dilated convolutions use
 # ConvDWNormAct2d instead of ConvDWSepNormAct2d
 class DWASPPLiteBlock(torch.nn.Module):
-    def __init__(self, in_chs, aspp_chs, out_chs, dilation=(6, 12, 18), groups=1, avg_pool=False, activation=DefaultAct2d, linear_dw=False):
+    def __init__(self, in_chs, aspp_chs, out_chs, dilation=(6, 12, 18), groups=1, group_size_dw=None, avg_pool=False,
+                 activation=DefaultAct2d, linear_dw=False):
         super().__init__()
 
         self.aspp_chs = aspp_chs
@@ -73,9 +74,15 @@ class DWASPPLiteBlock(torch.nn.Module):
         self.conv1x1 = ConvNormAct2d(in_chs, aspp_chs, kernel_size=1, activation=activation)
         normalizations_dw = ((not linear_dw), True)
         activations_dw = (False if linear_dw else activation, activation)
-        self.aspp_bra1 = ConvDWSepNormAct2d(in_chs, aspp_chs, kernel_size=3, dilation=dilation[0], normalization=normalizations_dw, activation=activations_dw, groups=groups)
-        self.aspp_bra2 = ConvDWSepNormAct2d(in_chs, aspp_chs, kernel_size=3, dilation=dilation[1], normalization=normalizations_dw, activation=activations_dw, groups=groups)
-        self.aspp_bra3 = ConvDWSepNormAct2d(in_chs, aspp_chs, kernel_size=3, dilation=dilation[2], normalization=normalizations_dw, activation=activations_dw, groups=groups)
+        self.aspp_bra1 = ConvDWSepNormAct2d(in_chs, aspp_chs, kernel_size=3, dilation=dilation[0],
+                                            normalization=normalizations_dw, activation=activations_dw,
+                                            groups=groups, group_size_dw=group_size_dw)
+        self.aspp_bra2 = ConvDWSepNormAct2d(in_chs, aspp_chs, kernel_size=3, dilation=dilation[1],
+                                            normalization=normalizations_dw, activation=activations_dw,
+                                            groups=groups, group_size_dw=group_size_dw)
+        self.aspp_bra3 = ConvDWSepNormAct2d(in_chs, aspp_chs, kernel_size=3, dilation=dilation[2],
+                                            normalization=normalizations_dw, activation=activations_dw,
+                                            groups=groups, group_size_dw=group_size_dw)
 
         self.dropout = torch.nn.Dropout2d(p=0.2, inplace=True)   
         self.aspp_out = ConvNormAct2d(self.last_chns, out_chs, kernel_size=1, groups=1, activation=activation)
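The new group_size_dw argument above is forwarded to ConvDWSepNormAct2d. A plausible reading, stated here as an assumption rather than a fact from this diff, is that it sets the number of channels per group in the depthwise-style convolution, so that group_size_dw=1 (or None) keeps a pure depthwise conv and larger values give grouped convolutions; the helper below is hypothetical:

```
import torch

def make_grouped_conv(channels, kernel_size=3, dilation=1, group_size_dw=None):
    # group_size_dw None/1 -> depthwise (groups == channels);
    # otherwise groups = channels // group_size_dw (assumed convention)
    groups = channels if group_size_dw in (None, 1) else channels // group_size_dw
    padding = dilation * (kernel_size // 2)
    return torch.nn.Conv2d(channels, channels, kernel_size, padding=padding,
                           dilation=dilation, groups=groups, bias=False)

print(make_grouped_conv(64).groups)                    # 64 (depthwise)
print(make_grouped_conv(64, group_size_dw=16).groups)  # 4  (grouped)
```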
 
index 11467d1f2495b0193ef235f47a9fe6607f668ccb..2996f3951e5ec59a735d396431e1db8bd036ac87 100644 (file)
@@ -2,12 +2,6 @@ import copy
 from .quant_graph_module import *
 
 ###########################################################
-class QuantEstimationType:
-    QUANTIZED_THROUGH_ESTIMATION = 0
-    STRAIGHT_THROUGH_ESTIMATION = 1
-    ALPHA_BLENDING_ESTIMATION = 2
-
-
 # base module to be used for all quantization modules
 class QuantBaseModule(QuantGraphModule):
     def __init__(self, module, dummy_input, *args, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
@@ -19,12 +13,18 @@ class QuantBaseModule(QuantGraphModule):
         self.per_channel_q = per_channel_q
         self.histogram_range = histogram_range
         self.constrain_weights = constrain_weights
-        self.constrain_bias = True if (constrain_bias is None) else constrain_bias
         self.bias_calibration = bias_calibration
         self.power2_weight_range = True if (power2_weight_range is None) else power2_weight_range
         self.power2_activation_range = True if (power2_activation_range is None) else power2_activation_range
         # range shrink - 0.0 indicates no shrink
-        self.percentile_range_shrink = (layers.PAct2.PACT2_RANGE_SHRINK if histogram_range else 0.0)
+        self.range_shrink_percentile = (layers.PAct2.PACT2_RANGE_SHRINK if histogram_range else 0.0)
+        # constrain_bias means bias that is being added to accumulator is limited to 16bit (for 8bit quantization).
+        # scale factor to be used for constrain_bias is the product of scale factors of weight and input
+        self.constrain_bias = True if (constrain_bias is None) else constrain_bias
+        # using per_channel_q when constrain_bias is set may not be good for accuracy.
+        if self.constrain_bias and self.per_channel_q:
+            warnings.warn('using per_channel_q when constrain_bias is set may not be good for accuracy.')
+        #
         # for help in debug/print
         utils.add_module_names(self)
         # put in eval mode before analyze
@@ -45,7 +45,7 @@ class QuantBaseModule(QuantGraphModule):
         # set attributes to all modules - can control the behaviour from here
         utils.apply_setattr(self, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
                             histogram_range=histogram_range, bias_calibration=self.bias_calibration, per_channel_q=self.per_channel_q,
-                            percentile_range_shrink=self.percentile_range_shrink, constrain_weights=self.constrain_weights,
+                            range_shrink_percentile=self.range_shrink_percentile, constrain_weights=self.constrain_weights,
                             power2_weight_range=self.power2_weight_range, power2_activation_range=self.power2_activation_range,
                             constrain_bias=self.constrain_bias)
 
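Note: the constrain_bias comments added in this hunk describe limiting the bias that feeds the accumulator to 16 bits (for 8-bit quantization), with the bias scale taken as the product of the weight and input scale factors. A minimal standalone sketch of that idea (illustrative only, not the repository's API):

```python
import torch

def constrain_bias_to_16bit(bias, weight_scale, input_scale):
    # the bias is accumulated in the weight*input fixed-point domain,
    # so its quantization scale is the product of the two scale factors
    bias_scale = weight_scale * input_scale
    qmin, qmax = -(2 ** 15), (2 ** 15) - 1        # 16-bit accumulator limits
    bias_q = torch.clamp(torch.round(bias * bias_scale), qmin, qmax)
    return bias_q / bias_scale                    # fake-quantized (dequantized) bias
```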
index a36adccaaaa98fc0b3672409623a56a8ca9932ef..4ed49f7c62942f560a55906eba8005ca73cf0b25 100644 (file)
@@ -31,6 +31,7 @@ class QuantCalibrateModule(QuantTrainModule):
         super().__init__(module, dummy_input, *args, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
                          per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration, constrain_weights=constrain_weights,
                          power2_weight_range=power2_weight_range, power2_activation_range=power2_activation_range, constrain_bias=constrain_bias, **kwargs)
+        self.calib_stats = dict()
 
 
     def forward(self, inputs, *args, **kwargs):
@@ -77,7 +78,8 @@ class QuantCalibrateModule(QuantTrainModule):
         outputs = self.forward_float(inputs, *args, **kwargs)
         # Then adjust weights/bias so that the quantized output matches float output
         outputs = self.forward_quantized(inputs, *args, **kwargs)
-
+        # not needed outside - clear
+        self.calib_stats = dict()
         return outputs
 
 
@@ -107,13 +109,15 @@ class QuantCalibrateModule(QuantTrainModule):
         reduce_dims = [0, 2, 3] if len(output.size()) == 4 else ([0] if len(output.size()) == 2 else None)
 
         bias = op.bias if hasattr(op, 'bias') else None
+        output_mean = output_std = None
         if (self.bias_calibration and bias is not None):
-            op.__output_mean_orig__ = torch.mean(output, dim=reduce_dims).data
+            output_mean = torch.mean(output, dim=reduce_dims).data
         #
 
         if self.weights_calibration and utils.is_conv_deconv(op):
-            op.__output_std_orig__ = torch.std(output, dim=reduce_dims).data
+            output_std = torch.std(output, dim=reduce_dims).data
         #
+        self.calib_stats[op] = dict(mean=output_mean, std=output_std)
         return outputs
     #
 
@@ -137,10 +141,12 @@ class QuantCalibrateModule(QuantTrainModule):
             output = output[0]
 
         bias = op.bias if hasattr(op, 'bias') else None
+        output_mean_float = self.calib_stats[op]['mean']
+        output_std_float = self.calib_stats[op]['std']
         if self.bias_calibration and bias is not None:
             reduce_dims = [0,2,3] if len(output.size()) == 4 else ([0] if len(output.size()) == 2 else None)
             output_mean = torch.mean(output, dim=reduce_dims).data
-            output_delta = op.__output_mean_orig__ - output_mean
+            output_delta = output_mean_float - output_mean
             output_delta = output_delta * self.calibration_factor
             bias.data += (output_delta)
         #
@@ -150,7 +156,7 @@ class QuantCalibrateModule(QuantTrainModule):
             weight = op.weight
             reduce_dims = [0, 2, 3] if len(output.size()) == 4 else ([0] if len(output.size()) == 2 else None)
             output_std = torch.std(output, dim=reduce_dims).data
-            output_ratio = (op.__output_std_orig__ + eps) / (output_std + eps)
+            output_ratio = (output_std_float + eps) / (output_std + eps)
             channels = output.size(1)
             output_ratio = output_ratio.view(channels, 1, 1, 1) if len(weight.data.size()) > 1 else output_ratio
             output_ratio = torch.pow(output_ratio, self.calibration_factor)
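Note: the calibration hooks now cache the float-mode output statistics in self.calib_stats and then nudge the bias and weights so that the quantized outputs match the float outputs. A rough standalone sketch of the update rule implied by this hunk (calibration_factor and eps values are illustrative):

```python
import torch

def calibration_step(bias, weight, mean_float, mean_quant, std_float, std_quant,
                     calibration_factor=0.5, eps=1e-6):
    # shift the bias towards the float-mode per-channel mean
    bias = bias + (mean_float - mean_quant) * calibration_factor
    # rescale the weight towards the float-mode per-channel std
    ratio = ((std_float + eps) / (std_quant + eps)) ** calibration_factor
    weight = weight * ratio.view(-1, 1, 1, 1)
    return bias, weight
```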
index 2a5b60d94833ee2a58a3e0abee85aa930a91dffd..c3433c3a549df179f0c582391e40ff3a489df6f4 100644 (file)
@@ -22,6 +22,9 @@ class QuantGraphModule(HookedModule):
         # this block is not quantized. Also if the next block is this, current block is not quantized
         self.ignore_out_blocks = (layers.NoQAct,torch.nn.Dropout2d)
 
+        # quantize the input to a block (under certain conditions, if the input was not already quantized)
+        self.quantize_in = True
+
         # TBD: is this required
         # # if the original module has load_weights, add it to the quant module also
         # if hasattr(module, 'load_weights'):
@@ -135,7 +138,7 @@ class QuantGraphModule(HookedModule):
                 #
             elif qparams.quantize_in:
                 if not hasattr(module, 'activation_in'):
-                    # TODO: set percentile_range_shrink=0.0 to avoid shrinking of input range, if needed.
+                    # TODO: set range_shrink_percentile=0.0 to avoid shrinking of input range, if needed.
                     activation_in = layers.PAct2(signed=None)
                     activation_in.train(self.training)
                     module.activation_in = activation_in
@@ -298,8 +301,8 @@ class QuantGraphModule(HookedModule):
             is_input_quantized = False
         #
 
-        quantize_in = utils.is_conv_deconv_linear(module) and not is_input_quantized and \
-                      not is_input_ignored and is_first_module
+        quantize_in = self.quantize_in and utils.is_conv_deconv_linear(module) and (not is_input_quantized) and \
+                      (not is_input_ignored) and is_first_module
         qparams.quantize_w = utils.is_conv_deconv_linear(module)                                # all conv/deconv layers will be quantized
         qparams.quantize_b = utils.is_conv_deconv_linear(module)                                # all conv/deconv layers will be quantized
         qparams.quantize_out = quantize_out                                                     # selectively quantize output
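Note: the new quantize_in flag defaults to True; with the modified condition above, clearing it skips inserting a PAct2 on the input of the first conv/deconv/linear layer. A hedged usage sketch (model and dummy_input are placeholders):

```python
# 'model' is a float torch.nn.Module; dummy_input matches its expected input shape
qat_model = QuantTrainModule(model, dummy_input=dummy_input)
qat_model.quantize_in = False   # do not quantize the input of the first layer
```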
index bc5e3d5a04af8e12ab1654d77c3ab974fb2758f7..add1384e49bb6b2403c6dc66a795d0cb4d14729d 100644 (file)
@@ -12,7 +12,6 @@ class QuantTestModule(QuantTrainModule):
     def __init__(self, module, dummy_input, *args, bitwidth_weights=8, bitwidth_activations=8, per_channel_q=False,
                  histogram_range=True, bias_calibration=False, constrain_weights=None, model_surgery_quantize=True,
                  power2_weight_range=None, power2_activation_range=None, constrain_bias=None, **kwargs):
-        constrain_weights = (not per_channel_q) if constrain_weights is None else constrain_weights
         super().__init__(module, dummy_input, *args, bitwidth_weights=bitwidth_weights, bitwidth_activations=bitwidth_activations,
                          per_channel_q=per_channel_q, histogram_range=histogram_range, bias_calibration=bias_calibration,
                          constrain_weights=constrain_weights,
@@ -57,10 +56,10 @@ class QuantEstimateModule(QuantBaseModule):
         self.range_expansion_factor = 1.0
 
         # set these to 0 to use faster min/max based range computation (lower accuracy) instead of histogram based range.
-        # shrink range: 0.01 means 0.01 percentile_range_shrink, not 1 percentile_range_shrink
-        self.percentile_range_shrink_activations = (layers.PAct2.PACT2_RANGE_SHRINK if histogram_range else 0)
+        # shrink range: 0.01 means 0.01 range_shrink_percentile, not 1 range_shrink_percentile
+        self.range_shrink_percentile_activations = (layers.PAct2.PACT2_RANGE_SHRINK if histogram_range else 0)
         # range shrinking of weight is hurting in some models
-        self.percentile_range_shrink_weights = 0 #(0.01 if histogram_range else 0)
+        self.range_shrink_percentile_weights = 0 #(0.01 if histogram_range else 0)
 
         self.idx_large_mse_for_act = 0
 
@@ -259,22 +258,22 @@ class QuantEstimateModule(QuantBaseModule):
         self.current_scale = output[0].scale
 
 
-    def compute_tensor_range(self, module, tensor_in, percentile_range_shrink):
+    def compute_tensor_range(self, module, tensor_in, range_shrink_percentile):
         if hasattr(tensor_in, 'scale') and utils.is_list(tensor_in.scale):
             scale_inv = [(1/s) for s in tensor_in.scale]
             tensor_scale_inv = torch.tensor(scale_inv).view(1,-1,1,1).to(tensor_in.device)
             tensor_scaled = tensor_in * tensor_scale_inv
-            (mn, mx) = self._compute_tensor_range_noscale(module, tensor_scaled, percentile_range_shrink)
+            (mn, mx) = self._compute_tensor_range_noscale(module, tensor_scaled, range_shrink_percentile)
         else:
             scale = tensor_in.scale if hasattr(tensor_in, 'scale') else 1.0
-            (mn, mx) = self._compute_tensor_range_noscale(module, tensor_in, percentile_range_shrink)
+            (mn, mx) = self._compute_tensor_range_noscale(module, tensor_in, range_shrink_percentile)
             (mn, mx) = (mn / scale, mx / scale)
         #
         return mn, mx
 
 
-    def _compute_tensor_range_noscale(self, module, tensor, percentile_range_shrink):
-        mn, mx = utils.extrema_fast(tensor.data, percentile_range_shrink)
+    def _compute_tensor_range_noscale(self, module, tensor, range_shrink_percentile):
+        mn, mx = utils.extrema_fast(tensor.data, range_shrink_percentile)
         return mn, mx
 
 
@@ -335,7 +334,7 @@ class QuantEstimateModule(QuantBaseModule):
                 tensor_in.scale = []
                 for chan in range(tensor_in.shape[0]):
                     # Range
-                    mn, mx = self.compute_tensor_range(module, tensor_in[chan], percentile_range_shrink=self.percentile_range_shrink_weights)
+                    mn, mx = self.compute_tensor_range(module, tensor_in[chan], range_shrink_percentile=self.range_shrink_percentile_weights)
                     tensor_scale, clamp_limits = compute_tensor_scale(tensor_in[chan], mn, mx, bitwidth_weights, self.power2_weight_range)
                     qrange.min.append(mn)
                     qrange.max.append(mx)
@@ -348,7 +347,7 @@ class QuantEstimateModule(QuantBaseModule):
                 #
             else:
                 # Range
-                mn, mx = self.compute_tensor_range(module, tensor_in, percentile_range_shrink=self.percentile_range_shrink_weights)
+                mn, mx = self.compute_tensor_range(module, tensor_in, range_shrink_percentile=self.range_shrink_percentile_weights)
                 tensor_scale, clamp_limits = compute_tensor_scale(tensor_in, mn, mx, bitwidth_weights, self.power2_weight_range)
                 qrange.min = mn
                 qrange.max = mx
@@ -368,7 +367,7 @@ class QuantEstimateModule(QuantBaseModule):
             #use same bitwidth as weight
             bitwidth_bias = bitwidth_weights
             
-            mn, mx = self.compute_tensor_range(module, tensor_in, percentile_range_shrink=0.0)
+            mn, mx = self.compute_tensor_range(module, tensor_in, range_shrink_percentile=0.0)
             tensor_scale, clamp_limits = compute_tensor_scale(tensor_in, mn, mx, bitwidth_bias, self.power2_weight_range)
 
             # --
@@ -426,7 +425,7 @@ class QuantEstimateModule(QuantBaseModule):
                     mn = op_range[0]
                     mx = op_range[1]
                 else:
-                    mn, mx = self.compute_tensor_range(module, tensor_in, percentile_range_shrink=self.percentile_range_shrink_activations)
+                    mn, mx = self.compute_tensor_range(module, tensor_in, range_shrink_percentile=self.range_shrink_percentile_activations)
 
             tensor_scale, clamp_limits = compute_tensor_scale(None, mn, mx, bitwidth_activations, True)
             tensor = upward_round_tensor(tensor_in*tensor_scale)
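Note: range_shrink_percentile is expressed as a percentile (0.01 means the 0.01/99.99 percentiles, not 1%). The repository computes the shrunk range with utils.extrema_fast; a rough torch-only approximation for small tensors, assuming torch.quantile is acceptable:

```python
import torch

def extrema_with_shrink(t, range_shrink_percentile=0.0):
    if range_shrink_percentile == 0.0:
        return t.min(), t.max()                   # plain min/max: faster, lower accuracy
    frac = range_shrink_percentile / 100.0        # e.g. 0.01 -> 0.0001
    flat = t.detach().flatten().float()
    return torch.quantile(flat, frac), torch.quantile(flat, 1.0 - frac)
```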
index 3c3cfedf1395e1065e9fcd5e8bf4c3aa0db6e368..7bcd5417c89b8d9defb5269ad458fdfa63454fd8 100644 (file)
@@ -60,21 +60,21 @@ class QuantTrainModule(QuantBaseModule):
                 elif isinstance(m, layers.PAct2):
                     new_m = QuantTrainPAct2(inplace=m.inplace, signed=m.signed, clip_range=m.clip_range,
                                              bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
-                                            per_channel_q=self.per_channel_q, percentile_range_shrink=self.percentile_range_shrink,
+                                            per_channel_q=self.per_channel_q, range_shrink_percentile=self.range_shrink_percentile,
                                             power2_weight_range=self.power2_weight_range, power2_activation_range=self.power2_activation_range)
                 elif isinstance(m, layers.QAct):
                     new_m = QuantTrainPAct2(signed=None, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
-                                             per_channel_q=self.per_channel_q, percentile_range_shrink=self.percentile_range_shrink,
+                                             per_channel_q=self.per_channel_q, range_shrink_percentile=self.range_shrink_percentile,
                                             power2_weight_range=self.power2_weight_range, power2_activation_range=self.power2_activation_range)
                 elif isinstance(m, (torch.nn.ReLU, torch.nn.ReLU6)):
                     new_m = QuantTrainPAct2(signed=False, bitwidth_weights=self.bitwidth_weights, bitwidth_activations=self.bitwidth_activations,
-                                             per_channel_q=self.per_channel_q, percentile_range_shrink=self.percentile_range_shrink,
+                                             per_channel_q=self.per_channel_q, range_shrink_percentile=self.range_shrink_percentile,
                                             power2_weight_range=self.power2_weight_range, power2_activation_range=self.power2_activation_range)
                 else:
                     new_m = None
                 #
                 if new_m is not None:
-                    copy_attr_list = ('weight', 'bias', 'eps', 'clips_act', 'clips_w', 'percentile_range_shrink')
+                    copy_attr_list = ('weight', 'bias', 'eps', 'clips_act', 'clips_w', 'range_shrink_percentile')
                     for attr in dir(m):
                         value = getattr(m,attr)
                         if isinstance(value,torch.Tensor) and value is not None:
@@ -250,8 +250,9 @@ class QuantTrainBatchNorm2d(torch.nn.BatchNorm2d):
 # fake quantized PAct2 for training
 class QuantTrainPAct2(layers.PAct2):
     def __init__(self, inplace=False, signed=False, clip_range=None, bitwidth_weights=None, bitwidth_activations=None,
-                 per_channel_q=False, percentile_range_shrink=layers.PAct2.PACT2_RANGE_SHRINK, power2_weight_range=True, power2_activation_range=True):
-        super().__init__(inplace=inplace, signed=signed, clip_range=clip_range, percentile_range_shrink=percentile_range_shrink,
+                 per_channel_q=False, range_shrink_percentile=layers.PAct2.PACT2_RANGE_SHRINK, power2_weight_range=True,
+                 power2_activation_range=True):
+        super().__init__(inplace=inplace, signed=signed, clip_range=clip_range, range_shrink_percentile=range_shrink_percentile,
                          power2_activation_range=power2_activation_range)
 
         self.bitwidth_weights = bitwidth_weights
@@ -259,6 +260,13 @@ class QuantTrainPAct2(layers.PAct2):
         self.per_channel_q = per_channel_q
         self.power2_weight_range = power2_weight_range
 
+        # quantize_backward_type can be 'ste' or 'qte'
+        # - quantize_backward_type == 'ste' will cause backward to happen using unquantized weights/bias
+        #   (as the contents of yq is propagated inside the tensor y). Uses: PropagateQuantTensorSTE
+        # - quantize_backward_type == 'qte' allows gradients to flow back through the conv using the quantized weights/bias
+        #   (as yq is directly propagated then). Uses: PropagateQuantTensorQTE
+        self.quantize_backward_type = 'ste'
+
         # weight shrinking is done by clamp weights - set this factor to zero.
         # this must me zero - as in pact we do not have the actual weight param, but just a temporary tensor
         # so any clipping we do here is not stored int he weight params
@@ -272,20 +280,10 @@ class QuantTrainPAct2(layers.PAct2):
         self.constrain_bias = None
         self.constrain_weights = True
         self.bias_calibration = False
-        # do joint quantization only after the activation range has stabilized reasonably.
-        self.constrain_bias_start_iter = 75
-
-        # set to STRAIGHT_THROUGH_ESTIMATION or ALPHA_BLENDING_ESTIMATION
-        # For a comparison of STE and ABE, read:
-        # Learning low-precision neural networks without Straight-Through Estimator (STE):
-        # https://arxiv.org/pdf/1903.01061.pdf
-        self.quantized_estimation_type = QuantEstimationType.STRAIGHT_THROUGH_ESTIMATION
-        self.alpha_blending_estimation_factor = 0.5
-
-        if (layers.PAct2.PACT2_RANGE_LEARN):
-            assert self.quantized_estimation_type != QuantEstimationType.STRAIGHT_THROUGH_ESTIMATION, \
-                'straight through estimation should not used when PACT clip values are being learned as it doesnt backpropagate gradients though quantization'
-        #
+        # start constraining the bias at this iteration
+        self.constrain_bias_start_iter = 0
+        # start storing weights at this iteration
+        self.store_weights_iter = 0
 
 
     def forward(self, x):
@@ -340,39 +338,22 @@ class QuantTrainPAct2(layers.PAct2):
             clip_min, clip_max, scale, scale_inv = self.get_clips_scale_act()
             width_min, width_max = self.get_widths_act()
             # no need to call super().forward here as clipping with width_min/windth_max-1 after scaling has the same effect.
-            # currently the gradient is set to 1 in round_up_g. shouldn't the gradient be (y-x)?
-            # see eqn 6 in https://arxiv.org/pdf/1903.08066v2.pdf
-            # yq = layers.clamp_g(layers.round_up_g(xq * scale), width_min, width_max-1, self.training) * scale_inv
             yq = layers.quantize_dequantize_g(xq, scale, width_min, width_max-1, self.power2_activation_range, 'round_up')
         else:
             yq = super().forward(xq, update_activation_range=False, enable=True)
         #
 
-        if (self.quantized_estimation_type == QuantEstimationType.STRAIGHT_THROUGH_ESTIMATION):
-            # replace the float output with quantized version
-            # the entire weight merging and quantization process is bypassed in the forward pass
-            # however, the backward gradients flow through only the float path - this is called straight through estimation (STE)
-            with torch.no_grad():
-                y.data.copy_(yq.data)
-            #
-        elif (self.quantized_estimation_type == QuantEstimationType.ALPHA_BLENDING_ESTIMATION):
-            if self.training:
-                # TODO: vary the alpha blending factor over the epochs
-                y = y * (1.0-self.alpha_blending_estimation_factor) + yq * self.alpha_blending_estimation_factor
-            else:
-                y = yq
-            #
-        elif (self.quantized_estimation_type == QuantEstimationType.QUANTIZED_THROUGH_ESTIMATION):
-            # pass on the quantized output - the backward gradients also flow through quantization.
-            # however, note the gradients of round and ceil operators are forced to be unity (1.0).
-            y = yq
+        if self.quantize_backward_type == 'ste':
+            yq = layers.PropagateQuantTensorSTE(None)(y, yq)
+        elif self.quantize_backward_type == 'qte':
+            yq = layers.PropagateQuantTensorQTE(None)(y, yq)
         else:
-            assert False, f'unsupported value for quantized_estimation_type: {self.quantized_estimation_type}'
+            assert False, 'quantize_backward_type must be one of ste or qte'
         #
 
         # pass on the clips to be used in the next quantization
-        y.clips_act = self.get_clips_act()
-        return y
+        yq.clips_act = self.get_clips_act()
+        return yq
     #
 
 
@@ -382,10 +363,10 @@ class QuantTrainPAct2(layers.PAct2):
 
     def merge_quantize_weights(self, qparams, conv, bn):
         num_batches_tracked = int(self.num_batches_tracked)
-        is_constrain_weights_iter = self.training and (num_batches_tracked == 0)
-        is_store_weights_iter = self.training and (num_batches_tracked == 0)
-        is_constrain_bias_iter = self.training and (num_batches_tracked>=self.constrain_bias_start_iter)
-        is_store_bias_iter = self.training and (num_batches_tracked==self.constrain_bias_start_iter)
+        is_constrain_weights_iter = self.training and (num_batches_tracked == self.store_weights_iter)
+        is_store_weights_iter = self.training and (num_batches_tracked == self.store_weights_iter)
+        is_store_bias_iter = self.training and (num_batches_tracked == self.constrain_bias_start_iter)
+        is_constrain_bias_iter = self.training and (num_batches_tracked >= self.constrain_bias_start_iter)
 
         # merge weight and bias (if possible) across layers
         if conv is not None and bn is not None:
@@ -438,41 +419,35 @@ class QuantTrainPAct2(layers.PAct2):
                 #
 
                 is_dw = utils.is_dwconv(conv)
+                is_deconv = utils.is_deconv(conv)
                 use_per_channel_q = (is_dw and self.per_channel_q is True) or (self.per_channel_q == 'all')
-                if use_per_channel_q:
-                    channels = int(merged_weight.size(0))
-                    scale2 = torch.zeros(channels,1,1,1).to(merged_weight.device)
-                    scale_inv2 = torch.zeros(channels,1,1,1).to(merged_weight.device)
-                    for chan_id in range(channels):
-                        clip_min, clip_max, scale2_value, scale_inv2_value = self.get_clips_scale_w(merged_weight[chan_id])
-                        scale2[chan_id,0,0,0] = scale2_value
-                        scale_inv2[chan_id,0,0,0] = scale_inv2_value
-                    #
-                else:
-                    clip_min, clip_max, scale2, scale_inv2 = self.get_clips_scale_w(merged_weight)
-                #
+                clip_min, clip_max, scale2, scale_inv2 = self.get_clips_scale_w(merged_weight,
+                                                use_per_channel_q=use_per_channel_q, is_deconv=is_deconv)
                 width_min, width_max = self.get_widths_w()
-                # merged_weight = layers.clamp_g(layers.round_sym_g(merged_weight * scale2), width_min, width_max-1, self.training) * scale_inv2
-                merged_weight = layers.quantize_dequantize_g(merged_weight, scale2, width_min, width_max-1, self.power2_weight_range, 'round_sym')
+                merged_weight = layers.quantize_dequantize_g(merged_weight, scale2, width_min, width_max-1,
+                                                             self.power2_weight_range, 'round_sym')
             #
 
             if (self.quantize_enable and self.quantize_bias):
                 bias_width_min, bias_width_max = self.get_widths_bias()
                 bias_clip_min, bias_clip_max, bias_scale2, bias_scale_inv2 = self.get_clips_scale_bias(merged_bias)
                 power2_bias_range = (self.power2_weight_range and self.power2_activation_range)
-                merged_bias = layers.quantize_dequantize_g(merged_bias, bias_scale2, bias_width_min, bias_width_max - 1, power2_bias_range, 'round_sym')
+                merged_bias = layers.quantize_dequantize_g(merged_bias, bias_scale2, bias_width_min, bias_width_max - 1,
+                                                           power2_bias_range, 'round_sym')
             #
 
             # in some cases, bias quantization can have additional restrictions if for example,
             # bias that is being added to accumulator is limited to 16bit.
             # scale factor to be used for bias is the product of scale factors of weight and input
             if self.quantize_enable and self.constrain_bias and is_constrain_bias_iter:
-                clips_scale_joint = self.get_clips_scale_joint(qparams, merged_weight, merged_bias)
+                clips_scale_joint = self.get_clips_scale_joint(qparams, merged_weight, merged_bias,
+                                                use_per_channel_q=use_per_channel_q, is_deconv=is_deconv)
                 if clips_scale_joint is not None:
                     bias_width_min, bias_width_max = self.get_widths_joint()
                     bias_clip_min, bias_clip_max, bias_scale2, bias_scale_inv2 = clips_scale_joint
                     power2_bias_range = (self.power2_weight_range and self.power2_activation_range)
-                    merged_bias = layers.quantize_dequantize_g(merged_bias, bias_scale2, bias_width_min, bias_width_max - 1, power2_bias_range, 'round_sym')
+                    merged_bias = layers.quantize_dequantize_g(merged_bias, bias_scale2, bias_width_min,
+                                                  bias_width_max - 1, power2_bias_range, 'round_sym')
                 #
             #
 
@@ -509,7 +484,7 @@ class QuantTrainPAct2(layers.PAct2):
 
     def get_clips_w(self, tensor):
         # find the clip values
-        w_min, w_max = utils.extrema_fast(tensor.data, percentile_range_shrink=self.range_shrink_weights)
+        w_min, w_max = utils.extrema_fast(tensor.data, range_shrink_percentile=self.range_shrink_weights)
         clip_max = torch.max(torch.abs(w_min), torch.abs(w_max))
         clip_max = torch.clamp(clip_max, min=self.eps)
         # in range learning mode + training - this power2_weight_range is taken care in the quantize function
@@ -519,13 +494,41 @@ class QuantTrainPAct2(layers.PAct2):
         return (clip_min2, clip_max2)
 
 
-    def get_clips_scale_w(self, weight):
+    def get_clips_scale_w(self, weight, use_per_channel_q=False, is_deconv=False):
         clip_min, clip_max = self.get_clips_w(weight)
         width_min, width_max = self.get_widths_w()
         scale2 = (width_max / clip_max)
         scale2 = torch.clamp(scale2, min=self.eps)
         scale_inv2 = scale2.pow(-1.0)
-        return (clip_min, clip_max, scale2, scale_inv2)
+        if not use_per_channel_q:
+            return (clip_min, clip_max, scale2, scale_inv2)
+        #
+        # the remaining part of the function is only used in the case of use_per_channel_q
+        # compute the per-channel weight scale.
+        # restrict the weight scale factor of a channel from becoming extremely large.
+        scale_factor_ratio_max = 256 #None
+        channels = int(weight.size(1)) if is_deconv else int(weight.size(0))
+        scale2_array = torch.zeros(1, channels, 1, 1).to(weight.device) if is_deconv else \
+            torch.zeros(channels, 1, 1, 1).to(weight.device)
+        scale_inv2_array = torch.zeros(1, channels, 1, 1).to(weight.device) if is_deconv else \
+            torch.zeros(channels, 1, 1, 1).to(weight.device)
+        for chan_id in range(channels):
+            weight_channel = weight[:,chan_id,...] if is_deconv else weight[chan_id]
+            _, _, scale2_value, scale_inv2_value = self.get_clips_scale_w(weight_channel)
+            scale2_value = torch.min(scale2_value, scale2*scale_factor_ratio_max) \
+                if (scale_factor_ratio_max is not None) else scale2_value
+            scale2_value = torch.clamp(scale2_value, min=self.eps)
+            scale_inv2_value = scale2_value.pow(-1.0)
+            if is_deconv:
+                scale2_array[0, chan_id, 0, 0] = scale2_value
+                scale_inv2_array[0, chan_id, 0, 0] = scale_inv2_value
+            else:
+                scale2_array[chan_id, 0, 0, 0] = scale2_value
+                scale_inv2_array[chan_id, 0, 0, 0] = scale_inv2_value
+            #
+        #
+        return (clip_min, clip_max, scale2_array, scale_inv2_array)
+
 
     ###########################################################
     def get_widths_act(self):
@@ -570,7 +573,7 @@ class QuantTrainPAct2(layers.PAct2):
 
     ###########################################################
     def get_widths_joint(self):
-        bw = (2*self.bitwidth_weights - 1)
+        bw = (2*self.bitwidth_weights-1)
         width_max = np.power(2.0, bw)
         width_min = -width_max
         return (width_min, width_max)
@@ -605,13 +608,27 @@ class QuantTrainPAct2(layers.PAct2):
         #
 
 
-    def get_clips_scale_joint(self, qparams, weights, bias):
+    def get_clips_scale_joint(self, qparams, weights, bias, use_per_channel_q=False, is_deconv=False):
         clips_scale_input = self.get_clips_scale_input(qparams)
-        if clips_scale_input is not None:
-            clip_min_input, clip_max_input, scale2_input, scale_inv2_input = clips_scale_input
-            clip_min_w, clip_max_w, scale2_w, scale_inv2_w = self.get_clips_scale_w(weights)
-            clip_min_bias, clip_max_bias, scale2_bias, scale_inv2_bias = self.get_clips_scale_bias(bias)
-            return (clip_min_bias, clip_max_bias, scale2_w*scale2_input, scale_inv2_w*scale_inv2_input)
-        else:
+        if clips_scale_input is None:
             return None
-        #
\ No newline at end of file
+        #
+        clip_min_input, clip_max_input, scale2_input, scale_inv2_input = clips_scale_input
+        clip_min_bias, clip_max_bias, scale2_bias, scale_inv2_bias = self.get_clips_scale_bias(bias)
+        if not use_per_channel_q:
+            clip_min_w, clip_max_w, scale2_w, scale_inv2_w = self.get_clips_scale_w(weights)
+            return (clip_min_bias, clip_max_bias, scale2_w * scale2_input, scale_inv2_w * scale_inv2_input)
+        #
+        # the remaining part of the function is only used in the case of use_per_channel_q
+        channels = int(weights.size(1)) if is_deconv else int(weights.size(0))
+        clip_min_w, clip_max_w, scale2_w, scale_inv2_w = self.get_clips_scale_w(weights,
+                                use_per_channel_q=use_per_channel_q, is_deconv=is_deconv)
+        scale2_joint_array = torch.zeros(channels).to(weights.device)
+        scale_joint_inv2_array = torch.zeros(channels).to(weights.device)
+        for chan_id in range(channels):
+            scale_value = scale2_w[0, chan_id, 0, 0] if is_deconv else scale2_w[chan_id, 0, 0, 0]
+            scale_inv_value = scale_inv2_w[0, chan_id, 0, 0] if is_deconv else scale_inv2_w[chan_id, 0, 0, 0]
+            scale2_joint_array[chan_id] = scale_value * scale2_input
+            scale_joint_inv2_array[chan_id] = scale_inv_value * scale_inv2_input
+        #
+        return (clip_min_bias, clip_max_bias, scale2_joint_array, scale_joint_inv2_array)
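Note: with quantize_backward_type='ste' the output is routed through PropagateQuantTensorSTE, so the forward pass carries the quantized values while gradients flow back along the float path; 'qte' propagates the quantized tensor through backward as well. A minimal, self-contained illustration of the straight-through behaviour (not the repository's PropagateQuantTensorSTE itself):

```python
import torch

class QuantSTEFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, y, yq):
        # forward pass propagates the quantized values
        return yq

    @staticmethod
    def backward(ctx, grad_output):
        # backward pass sends the gradient to the float path only
        return grad_output, None

# an equivalent one-liner commonly used for STE:
# out = y + (yq - y).detach()
```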
index 2532bfe9854b47c8e1415f4d10485eb9a5ce9231..6ee2e8d69b2844713f15289a72a3e7e4900b344e 100644 (file)
@@ -1,17 +1,27 @@
 from .print_utils import *
-from .util_functions import *
-from .utils_data import *
+from .function_utils import *
+from .data_utils import *
 from .load_weights import *
 from .tensor_utils import *
 from .logger import *
-from .utils_hist import *
+from .hist_utils import *
 from .attr_dict import *
 from .weights_utils import *
 from .image_utils import *
 from .module_utils import *
 from .count_flops import forward_count_flops
 from .bn_utils import *
+from .range_utils import *
+
 try: from .tensor_utils_internal import *
 except: pass
-try: from .utils_export_internal import *
+
+try: from .export_utils_internal import *
 except: pass
+
+try:
+    from .range_estimator_internal import *
+    has_range_estimator = True
+except:
+    has_range_estimator = False
+    pass
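Note: range_estimator_internal is an optional, internal-only module; has_range_estimator records whether it could be imported. A hedged usage sketch (the package path assumes the modules/ layout of this repository):

```python
from pytorch_jacinto_ai.xnn import utils

print('internal range estimator available:', utils.has_range_estimator)
```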
index bf63abbcfd5a7e3c795c33d1019c2e0aac4848be..abe6d22c39584d5b7ed0885dfbe10de2f062e701 100644 (file)
@@ -17,6 +17,18 @@ def forward_count_flops(model, inp):
     _remove_hook(model)
     return num_flops
 
+def fw_bw_count_flops(model, inp):
+    _add_hook_fw_bw(model, _count_fw_bw_flops_func)
+    _ = model(inp)
+    num_fw_bw_flops = 0
+    op_vol = 0
+    for m in model.modules():
+        num_fw_bw_flops += m.__num_fw_bw_flops__
+        op_vol += m.__op_vol__
+    #
+    _remove_hook_fw_bw(model)
+    return num_fw_bw_flops, op_vol
+
 
 def count_params(model):
     layer_params = [p.numel() for p in model.parameters if p.requires_grad]
 
     else:
         m.__num_flops__ = 0
     #
     else:
         m.__num_flops__ = 0
     #
-
+def _count_fw_bw_flops_func(m, inp, out):
+    # trained calibration/quantization can do model surgery and return extra outputs - ignroe them
+    if isinstance(out, (list,tuple)):
+        out = out[0]
+    #
+    if module_utils.is_conv_deconv(m):
+        num_pixels = (out.shape[2] * out.shape[3])
+        # Note: channels_in taken from weight shape is already divided by m.groups - no need to divide again
+        channels_out, channels_in, kernel_height, kernel_width = m.weight.shape
+        macs_per_pixel = (channels_out * channels_in *  kernel_height * kernel_width)
+        num_flops = 2 * macs_per_pixel * num_pixels
+        if hasattr(m, 'bias') and (m.bias is not None):
+            num_flops += m.weight.shape[0]
+        #1 for fw + 2 for bw
+        m.__num_fw_bw_flops__ = num_flops*3
+        m.__op_vol__ = channels_out*(out.shape[2] * out.shape[3])
+    elif module_utils.is_bn(m):
+        channels_out = m.weight.shape[0]
+        out_vol = channels_out*(out.shape[2] * out.shape[3])
+        num_flops = 2 * out_vol
+        # 1 for fw + 2 for bw + 1 for recomputing BN to save memory
+        m.__num_fw_bw_flops__ = num_flops*4
+        m.__op_vol__ = 0
 
 def _add_hook(module, hook_func):
     for m in module.modules():
 
         m.__num_flops__ = 0
 
 
         m.__num_flops__ = 0
 
 
+def _add_hook_fw_bw(module, hook_func):
+    for m in module.modules():
+        m.__count_fw_bw_flops_hook__ = m.register_forward_hook(hook_func)
+        m.__num_fw_bw_flops__ = 0
+        m.__op_vol__ = 0
+
+
 def _remove_hook(module):
     for m in module.modules():
         if hasattr(m, '__count_flops_hook__'):
 def _remove_hook(module):
     for m in module.modules():
         if hasattr(m, '__count_flops_hook__'):
@@ -59,3 +100,11 @@ def _remove_hook(module):
             del m.__num_flops__
         #
 
             del m.__num_flops__
         #
 
+def _remove_hook_fw_bw(module):
+    for m in module.modules():
+        if hasattr(m, '__count_fw_bw_flops_hook__'):
+            m.__count_fw_bw_flops_hook__.remove()
+        #
+        if hasattr(m, '__num_fw_bw_flops__'):
+            del m.__num_fw_bw_flops__
+
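Note: fw_bw_count_flops mirrors forward_count_flops but also accounts for the backward pass (3x the conv FLOPs, 4x for BatchNorm including one recompute) and additionally returns the activation volume. A hedged usage sketch (the model and input size are placeholders, and fw_bw_count_flops is assumed to be importable from this module):

```python
import torch
from torchvision.models import resnet18   # any torch.nn.Module works here

model = resnet18().eval()
dummy_input = torch.rand(1, 3, 224, 224)
num_fw_bw_flops, op_vol = fw_bw_count_flops(model, dummy_input)
print(f'fw+bw GFLOPs: {num_fw_bw_flops / 1e9:.2f}, activation volume: {op_vol}')
```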
diff --git a/modules/pytorch_jacinto_ai/xnn/utils/data_utils.py b/modules/pytorch_jacinto_ai/xnn/utils/data_utils.py
new file mode 100644 (file)
index 0000000..f0313f2
--- /dev/null
@@ -0,0 +1,284 @@
+# from torchvision.datasets
+import os
+import os.path
+import hashlib
+import gzip
+import errno
+import tarfile
+import zipfile
+
+import torch
+from torch.utils.model_zoo import tqdm
+
+
+def gen_bar_updater():
+    pbar = tqdm(total=None)
+
+    def bar_update(count, block_size, total_size):
+        if pbar.total is None and total_size:
+            pbar.total = total_size
+        progress_bytes = count * block_size
+        pbar.update(progress_bytes - pbar.n)
+
+    return bar_update
+
+
+def calculate_md5(fpath, chunk_size=1024 * 1024):
+    md5 = hashlib.md5()
+    with open(fpath, 'rb') as f:
+        for chunk in iter(lambda: f.read(chunk_size), b''):
+            md5.update(chunk)
+    return md5.hexdigest()
+
+
+def check_md5(fpath, md5, **kwargs):
+    return md5 == calculate_md5(fpath, **kwargs)
+
+
+def check_integrity(fpath, md5=None):
+    if not os.path.isfile(fpath):
+        return False
+    if md5 is None:
+        return True
+    return check_md5(fpath, md5)
+
+
+def makedir_exist_ok(dirpath):
+    """
+    Python2 support for os.makedirs(.., exist_ok=True)
+    """
+    try:
+        os.makedirs(dirpath)
+    except OSError as e:
+        if e.errno == errno.EEXIST:
+            pass
+        else:
+            raise
+
+
+def download_url(url, root, filename=None, md5=None):
+    """Download a file from a url and place it in root.
+
+    Args:
+        url (str): URL to download file from
+        root (str): Directory to place downloaded file in
+        filename (str, optional): Name to save the file under. If None, use the basename of the URL
+        md5 (str, optional): MD5 checksum of the download. If None, do not check
+    """
+    from six.moves import urllib
+
+    root = os.path.expanduser(root)
+    if not filename:
+        filename = os.path.basename(url)
+    fpath = os.path.join(root, filename)
+
+    makedir_exist_ok(root)
+
+    # downloads file
+    if check_integrity(fpath, md5):
+        print('Using downloaded and verified file: ' + fpath)
+    else:
+        try:
+            print('Downloading ' + url + ' to ' + fpath)
+            urllib.request.urlretrieve(
+                url, fpath,
+                reporthook=gen_bar_updater()
+            )
+        except (urllib.error.URLError, IOError) as e:
+            if url[:5] == 'https':
+                url = url.replace('https:', 'http:')
+                print('Failed download. Trying https -> http instead.'
+                      ' Downloading ' + url + ' to ' + fpath)
+                urllib.request.urlretrieve(
+                    url, fpath,
+                    reporthook=gen_bar_updater()
+                )
+            else:
+                raise e
+
+    return fpath
+
+
+def list_dir(root, prefix=False):
+    """List all directories at a given root
+
+    Args:
+        root (str): Path to directory whose folders need to be listed
+        prefix (bool, optional): If true, prepends the path to each result, otherwise
+            only returns the name of the directories found
+    """
+    root = os.path.expanduser(root)
+    directories = list(
+        filter(
+            lambda p: os.path.isdir(os.path.join(root, p)),
+            os.listdir(root)
+        )
+    )
+
+    if prefix is True:
+        directories = [os.path.join(root, d) for d in directories]
+
+    return directories
+
+
+def list_files(root, suffix, prefix=False):
+    """List all files ending with a suffix at a given root
+
+    Args:
+        root (str): Path to directory whose folders need to be listed
+        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
+            It uses the Python "str.endswith" method and is passed directly
+        prefix (bool, optional): If true, prepends the path to each result, otherwise
+            only returns the name of the files found
+    """
+    root = os.path.expanduser(root)
+    files = list(
+        filter(
+            lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix),
+            os.listdir(root)
+        )
+    )
+
+    if prefix is True:
+        files = [os.path.join(root, d) for d in files]
+
+    return files
+
+
+def download_file_from_google_drive(file_id, root, filename=None, md5=None):
+    """Download a file from Google Drive and place it in root.
+
+    Args:
+        file_id (str): id of file to be downloaded
+        root (str): Directory to place downloaded file in
+        filename (str, optional): Name to save the file under. If None, use the id of the file.
+        md5 (str, optional): MD5 checksum of the download. If None, do not check
+    """
+    # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
+    import requests
+    url = "https://docs.google.com/uc?export=download"
+
+    root = os.path.expanduser(root)
+    if not filename:
+        filename = file_id
+    fpath = os.path.join(root, filename)
+
+    makedir_exist_ok(root)
+
+    if os.path.isfile(fpath) and check_integrity(fpath, md5):
+        print('Using downloaded and verified file: ' + fpath)
+    else:
+        session = requests.Session()
+
+        response = session.get(url, params={'id': file_id}, stream=True)
+        token = _get_confirm_token(response)
+
+        if token:
+            params = {'id': file_id, 'confirm': token}
+            response = session.get(url, params=params, stream=True)
+
+        _save_response_content(response, fpath)
+
+
+def _get_confirm_token(response):
+    for key, value in response.cookies.items():
+        if key.startswith('download_warning'):
+            return value
+
+    return None
+
+
+def _save_response_content(response, destination, chunk_size=32768):
+    with open(destination, "wb") as f:
+        pbar = tqdm(total=None)
+        progress = 0
+        for chunk in response.iter_content(chunk_size):
+            if chunk:  # filter out keep-alive new chunks
+                f.write(chunk)
+                progress += len(chunk)
+                pbar.update(progress - pbar.n)
+        pbar.close()
+
+
+def _is_tar(filename):
+    return filename.endswith(".tar")
+
+
+def _is_targz(filename):
+    return filename.endswith(".tar.gz")
+
+
+def _is_gzip(filename):
+    return filename.endswith(".gz") and not filename.endswith(".tar.gz")
+
+
+def _is_zip(filename):
+    return filename.endswith(".zip")
+
+
+def extract_archive(from_path, to_path=None, remove_finished=False):
+    if to_path is None:
+        to_path = os.path.dirname(from_path)
+
+    if _is_tar(from_path):
+        with tarfile.open(from_path, 'r') as tar:
+            tar.extractall(path=to_path)
+    elif _is_targz(from_path):
+        with tarfile.open(from_path, 'r:gz') as tar:
+            tar.extractall(path=to_path)
+    elif _is_gzip(from_path):
+        to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
+        with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
+            out_f.write(zip_f.read())
+    elif _is_zip(from_path):
+        with zipfile.ZipFile(from_path, 'r') as z:
+            z.extractall(to_path)
+    else:
+        raise ValueError("Extraction of {} not supported".format(from_path))
+
+    if remove_finished:
+        os.remove(from_path)
+
+
+def download_and_extract_archive(url, download_root, extract_root=None, filename=None,
+                                 md5=None, remove_finished=False):
+    download_root = os.path.expanduser(download_root)
+    if extract_root is None:
+        extract_root = download_root
+    if not filename:
+        filename = os.path.basename(url)
+
+    download_url(url, download_root, filename, md5)
+
+    archive = os.path.join(download_root, filename)
+    print("Extracting {} to {}".format(archive, extract_root))
+    extract_archive(archive, extract_root, remove_finished)
+
+
+def iterable_to_str(iterable):
+    return "'" + "', '".join([str(item) for item in iterable]) + "'"
+
+
+def verify_str_arg(value, arg=None, valid_values=None, custom_msg=None):
+    if not isinstance(value, torch._six.string_classes):
+        if arg is None:
+            msg = "Expected type str, but got type {type}."
+        else:
+            msg = "Expected type str for argument {arg}, but got type {type}."
+        msg = msg.format(type=type(value), arg=arg)
+        raise ValueError(msg)
+
+    if valid_values is None:
+        return value
+
+    if value not in valid_values:
+        if custom_msg is not None:
+            msg = custom_msg
+        else:
+            msg = ("Unknown value '{value}' for argument {arg}. "
+                   "Valid values are {{{valid_values}}}.")
+            msg = msg.format(value=value, arg=arg,
+                             valid_values=iterable_to_str(valid_values))
+        raise ValueError(msg)
+
+    return value
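Note: these helpers are adapted from torchvision.datasets.utils. A small hedged example of fetching and unpacking an archive with them (the URL and directories are placeholders, not from the repository):

```python
url = 'https://example.com/dataset.tar.gz'
download_and_extract_archive(url, download_root='./data/downloads',
                             extract_root='./data/datasets', md5=None)
```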
diff --git a/modules/pytorch_jacinto_ai/xnn/utils/depth_utils.py b/modules/pytorch_jacinto_ai/xnn/utils/depth_utils.py
new file mode 100644 (file)
index 0000000..c630286
--- /dev/null
@@ -0,0 +1,76 @@
+import sys
+import numpy as np
+import cv2
+import torch
+import os
+import matplotlib.pyplot as plt
+import math
+
+# FIX_ME:SN move to utils
+def fish_cord_to_rect(dx_f=0, dy_f=0, theta_to_r_rect=[], params=[]):
+    f_x = params.cam_info_K[0]
+    f = f_x
+    r_f = np.sqrt(dx_f * dx_f + dy_f * dy_f)
+    calib_idx = int((r_f - params.fisheye_r_start) / params.fisheye_r_step)
+    calib_idx = max(0, min(calib_idx, len(theta_to_r_rect) - 2))
+    # getting theta using interpolation from theta_to_r_rect and then r
+    r_f_lower = params.fisheye_r_start + calib_idx * params.fisheye_r_step
+    theta = theta_to_r_rect[calib_idx] + (theta_to_r_rect[calib_idx + 1] - theta_to_r_rect[calib_idx]) * (
+            r_f - r_f_lower) / params.fisheye_r_step
+
+    r_rect = f * abs(math.tan(theta * np.pi / 180.0))
+    if (r_f != 0.0):
+        dx_r = dx_f * r_rect / r_f
+        dy_r = dy_f * r_rect / r_f
+    else:
+        dx_r = 0.0
+        dy_r = 0.0
+
+    return dx_r, dy_r
+
+
+def ZtoAbsYForaPixel(Zc=0.0, y_f=0.0, x_f=0.0, params=[]):
+    cx_f = params.cam_info_K[2]
+    cy_f = params.cam_info_K[5]
+    f_x = params.cam_info_K[0]
+    f_y = params.cam_info_K[4]
+
+    dy_f = y_f - cy_f
+    dx_f = x_f - cx_f
+
+    dx_r, dy_r = fish_cord_to_rect(dx_f=dx_f, dy_f=dy_f, theta_to_r_rect=r_fish_to_theta_rect, params=params)
+    Yc = dy_r * Zc / f_y
+    Xc = dx_r * Zc / f_x
+    M_c_l_r = np.array([[0.0046, 1.0000, -0.0061],
+                        [0.45353, -0.0075, -0.8912],
+                        [-0.8913, 0.0014, -0.4535]])  # Rotation matrix for TIAD_2017_seq1
+
+    M_c_l_t = np.array([[0.6872, 0.2081, -2.2312]])
+    M_c_l = np.vstack((np.hstack((M_c_l_r, M_c_l_t.transpose())), np.array([0, 0, 0, 1])))
+    M_l_c = np.linalg.inv(M_c_l)
+    print('Yc', Yc)
+    [Xl, Yl, Zl, _] = np.matmul(M_l_c, np.array([Xc, Yc, Zc, 1]))
+    print('Yl', Yl)
+    return abs(Yl)
+
+
+# This function converts Z to Y using fisheye model  + pinhole camera geometry
+def ZtoY(image_Z=None):
+    class Params:
+        def __init__(self):
+            # camera params
+
+            self.cam_info_K = np.array([311.8333, 0.0000, 640.0000,
+                                        0.0000, 311.8333, 360.0000,
+                                        0.0000, 0.0000, 1.0000])
+            self.fisheye_r_start = 0.0
+            self.fisheye_r_step = 0.5
+
+    params = Params()
+    image_Y = image_Z
+    height, width = image_Z.shape
+
+    for row in range(height):
+        for col in range(width):
+            image_Y[row, col] = ZtoAbsYForaPixel(Zc=image_Z[row, col], y_f=row, x_f=col, params=params)
+    return image_Y
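
The conversion above first maps a fisheye pixel offset to a rectified offset via the theta lookup table and then applies the pinhole relations Xc = dx_r * Zc / f_x and Yc = dy_r * Zc / f_y. A small numeric sketch of just the pinhole step, with assumed values:

```python
# Numeric sketch of the pinhole step used in ZtoAbsYForaPixel; dy_r and Zc are
# made-up values, f_y is taken from Params.cam_info_K defined above.
f_y = 311.8333    # focal length in pixels
dy_r = 100.0      # rectified vertical offset from the principal point, in pixels
Zc = 5.0          # depth along the optical axis, in metres
Yc = dy_r * Zc / f_y
print(round(Yc, 2))   # ~1.6 metres
```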
diff --git a/modules/pytorch_jacinto_ai/xnn/utils/function_utils.py b/modules/pytorch_jacinto_ai/xnn/utils/function_utils.py
new file mode 100644 (file)
index 0000000..d7c1228
--- /dev/null
@@ -0,0 +1,82 @@
+import os
+import numpy as np
+import cv2
+import torch
+from .. import layers as xtensor_layers
+from .image_utils import *
+
+##################################################
+# a utility function used for argument parsing
+def str2bool(v):
+    if isinstance(v, str):
+        if v.lower() in ("yes", "true", "t", "1"):
+            return True
+        elif v.lower() in ("no", "false", "f", "0"):
+            return False
+        else:
+            return v
+        #
+    else:
+        return v
+
+def splitstr2bool(v):
+    v = v.split(',')
+    for index, args in enumerate(v):
+        v[index] = str2bool(args)
+    return v
+
+
+#########################################################################
+def make_divisible(value, factor, min_value=None):
+    """
+    Inspired by https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+    """
+    min_value = factor if min_value is None else min_value
+    round_factor = factor/2
+    quotient = int(value + round_factor) // factor
+    value_multiple = max(quotient * factor, min_value)
+    # make sure that the change is contained
+    if value_multiple < 0.9*value:
+        value_multiple = value_multiple + factor
+    #
+    return int(value_multiple)
+
+
+def make_divisible_by8(v):
+    return make_divisible(v, 8)
+
+
+#########################################################################
+def recursive_glob(rootdir='.', suffix=''):
+    """Performs recursive glob with given suffix and rootdir
+        :param rootdir is the root directory
+        :param suffix is the suffix to be searched
+    """
+    return [os.path.join(looproot, filename)
+            for looproot, _, filenames in os.walk(rootdir)
+            for filename in filenames if filename.endswith(suffix)]
+
+
+###############################################################
+def get_shape_with_stride(in_shape, stride):
+    shape_s = [in_shape[0],in_shape[1],in_shape[2]//stride,in_shape[3]//stride]
+    if (int(in_shape[2]) % 2) == 1:
+        shape_s[2] += 1
+    if (int(in_shape[3]) % 2) == 1:
+        shape_s[3] += 1
+    return shape_s
+
+
+def get_blob_from_list(x_list, search_shape, start_dim=None):
+    x_ret = None
+    start_dim = start_dim if start_dim is not None else 0
+    for x in x_list:
+        if isinstance(x, list):
+            x = torch.cat(x,dim=1)
+        #
+        if (x.shape[start_dim:] == torch.Size(search_shape[start_dim:])):
+            x_ret = x
+        #
+    return x_ret
+
+
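
Worked examples for the helpers above; the values are illustrative and the import path is assumed from this file's location.

```python
# Hedged worked examples; the import path is an assumption based on the file location.
from pytorch_jacinto_ai.xnn.utils.function_utils import (
    make_divisible, make_divisible_by8, get_shape_with_stride)

# make_divisible rounds a channel count to a multiple of `factor`, bumping the
# result up whenever plain rounding would fall below 90% of the requested value.
assert make_divisible(30, 8) == 32       # int(30+4)//8 = 4 -> 4*8 = 32
assert make_divisible(21, 16) == 32      # 16 < 0.9*21, so one extra factor is added
assert make_divisible_by8(17) == 16

# get_shape_with_stride derives an NCHW shape at a given stride, rounding odd sizes up.
assert get_shape_with_stride([1, 3, 224, 224], 32) == [1, 3, 7, 7]
assert get_shape_with_stride([1, 3, 225, 225], 32) == [1, 3, 8, 8]
```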
diff --git a/modules/pytorch_jacinto_ai/xnn/utils/hist_utils.py b/modules/pytorch_jacinto_ai/xnn/utils/hist_utils.py
new file mode 100644 (file)
index 0000000..ce17062
--- /dev/null
@@ -0,0 +1,121 @@
+import sys
+import numpy as np
+import cv2
+import torch
+import os
+import matplotlib.pyplot as plt
+
+#from .. import layers as xtensor_layers
+
+#study the histogram of a 3D tensor, one channel at a time
+#ch_dim selects which axis holds the channels (0 or 2)
+def comp_hist_tensor3d(x=[], name='tensor', en=True, dir = 'dir_name', log = False, ch_dim=2):
+    if en == False:
+        return
+    root = os.getcwd()
+    path = root + '/checkpoints/debug/'+ dir
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+    if ch_dim == 2:
+        for ch in range(x.shape[ch_dim]):
+            x_ch = x[:,:,ch]
+            print('min={}, max={}, std={}, mean={}'.format(x_ch.min(), x_ch.max(), x_ch.std(), x_ch.mean()))
+            plt.hist(x_ch.flatten(), bins=256, log=log)  # arguments are passed to np.histogram
+            #plt.title("Histogram with 'auto' bins")
+            #plt.show()
+            plt.savefig('{}/{}_{:03d}.jpg'.format(path, name, ch))
+            plt.close()
+    elif ch_dim == 0:
+        for ch in range(x.shape[ch_dim]):
+            x_ch = x[ch,:,:]
+            print('min={}, max={}, std={}, mean={}'.format(x_ch.min(), x_ch.max(), x_ch.std(), x_ch.mean()))
+            plt.hist(x_ch.flatten(), bins=256, log=log)  # arguments are passed to np.histogram
+            #plt.title("Histogram with 'auto' bins")
+            #plt.show()
+            plt.savefig('{}/{}_{:03d}.jpg'.format(path, name, ch))
+            plt.close()
+
+def hist_tensor2D(x_ch=[], dir = 'tensor_dir', name='tensor', en=True, log=False, ch=0):
+    if en == False:
+        return
+    root = os.getcwd()
+    path = root + '/checkpoints/debug/'+ dir
+    if not os.path.exists(path):
+        os.makedirs(path)
+    
+    print('min={:.3f}, max={:.3f}, std={:.3f}, mean={:.3f}'.format(x_ch.min(), x_ch.max(), x_ch.std(), x_ch.mean()))
+    hist_ary = plt.hist(x_ch.flatten(), bins=256, log=log)  # arguments are passed to np.histogram
+    
+
+    #plt.title("Histogram with 'auto' bins")
+    #plt.show()
+    plt.savefig('{}/{}_{:03d}.jpg'.format(path, name,ch))
+    plt.close()
+    return hist_ary
+
+def analyze_model(model):
+    num_dead_ch = 0
+    for n, m in model.named_modules():
+        if isinstance(m, torch.nn.Conv2d):
+            if m.weight.shape[1] == 1:
+                for ch in range(m.weight.shape[0]):
+                    cur_ch_wt = m.weight[ch][0][...]
+                    mn = cur_ch_wt.min()
+                    mx = cur_ch_wt.max()
+                    mn = mn.cpu().detach().numpy()
+                    mx = mx.cpu().detach().numpy()
+                    print(n, 'dws weight ch mn mx', ch, mn, mx)
+                    #print(n, 'dws weight ch', ch, cur_ch_wt)
+                    if max(abs(mn), abs(mx)) <= 1E-40:
+                        num_dead_ch += 1
+            else:
+                print(n, 'weight', 'shape', m.weight.shape, m.weight.min(), m.weight.max())
+                if m.bias is not None:
+                    print(n, 'bias', m.bias.min(), m.bias.max())        
+
+    print("num_dead_ch: ", num_dead_ch)                
+
+def study_wts(self, modules):
+    for key, value in modules.items():
+        print(key, value)
+        for key2, value2 in value._modules.items():
+            print(key2, value2)
+            print(value2.weight.shape)
+
+
+def comp_hist(self, x=[], ch_idx=0, name='tensor'):
+    #hist_pred = torch.histc(x.cpu(), bins=256)
+    root = os.getcwd()
+    path = root + '/checkpoints/debug/'+ name
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+    #print(hist_pred)
+    for ch in range(x.shape[1]):
+        x_ch = x[0][ch]
+        plt.hist(x_ch.view(-1).cpu().numpy(), bins=256)  # arguments are passed to np.histogram
+        print('min={}, max={}, std={}, mean={}'.format(x_ch.min(), x_ch.max(), x_ch.std(), x_ch.mean()))
+        #plt.title("Histogram with 'auto' bins")
+        #plt.show()
+        plt.savefig('{}/{}_{:03d}.jpg'.format(path, name, ch))
+        plt.close()
+
+
+def store_layer_op(en=False, tensor= [], name='tensor_name'):
+    if en == False:
+        return
+
+    # write tensor
+    tensor = tensor.astype(np.int16)
+    print("writing tensor {} : {} : {} : {} : {}".format(name, tensor.shape, tensor.dtype, tensor.min(), tensor.max()))
+
+    root = os.getcwd()
+    tensor_dir = root + '/checkpoints/debug/' + name
+
+    if not os.path.exists(tensor_dir):
+        os.makedirs(tensor_dir)
+
+    tensor_name = os.path.join(tensor_dir, "{}.npy".format(name))
+    np.save(tensor_name, tensor)
+    comp_hist_tensor3d(x=tensor, name=name, en=True, dir=name, log=True, ch_dim=0)
\ No newline at end of file
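
These are debug helpers that dump per-channel histograms as images under ./checkpoints/debug/. A hedged sketch of how comp_hist_tensor3d might be called, using random data and an assumed import path:

```python
# Hedged sketch; the array is random data and the import path is assumed.
import numpy as np
from pytorch_jacinto_ai.xnn.utils.hist_utils import comp_hist_tensor3d

x = np.random.randn(16, 32, 32).astype(np.float32)   # a CHW feature map
# writes one histogram image per channel to ./checkpoints/debug/conv1_out/
comp_hist_tensor3d(x=x, name='conv1_out', en=True, dir='conv1_out', log=True, ch_dim=0)
```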
index 5a933682a6f3b17edb60b7a5c9dc2b8a7895a115..9208f2a483816fbe22231b1d2b9f67fd89e4b437 100644 (file)
@@ -7,7 +7,7 @@ import torch
 import copy
 from collections import OrderedDict
 from . import print_utils
-from . import utils_data
+from . import data_utils
 
 ######################################################
 # the method used in vision/models
@@ -20,7 +20,8 @@ except ImportError:
 ######################################################
 # our custom load function with more features
 def load_weights(model, pretrained, change_names_dict=None, keep_original_names=False, width_mult=1.0,
-                       ignore_size=True, verbose=False, num_batches_tracked = None, download_root=None, **kwargs):
+                        ignore_size=True, verbose=False, num_batches_tracked = None, download_root=None,
+                        state_dict_name='state_dict', **kwargs):
     download_root = './' if (download_root is None) else download_root
     if pretrained is None or pretrained is False:
         print_utils.print_yellow(f'=> weights could not be loaded. pretrained data given is {pretrained}')
@@ -28,7 +29,7 @@ def load_weights(model, pretrained, change_names_dict=None, keep_original_names=
 
     if isinstance(pretrained, str):
         if pretrained.startswith('http://') or pretrained.startswith('https://'):
-            pretrained_file = utils_data.download_url(pretrained, root=download_root)
+            pretrained_file = data_utils.download_url(pretrained, root=download_root)
         else:
             pretrained_file = pretrained
         #
@@ -38,7 +39,9 @@ def load_weights(model, pretrained, change_names_dict=None, keep_original_names=
     #
 
     load_error = False
-    data = data['state_dict'] if ((data is not None) and 'state_dict' in data) else data
+    state_dict_names = state_dict_name if isinstance(state_dict_name, (list,tuple)) else [state_dict_name]
+    for s_name in state_dict_names:
+        data = data[s_name] if ((data is not None) and s_name in data) else data
 
     if width_mult != 1.0:
         data = widen_model_data(data, factor=width_mult)
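
With the new state_dict_name argument, load_weights can unwrap checkpoints that store their weights under a key other than 'state_dict' (for example pycls-style 'model_state'). A hedged sketch, assuming `model` is a torch.nn.Module built elsewhere and the checkpoint path is a placeholder:

```python
# Hedged sketch of the new state_dict_name argument; the checkpoint path is a placeholder.
from pytorch_jacinto_ai import xnn

# `model` is any torch.nn.Module constructed elsewhere
model = xnn.utils.load_weights(
    model, './checkpoints/regnetx800mf.pth',
    state_dict_name=['state_dict', 'model_state'],   # each key is used to unwrap the checkpoint if present
    ignore_size=True, verbose=True)
```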
diff --git a/modules/pytorch_jacinto_ai/xnn/utils/range_utils.py b/modules/pytorch_jacinto_ai/xnn/utils/range_utils.py
new file mode 100644 (file)
index 0000000..3e4f4ca
--- /dev/null
@@ -0,0 +1,97 @@
+import random
+import torch
+from ..layers import functional
+
+
+########################################################################
+# pytorch implementation of a single tensor range
+########################################################################
+
+def extrema_fast(src, range_shrink_percentile=0.0, sigma=0.0, fast_mode=True):
+    return extrema(src, range_shrink_percentile, sigma, fast_mode)
+
+
+def extrema(src, range_shrink_percentile=0.0, sigma=0.0, fast_mode=False):
+    if range_shrink_percentile == 0 and sigma == 0:
+        mn = src.min()
+        mx = src.max()
+        return (mn, mx)
+    elif range_shrink_percentile:
+        # downsample for fast_mode
+        hist_array, mn, mx, mult_factor, offset = tensor_histogram(src, fast_mode=fast_mode)
+
+        new_mn_scaled, new_mx_scaled = extrema_hist_search(hist_array, range_shrink_percentile)
+        new_mn = (new_mn_scaled / mult_factor) + offset
+        new_mx = (new_mx_scaled / mult_factor) + offset
+
+        # take care of floating point inaccuracies that can
+        # increase the range (in rare cases) beyond the actual range.
+        new_mn = max(mn, new_mn)
+        new_mx = min(mx, new_mx)
+        return new_mn, new_mx
+    elif sigma:
+        mean = torch.mean(src)
+        std = torch.std(src)
+        mn = mean - sigma*std
+        mx = mean + sigma*std
+        return mn, mx
+    else:
+        assert False, 'unknown extrema computation mode'
+
+
+def tensor_histogram(src, fast_mode=False):
+    # downsample for fast_mode
+    fast_stride = 2
+    fast_stride2 = fast_stride * 2
+    if fast_mode and len(src.size()) == 4 and (src.size(2) > fast_stride2) and (src.size(3) > fast_stride2):
+        r_start = random.randint(0, fast_stride - 1)
+        c_start = random.randint(0, fast_stride - 1)
+        src = src[..., r_start::fast_stride, c_start::fast_stride]
+    #
+    mn = src.min()
+    mx = src.max()
+    if mn == 0 and mx == 0:
+        # degenerate all-zero tensor: return a one-bin histogram so that callers
+        # unpacking (hist_array, mn, mx, mult_factor, offset) still work
+        return [100.0], mn, mx, 1.0, mn
+    #
+
+    # compute range_shrink_percentile based min/max
+    # frequency - bincount can only operate on unsigned
+    num_bins = 255.0
+    cum_freq = float(100.0)
+    offset = mn
+    range_val = torch.abs(mx - mn)
+    mult_factor = (num_bins / range_val)
+    tensor_int = (src.contiguous().view(-1) - offset) * mult_factor
+    tensor_int = functional.round_g(tensor_int).int()
+
+    # numpy version
+    # hist = np.bincount(tensor_int.cpu().numpy())
+    # hist_sum = np.sum(hist)
+    # hist_array = hist.astype(np.float32) * cum_freq / float(hist_sum)
+
+    # torch version
+    hist = torch.bincount(tensor_int)
+    hist_sum = torch.sum(hist)
+    hist = hist.float() * cum_freq / hist_sum.float()
+    hist_array = hist.cpu().numpy()
+    return hist_array, mn, mx, mult_factor, offset
+
+
+# this code is not parallelizable. better to pass a numpy array
+def extrema_hist_search(hist_array, range_shrink_percentile):
+    new_mn_scaled = 0
+    new_mx_scaled = len(hist_array) - 1
+    hist_sum_left = 0.0
+    hist_sum_right = 0.0
+    for h_idx in range(len(hist_array)):
+        r_idx = len(hist_array) - 1 - h_idx
+        hist_sum_left += hist_array[h_idx]
+        hist_sum_right += hist_array[r_idx]
+        if hist_sum_left < range_shrink_percentile:
+            new_mn_scaled = h_idx
+        if hist_sum_right < range_shrink_percentile:
+            new_mx_scaled = r_idx
+        #
+    #
+    return new_mn_scaled, new_mx_scaled
+
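
extrema returns the plain min/max by default, a histogram-trimmed range when range_shrink_percentile is given (the percentage of samples clipped from each tail), or a mean +/- sigma*std range. A hedged usage sketch with an assumed import path:

```python
# Hedged sketch of the range helpers; the import path is assumed from the file location.
import torch
from pytorch_jacinto_ai.xnn.utils.range_utils import extrema, extrema_fast

x = torch.randn(1, 8, 64, 64)
x[0, 0, 0, 0] = 100.0                                 # inject a single outlier
print(extrema(x))                                     # exact min/max, includes the outlier
print(extrema(x, range_shrink_percentile=0.01))       # trims ~0.01% from each tail of the histogram
print(extrema(x, sigma=3.0))                          # mean +/- 3*std
print(extrema_fast(x, range_shrink_percentile=0.01))  # same, but histogram on a strided subsample
```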
index 70c90c54118edeeb00f45b13a9c347d05cc23933..e171397b7b90578393d8d5abe6bddf1ae3d14c4b 100644 (file)
@@ -31,90 +31,6 @@ def signed_pow(x, base):
     return y
 
 
-###############################################################
-def extrema_fast(src, percentile_range_shrink=0.0, sigma=0.0, fast_mode=True):
-    return extrema(src, percentile_range_shrink, sigma, fast_mode)
-
-
-def extrema(src, percentile_range_shrink=0.0, sigma=0.0, fast_mode=False):
-    if percentile_range_shrink == 0 and sigma == 0:
-        mn = src.min()
-        mx = src.max()
-        return (mn, mx)
-    elif percentile_range_shrink:
-        # downsample for fast_mode
-        fast_stride = 2
-        fast_stride2 = fast_stride*2
-        if fast_mode and len(src.size())==4 and (src.size(2)>fast_stride2) and (src.size(3)>fast_stride2):
-            r_start = random.randint(0, fast_stride-1)
-            c_start = random.randint(0, fast_stride-1)
-            src = src[..., r_start::fast_stride, c_start::fast_stride]
-        #
-        mn = src.min()
-        mx = src.max()
-        if mn ==0 and mx == 0:
-            return mn, mx
-        #
-
-        # compute percentile_range_shrink based min/max
-        # frequency - bincount can only operate on unsigned
-        num_bins = 255.0
-        cum_freq = float(100.0)
-        offset = mn
-        range_val = torch.abs(mx - mn)
-        mult_factor = (num_bins / range_val)
-        tensor_int = (src.contiguous().view(-1) - offset) * mult_factor
-        tensor_int = functional.round_g(tensor_int).int()
-
-        # numpy version
-        #hist = np.bincount(tensor_int.cpu().numpy())
-        #hist_sum = np.sum(hist)
-        #hist_array = hist.astype(np.float32) * cum_freq / float(hist_sum)
-
-        # torch version
-        hist = torch.bincount(tensor_int)
-        hist_sum = torch.sum(hist)
-        hist = hist.float() * cum_freq / hist_sum.float()
-        hist_array = hist.cpu().numpy()
-
-        new_mn_scaled, new_mx_scaled = extrema_hist_search(hist_array, percentile_range_shrink)
-        new_mn = (new_mn_scaled / mult_factor) + offset
-        new_mx = (new_mx_scaled / mult_factor) + offset
-
-        # take care of floating point inaccuracies that can
-        # increase the range (in rare cases) beyond the actual range.
-        new_mn = max(mn, new_mn)
-        new_mx = min(mx, new_mx)
-        return new_mn, new_mx
-    elif sigma:
-        mean = torch.mean(src)
-        std = torch.std(src)
-        mn = mean - sigma*std
-        mx = mean + sigma*std
-        return mn, mx
-    else:
-        assert False, 'unknown extrema computation mode'
-
-
-# this code is not parallelizable. better to pass a numpy array
-def extrema_hist_search(hist_array, percentile_range_shrink):
-    new_mn_scaled = 0
-    new_mx_scaled = len(hist_array) - 1
-    hist_sum_left = 0.0
-    hist_sum_right = 0.0
-    for h_idx in range(len(hist_array)):
-        r_idx = len(hist_array) - 1 - h_idx
-        hist_sum_left += hist_array[h_idx]
-        hist_sum_right += hist_array[r_idx]
-        if hist_sum_left < percentile_range_shrink:
-            new_mn_scaled = h_idx
-        if hist_sum_right < percentile_range_shrink:
-            new_mx_scaled = r_idx
-        #
-    #
-    return new_mn_scaled, new_mx_scaled
-
-
 ##################################################################
 def check_sizes(input, input_name, expected):
     condition = [input.ndimension() == len(expected)]
index af1f6ccff959ce5827b0b2830eb2e0fa13c1ebb9..71491599bfabf9c41f7f5cd39af5c4dda1a38b19 100644 (file)
@@ -1,9 +1,8 @@
 from .. import mobilenetv2
 from .. import mobilenetv1
 from .. import resnet
-from ..shufflenetv2 import shufflenet_v2_x1_0
-from ..shufflenetv2 import shufflenet_v2_x1_5
-from ..shufflenetv2 import shufflenet_v2_x2_0
+from ..shufflenetv2 import shufflenet_v2_x1_0, shufflenet_v2_x1_5, shufflenet_v2_x2_0
+from .. import regnet
 
 try: from .. import mobilenetv2_internal
 except: pass
@@ -28,6 +27,7 @@ from .... import xnn
 
 __all__ = ['mobilenetv1_x1', 'mobilenetv2_tv_x1', 'mobilenetv2_x1', 'mobilenetv2_tv_x2_t2',
            'resnet50_x1', 'resnet50_xp5', 'resnet18_x1',
+           'regnetx800mf_x1',
            # experimental
            'mobilenetv2_ericsun_x1', 'mobilenetv2_shicai_x1',
            'mobilenetv2_tv_gws_x1', 'flownetslite_base_x1', 'mobilenetv1_multi_label_x1',
@@ -35,14 +35,16 @@ __all__ = ['mobilenetv1_x1', 'mobilenetv2_tv_x1', 'mobilenetv2_x1', 'mobilenetv2
 
 
 #####################################################################
-def resnet50_x1(model_config=None, pretrained=None):
+def resnet50_x1(model_config=None, pretrained=None, change_names_dict=None):
     model_config = resnet.get_config().merge_from(model_config)
     model = resnet.resnet50_with_model_config(model_config)
     # the pretrained model provided by torchvision and what is defined here differs slightly
     # note: that this change_names_dict  will take effect only if the direct load fails
-    change_names_dict = {'^conv1.': 'features.conv1.', '^bn1.': 'features.bn1.',
-                         '^relu.': 'features.relu.', '^maxpool.': 'features.maxpool.',
-                         '^layer': 'features.layer', '^fc.': 'classifier.'}
+    if change_names_dict is None:
+        change_names_dict = {'^conv1.': 'features.conv1.', '^bn1.': 'features.bn1.',
+                             '^relu.': 'features.relu.', '^maxpool.': 'features.maxpool.',
+                             '^layer': 'features.layer', '^fc.': 'classifier.'}
+    #
     if pretrained:
         model = xnn.utils.load_weights(model, pretrained, change_names_dict=change_names_dict)
     return model, change_names_dict
@@ -54,14 +56,16 @@ def resnet50_xp5(model_config=None, pretrained=None):
 
 
 #####################################################################
-def resnet18_x1(model_config=None, pretrained=None):
+def resnet18_x1(model_config=None, pretrained=None, change_names_dict=None):
     model_config = resnet.get_config().merge_from(model_config)
     model = resnet.resnet18_with_model_config(model_config)
     # the pretrained model provided by torchvision and what is defined here differs slightly
     # note: that this change_names_dict  will take effect only if the direct load fails
-    change_names_dict = {'^conv1.': 'features.conv1.', '^bn1.': 'features.bn1.',
-                         '^relu.': 'features.relu.', '^maxpool.': 'features.maxpool.',
-                         '^layer': 'features.layer', '^fc.': 'classifier.'}
+    if change_names_dict is None:
+        change_names_dict = {'^conv1.': 'features.conv1.', '^bn1.': 'features.bn1.',
+                             '^relu.': 'features.relu.', '^maxpool.': 'features.maxpool.',
+                             '^layer': 'features.layer', '^fc.': 'classifier.'}
+    #
     if pretrained:
         model = xnn.utils.load_weights(model, pretrained, change_names_dict=change_names_dict)
     return model, change_names_dict
@@ -145,4 +149,24 @@ def flownetslite_base_x1(model_config, pretrained=None):
         model = xnn.utils.load_weights(model, pretrained)
     return model
 
-    #####################################################################
+#####################################################################
+# note: there is nothing specific about bgr in this model definition itself;
+# it is just a reminder that regnet models are typically trained with bgr input
+def regnetx800mf_x1(model_config=None, pretrained=None, change_names_dict=None):
+    model_config = resnet.get_config().merge_from(model_config)
+    model = regnet.regnetx800mf_with_model_config(model_config)
+    # the pretrained model provided by torchvision and what is defined here differs slightly
+    # note that this change_names_dict will take effect only if the direct load fails
+    if change_names_dict is None:
+        change_names_dict = {'^stem.': 'features.stem.',
+                             '^s1': 'features.s1',
+                             '^s2': 'features.s2',
+                             '^s3': 'features.s3',
+                             '^s4': 'features.s4'}
+    #
+    if pretrained:
+        model = xnn.utils.load_weights(model, pretrained, change_names_dict=change_names_dict,
+                                       state_dict_name=['state_dict','model_state'])
+    return model, change_names_dict
+
+
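
The new RegNetX-800MF classifier can then be requested like the other entries in this file; a hedged sketch (the checkpoint path is a placeholder and the import path is assumed):

```python
# Hedged sketch of the new regnetx800mf_x1 entry point; the checkpoint path is a placeholder.
from pytorch_jacinto_ai.xvision.models.classification import regnetx800mf_x1

# pycls-style checkpoints keep their weights under 'model_state', which
# load_weights now handles via state_dict_name=['state_dict', 'model_state'].
model, change_names_dict = regnetx800mf_x1(
    model_config=None, pretrained='./checkpoints/regnetx_800mf.pth')
```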
index b411c0d2509132beddd468222fd7afad7746dbd7..3d9193f12600f6e51245802ffb797290231e5b61 100644 (file)
@@ -3,6 +3,7 @@ import torch
 from ... import xnn
 from .mobilenetv2 import MobileNetV2TV
 from .resnet import resnet50_with_model_config
+from .regnet import regnetx800mf_with_model_config
 
 try: from .mobilenetv2_ericsun_internal import *
 except: pass
@@ -239,3 +240,11 @@ class ResNet50MI4(MultiInputNet):
         model_config.fuse_channels = 64
         super().__init__(resnet50_with_model_config, model_config)
 
+
+###################################################
+# these are multi input blocks; num_input_blocks is set in the constructor below
+class RegNetX800MFMI4(MultiInputNet):
+    def __init__(self, model_config):
+        model_config.num_input_blocks = 2
+        model_config.fuse_channels = 64
+        super().__init__(regnetx800mf_with_model_config, model_config)
\ No newline at end of file
index dda072e2383353af5455e9d90325df435e5cc4fa..88e3ff1891e994b3d63c2fbf6d831da3eae3b487 100644 (file)
@@ -7,13 +7,15 @@ from .pixel2pixelnet import *
 try: from .pixel2pixelnet_internal import *
 except: pass
 
-from ..multi_input_net import MobileNetV2TVMI4, MobileNetV2EricsunMI4, ResNet50MI4
+from ..multi_input_net import MobileNetV2TVMI4, MobileNetV2EricsunMI4, \
+                              ResNet50MI4, RegNetX800MFMI4
 
 ###########################################
 __all__ = ['DeepLabV3Lite', 'DeepLabV3LiteDecoder',
            'deeplabv3lite_mobilenetv2_tv', 'deeplabv3lite_mobilenetv2_tv_fd',
            'deeplabv3lite_mobilenetv2_ericsun',
-           'deeplabv3lite_resnet50', 'deeplabv3lite_resnet50_p5', 'deeplabv3lite_resnet50_p5_fd']
+           'deeplabv3lite_resnet50', 'deeplabv3lite_resnet50_p5', 'deeplabv3lite_resnet50_p5_fd',
+           'deeplabv3lite_regnetx800mf']
 
 
 ###########################################
@@ -28,9 +30,11 @@ class DeepLabV3LiteDecoder(torch.nn.Module):
         aspp_channels = round(model_config.aspp_chan*model_config.decoder_factor)
 
         if model_config.use_aspp:
+            group_size_dw = model_config.group_size_dw if hasattr(model_config, 'group_size_dw') else None
             ASPPBlock = xnn.layers.GWASPPLiteBlock if model_config.groupwise_sep else xnn.layers.DWASPPLiteBlock
             self.aspp = ASPPBlock(current_channels, aspp_channels, decoder_channels, dilation=model_config.aspp_dil,
-                                              activation=model_config.activation, linear_dw=model_config.linear_dw)
+                                  activation=model_config.activation, linear_dw=model_config.linear_dw,
+                                  group_size_dw=group_size_dw)
         else:
             self.aspp = None
 
@@ -90,7 +94,8 @@ class DeepLabV3LiteDecoder(torch.nn.Module):
             if (not self.training) and (self.model_config.output_type == 'segmentation'):
                 x = torch.argmax(x, dim=1, keepdim=True)
 
-            assert int(in_shape[2]) == int(x.shape[2]*self.model_config.target_input_ratio) and int(in_shape[3]) == int(x.shape[3]*self.model_config.target_input_ratio), 'incorrect output shape'
+            assert int(in_shape[2]) == int(x.shape[2]*self.model_config.target_input_ratio) and \
+                   int(in_shape[3]) == int(x.shape[3]*self.model_config.target_input_ratio), 'incorrect output shape'
 
         if self.model_config.freeze_decoder:
             x = x.detach()
@@ -268,4 +273,58 @@ def deeplabv3lite_resnet50_p5_fd(model_config, pretrained=None):
     model_config.fastdown = True
     model_config.shortcut_channels = (128,1024)
     model_config.shortcut_strides = (8,64)
-    return deeplabv3lite_resnet50(model_config, pretrained=pretrained)
\ No newline at end of file
+    return deeplabv3lite_resnet50(model_config, pretrained=pretrained)
+
+
+###########################################
+# config settings for regnetx800mf backbone
+def get_config_deeplav3lite_regnetx800mf():
+    # only the delta compared to the one defined for mobilenetv2
+    model_config = get_config_deeplav3lite_mnv2()
+    model_config.shortcut_channels = (64,672)
+    model_config.group_size_dw = 16
+    return model_config
+
+
+# note: there is nothing specific about bgr in this model definition itself;
+# it is just a reminder that regnet models are typically trained with bgr input
+def deeplabv3lite_regnetx800mf(model_config, pretrained=None):
+    model_config = get_config_deeplav3lite_regnetx800mf().merge_from(model_config)
+    # encoder setup
+    model_config_e = model_config.clone()
+    base_model = RegNetX800MFMI4(model_config_e)
+    # decoder setup
+    model = DeepLabV3Lite(base_model, model_config)
+
+    # the pretrained model provided by torchvision and what is defined here differs slightly
+    # note that this change_names_dict will take effect only if the direct load fails
+    # finally take care of the change for deeplabv3lite (features->encoder.features)
+    num_inputs = len(model_config.input_channels)
+    num_decoders = len(model_config.output_channels) if (model_config.num_decoders is None) else model_config.num_decoders
+    if num_inputs > 1:
+        change_names_dict = {'^features.': ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)],
+                             '^encoder.features.': ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)],
+                             '^decoders.0.': ['decoders.{}.'.format(d) for d in range(num_decoders)]}
+    else:
+        change_names_dict = {'^stem.': 'encoder.features.stem.',
+                             '^s1': 'encoder.features.s1', '^s2': 'encoder.features.s2',
+                             '^s3': 'encoder.features.s3', '^s4': 'encoder.features.s4',
+                             '^features.': 'encoder.features.',
+                             '^decoders.0.': ['decoders.{}.'.format(d) for d in range(num_decoders)]}
+    #
+
+    if pretrained:
+        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True,
+                                       state_dict_name=['state_dict','model_state'])
+    else:
+        # need to use state_dict_name as the checkpoint uses a different name for state_dict
+        # provide a custom load_weights for the model
+        def load_weights_func(pretrained, change_names_dict, ignore_size=True, verbose=True,
+                                       state_dict_name=['state_dict','model_state']):
+            xnn.utils.load_weights(model, pretrained, change_names_dict=change_names_dict, ignore_size=ignore_size, verbose=verbose,
+                                           state_dict_name=state_dict_name)
+        #
+        model.load_weights = load_weights_func
+
+    return model, change_names_dict
+
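
The regnet variant reuses the mobilenetv2 decoder configuration, changing only shortcut_channels and setting group_size_dw=16 so the depthwise convolutions in the decoder become grouped convolutions with 16-channel groups. The fpnlite and unetlite regnet variants added further below follow the same pattern. A hedged construction sketch (import paths, channel counts and checkpoint path are assumptions):

```python
# Hedged sketch of building the new deeplabv3lite_regnetx800mf model.
from pytorch_jacinto_ai.xvision.models.pixel2pixel.deeplabv3lite import (
    get_config_deeplav3lite_regnetx800mf, deeplabv3lite_regnetx800mf)

model_config = get_config_deeplav3lite_regnetx800mf()
model_config.input_channels = (3,)        # single input stream
model_config.output_channels = (19,)      # e.g. 19 segmentation classes
model, change_names_dict = deeplabv3lite_regnetx800mf(
    model_config, pretrained='./checkpoints/regnetx_800mf_imagenet.pth')  # placeholder path
```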
index 986b46643beeaaf9038a9d85f7b14f1638ca846f..293325262ce25dbdb4754e9052648183d136ed99 100644 (file)
@@ -3,7 +3,7 @@ import numpy as np
 from .... import xnn
 
 from .pixel2pixelnet import *
-from ..multi_input_net import MobileNetV2TVMI4, ResNet50MI4
+from ..multi_input_net import MobileNetV2TVMI4, ResNet50MI4, RegNetX800MFMI4
 
 
 __all__ = ['FPNLitePixel2PixelASPP', 'FPNLitePixel2PixelDecoder',
@@ -12,6 +12,7 @@ __all__ = ['FPNLitePixel2PixelASPP', 'FPNLitePixel2PixelDecoder',
            'fpnlite_pixel2pixel_mobilenetv2_tv', 'fpnlite_pixel2pixel_mobilenetv2_tv_fd',
            # resnet models
            'fpnlite_pixel2pixel_aspp_resnet50', 'fpnlite_pixel2pixel_aspp_resnet50_fd',
+           'fpnlite_pixel2pixel_aspp_regnetx800mf'
            ]
 
 # config settings for mobilenetv2 backbone
@@ -59,7 +60,9 @@ def get_config_fpnlitep2p_mnv2():
 
 ###########################################
 class FPNLitePyramid(torch.nn.Module):
-    def __init__(self, current_channels, decoder_channels, shortcut_strides, shortcut_channels, activation, kernel_size_smooth, interpolation_type, interpolation_mode, inloop_fpn=False, all_outputs=False):
+    def __init__(self, current_channels, decoder_channels, shortcut_strides, shortcut_channels, activation,
+                 kernel_size_smooth, interpolation_type, interpolation_mode, inloop_fpn=False, all_outputs=False,
+                 group_size_dw=None):
         super().__init__()
         self.inloop_fpn = inloop_fpn
         self.shortcut_strides = shortcut_strides
@@ -68,10 +71,11 @@ class FPNLitePyramid(torch.nn.Module):
         self.shortcuts = torch.nn.ModuleList()
         self.upsamples = torch.nn.ModuleList()
 
-        shortcut0 = self.create_shortcut(current_channels, decoder_channels, activation) if (current_channels != decoder_channels) else None
+        shortcut0 = self.create_shortcut(current_channels, decoder_channels, activation) \
+            if (current_channels != decoder_channels) else None
         self.shortcuts.append(shortcut0)
 
-        smooth_conv0 = None #xnn.layers.ConvDWSepNormAct2d(decoder_channels, decoder_channels, kernel_size=kernel_size_smooth, activation=(activation, activation)) if all_outputs else None
+        smooth_conv0 = None
         self.smooth_convs.append(smooth_conv0)
 
         upstride = 2
@@ -79,10 +83,16 @@ class FPNLitePyramid(torch.nn.Module):
             shortcut = self.create_shortcut(feat_chan, decoder_channels, activation)
             self.shortcuts.append(shortcut)
             is_last = (idx == len(shortcut_channels)-1)
-            smooth_conv = xnn.layers.ConvDWSepNormAct2d(decoder_channels, decoder_channels, kernel_size=kernel_size_smooth, activation=(activation,activation)) \
-                        if (inloop_fpn or all_outputs or is_last) else None
+            if (inloop_fpn or all_outputs or is_last):
+                smooth_conv = xnn.layers.ConvDWSepNormAct2d(decoder_channels, decoder_channels,
+                                    kernel_size=kernel_size_smooth, activation=(activation,activation),
+                                    group_size_dw=group_size_dw)
+            else:
+                smooth_conv = None
+            #
             self.smooth_convs.append(smooth_conv)
-            upsample = xnn.layers.UpsampleWith(decoder_channels, decoder_channels, upstride, interpolation_type, interpolation_mode)
+            upsample = xnn.layers.UpsampleWith(decoder_channels, decoder_channels, upstride, interpolation_type,
+                                               interpolation_mode)
             self.upsamples.append(upsample)
         #
     #
@@ -102,7 +112,8 @@ class FPNLitePyramid(torch.nn.Module):
         x = y if self.inloop_fpn else x
         outputs.append(y)
 
-        for idx, (shortcut, smooth_conv, s_stride, short_chan, upsample) in enumerate(zip(self.shortcuts[1:], self.smooth_convs[1:], self.shortcut_strides, self.shortcut_channels, self.upsamples)):
+        for idx, (shortcut, smooth_conv, s_stride, short_chan, upsample) in \
+                enumerate(zip(self.shortcuts[1:], self.smooth_convs[1:], self.shortcut_strides, self.shortcut_channels, self.upsamples)):
             # get the feature of lower stride
             shape_s = xnn.utils.get_shape_with_stride(in_shape, s_stride)
             shape_s[1] = short_chan
@@ -122,8 +133,12 @@ class FPNLitePyramid(torch.nn.Module):
 
 
 class InLoopFPNLitePyramid(FPNLitePyramid):
-    def __init__(self, input_channels, decoder_channels, shortcut_strides, shortcut_channels, activation, kernel_size_smooth, interpolation_type, interpolation_mode, inloop_fpn=True, all_outputs=False):
-        super().__init__(input_channels, decoder_channels, shortcut_strides, shortcut_channels, activation, kernel_size_smooth, interpolation_type, interpolation_mode, inloop_fpn=inloop_fpn, all_outputs=all_outputs)
+    def __init__(self, input_channels, decoder_channels, shortcut_strides, shortcut_channels, activation,
+                 kernel_size_smooth, interpolation_type, interpolation_mode, inloop_fpn=True,
+                 all_outputs=False, group_size_dw=None):
+        super().__init__(input_channels, decoder_channels, shortcut_strides, shortcut_channels, activation,
+                 kernel_size_smooth, interpolation_type, interpolation_mode, inloop_fpn=inloop_fpn,
+                 all_outputs=all_outputs, group_size_dw=group_size_dw)
 
 
 ###########################################
@@ -134,18 +149,24 @@ class FPNLitePixel2PixelDecoder(torch.nn.Module):
         activation = self.model_config.activation
         self.output_type = model_config.output_type
         self.decoder_channels = decoder_channels = round(self.model_config.decoder_chan*self.model_config.decoder_factor)
+        group_size_dw = model_config.group_size_dw if hasattr(model_config, 'group_size_dw') else None
 
         self.rfblock = None
         if self.model_config.use_aspp:
             current_channels = self.model_config.shortcut_channels[-1]
             aspp_channels = round(self.model_config.aspp_chan * self.model_config.decoder_factor)
-            self.rfblock = xnn.layers.DWASPPLiteBlock(current_channels, aspp_channels, decoder_channels, dilation=self.model_config.aspp_dil, avg_pool=False, activation=activation)
+            self.rfblock = xnn.layers.DWASPPLiteBlock(current_channels, aspp_channels, decoder_channels,
+                                dilation=self.model_config.aspp_dil, avg_pool=False, activation=activation,
+                                group_size_dw=group_size_dw)
             current_channels = decoder_channels
         elif self.model_config.use_extra_strides:
             # a low complexity pyramid
             current_channels = self.model_config.shortcut_channels[-3]
-            self.rfblock = torch.nn.Sequential(xnn.layers.ConvDWSepNormAct2d(current_channels, current_channels, kernel_size=3, stride=2, activation=(activation, activation)),
-                                               xnn.layers.ConvDWSepNormAct2d(current_channels, decoder_channels, kernel_size=3, stride=2, activation=(activation, activation)))
+            self.rfblock = torch.nn.Sequential(
+                xnn.layers.ConvDWSepNormAct2d(current_channels, current_channels, kernel_size=3,
+                        stride=2, activation=(activation, activation), group_size_dw=group_size_dw),
+                xnn.layers.ConvDWSepNormAct2d(current_channels, decoder_channels, kernel_size=3,
+                        stride=2, activation=(activation, activation), group_size_dw=group_size_dw))
             current_channels = decoder_channels
         else:
             current_channels = self.model_config.shortcut_channels[-1]
@@ -156,8 +177,10 @@ class FPNLitePixel2PixelDecoder(torch.nn.Module):
         shortcut_strides = self.model_config.shortcut_strides[::-1][1:]
         shortcut_channels = self.model_config.shortcut_channels[::-1][1:]
         FPNType = InLoopFPNLitePyramid if model_config.inloop_fpn else FPNLitePyramid
-        self.fpn = FPNType(current_channels, decoder_channels, shortcut_strides, shortcut_channels, self.model_config.activation, self.model_config.kernel_size_smooth,
-                           self.model_config.interpolation_type, self.model_config.interpolation_mode)
+        self.fpn = FPNType(current_channels, decoder_channels, shortcut_strides, shortcut_channels,
+                        self.model_config.activation, self.model_config.kernel_size_smooth,
+                        self.model_config.interpolation_type, self.model_config.interpolation_mode,
+                        group_size_dw=group_size_dw)
 
         # add prediction & upsample modules
         if self.model_config.final_prediction:
@@ -337,4 +360,56 @@ def fpnlite_pixel2pixel_aspp_resnet50_fd(model_config, pretrained=None):
     model_config.shortcut_channels = (256,512,1024,2048) #(64,256,512,1024,2048)
     model_config.decoder_chan = 256 #128
     model_config.aspp_chan = 256 #128
-    return fpnlite_pixel2pixel_aspp_resnet50(model_config, pretrained=pretrained)
\ No newline at end of file
+    return fpnlite_pixel2pixel_aspp_resnet50(model_config, pretrained=pretrained)
+
+
+###########################################
+# config settings for regnetx800mf backbone
+def get_config_fpnlite_regnetx800mf():
+    # only the delta compared to the one defined for mobilenetv2
+    model_config = get_config_fpnlitep2p_mnv2()
+    model_config.shortcut_channels = (64,128,288,672)
+    return model_config
+
+
+# note: there is nothing specific about bgr in this model definition itself;
+# it is just a reminder that regnet models are typically trained with bgr input
+def fpnlite_pixel2pixel_aspp_regnetx800mf(model_config, pretrained=None):
+    model_config = get_config_fpnlite_regnetx800mf().merge_from(model_config)
+    # encoder setup
+    model_config_e = model_config.clone()
+    base_model = RegNetX800MFMI4(model_config_e)
+    # decoder setup
+    model = FPNLitePixel2PixelASPP(base_model, model_config)
+
+    # the pretrained model provided by torchvision and what is defined here differs slightly
+    # note that this change_names_dict will take effect only if the direct load fails
+    # finally take care of the encoder wrapping (features->encoder.features)
+    num_inputs = len(model_config.input_channels)
+    num_decoders = len(model_config.output_channels) if (model_config.num_decoders is None) else model_config.num_decoders
+    if num_inputs > 1:
+        change_names_dict = {'^features.': ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)],
+                             '^encoder.features.': ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)],
+                             '^decoders.0.': ['decoders.{}.'.format(d) for d in range(num_decoders)]}
+    else:
+        change_names_dict = {'^stem.': 'encoder.features.stem.',
+                             '^s1': 'encoder.features.s1', '^s2': 'encoder.features.s2',
+                             '^s3': 'encoder.features.s3', '^s4': 'encoder.features.s4',
+                             '^features.': 'encoder.features.',
+                             '^decoders.0.': ['decoders.{}.'.format(d) for d in range(num_decoders)]}
+    #
+
+    if pretrained:
+        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True,
+                                       state_dict_name=['state_dict','model_state'])
+    else:
+        # need to use state_dict_name as the checkpoint uses a different name for state_dict
+        # provide a custom load_weights for the model
+        def load_weights_func(pretrained, change_names_dict, ignore_size=True, verbose=True,
+                                       state_dict_name=['state_dict','model_state']):
+            xnn.utils.load_weights(model, pretrained, change_names_dict=change_names_dict, ignore_size=ignore_size, verbose=verbose,
+                                           state_dict_name=state_dict_name)
+        #
+        model.load_weights = load_weights_func
+
+    return model, change_names_dict
\ No newline at end of file
index dce04718d023ccbe968ee2ea06f48292a4d1af29..c1af102d056abc6b19e5eca4c811b40367a0c2ba 100644 (file)
@@ -28,9 +28,10 @@ def add_lite_prediction_modules(self, model_config, current_channels, module_nam
         else:
             # prediction followed by conventional interpolation
             ConvXWSepBlock = xnn.layers.ConvGWSepNormAct2d if model_config.groupwise_sep else xnn.layers.ConvDWSepNormAct2d
+            group_size_dw = model_config.group_size_dw if hasattr(model_config, 'group_size_dw') else None
             pred = ConvXWSepBlock(current_channels, model_config.output_channels, kernel_size=3,
                                        normalization=((not model_config.linear_dw),False),
-                                       activation=(False,final_activation), groups=1)
+                                       activation=(False,final_activation), groups=1, group_size_dw=group_size_dw)
             setattr(self, module_names[0], pred)
 
             if self.model_config.final_upsample:
index 0d6c5b6da449d782fa00e7e724f2e939999ae099..a936c18e8d3501248bda1987827090805f38da54 100644 (file)
@@ -3,12 +3,13 @@ import numpy as np
 from .... import xnn
 
 from .pixel2pixelnet import *
-from ..multi_input_net import MobileNetV2TVMI4, ResNet50MI4
+from ..multi_input_net import MobileNetV2TVMI4, ResNet50MI4, RegNetX800MFMI4
 
 
 __all__ = ['UNetLitePixel2PixelASPP', 'UNetLitePixel2PixelDecoder',
            'unetlite_pixel2pixel_aspp_mobilenetv2_tv', 'unetlite_pixel2pixel_aspp_mobilenetv2_tv_fd',
            'unetlite_pixel2pixel_aspp_resnet50', 'unetlite_pixel2pixel_aspp_resnet50_fd',
+           'unetlite_pixel2pixel_aspp_regnetx800mf'
            ]
 
 # config settings for mobilenetv2 backbone
@@ -54,7 +55,8 @@ def get_config_unetlitep2p_mnv2():
 
 ###########################################
 class UNetLitePyramid(torch.nn.Module):
-    def __init__(self, current_channels, minimum_channels, shortcut_strides, shortcut_channels, activation, kernel_size_smooth, interpolation_type, interpolation_mode):
+    def __init__(self, current_channels, minimum_channels, shortcut_strides, shortcut_channels, activation,
+                 kernel_size_smooth, interpolation_type, interpolation_mode, group_size_dw=None):
         super().__init__()
         self.shortcut_strides = shortcut_strides
         self.shortcut_channels = shortcut_channels
@@ -68,10 +70,12 @@ class UNetLitePyramid(torch.nn.Module):
         upstride = 2
         activation2 = (activation, activation)
         for idx, (s_stride, feat_chan) in enumerate(zip(shortcut_strides, shortcut_channels)):
-            self.upsamples.append(xnn.layers.UpsampleWith(current_channels, current_channels, upstride, interpolation_type, interpolation_mode))
+            self.upsamples.append(xnn.layers.UpsampleWith(current_channels, current_channels, upstride,
+                                                          interpolation_type, interpolation_mode))
             self.concats.append(xnn.layers.CatBlock())
             smooth_channels = max(minimum_channels, feat_chan)
-            self.smooth_convs.append( xnn.layers.ConvDWSepNormAct2d(current_channels+feat_chan, smooth_channels, kernel_size=kernel_size_smooth, activation=activation2))
+            self.smooth_convs.append(xnn.layers.ConvDWSepNormAct2d(current_channels+feat_chan, smooth_channels,
+                                kernel_size=kernel_size_smooth, activation=activation2, group_size_dw=group_size_dw))
             current_channels = smooth_channels
         #
     #
@@ -111,18 +115,24 @@ class UNetLitePixel2PixelDecoder(torch.nn.Module):
         activation = self.model_config.activation
         self.output_type = model_config.output_type
         self.decoder_channels = decoder_channels = round(self.model_config.decoder_chan*self.model_config.decoder_factor)
+        group_size_dw = model_config.group_size_dw if hasattr(model_config, 'group_size_dw') else None
 
         self.rfblock = None
         if self.model_config.use_aspp:
             current_channels = self.model_config.shortcut_channels[-1]
             aspp_channels = round(self.model_config.aspp_chan * self.model_config.decoder_factor)
-            self.rfblock = xnn.layers.DWASPPLiteBlock(current_channels, aspp_channels, decoder_channels, dilation=self.model_config.aspp_dil, avg_pool=False, activation=activation)
+            self.rfblock = xnn.layers.DWASPPLiteBlock(current_channels, aspp_channels, decoder_channels,
+                                    dilation=self.model_config.aspp_dil, avg_pool=False, activation=activation,
+                                    group_size_dw=group_size_dw)
             current_channels = decoder_channels
         elif self.model_config.use_extra_strides:
             # a low complexity pyramid
             current_channels = self.model_config.shortcut_channels[-3]
-            self.rfblock = torch.nn.Sequential(xnn.layers.ConvDWSepNormAct2d(current_channels, current_channels, kernel_size=3, stride=2, activation=(activation, activation)),
-                                               xnn.layers.ConvDWSepNormAct2d(current_channels, decoder_channels, kernel_size=3, stride=2, activation=(activation, activation)))
+            self.rfblock = torch.nn.Sequential(
+                xnn.layers.ConvDWSepNormAct2d(current_channels, current_channels, kernel_size=3, stride=2,
+                                              activation=(activation, activation), group_size_dw=group_size_dw),
+                xnn.layers.ConvDWSepNormAct2d(current_channels, decoder_channels, kernel_size=3, stride=2,
+                                              activation=(activation, activation), group_size_dw=group_size_dw))
             current_channels = decoder_channels
         else:
             current_channels = self.model_config.shortcut_channels[-1]
@@ -133,8 +143,10 @@ class UNetLitePixel2PixelDecoder(torch.nn.Module):
         minimum_channels = max(self.model_config.output_channels*2, 32)
         shortcut_strides = self.model_config.shortcut_strides[::-1][1:]
         shortcut_channels = self.model_config.shortcut_channels[::-1][1:]
-        self.unet = UNetLitePyramid(current_channels, minimum_channels, shortcut_strides, shortcut_channels, self.model_config.activation, self.model_config.kernel_size_smooth,
-                           self.model_config.interpolation_type, self.model_config.interpolation_mode)
+        self.unet = UNetLitePyramid(current_channels, minimum_channels, shortcut_strides, shortcut_channels,
+                           self.model_config.activation, self.model_config.kernel_size_smooth,
+                           self.model_config.interpolation_type, self.model_config.interpolation_mode,
+                           group_size_dw=group_size_dw)
         current_channels = max(minimum_channels, shortcut_channels[-1])
 
         # add prediction & upsample modules
@@ -173,7 +185,8 @@ class UNetLitePixel2PixelDecoder(torch.nn.Module):
             if (not self.training) and (self.output_type == 'segmentation'):
                 x = torch.argmax(x, dim=1, keepdim=True)
 
-            assert int(in_shape[2]) == int(x.shape[2]*self.model_config.target_input_ratio) and int(in_shape[3]) == int(x.shape[3]*self.model_config.target_input_ratio), 'incorrect output shape'
+            assert int(in_shape[2]) == int(x.shape[2]*self.model_config.target_input_ratio) and \
+                   int(in_shape[3]) == int(x.shape[3]*self.model_config.target_input_ratio), 'incorrect output shape'
 
         return x
 
 
@@ -279,4 +292,57 @@ def unetlite_pixel2pixel_aspp_resnet50_fd(model_config, pretrained=None):
     model_config.shortcut_channels = (64,64,256,512,1024,2048) #(64,256,512,1024,2048)
     model_config.decoder_chan = 256 #128
     model_config.aspp_chan = 256 #128
-    return unetlite_pixel2pixel_aspp_resnet50(model_config, pretrained=pretrained)
\ No newline at end of file
+    return unetlite_pixel2pixel_aspp_resnet50(model_config, pretrained=pretrained)
+
+
+###########################################
+# config settings for regnetx800mf backbone
+def get_config_unetlite_regnetx800mf():
+    # only the delta compared to the one defined for mobilenetv2
+    model_config = get_config_unetlitep2p_mnv2()
+    model_config.shortcut_strides = (2,4,8,16,32)
+    model_config.shortcut_channels = (32,64,128,288,672)
+    return model_config
+
+
+# there is nothing specific about bgr input in this model itself;
+# this is just a reminder that regnet models are typically trained with bgr input
+def unetlite_pixel2pixel_aspp_regnetx800mf(model_config, pretrained=None):
+    model_config = get_config_unetlite_regnetx800mf().merge_from(model_config)
+    # encoder setup
+    model_config_e = model_config.clone()
+    base_model = RegNetX800MFMI4(model_config_e)
+    # decoder setup
+    model = UNetLitePixel2PixelASPP(base_model, model_config)
+
+    # the pretrained model provided by pycls and what is defined here differ slightly
+    # note that this change_names_dict will take effect only if the direct load fails
+    # it also takes care of the encoder wrapping for this model (features->encoder.features)
+    num_inputs = len(model_config.input_channels)
+    num_decoders = len(model_config.output_channels) if (model_config.num_decoders is None) else model_config.num_decoders
+    if num_inputs > 1:
+        change_names_dict = {'^features.': ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)],
+                             '^encoder.features.': ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)],
+                             '^decoders.0.': ['decoders.{}.'.format(d) for d in range(num_decoders)]}
+    else:
+        change_names_dict = {'^stem.': 'encoder.features.stem.',
+                             '^s1': 'encoder.features.s1', '^s2': 'encoder.features.s2',
+                             '^s3': 'encoder.features.s3', '^s4': 'encoder.features.s4',
+                             '^features.': 'encoder.features.',
+                             '^decoders.0.': ['decoders.{}.'.format(d) for d in range(num_decoders)]}
+    #
+
+    if pretrained:
+        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True,
+                                       state_dict_name=['state_dict','model_state'])
+    else:
+        # need to use state_dict_name as the checkpoint uses a different name for state_dict
+        # provide a custom load_weights for the model
+        def load_weights_func(pretrained, change_names_dict, ignore_size=True, verbose=True,
+                                       state_dict_name=['state_dict','model_state']):
+            xnn.utils.load_weights(model, pretrained, change_names_dict=change_names_dict, ignore_size=ignore_size,
+                                       verbose=verbose, state_dict_name=state_dict_name)
+        #
+        model.load_weights = load_weights_func
+    #
+    return model, change_names_dict
\ No newline at end of file
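The factory above wires the RegNetX-800MF encoder into the UNetLite decoder and remaps pretrained checkpoint names. A hedged usage sketch, not part of this commit; the import path and config fields are assumptions based on the repository layout and the defaults referenced via get_config_unetlite_regnetx800mf():
```
# Hypothetical usage of the new unetlite_pixel2pixel_aspp_regnetx800mf factory.
from pytorch_jacinto_ai import xnn                       # assumed package path
from pytorch_jacinto_ai.xvision.models import pixel2pixel

model_config = xnn.utils.ConfigNode()
model_config.input_channels = (3,)       # single input stream (BGR for RegNet checkpoints)
model_config.output_channels = (19,)     # e.g. Cityscapes semantic classes
model_config.output_type = 'segmentation'

# pycls backbone checkpoint referenced elsewhere in this commit; its keys are remapped via change_names_dict
pretrained = 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906036/RegNetX-800MF_dds_8gpu.pyth'
model, change_names_dict = pixel2pixel.unetlite_pixel2pixel_aspp_regnetx800mf(model_config, pretrained=pretrained)
```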
diff --git a/modules/pytorch_jacinto_ai/xvision/models/regnet.py b/modules/pytorch_jacinto_ai/xvision/models/regnet.py
new file mode 100644 (file)
index 0000000..814a146
--- /dev/null
@@ -0,0 +1,257 @@
+'''
+An independent implementation of RegNet:
+Designing Network Design Spaces, Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár,
+Facebook AI Research (FAIR),
+https://arxiv.org/pdf/2003.13678.pdf, https://github.com/facebookresearch/pycls
+This implementation re-uses functions and classes from resnet.py
+'''
+
+
+import torch
+import torch.nn as nn
+import collections
+from .utils import load_state_dict_from_url
+from ... import xnn
+from .resnet import conv1x1, conv3x3
+
+
+__all__ = ['RegNet', 'regnetx800mf', 'regnetx800mf_with_model_config']
+
+
+model_urls = {
+    'regnetx800mf': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906036/RegNetX-800MF_dds_8gpu.pyth',
+}
+
+
+class RegNetBottleneck(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, group_width=None,
+                 dilation=1, norm_layer=None):
+        super(RegNetBottleneck, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        #
+        groups = int(planes//group_width) if (group_width is not None) else 1
+        width = int(planes//groups) * groups
+        # Both the 3x3 conv (b) and the downsample projection reduce the resolution when stride != 1
+        a = conv1x1(inplanes, width)
+        a_bn = norm_layer(width)
+        a_relu = nn.ReLU(inplace=True)
+        b = conv3x3(width, width, stride, groups, dilation)
+        b_bn = norm_layer(width)
+        b_relu = nn.ReLU(inplace=True)
+        c = conv1x1(width, planes * self.expansion)
+        c_bn = norm_layer(planes * self.expansion)
+        self.f = torch.nn.Sequential(
+            collections.OrderedDict([('a',a),('a_bn',a_bn),('a_relu',a_relu),
+                 ('b',b),('b_bn',b_bn),('b_relu',b_relu),
+                 ('c',c),('c_bn',c_bn)]))
+
+        if downsample is not None:
+            self.proj = downsample[0]
+            self.bn = downsample[1]
+            self.do_downsample = True
+        else:
+            self.do_downsample = False
+        #
+        self.add = xnn.layers.AddBlock()
+        self.relu = nn.ReLU(inplace=True)
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.f(x)
+        if self.do_downsample:
+            identity = self.bn(self.proj(x))
+
+        out = self.add((out,identity))
+        out = self.relu(out)
+
+        return out
+
+
+class RegNet(nn.Module):
+
+    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
+                 channels=(64,64,128,256,512), group_width=None, replace_stride_with_dilation=None,
+                 norm_layer=None, input_channels=3, strides=None,
+                 width_mult=1.0, fastdown=False):
+        super(RegNet, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        #
+        self.group_width = group_width
+        self._norm_layer = norm_layer
+        self.num_classes = num_classes
+        self.simple_stem = True
+
+        self.inplanes = int(channels[0]*width_mult)
+        self.dilation = 1
+        if replace_stride_with_dilation is None:
+            # each element in the tuple indicates if we should replace
+            # the 2x2 stride with a dilated convolution instead
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError(f"replace_stride_with_dilation should be None or a 3-element tuple, got {replace_stride_with_dilation}")
+
+        # strides of various layers
+        strides = strides if (strides is not None) else (2,2,2,2,2)
+        s0 = strides[0]
+        s1 = strides[1]
+        sf = 2 if (fastdown or self.simple_stem) else 1 # additional stride when fastdown or the simple stem (no maxpool) is used
+        s2 = strides[2]
+        s3 = strides[3]
+        s4 = strides[4]
+
+        if self.simple_stem:
+            conv1 = nn.Conv2d(input_channels, self.inplanes, kernel_size=3, stride=s0, padding=1, bias=False)
+            bn1 = norm_layer(self.inplanes)
+            relu = nn.ReLU(inplace=True)
+            stem = [('conv',conv1), ('bn',bn1), ('relu1',relu)]
+            if fastdown:
+                maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+                stem += [('maxpool', maxpool)]
+            #
+        else:
+            conv1 = nn.Conv2d(input_channels, self.inplanes, kernel_size=7, stride=s0, padding=3, bias=False)
+            bn1 = norm_layer(self.inplanes)
+            relu = nn.ReLU(inplace=True)
+            maxpool = nn.MaxPool2d(kernel_size=3, stride=s1, padding=1)
+            stem = [('conv', conv1), ('bn', bn1), ('relu1', relu), ('maxpool', maxpool)]
+        #
+        stem = torch.nn.Sequential(collections.OrderedDict(stem))
+        features = [('stem',stem)]
+
+        layer1 = self._make_layer(block, int(channels[1]*width_mult), layers[0], stride=sf)
+        layer2 = self._make_layer(block, int(channels[2]*width_mult), layers[1], stride=s2,
+                                       dilate=replace_stride_with_dilation[0])
+        layer3 = self._make_layer(block, int(channels[3]*width_mult), layers[2], stride=s3,
+                                       dilate=replace_stride_with_dilation[1])
+        layer4 = self._make_layer(block, int(channels[4]*width_mult), layers[3], stride=s4,
+                                       dilate=replace_stride_with_dilation[2])
+
+        features.append(('s1',layer1))
+        features.append(('s2',layer2))
+        features.append(('s3',layer3))
+        features.append(('s4',layer4))
+        self.features = torch.nn.Sequential(collections.OrderedDict(features))
+
+        if self.num_classes:
+            avgpool = nn.AdaptiveAvgPool2d((1, 1))
+            flatten = torch.nn.Flatten(start_dim=1)
+            fc = nn.Linear(int(channels[4]*width_mult) * block.expansion, num_classes)
+            self.head = torch.nn.Sequential(collections.OrderedDict(
+                [('avgpool',avgpool),('flatten',flatten),('fc',fc)]))
+
+        xnn.utils.module_weights_init(self)
+
+        # Zero-initialize the last BN in each residual branch,
+        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, RegNetBottleneck):
+                    # the last BN of the residual branch is registered as 'c_bn' inside self.f
+                    nn.init.constant_(m.f.c_bn.weight, 0)
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+        norm_layer = self._norm_layer
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = (
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                norm_layer(planes * block.expansion),
+            )
+
+        layers = []
+        bx = block(self.inplanes, planes, stride, downsample, group_width=self.group_width,
+                            dilation=previous_dilation, norm_layer=norm_layer)
+        layers.append(('b1', bx))
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            bx = block(self.inplanes, planes, group_width=self.group_width,
+                                dilation=self.dilation, norm_layer=norm_layer)
+            layers.append((f'b{i+1}', bx))
+        #
+        layers = torch.nn.Sequential(collections.OrderedDict(layers))
+        return layers
+
+    def forward(self, x):
+        x = self.features(x)
+        if self.num_classes:
+            x = self.head(x)
+
+        return x
+
+
+    # define a load_weights function in the module since the module names differ slightly w.r.t. pycls,
+    # as we want to be able to load the existing pycls pretrained weights
+    def load_weights(self, pretrained, change_names_dict=None, download_root=None):
+        if change_names_dict is None:
+            # the pretrained model provided by pycls and what is defined here differ slightly
+            # note that this change_names_dict will take effect only if the direct load fails
+            change_names_dict = {'^stem.': 'features.stem.',
+                                 '^s1': 'features.s1',
+                                 '^s2': 'features.s2',
+                                 '^s3': 'features.s3',
+                                 '^s4': 'features.s4'}
+        #
+        if pretrained is not None:
+            xnn.utils.load_weights(self, pretrained, change_names_dict=change_names_dict,
+                                   download_root=download_root, state_dict_name=['state_dict','model_state'])
+        return self, change_names_dict
+
+
+def _regnet(arch, block, layers, pretrained, progress, **kwargs):
+    model = RegNet(block, layers, **kwargs)
+    if pretrained is True:
+        change_names_dict = kwargs.get('change_names_dict', None)
+        state_dict = load_state_dict_from_url(model_urls[arch],
+                                              progress=progress)
+        model.load_weights(state_dict, change_names_dict=change_names_dict)
+    elif pretrained:
+        change_names_dict = kwargs.get('change_names_dict', None)
+        download_root = kwargs.get('download_root', None)
+        model.load_weights(pretrained, change_names_dict=change_names_dict, download_root=download_root)
+    return model
+
+
+def regnetx800mf(pretrained=False, progress=True, **kwargs):
+    r"""ResNeXt-50 32x4d model from
+    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['channels'] = (32,64,128,288,672)
+    kwargs['group_width'] = 16
+    return _regnet('regnetx800mf', RegNetBottleneck, [1, 3, 7, 5],
+                   pretrained, progress, **kwargs)
+
+
+###################################################
+def get_config():
+    model_config = xnn.utils.ConfigNode()
+    model_config.input_channels = 3
+    model_config.num_classes = 1000
+    model_config.width_mult = 1.0
+    model_config.strides = None
+    model_config.fastdown = False
+    return model_config
+
+
+def regnetx800mf_with_model_config(model_config, pretrained=None):
+    model_config = get_config().merge_from(model_config)
+    model = regnetx800mf(input_channels=model_config.input_channels, strides=model_config.strides,
+                     num_classes=model_config.num_classes, pretrained=pretrained,
+                     width_mult=model_config.width_mult, fastdown=model_config.fastdown)
+    return model
+
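A minimal sketch (not from this commit) of using the new backbone directly; the import path is an assumption. With pretrained=True the pycls checkpoint from model_urls is downloaded, and load_weights() falls back to the prefix remapping shown above only if the direct state_dict load fails:
```
# Hedged usage sketch for the new RegNetX-800MF backbone; import path assumed from the repo layout.
import torch
from pytorch_jacinto_ai.xvision.models import regnet

model = regnet.regnetx800mf(pretrained=True)   # downloads the pycls checkpoint listed in model_urls
model.eval()
with torch.no_grad():
    logits = model(torch.rand(1, 3, 224, 224))
print(logits.shape)   # torch.Size([1, 1000])
```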
index c4a60c6a8a6748352b3256bbd3443bffafafae47..98d9e5040f15d42f5d104330e00316cb1370f1df 100644 (file)
@@ -3,6 +3,7 @@ import numpy as np
 import cv2
 import torch
 import types
+import PIL
 
 class Compose(object):
     """ Composes several co_transforms together.
 
@@ -144,6 +145,15 @@ class ImageTransformUtils(object):
         img = cv2.warpAffine(img, rmat2x3, (w,h), flags=interpolation)
         return img
 
+    @staticmethod
+    def reverse_channels(img):
+        if isinstance(img, np.ndarray):
+            return img[:,:,::-1]
+        elif isinstance(img, PIL.Image.Image):
+            return PIL.Image.fromarray(np.array(img)[:,:,::-1])
+        else:
+            assert False, 'unrecognized image type'
+
     @staticmethod
     def array_to_tensor(array):
         """Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W)."""
index fbf9ac53908070fa0b25f62c25e2a645ab8105a5..df29b433cb630081b4f3b94e7c8b734ab3775bf3 100644 (file)
@@ -45,6 +45,20 @@ class AlignImages(object):
         return images, targets
 
 
+class ReverseImageChannels(object):
+    """Reverse the channels fo the tensor. eg. RGB to BGR
+    """
+    def __call__(self, images, targets):
+        def func(imgs, img_idx):
+            imgs = [ImageTransformUtils.reverse_channels(img_plane) for img_plane in imgs] \
+                if isinstance(imgs, list) else ImageTransformUtils.reverse_channels(imgs)
+            return imgs
+
+        images = ImageTransformUtils.apply_to_list(func, images)
+        # do not apply to targets
+        return images, targets
+
+
 class ConvertToTensor(object):
     """Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W)."""
     def __call__(self, images, targets):
index 6f09a1c52bbc68901db11ddcef0b98f498e968b4..24c2cbad34b9bd7e6795ed77a43d29a44fa305b2 100644 (file)
@@ -30,7 +30,8 @@ __all__ = ["Compose", "ToTensor", "ToPILImage", "ToFloat", "Normalize", "Resize"
            "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
            "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
            "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
            "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
            "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
            "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
-           "RandomPerspective", "RandomErasing", "MultiColor", "Bypass", "NormalizeMeanScale"]
+           "RandomPerspective", "RandomErasing", "MultiColor", "Bypass", "NormalizeMeanScale",
+           "ReverseChannels"]
 
 _pil_interpolation_to_str = {
     Image.NEAREST: 'PIL.Image.NEAREST',
 
@@ -102,6 +103,23 @@ class ToTensor(object):
         return self.__class__.__name__ + '()'
 
 
+class ReverseChannels(object):
+    """Reverse the channels fo the tensor. eg. RGB to BGR
+    """
+    def __call__(self, pic):
+        """
+        Args:
+            pic (PIL.Image): Image whose channels are to be reversed.
+
+        Returns:
+            PIL.Image: Image with the channel order reversed.
+        """
+        return Image.fromarray(np.array(pic)[:,:,::-1])
+
+    def __repr__(self):
+        return self.__class__.__name__ + '()'
+
+
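The new ReverseChannels transform can be chained with the transforms already defined in this file to produce BGR tensors for the RegNet models; a hedged sketch (normalization with the BGR mean/scale quoted in run_classification.sh would typically follow):
```
import numpy as np
from PIL import Image

# Compose and ToTensor are the classes defined in this file; ReverseChannels is the new addition.
transform = Compose([
    ReverseChannels(),    # PIL image: RGB -> BGR
    ToTensor(),           # HWC [0,255] -> CHW float tensor in [0,1]
])
img = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))
bgr_tensor = transform(img)
print(bgr_tensor.shape)   # torch.Size([3, 224, 224])
```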
 class ToPILImage(object):
     """Convert a tensor or an ndarray to PIL Image.
 
index 3b5ac5e0809a2e2fbd542f180c31bd545716b9e8..67fba0f17beb8aa61fed3b494ace0a4b0f95385f 100644 (file)
@@ -12,4 +12,5 @@ torchvision>=0.6
 tensorboard
 onnx
 packaging
+fast-histogram
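fast-histogram is added as a dependency here; presumably it backs the new hist_utils.py for fast activation-range/histogram computation during calibration and quantization. For reference, a hedged example of the package's basic API (not code from this commit):
```
import numpy as np
from fast_histogram import histogram1d

activations = np.random.randn(10000).astype(np.float32)
hist = histogram1d(activations, bins=256, range=(-4.0, 4.0))   # fixed-range histogram, faster than np.histogram
print(hist.shape)   # (256,)
```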
 
 
index 782fe1cd9a0c06203573f6f01bd9770e4c7262ba..3315d85accc649f5f6c8da00465b693b3bd24644 100755 (executable)
 #### For the datasets in sections marked "Automatic Download", the dataset will be downloaded automatically before training begins. For "Manual Download", it is expected to be downloaded manually and kept in the folder specified against the --data_path option.
 
 ## =====================================================================================
-## Training
+## Cifar Training (Dataset will be automatically downloaded)
 ## =====================================================================================
 ## Cifar100 Classification (Automatic Download)
 #### Training with MobileNetV2
 ## =====================================================================================
 ## Cifar100 Classification (Automatic Download)
 #### Training with MobileNetV2
-#python ./scripts/train_classification_main.py --dataset_name cifar100_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/cifar100_classification --img_resize 32 --img_crop 32 --rand_scale 0.5 1.0 --strides 1 1 1 2 2
+#python ./scripts/train_classification_main.py --dataset_name cifar100_classification --model_name mobilenetv2_tv_x1
+#--data_path ./data/datasets/cifar100_classification --img_resize 32 --img_crop 32 --rand_scale 0.5 1.0 --strides 1 1 1 2 2
 
 ## Cifar10 Classification (Automatic Download)
 #### Training with MobileNetV2
 
 ## Cifar10 Classification (Automatic Download)
 #### Training with MobileNetV2
-#python ./scripts/train_classification_main.py --dataset_name cifar10_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/cifar10_classification --img_resize 32 --img_crop 32 --rand_scale 0.5 1.0 --strides 1 1 1 2 2
+#python ./scripts/train_classification_main.py --dataset_name cifar10_classification --model_name mobilenetv2_tv_x1
+#--data_path ./data/datasets/cifar10_classification --img_resize 32 --img_crop 32 --rand_scale 0.5 1.0 --strides 1 1 1 2 2
 
 
-## ImageNet Classification (Automatic Download)
-#### Training with MobileNetV2
-#python ./scripts/train_classification_main.py --dataset_name imagenet_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/imagenet_classification
 
 
-## ImageNet Classification (Manual Download)
+## =====================================================================================
+## ImageNet Training (Assuming ImageNet data is already Manually Downloaded)
+## =====================================================================================
+#MobileNetV2 based Models
+#------------------------
 #### Training with MobileNetV2
 #### Training with MobileNetV2
-#python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification
+#python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1
+#--data_path ./data/datasets/image_folder_classification
+
 #### Training with MobileNetV2 - Small Resolution
 #### Training with MobileNetV2 - Small Resolution
-#python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification --img_resize 146 --img_crop 128 --batch_size 1024 --lr 0.2 --workers 16
+#python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1
+#--data_path ./data/datasets/image_folder_classification --img_resize 146 --img_crop 128 --batch_size 1024 --lr 0.2 --workers 16
+
 #### Training with MobileNetV2 with 2x channels and expansion factor of 2
 #### Training with MobileNetV2 with 2x channels and expansion factor of 2
-#python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_tv_x2_t2 --data_path ./data/datasets/image_folder_classification --batch_size 256
+#python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_tv_x2_t2
+#--data_path ./data/datasets/image_folder_classification --batch_size 256
+
 
 
+#ResNet50 based Models
+#------------------------
 ### Training with ResNet50
 ### Training with ResNet50
-#python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name resnet50_x1 --data_path ./data/datasets/image_folder_classification
+#python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name resnet50_x1
+#--data_path ./data/datasets/image_folder_classification
+
 ### Training with ResNet50 - with half the number of channels - so roughly 1/4 the complexity
 ### Training with ResNet50 - with half the number of channels - so roughly 1/4 the complexity
-#python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name resnet50_xp5 --data_path ./data/datasets/image_folder_classification
+#python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name resnet50_xp5
+#--data_path ./data/datasets/image_folder_classification
+
+
+#RegNetX based Models
+#------------------------
+### Training with RegNetX800MF with BGR input
+#Note: to use BGR input, set --input_channel_reverse True; for RGB input omit this argument or set it to False.
+#python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name regnetx800mf_x1
+#--data_path ./data/datasets/image_folder_classification \
+#--input_channel_reverse True --image_mean 103.53 116.28 123.675 --image_scale 0.017429 0.017507 0.017125
+
 
 ## =====================================================================================
 ## Validation
 ## =====================================================================================
 #### cifar100 Validation - populate the pretrained model path below in ??
 
 ## =====================================================================================
 ## Validation
 ## =====================================================================================
 #### cifar100 Validation - populate the pretrained model path below in ??
-#python ./scripts/train_classification_main.py --phase validation --dataset_name cifar100_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/cifar100_classification --img_resize 32 --img_crop 32 \
+#python ./scripts/train_classification_main.py --phase validation --dataset_name cifar100_classification --model_name mobilenetv2_tv_x1
+#--data_path ./data/datasets/cifar100_classification --img_resize 32 --img_crop 32 \
 #--pretrained=???
 
 #### cifar10 Validation - populate the pretrained model path below in ??
 #--pretrained=???
 
 #### cifar10 Validation - populate the pretrained model path below in ??
-#python ./scripts/train_classification_main.py --phase validation --dataset_name cifar10_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/cifar10_classification --img_resize 32 --img_crop 32 \
+#python ./scripts/train_classification_main.py --phase validation --dataset_name cifar10_classification --model_name mobilenetv2_tv_x1
+#--data_path ./data/datasets/cifar10_classification --img_resize 32 --img_crop 32 \
 #--pretrained=???
 
 #--pretrained=???
 
-#### Validation - populate the pretrained model path below in ?? or use https://download.pytorch.org/models/resnet50-19c8e357.pth for resnet50_x1
-#python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name resnet50_x1 --data_path ./data/datasets/image_folder_classification \
-#--pretrained https://download.pytorch.org/models/resnet50-19c8e357.pth
 
 
-#### Validation - populate the pretrained model path below in ?? or use https://download.pytorch.org/models/mobilenet_v2-b0353104.pth for mobilenetv2_tv_x1
-#python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification \
+#MobileNetV2 based Models
+#------------------------
+#### Validation - ImageNet - populate the pretrained model path below in ?? or use https://download.pytorch.org/models/mobilenet_v2-b0353104.pth for mobilenetv2_tv_x1
+#python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1
+#--data_path ./data/datasets/image_folder_classification \
 #--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
 
 #--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
 
-#### Validation - populate the pretrained model path below in ?? for resnet50_xp5
-#python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name resnet50_xp5 --data_path ./data/datasets/image_folder_classification \
+#### Validation - ImageNet - populate the pretrained model path below in ?? or use https://download.pytorch.org/models/resnet50-19c8e357.pth for resnet50_x1
+#python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name resnet50_x1
+#--data_path ./data/datasets/image_folder_classification \
+#--pretrained https://download.pytorch.org/models/resnet50-19c8e357.pth
+
+
+#ResNet50 based Models
+#------------------------
+#### Validation - ImageNet - populate the pretrained model path below in ?? for resnet50_xp5
+#python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name resnet50_xp5
+#--data_path ./data/datasets/image_folder_classification \
 #--pretrained ./data/modelzoo/pytorch/image_classification/imagenet1k/jacinto_ai/resnet50-0.5_2018-07-23_12-10-23.pth
 #--pretrained ./data/modelzoo/pytorch/image_classification/imagenet1k/jacinto_ai/resnet50-0.5_2018-07-23_12-10-23.pth
+
+
+#RegNetX based Models
+#------------------------
+#### Validation - ImageNet regnetx800mf_x1 with BGR input
+#Note: to use BGR input, set --input_channel_reverse True; for RGB input omit this argument or set it to False.
+#python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name regnetx800mf_x1
+#--data_path ./data/datasets/image_folder_classification \
+#--input_channel_reverse True --image_mean 103.53 116.28 123.675 --image_scale 0.017429 0.017507 0.017125 \
+#--pretrained https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906036/RegNetX-800MF_dds_8gpu.pyth
+
+
+
+## =====================================================================================
+#Training with ImageNet data download - the download may take a long time - we have not tested this.
+## =====================================================================================
+#### Training with MobileNetV2
+#python ./scripts/train_classification_main.py --dataset_name imagenet_classification --model_name mobilenetv2_tv_x1
+#--data_path ./data/datasets/imagenet_classification
index 3a288e7b59c7d1103712192843a7b103e7911f7b..3b48ccd5818061e21855e9c4ae1ae16095fba42c 100755 (executable)
@@ -5,43 +5,43 @@
 ## =====================================================================================
 #
 #### Image Classification - Post Training Calibration & Quantization - ResNet50
 ## =====================================================================================
 #
 #### Image Classification - Post Training Calibration & Quantization - ResNet50
-#python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name resnet50_x1 --data_path ./data/datasets/image_folder_classification --gpus 0 \
+#python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name resnet50_x1 \
+#--data_path ./data/datasets/image_folder_classification --gpus 0 \
 #--pretrained https://download.pytorch.org/models/resnet50-19c8e357.pth \
 #--batch_size 64 --quantize True --epochs 1 --epoch_size 0.1 --evaluate_start False
 #
 #
 #### Image Classification - Post Training Calibration & Quantization - ResNet18
 #--pretrained https://download.pytorch.org/models/resnet50-19c8e357.pth \
 #--batch_size 64 --quantize True --epochs 1 --epoch_size 0.1 --evaluate_start False
 #
 #
 #### Image Classification - Post Training Calibration & Quantization - ResNet18
-#python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name resnet18_x1 --data_path ./data/datasets/image_folder_classification --gpus 0 \
+#python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name resnet18_x1 \
+#--data_path ./data/datasets/image_folder_classification --gpus 0 \
 #--pretrained https://download.pytorch.org/models/resnet18-5c106cde.pth \
 #--batch_size 64 --quantize True --epochs 1 --epoch_size 0.1 --evaluate_start False
 #
 #
 #### Image Classification - Post Training Calibration & Quantization - MobileNetV2
 #--pretrained https://download.pytorch.org/models/resnet18-5c106cde.pth \
 #--batch_size 64 --quantize True --epochs 1 --epoch_size 0.1 --evaluate_start False
 #
 #
 #### Image Classification - Post Training Calibration & Quantization - MobileNetV2
-#python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification --gpus 0 \
+#python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 \
+#--data_path ./data/datasets/image_folder_classification --gpus 0 \
 #--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth \
 #--batch_size 64 --quantize True --epochs 1 --epoch_size 0.1 --evaluate_start False
 #
 #
 #### Image Classification - Post Training Calibration & Quantization for a TOUGH MobileNetV2 pretrained model
 #--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth \
 #--batch_size 64 --quantize True --epochs 1 --epoch_size 0.1 --evaluate_start False
 #
 #
 #### Image Classification - Post Training Calibration & Quantization for a TOUGH MobileNetV2 pretrained model
-#python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name mobilenetv2_shicai_x1 --data_path ./data/datasets/image_folder_classification --gpus 0 \
+#python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name mobilenetv2_shicai_x1 \
+#--data_path ./data/datasets/image_folder_classification --gpus 0 \
 #--pretrained ./data/modelzoo/pytorch/image_classification/imagenet1k/shicai/mobilenetv2_shicai_rgb.pth \
 #--batch_size 64 --quantize True --epochs 1 --epoch_size 0.1 --evaluate_start False
 #
 #
 #--pretrained ./data/modelzoo/pytorch/image_classification/imagenet1k/shicai/mobilenetv2_shicai_rgb.pth \
 #--batch_size 64 --quantize True --epochs 1 --epoch_size 0.1 --evaluate_start False
 #
 #
-#### Image Classification - Post Training Calibration & Quantization - ONNX Model Import
-#python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --gpus 0 \
-#--model_name resnet18-v1-7 --model /data/tensorlabdata1/modelzoo/pytorch/image_classification/imagenet1k/onnx-model-zoo/resnet18-v1-7.onnx \
-#--data_path ./data/datasets/image_folder_classification --batch_size 64 --quantize True --epochs 1 --epoch_size 0.1 --evaluate_start False
-#
-#
 ### Semantic Segmentation - Post Training Calibration &  Quantization for MobileNetV2+DeeplabV3Lite
 ### Semantic Segmentation - Post Training Calibration &  Quantization for MobileNetV2+DeeplabV3Lite
-#python ./scripts/train_segmentation_main.py --phase calibration --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 \
+#python ./scripts/train_segmentation_main.py --phase calibration --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv \
+#--data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 \
 #--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/deeplabv3lite_mobilenetv2_tv_768x384_best.pth \
 #--batch_size 6 --quantize True --epochs 1 --evaluate_start False
 #
 #
 ### Semantic Segmentation - Post Training Calibration &  Quantization for MobileNetV2+UNetLite
 #--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/deeplabv3lite_mobilenetv2_tv_768x384_best.pth \
 #--batch_size 6 --quantize True --epochs 1 --evaluate_start False
 #
 #
 ### Semantic Segmentation - Post Training Calibration &  Quantization for MobileNetV2+UNetLite
-#python ./scripts/train_segmentation_main.py --phase calibration --dataset_name cityscapes_segmentation --model_name unetlite_pixel2pixel_aspp_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 \
+#python ./scripts/train_segmentation_main.py --phase calibration --dataset_name cityscapes_segmentation --model_name unetlite_pixel2pixel_aspp_mobilenetv2_tv \
+#--data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 \
 #--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/unet_aspp_mobilenetv2_tv_768x384_best.pth \
 #--batch_size 6 --quantize True --epochs 1 --evaluate_start False
 #
 #--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/unet_aspp_mobilenetv2_tv_768x384_best.pth \
 #--batch_size 6 --quantize True --epochs 1 --evaluate_start False
 #
 ## =====================================================================================
 #
 #### Image Classification - Quantization Aware Training - MobileNetV2
 ## =====================================================================================
 #
 #### Image Classification - Quantization Aware Training - MobileNetV2
-#python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification \
+#python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 \
+#--data_path ./data/datasets/image_folder_classification \
 #--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth \
 #--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth \
-#--batch_size 64 --quantize True --epochs 50 --epoch_size 0.1 --lr 1e-5 --evaluate_start False
+#--batch_size 64 --quantize True --epoch_size 0.1 --epochs 50 --lr 1e-5 --evaluate_start False
 #
 #
 #### Image Classification - Quantization Aware Training - MobileNetV2(Shicai) - a TOUGH MobileNetV2 pretrained model
 #
 #
 #### Image Classification - Quantization Aware Training - MobileNetV2(Shicai) - a TOUGH MobileNetV2 pretrained model
-#python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_shicai_x1 --data_path ./data/datasets/image_folder_classification \
+#python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_shicai_x1 \
+#--data_path ./data/datasets/image_folder_classification \
 #--pretrained ./data/modelzoo/pytorch/image_classification/imagenet1k/shicai/mobilenetv2_shicai_rgb.pth \
 #--pretrained ./data/modelzoo/pytorch/image_classification/imagenet1k/shicai/mobilenetv2_shicai_rgb.pth \
-#--batch_size 64 --quantize True --epochs 50 --epoch_size 0.1 --lr 1e-5 --evaluate_start False
+#--batch_size 64 --quantize True --epoch_size 0.1 --epochs 50 --lr 1e-5 --evaluate_start False
 #
 #
 #### Semantic Segmentation - Quantization Aware Training for MobileNetV2+DeeplabV3Lite
 #
 #
 #### Semantic Segmentation - Quantization Aware Training for MobileNetV2+DeeplabV3Lite
-#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
+#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv \
+#--data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
 #--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/deeplabv3lite_mobilenetv2_tv_768x384_best.pth \
 #--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/deeplabv3lite_mobilenetv2_tv_768x384_best.pth \
-#--batch_size 12 --quantize True --epochs 50 --lr 1e-5 --evaluate_start False
+#--batch_size 12 --quantize True --epochs 50 --lr 1e-5 --evaluate_start False
 #
 #
 #### Semantic Segmentation - Quantization Aware Training for MobileNetV2+UNetLite
 #
 #
 #### Semantic Segmentation - Quantization Aware Training for MobileNetV2+UNetLite
-#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name unetlite_pixel2pixel_aspp_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
+#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name unetlite_pixel2pixel_aspp_mobilenetv2_tv \
+#--data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
 #--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/unet_aspp_mobilenetv2_tv_768x384_best.pth \
 #--batch_size 12 --quantize True --epochs 50 --lr 1e-5 --evaluate_start False
 #
 #
 #--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/unet_aspp_mobilenetv2_tv_768x384_best.pth \
 #--batch_size 12 --quantize True --epochs 50 --lr 1e-5 --evaluate_start False
 #
 #
-### Depth Estimation - Post Training Calibration &  Quantization for MobileNetV2+DeeplabV3Lite
+### Depth Estimation - Quantization Aware Training for MobileNetV2+DeeplabV3Lite
 #python ./scripts/train_depth_main.py --dataset_name kitti_depth --model_name deeplabv3lite_mobilenetv2_tv --img_resize 384 768 --output_size 1024 2048 \
 #--pretrained ./data/modelzoo/pytorch/monocular_depth/kitti_depth/jacinto_ai/deeplabv3lite_mobilenetv2_tv_768x384_best.pth \
 #--batch_size 32 --quantize True --epochs 50 --lr 1e-5 --evaluate_start False
 #python ./scripts/train_depth_main.py --dataset_name kitti_depth --model_name deeplabv3lite_mobilenetv2_tv --img_resize 384 768 --output_size 1024 2048 \
 #--pretrained ./data/modelzoo/pytorch/monocular_depth/kitti_depth/jacinto_ai/deeplabv3lite_mobilenetv2_tv_768x384_best.pth \
 #--batch_size 32 --quantize True --epochs 50 --lr 1e-5 --evaluate_start False
 ## =====================================================================================
 #
 #### Image Classification - Accuracy Estimation with Post Training Quantization - MobileNetV2
 ## =====================================================================================
 #
 #### Image Classification - Accuracy Estimation with Post Training Quantization - MobileNetV2
-#python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification \
+#python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 \
+#--data_path ./data/datasets/image_folder_classification \
 #--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth \
 #--batch_size 64 --quantize True
 #
 #### Image Classification - Accuracy Estimation with Post Training Quantization - ResNet50
 #--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth \
 #--batch_size 64 --quantize True
 #
 #### Image Classification - Accuracy Estimation with Post Training Quantization - ResNet50
-#python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name resnet50_x1 --data_path ./data/datasets/image_folder_classification \
+#python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name resnet50_x1 \
+#--data_path ./data/datasets/image_folder_classification \
 #--pretrained https://download.pytorch.org/models/resnet50-19c8e357.pth \
 #--batch_size 64 --quantize True
 #
 #### Image Classification - Accuracy Estimation with Post Training Quantization - ResNet18
 #--pretrained https://download.pytorch.org/models/resnet50-19c8e357.pth \
 #--batch_size 64 --quantize True
 #
 #### Image Classification - Accuracy Estimation with Post Training Quantization - ResNet18
-#python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name resnet18_x1 --data_path ./data/datasets/image_folder_classification \
+#python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name resnet18_x1 \
+#--data_path ./data/datasets/image_folder_classification \
 #--pretrained https://download.pytorch.org/models/resnet18-5c106cde.pth \
 #--batch_size 64 --quantize True
 #
 #### Image Classification - Accuracy Estimation with Post Training Quantization - A TOUGH MobileNetV2 pretrained model
 #--pretrained https://download.pytorch.org/models/resnet18-5c106cde.pth \
 #--batch_size 64 --quantize True
 #
 #### Image Classification - Accuracy Estimation with Post Training Quantization - A TOUGH MobileNetV2 pretrained model
-#python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name mobilenetv2_shicai_x1 --data_path ./data/datasets/image_folder_classification \
+#python ./scripts/train_classification_main.py --phase validation --dataset_name image_folder_classification --model_name mobilenetv2_shicai_x1 \
+#--data_path ./data/datasets/image_folder_classification \
 #--pretrained ./data/modelzoo/pytorch/image_classification/imagenet1k/shicai/mobilenetv2_shicai_rgb.pth \
 #--batch_size 64 --quantize True
 #
 #### Semantic Segmentation - Accuracy Estimation with Post Training Quantization - MobileNetV2+DeeplabV3Lite
 #--pretrained ./data/modelzoo/pytorch/image_classification/imagenet1k/shicai/mobilenetv2_shicai_rgb.pth \
 #--batch_size 64 --quantize True
 #
 #### Semantic Segmentation - Accuracy Estimation with Post Training Quantization - MobileNetV2+DeeplabV3Lite
-#python ./scripts/train_segmentation_main.py --phase validation --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
+#python ./scripts/train_segmentation_main.py --phase validation --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv \
+#--data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
 #--pretrained './data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/deeplabv3lite_mobilenetv2_tv_768x384_best.pth' \
 #--batch_size 1 --quantize True
 #
 #### Semantic Segmentation - Accuracy Estimation with Post Training Quantization - MobileNetV2+UNetLite
 #--pretrained './data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/deeplabv3lite_mobilenetv2_tv_768x384_best.pth' \
 #--batch_size 1 --quantize True
 #
 #### Semantic Segmentation - Accuracy Estimation with Post Training Quantization - MobileNetV2+UNetLite
-#python ./scripts/train_segmentation_main.py --phase validation --dataset_name cityscapes_segmentation --model_name unetlite_pixel2pixel_aspp_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
+#python ./scripts/train_segmentation_main.py --phase validation --dataset_name cityscapes_segmentation --model_name unetlite_pixel2pixel_aspp_mobilenetv2_tv \
+#--data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
 #--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/unet_aspp_mobilenetv2_tv_768x384_best.pth \
 #--pretrained ./data/modelzoo/pytorch/semantic_seg/cityscapes/jacinto_ai/unet_aspp_mobilenetv2_tv_768x384_best.pth \
-#--batch_size 1 --quantize True
\ No newline at end of file
+#--batch_size 1 --quantize True
+
+
+## =====================================================================================
+# Not a completely supported feature - ONNX Model Import and PTQ
+## =====================================================================================
+#### Image Classification - Post Training Calibration & Quantization
+#python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --gpus 0 \
+#--model_name resnet18-v1-7 --model /data/tensorlabdata1/modelzoo/pytorch/image_classification/imagenet1k/onnx-model-zoo/resnet18-v1-7.onnx \
+#--data_path ./data/datasets/image_folder_classification --batch_size 64 --quantize True --epochs 1 --epoch_size 0.1 --evaluate_start False
\ No newline at end of file
index d25a5d8e1971f7b7807cb75494e07fc7776bd06b..11616e43e03bad07b5a5198ee3f06423028668a3 100755 (executable)
 # unetlite_pixel2pixel_aspp_mobilenetv2_tv_fd: low complexity model with fast downsampling strategy
 #
 # deeplabv3lite_resnet50: uses resnet50 encoder
 # unetlite_pixel2pixel_aspp_mobilenetv2_tv_fd: low complexity model with fast downsampling strategy
 #
 # deeplabv3lite_resnet50: uses resnet50 encoder
-# deeplabv3lite_resnet50_p5: low complexity model - uses resnet50 encoder with half the number of channels (1/4 the complexity). note this need specially trained resnet50 pretrained weights
+# deeplabv3lite_resnet50_p5: low complexity model - uses resnet50 encoder with half the number of channels (1/4 the complexity).
+# note: this needs specially trained resnet50 pretrained weights
 # fpnlite_pixel2pixel_aspp_resnet50_fd: low complexity model - with fast downsampling strategy
 
 # fpnlite_pixel2pixel_aspp_resnet50_fd: low complexity model - with fast downsampling strategy
 
+# unetlite_pixel2pixel_aspp_regnetx800mf: uses the regnetx800mf encoder and a matching group_width (even in the decoder)
+# fpnlite_pixel2pixel_aspp_regnetx800mf: uses the regnetx800mf encoder and a matching group_width (even in the decoder)
+# deeplabv3lite_regnetx800mf: uses the regnetx800mf encoder and a matching group_width (even in the decoder)
+
 
 ## =====================================================================================
 ## Training
 ## =====================================================================================
 
 ## =====================================================================================
 ## Training
 ## =====================================================================================
+#MobileNetV2 based Models
+#------------------------
 #### Cityscapes Semantic Segmentation - Training with MobileNetV2+DeeplabV3Lite
 #### Cityscapes Semantic Segmentation - Training with MobileNetV2+DeeplabV3Lite
-#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
+#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv \
+#--data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
 #--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
 
 #### Cityscapes Semantic Segmentation - Training with MobileNetV2+DeeplabV3Lite, Higher Resolution
 #--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
 
 #### Cityscapes Semantic Segmentation - Training with MobileNetV2+DeeplabV3Lite, Higher Resolution
-#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 768 1536 --rand_crop 512 1024 --output_size 1024 2048 --gpus 0 1 \
+#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv \
+#--data_path ./data/datasets/cityscapes/data --img_resize 768 1536 --rand_crop 512 1024 --output_size 1024 2048 --gpus 0 1 \
 #--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
 
 #### Cityscapes Semantic Segmentation - original fpn - no aspp model, stride 64 model, Higher Resolution - Low Complexity Model
 #--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
 
 #### Cityscapes Semantic Segmentation - original fpn - no aspp model, stride 64 model, Higher Resolution - Low Complexity Model
-#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv_fd --data_path ./data/datasets/cityscapes/data --img_resize 768 1536 --rand_crop 512 1024 --output_size 1024 2048 --gpus 0 1 \
+#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv_fd \
+#--data_path ./data/datasets/cityscapes/data --img_resize 768 1536 --rand_crop 512 1024 --output_size 1024 2048 --gpus 0 1 \
 #--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
 
 
 
 #--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
 
 
 
-##--ResNet50 + deeplabv3lite
+#ResNet50 based Models
+#------------------------
 #### Cityscapes Semantic Segmentation - Training with ResNet50+DeeplabV3Lite
 #### Cityscapes Semantic Segmentation - Training with ResNet50+DeeplabV3Lite
-#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_resnet50 --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
+#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_resnet50 \
+#--data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
 #--pretrained https://download.pytorch.org/models/resnet50-19c8e357.pth
 
 #### Cityscapes Semantic Segmentation - Training with FD-ResNet50+FPN - High Resolution - Low Complexity Model
-#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_resnet50_fd --data_path ./data/datasets/cityscapes/data --img_resize 768 1536 --rand_crop 512 1024 --output_size 1024 2048 --gpus 0 1 \
+#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_resnet50_fd \
+#--data_path ./data/datasets/cityscapes/data --img_resize 768 1536 --rand_crop 512 1024 --output_size 1024 2048 --gpus 0 1 \
 #--pretrained https://download.pytorch.org/models/resnet50-19c8e357.pth
 
-#--ResNet50 encoder with half the channels + deeplabv3lite
-#### Cityscapes Semantic Segmentation - Training with ResNet50_p5+DeeplabV3Lite (ResNet50 encoder with half the channels): deeplabv3lite_resnet50_p5 & deeplabv3lite_resnet50_p5_fd
-#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_resnet50_p5 --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
+#### Cityscapes Semantic Segmentation - Training with ResNet50_p5+DeeplabV3Lite (ResNet50 encoder with half the channels):
+#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_resnet50_p5 \
+#--data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
 #--pretrained "./data/modelzoo/pretrained/pytorch/imagenet_classification/jacinto_ai/resnet50-0.5_2018-07-23_12-10-23.pth"
 
 
 #--pretrained "./data/modelzoo/pretrained/pytorch/imagenet_classification/jacinto_ai/resnet50-0.5_2018-07-23_12-10-23.pth"
 
 
+#RegNetX-based Models
+#------------------------
+#### Cityscapes Semantic Segmentation - Training with RegNetX800MF+FPNLite
+#Note: to use BGR input, set --input_channel_reverse True; for RGB input, omit this argument or set it to False.
+#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name fpnlite_pixel2pixel_aspp_regnetx800mf \
+#--data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
+#--pretrained https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906036/RegNetX-800MF_dds_8gpu.pyth
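#Example (editor's illustration, not part of the original commit): the same RegNetX800MF training with BGR input enabled:
#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name fpnlite_pixel2pixel_aspp_regnetx800mf \
#--data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 --input_channel_reverse True \
#--pretrained https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906036/RegNetX-800MF_dds_8gpu.pyth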
+
+
+
+
 #-- VOC Segmentation
 #### VOC Segmentation - Training with MobileNetV2+DeeplabV3Lite
-#python ./scripts/train_segmentation_main.py --dataset_name voc_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/voc --img_resize 512 512 --output_size 512 512 --gpus 0 1 \
+#python ./scripts/train_segmentation_main.py --dataset_name voc_segmentation --model_name deeplabv3lite_mobilenetv2_tv \
+#--data_path ./data/datasets/voc --img_resize 512 512 --output_size 512 512 --gpus 0 1 \
 #--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
 
 
 ## Validation
 ## =====================================================================================
 #### Validation - Cityscapes Semantic Segmentation - Validation with MobileNetV2+DeeplabV3Lite - populate the pretrained checkpoint filename in place of ??
-#python ./scripts/train_segmentation_main.py --phase validation --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
+#python ./scripts/train_segmentation_main.py --phase validation --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv \
+#--data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
 #--pretrained ??
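#Example with ?? filled in (editor's illustration; the checkpoint path below is hypothetical and only shows the expected form of the argument):
#python ./scripts/train_segmentation_main.py --phase validation --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv \
#--data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
#--pretrained ./data/checkpoints/cityscapes_segmentation/<your-trained-run>/model_best.pth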
 
 #### Inference - Cityscapes Semantic Segmentation - Inference with MobileNetV2+DeeplabV3Lite - populate the pretrained checkpoint filename in place of ??
-#python ./scripts/infer_segmentation_main.py --dataset_name cityscapes_segmentation_measure --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
+#python ./scripts/infer_segmentation_main.py --dataset_name cityscapes_segmentation_measure --model_name deeplabv3lite_mobilenetv2_tv \
+#--data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
 #--pretrained ???
 
 #### Validation - VOC Segmentation - Validation with MobileNetV2+DeeplabV3Lite - populate the pretrained checkpoint filename in place of ??
-#python ./scripts/train_segmentation_main.py --phase validation --dataset_name voc_segmentation_measure --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/voc --img_resize 512 512 --output_size 512 512 --gpus 0 1
+#python ./scripts/train_segmentation_main.py --phase validation --dataset_name voc_segmentation_measure --model_name deeplabv3lite_mobilenetv2_tv \
+#--data_path ./data/datasets/voc --img_resize 512 512 --output_size 512 512 --gpus 0 1
 #--phase validation --pretrained ???
 
-
-
-
index ebd9b02f1089011e7dfcfa7b3644a853692d09b3..a6bd657e05b543eddb192925ff55928869ea789f 100755 (executable)
@@ -10,6 +10,9 @@ import numpy as np
 ################################
 from pytorch_jacinto_ai.xnn.utils import str2bool
 parser = argparse.ArgumentParser()
+parser.add_argument('--image_mean', type=float, nargs='*', default=None, help='image mean values for input normalization')
+parser.add_argument('--image_scale', type=float, nargs='*', default=None, help='image scale values for input normalization')
+parser.add_argument('--input_channel_reverse', type=str2bool, default=None, help='reverse the input channel order, for example RGB to BGR')
 parser.add_argument('--save_path', type=str, default=None, help='checkpoint save folder')
 parser.add_argument('--gpus', type=int, nargs='*', default=None, help='GPU ids to use, for example 0 1')
 parser.add_argument('--batch_size', type=int, default=None, help='Batch size')
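# Editor's note (illustrative, not part of the original commit): the new image_mean/image_scale options typically take
# one value per channel, and input_channel_reverse takes a boolean string; for example, appended to a training command:
#   --image_mean 123.675 116.28 103.53 --image_scale 0.017125 0.017507 0.017429 --input_channel_reverse True
# The mean/scale values above are common ImageNet normalization constants, shown here only as an example.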
index 2fb707818086bd2dd310832d63585f68296de585..9fdfe61fb52cf441b41ac2b512933713e6a3f21a 100755 (executable)
@@ -10,6 +10,9 @@ import numpy as np
 ################################
 from pytorch_jacinto_ai.xnn.utils import str2bool
 parser = argparse.ArgumentParser()
+parser.add_argument('--image_mean', type=float, nargs='*', default=None, help='image mean values for input normalization')
+parser.add_argument('--image_scale', type=float, nargs='*', default=None, help='image scale values for input normalization')
+parser.add_argument('--input_channel_reverse', type=str2bool, default=None, help='reverse the input channel order, for example RGB to BGR')
 parser.add_argument('--save_path', type=str, default=None, help='checkpoint save folder')
 parser.add_argument('--gpus', type=int, nargs='*', default=None, help='GPU ids to use, for example 0 1')
 parser.add_argument('--batch_size', type=int, default=None, help='Batch size')
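# Editor's note (illustrative, not part of the original commit): --input_channel_reverse is parsed with str2bool, so it is
# given as a string on the command line, e.g. --input_channel_reverse True for BGR input (see the RegNetX note earlier in this commit).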