release commit
author     Manu Mathew <a0393608@ti.com>
           Tue, 17 Dec 2019 05:26:37 +0000 (10:56 +0530)
committer  Manu Mathew <a0393608@ti.com>
           Tue, 17 Dec 2019 05:26:37 +0000 (10:56 +0530)
181 files changed:
.gitignore [new file with mode: 0644]
.gitmodules [new file with mode: 0644]
LICENSE [new file with mode: 0644]
README.md [new file with mode: 0644]
data/checkpoints/readme.txt [new file with mode: 0644]
data/datasets/readme.txt [new file with mode: 0644]
data/downloads/readme.txt [new file with mode: 0644]
docs/Depth_Estimation.md [new file with mode: 0644]
docs/Image_Classification.md [new file with mode: 0644]
docs/Keypoint_Estimation.md [new file with mode: 0644]
docs/Motion_Segmentation.md [new file with mode: 0644]
docs/Multi_Task_Learning.md [new file with mode: 0644]
docs/Object_Detection.md [new file with mode: 0644]
docs/Quantization.md [new file with mode: 0644]
docs/Semantic_Segmentation.md [new file with mode: 0644]
docs/motion_segmentation/motion_segmentation_network.PNG [new file with mode: 0755]
docs/multi_task_learning/multi_task_network.PNG [new file with mode: 0755]
examples/write_onnx_model_example.py [new file with mode: 0644]
modules/__init__.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/__init__.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/engine/__init__.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/engine/evaluate_pixel2pixel.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/engine/infer_pixel2pixel.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/engine/test_classification.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/engine/test_pixel2pixel_onnx.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/engine/train_classification.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/__init__.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/__init__.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/caltech.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/celeba.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/cifar.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/cityscapes.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/classification/__init__.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/coco.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/fakedata.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/flickr.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/folder.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/hmdb51.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/imagenet.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/kinetics.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/lsun.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/mnist.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/omniglot.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/phototour.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/pixel2pixel/__init__.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/pixel2pixel/calculate_class_weight.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/pixel2pixel/cityscapes_plus.py [new file with mode: 0755]
modules/pytorch_jacinto_ai/vision/datasets/pixel2pixel/dataset_utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/pixel2pixel/flyingchairs.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/pixel2pixel/kitti_depth.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/pixel2pixel/kitti_sceneflow.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/pixel2pixel/mpisintel.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/pixel2pixel/segmentation.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/sbd.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/sbu.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/semeion.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/stl10.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/svhn.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/ucf101.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/usps.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/video_utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/vision.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/datasets/voc.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/extension.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/io/__init__.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/io/video.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/losses/__init__.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/losses/basic_loss.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/losses/flow_loss.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/losses/interest_pt_loss.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/losses/loss_utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/losses/norm_loss.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/losses/scale_loss.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/losses/segmentation_loss.py [new file with mode: 0755]
modules/pytorch_jacinto_ai/vision/models/__init__.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/_utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/alexnet.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/classification/__init__.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/densenet.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/detection/__init__.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/detection/_utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/detection/backbone_utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/detection/faster_rcnn.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/detection/generalized_rcnn.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/detection/image_list.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/detection/keypoint_rcnn.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/detection/mask_rcnn.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/detection/roi_heads.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/detection/rpn.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/detection/transform.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/mnasnet.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/mobilenetv1.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/mobilenetv2.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/multi_input_net.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/pixel2pixel/__init__.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/pixel2pixel/deeplabv3lite.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/pixel2pixel/fpn_pixel2pixel.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/pixel2pixel/pixel2pixelnet.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/resnet.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/segmentation/__init__.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/segmentation/_utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/segmentation/deeplabv3.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/segmentation/fcn.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/segmentation/segmentation.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/shufflenetv2.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/squeezenet.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/vgg.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/video/__init__.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/models/video/resnet.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/ops/__init__.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/ops/_utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/ops/boxes.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/ops/feature_pyramid_network.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/ops/misc.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/ops/poolers.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/ops/roi_align.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/ops/roi_pool.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/transforms/__init__.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/transforms/functional.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/transforms/image_transform_utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/transforms/image_transforms.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/transforms/transforms.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/vision/utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/__init__.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/layers/__init__.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/layers/activation.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/layers/common_blocks.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/layers/conv_blocks.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/layers/deconv_blocks.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/layers/function.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/layers/functional.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/layers/layer_config.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/layers/model_utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/layers/multi_task.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/layers/normalization.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/layers/rf_blocks.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/optim/__init__.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/optim/lr_scheduler.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/quantize/__init__.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/quantize/hooked_module.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/quantize/quant_base_module.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/quantize/quant_test_module.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/quantize/quant_train_module.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/quantize/quant_train_utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/quantize/quant_utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/readme.txt [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/utils/__init__.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/utils/attr_dict.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/utils/bn_utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/utils/count_flops.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/utils/image_utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/utils/load_weights.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/utils/logger.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/utils/module_utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/utils/print_utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/utils/tensor_utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/utils/util_functions.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/utils/utils_depth.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/utils/utils_hist.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/utils/weights_init.py [new file with mode: 0644]
requirements.txt [new file with mode: 0644]
run_classification.sh [new file with mode: 0755]
run_depth.sh [new file with mode: 0755]
run_quantization.sh [new file with mode: 0755]
run_segmentation.sh [new file with mode: 0755]
scripts/evaluate_segmentation_main.py [new file with mode: 0755]
scripts/infer_segmentation_main.py [new file with mode: 0755]
scripts/infer_segmentation_onnx_main.py [new file with mode: 0755]
scripts/test_classification_main.py [new file with mode: 0755]
scripts/train_classification_main.py [new file with mode: 0755]
scripts/train_depth_main.py [new file with mode: 0755]
scripts/train_motion_segmentation_main.py [new file with mode: 0755]
scripts/train_pixel2pixel_multitask_main.py [new file with mode: 0755]
scripts/train_segmentation_main.py [new file with mode: 0755]
setup.py [new file with mode: 0644]
setup.sh [new file with mode: 0755]
version.py [new file with mode: 0644]

diff --git a/.gitignore b/.gitignore
new file mode 100644 (file)
index 0000000..4257a7e
--- /dev/null
@@ -0,0 +1,16 @@
+__pycache__
+*.pyc
+.eggs
+.idea
+.vscode
+*.egg-info
+
+data/*
+data/modelzoo/*
+data/checkpoints/*
+!data/checkpoints/readme.txt
+data/datasets/*
+!data/datasets/readme.txt
+data/downloads/*
+!data/downloads/readme.txt
+
diff --git a/.gitmodules b/.gitmodules
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/LICENSE b/LICENSE
new file mode 100644 (file)
index 0000000..455a0f8
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,149 @@
+Texas Instruments (C) 2018-2019 
+All Rights Reserved
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+  list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+  this list of conditions and the following disclaimer in the documentation
+  and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+  contributors may be used to endorse or promote products derived from
+  this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+==============================================================================
+Some parts of the code are borrowed from: https://github.com/pytorch/vision
+with the following license:
+
+BSD 3-Clause License
+
+Copyright (c) Soumith Chintala 2016,
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+  list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+  this list of conditions and the following disclaimer in the documentation
+  and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+  contributors may be used to endorse or promote products derived from
+  this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+==============================================================================
+Some parts of the code are borrowed from: https://github.com/pytorch/examples
+with the following license:
+
+BSD 3-Clause License
+
+Copyright (c) 2017,
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+  list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+  this list of conditions and the following disclaimer in the documentation
+  and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+  contributors may be used to endorse or promote products derived from
+  this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+==============================================================================
+Some parts of the code are borrowed from: https://github.com/ClementPinard/FlowNetPytorch
+with the following license:
+
+MIT License
+
+Copyright (c) 2017 Clément Pinard
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+==============================================================================
+Some parts of the code are borrowed from: https://github.com/ansleliu/LightNet
+with the following license:
+
+MIT License
+
+Copyright (c) 2018 Huijun Liu
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+
diff --git a/README.md b/README.md
new file mode 100644 (file)
index 0000000..510711c
--- /dev/null
+++ b/README.md
@@ -0,0 +1,64 @@
+# Jacinto-AI-DevKit (PyTorch)
+
+### Deep Learning Models / Training / Calibration & Quantization - Using PyTorch<br>
+Internal URL: https://bitbucket.itg.ti.com/projects/jacinto-ai-devkit/repos/pytorch-jacinto-ai-devkit<br>
+External URL: https://git.ti.com/jacinto-ai-devkit/pytorch-jacinto-ai-devkit<br>
+
+We provide a set of low complexity deep learning examples and models for low power embedded systems. Low power embedded systems often require balancing complexity and accuracy, a tough task that requires a significant amount of expertise and experimentation. We call this process **complexity optimization**. In addition, we would like to bridge the gap between Deep Learning training frameworks and real-time embedded inference by providing ready-to-use examples that enable **ease of use**.
+
+We have added several complexity optimized Deep Learning examples for commonly used vision tasks. We provide training scripts, accuracy, complexity and in some cases, the trained models as well. Our expectation is that these Deep Learning examples and models will find application in a variety of problems, and you will be able to build upon the **building blocks** that we have provided. 
+
+We also have a **Calibration tool for Quantization** that can output an 8-bit quantization-friendly model using a few calibration images - this tool can be used to improve the quantized accuracy and bring it close to the floating point accuracy. This tool adjusts weights and biases and also collects the ranges of activations to make the model quantization friendly. For more details, please refer to the section on Quantization.
+
+**Several of these models have been verified to work on [TI's Jacinto Automotive Processors](http://www.ti.com/processors/automotive-processors/tdax-adas-socs/overview.html).** 
+
+## Installation Instructions
+- These instructions are for installation on **Ubuntu 18.04**. 
+- Install Anaconda with Python 3.7 or higher from https://www.anaconda.com/distribution/ <br>
+- After installation, make sure that your python is indeed Anaconda Python 3.7 or higher by typing:<br>
+    ```
+    python --version
+    ```
+- Clone this repository into your local folder
+- Execute the following shell script to install the dependencies:<br>
+    ```
+    ./setup.sh
+    ```
+
+## Examples
+The following examples are currently available. Click on each of the links below to go into the full description of the example. 
+- Image Classification
+    - [**Image Classification**](docs/Image_Classification.md)
+- Pixel2Pixel prediction
+    - [**Semantic Segmentation**](docs/Semantic_Segmentation.md)
+    - [Depth Estimation](docs/Depth_Estimation.md)
+    - [Motion Segmentation](docs/Motion_Segmentation.md)
+    - [**Multi Task Learning**](docs/Multi_Task_Learning.md)
+- Object Detection
+    - [**Object Detection**](docs/Object_Detection.md)
+    - [Object Keypoint Estimation](docs/Keypoint_Estimation.md)
+ - [**Quantization**](docs/Quantization.md)<br>
+
+
+We have written down some of the common training and validation commands in the shell scripts (.sh files) provided in the root folder.
+
+## Model Zoo
+Sample models are uploaded in our [modelzoo](./data/modelzoo). Some of our scripts use the pretrained models from this modelzoo.
+
+## Additional Information
+For information on other similar devkits, please visit:<br> 
+http://git.ti.com/jacinto-ai-devkit
+
+## Acknowledgements
+
+Our source code uses parts of the following open source projects. We would like to sincerely thank their authors for making their code bases publicly available.
+
+|Module/Functionality              |Parts of the code borrowed/modified from                                             |
+|----------------------------------|-------------------------------------------------------------------------------------|
+|Datasets, Models                  |https://github.com/pytorch/vision, https://github.com/ansleliu/LightNet              |
+|Training, Validation Engine/Loops |https://github.com/pytorch/examples, https://github.com/ClementPinard/FlowNetPytorch |
+|Object Detection                  |https://github.com/open-mmlab/mmdetection                                            |
+
+## License
+
+Please see the [LICENSE](./LICENSE) file for more information about the license under which this code is made available.
\ No newline at end of file
diff --git a/data/checkpoints/readme.txt b/data/checkpoints/readme.txt
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/data/datasets/readme.txt b/data/datasets/readme.txt
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/data/downloads/readme.txt b/data/downloads/readme.txt
new file mode 100644 (file)
index 0000000..3b94f91
--- /dev/null
@@ -0,0 +1 @@
+Placeholder
diff --git a/docs/Depth_Estimation.md b/docs/Depth_Estimation.md
new file mode 100644 (file)
index 0000000..12f7d05
--- /dev/null
@@ -0,0 +1,2 @@
+# Work in progress
+
diff --git a/docs/Image_Classification.md b/docs/Image_Classification.md
new file mode 100644 (file)
index 0000000..a05cc5d
--- /dev/null
@@ -0,0 +1,126 @@
+# Training for Image Classification
+ Image Classification is a fundamental task in Deep Learning and Computer Vision. Here we show a couple of examples of training CNN / Deep Learning models for Image Classification. In these examples, we use MobileNetV2 as the model for training, but other models can also be used.
+
+ Commonly used Training/Validation commands are listed in the file [run_classification.sh](../run_classification.sh). Uncomment one and run the file to start the run.
+
+ ## Cifar Dataset 
+ Cifar10 and Cifar100 are popular datasets used for training CNNs. Since these datasets are small, the training can be finished in a short time and gives an indication of how good a particular CNN model is. The images in these datasets are small (32x32).
+
+### Cifar100 Dataset
+ * Since the dataset is small, the training script itself can download the dataset before training.
+
+ * Training can be started by the following command:<br>
+    ```
+    python ./scripts/train_classification_main.py --dataset_name cifar100_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/cifar100_classification --img_resize 32 --img_crop 32 --rand_scale 0.5 1.0
+   ```
+ * In the script, note that there are some special settings for the Cifar datasets. The most important one is the 'strides' setting. Since the input images in Cifar are small, we do not want as many stride-2 stages as for large images. See the argument args.model_config.strides being set in the script; a hedged sketch of this idea is shown at the end of this subsection.
+
+ * During the training, the **validation** accuracy will also be printed. But if you want to explicitly check the accuracy again on the **validation** set, it can be done as follows:<br>
+    ```
+    python ./scripts/train_classification_main.py --evaluate True --dataset_name cifar100_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/cifar100_classification --img_resize 32 --img_crop 32
+   ```
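+
+ * The snippet below is a minimal sketch of the strides idea mentioned above, assuming an argparse-style args object with a model_config attribute as used by the training script; the values shown are illustrative, not the script's actual defaults.
+    ```
+    from argparse import Namespace
+
+    # hypothetical illustration: small 32x32 Cifar images need fewer stride-2
+    # stages than 224x224 ImageNet images, so that the feature maps do not
+    # shrink too much before the classifier
+    args = Namespace(dataset_name='cifar100_classification',
+                     model_config=Namespace(strides=None))
+    if 'cifar' in args.dataset_name:
+        args.model_config.strides = (1, 1, 1, 2, 2)
+    else:
+        args.model_config.strides = (2, 2, 2, 2, 2)
+    ```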
+
+ ### Cifar10 Dataset
+  * Training can be started by the following command:<br>
+    ```
+    python ./scripts/train_classification_main.py --dataset_name cifar10_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/cifar10_classification --img_resize 32 --img_crop 32 --rand_scale 0.5 1.0
+    ```
+
+  ## ImageNet Dataset
+  * It is difficult to reproduce the accuracies reported in the original papers for certain models (especially true for MobileNetV1 and MobileNetV2 models) due to the need for careful hyper-parameter tuning. In our examples, we have incorporated hyper-parameters required to train high accuracy classification models.
+
+  * Important note: the ImageNet dataset is huge and the download may take a long time. Attempt this only if you have a good internet connection. The training also takes a long time; in our case, using four GTX 1080 Ti GPUs, it takes nearly two days to train.
+
+  * Training can be started by the following command:<br>
+    ```
+    python ./scripts/train_classification_main.py --dataset_name imagenet_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/imagenet_classification
+    ```
+
+* Training with ResNet50:<br>
+    ```
+    python ./scripts/train_classification_main.py --dataset_name imagenet_classification --model_name resnet50_x1 --data_path ./data/datasets/imagenet_classification
+  ```
+  
+  * After the training, the **validation** accuracy can be measured using the following command (make sure that args.dataset_name and args.pretrained are correctly set):<br>
+    ```
+    python ./scripts/train_classification_main.py --evaluate True --dataset_name imagenet_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/imagenet_classification --pretrained <checkpoint_path>
+    ```
+
+## ImageNet or any other classification dataset - manual download
+  * In this case, the images of the dataset are assumed to be arranged in folders. 'train' and 'validation' are two separate folders, and underneath each of them, every class should have its own folder.
+
+  * Assume that the folder './data/datasets/image_folder_classification' contains the classification dataset. This folder should contain folders and images as follows:
+  image_folder_classification<br>
+  &nbsp;&nbsp;train<br>
+  &nbsp;&nbsp;&nbsp;&nbsp;class1<br>
+  &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;image files here<br>
+  &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;....<br>
+  &nbsp;&nbsp;&nbsp;&nbsp;class2<br>
+  &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;image files here<br>
+  &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;....<br>
+  &nbsp;&nbsp;validation<br>
+  &nbsp;&nbsp;&nbsp;&nbsp;class1<br>
+  &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;image files here<br>
+  &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;....<br>
+  &nbsp;&nbsp;&nbsp;&nbsp;class2<br>
+  &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;image files here<br>
+  &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;....<br>
+
+* Note that 'class1', 'class2' etc. are examples; they stand for the names of the classes that we are trying to classify.
+
+* Here we use the ImageNet dataset as an example, but it could be any other image classification dataset arranged in folders.
+
+* The download links for ImageNet are given in ../modules/pytorch_jacinto_ai/vision/datasets/imagenet.py. 
+    ```
+    cd ./data/datasets/image_folder_classification
+    wget http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar
+    wget http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar
+    mkdir train
+    mkdir validation
+    tar -C train -xvf ILSVRC2012_img_train.tar
+    tar -C validation -xvf ILSVRC2012_img_val.tar
+    ```
+* After downloading and extracting, use this script to arrange the validation folder into folders of classes: 
+    ```
+    cd validation
+    wget https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh
+    chmod +x ./valprep.sh
+    ./valprep.sh
+    rm ./valprep.sh
+    ```
+
+* Training can be started by the following command from the base folder of the repository:<br>
+    ```
+    python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification
+    ```
+
+* Training with ResNet50:<br>
+    ```
+    python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name resnet50_x1 --data_path ./data/datasets/image_folder_classification
+    ```
+
+* If the dataset is in a different location, it can be specified by the --data_path option, but dataset_name must be *image_folder_classification* for folder based classification.
+
+
+## Results
+
+### ImageNet (ILSVRC2012) Classification (1000 class)
+
+* ImageNet classification results are as follows:
+
+|Dataset  |Model Name    |Resize Resolution|Crop Resolution|Complexity (GigaMACS)|Accuracy% |
+|---------|----------    |-----------      |----------     |--------             |--------  |
+|ImageNet |MobileNetV1   |256x256          |224x224        |0.568                |**71.83**|
+|ImageNet |MobileNetV2   |256x256          |224x224        |0.296                |**71.89**|
+|ImageNet |ResNet50      |256x256          |224x224        |                     |         |
+|.
+|ImageNet |MobileNetV1[1]|256x256          |224x224        |0.569                |70.60    |
+|ImageNet |MobileNetV2[2]|256x256          |224x224        |0.300                |72.00    |
+
+
+## References
+
+[1] MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications, Howard AG, Zhu M, Chen B, Kalenichenko D, Wang W, Weyand T, Andreetto M, Adam H, arXiv:1704.04861, 2017
+
+[2] MobileNetV2: Inverted Residuals and Linear Bottlenecks, Sandler M, Howard A, Zhu M, Zhmoginov A, Chen LC. arXiv preprint. arXiv:1801.04381, 2018.
\ No newline at end of file
diff --git a/docs/Keypoint_Estimation.md b/docs/Keypoint_Estimation.md
new file mode 100644 (file)
index 0000000..12f7d05
--- /dev/null
@@ -0,0 +1,2 @@
+# Work in progress
+
diff --git a/docs/Motion_Segmentation.md b/docs/Motion_Segmentation.md
new file mode 100644 (file)
index 0000000..5d11ecb
--- /dev/null
@@ -0,0 +1,67 @@
+# Motion Segmentation
+A motion segmentation network predicts the state of motion (moving or static) for each pixel.
+
+The model used for this task has a two-stream architecture, as shown in the figure below. Two parallel encoders extract appearance and temporal features separately and fuse them at a stride of 4.
+
+<p float="left">
+  <img src="motion_segmentation/motion_segmentation_network.PNG" width="555" hspace="5"/>
+</p>
+
+## Models
+We provide scripts for training models with three different input combinations: (previous image, current image), (optical flow, current image) and (optical flow with confidence, current image). Optical flow is generated from the current frame and the previous frame.
+
+The following model is used for training:
+
+**deeplabv3lite_mobilenetv2_tv**: This is the same as the deeplabv3lite_mobilenetv2_tv model described [`here`](Semantic_Segmentation.md). The sole difference is that it takes two inputs and fuses their feature maps at a stride of 4, which increases the complexity by roughly 20% compared to a single stream model.
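+
+A minimal sketch of this two-stream fusion idea is shown below, using hypothetical module names (this is not the actual deeplabv3lite_mobilenetv2_tv implementation): two encoders run in parallel and their feature maps are added element-wise once both reach a stride of 4.
+```
+import torch.nn as nn
+
+class TwoStreamFusionNet(nn.Module):
+    def __init__(self, appearance_stem, temporal_stem, shared_encoder, decoder):
+        super().__init__()
+        self.appearance_stem = appearance_stem  # image branch up to stride 4
+        self.temporal_stem = temporal_stem      # flow / previous-image branch up to stride 4
+        self.shared_encoder = shared_encoder    # remaining backbone after fusion
+        self.decoder = decoder                  # prediction head
+
+    def forward(self, image, temporal_input):
+        # element-wise fusion of the two streams at stride 4
+        fused = self.appearance_stem(image) + self.temporal_stem(temporal_input)
+        return self.decoder(self.shared_encoder(fused))
+```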
+
+## Datasets: Cityscapes Motion Dataset
+**Dataset preparation:** The dataset used for this training is the Cityscapes dataset with motion annotations. This training requires either the previous frame or the optical flow generated from (current frame, previous frame). Given below are the details to download these extra files.
+
+* Clone this [`repository`](https://bitbucket.itg.ti.com/projects/ALGO-DEVKIT/repos/cityscapes_motion_dataset/browse) for all the tools required to proceed further.
+
+* **Current frame:** This can be downloaded from https://www.cityscapes-dataset.com/. Download the zip file leftImg8bit_trainvaltest.zip and keep the directory leftImg8bit in ./data/datasets/cityscapes/data/.
+
+* **Previous frame:** Previous frames can be downloaded from https://www.cityscapes-dataset.com/ as well. The zip file is cityscapes_leftImg8bit_sequence_trainvaltest.zip. The previous frame corresponds to the 18th frame in each non-overlapping 30 frame snippet. Once you have downloaded the entire sequence, run the script `filter_cityscape_previous_frame.py` from this [`repository`](https://bitbucket.itg.ti.com/projects/ALGO-DEVKIT/repos/cityscapes_motion_dataset/browse) to extract the previous frame from the sequence. This script will keep the previous frames in a directory named leftImg8bit_previous. Move leftImg8bit_previous to ./data/datasets/cityscapes/data/.
+* **Optical flow:** Optical flow can be generated from the current frame and the previous frame. The repository above contains a script to generate the flow if you have downloaded the previous frames; otherwise, the repository also provides the pre-generated flow in two folders named `leftimg8bit_flow_farneback` and `leftimg8bit_flow_farneback_confidence`. Move them to ./data/datasets/cityscapes/data/. The optical flow has been generated using OpenCV Farneback flow with some parameter tuning, whereas the confidence is computed using forward-backward consistency; a hedged sketch is shown after this list.
+
+* **Motion annotation:** The same repository contains the motion annotations as well, inside `gtFine`. Move the gtFine directory into ./data/datasets/cityscapes/data/. Here, moving pixels are marked with 255 whereas static pixels are marked with 0.
+
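+A hedged sketch of this flow and confidence generation is given below, assuming OpenCV; the Farneback parameters and the confidence mapping are illustrative and not the exact tuning used to create the folders mentioned above.
+```
+import cv2
+import numpy as np
+
+def farneback_flow_with_confidence(prev_gray, curr_gray):
+    # dense forward and backward Farneback flow (parameters are illustrative)
+    fwd = cv2.calcOpticalFlowFarneback(prev_gray, curr_gray, None, 0.5, 3, 15, 3, 5, 1.2, 0)
+    bwd = cv2.calcOpticalFlowFarneback(curr_gray, prev_gray, None, 0.5, 3, 15, 3, 5, 1.2, 0)
+    h, w = prev_gray.shape
+    ys, xs = np.mgrid[0:h, 0:w].astype(np.float32)
+    # warp the backward flow onto the previous-frame grid using the forward flow
+    bwd_warped = cv2.remap(bwd, xs + fwd[..., 0], ys + fwd[..., 1], cv2.INTER_LINEAR)
+    # forward-backward consistency error: a small error means high confidence
+    fb_error = np.linalg.norm(fwd + bwd_warped, axis=-1)
+    confidence = np.exp(-fb_error)
+    return fwd, confidence
+```
+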
+Now, the final directory structure must look like this:
+```
+./data/datasets/cityscapes/data/
+    leftImg8bit                            - Current frame from the Cityscapes dataset
+    leftImg8bit_previous                   - Previous frame from the Cityscapes dataset
+    leftimg8bit_flow_farneback             - Optical flow generated from (current frame, previous frame), stored in the format (u', v', 128)
+    leftimg8bit_flow_farneback_confidence  - Optical flow and confidence generated from (current frame, previous frame), stored in the format (u', v', confidence)
+    gtFine                                 - Ground truth with motion annotation
+```
+
+Following are the commands for training networks with the various inputs:
+<br>
+
+**(Previous frame, Current frame):** 
+```
+python ./scripts/train_motion_segmentation_main.py --image_folders  leftImg8bit_previous leftImg8bit --is_flow 0,0 0 --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1
+```
+**(Optical flow, Current frame):**
+``` 
+python ./scripts/train_motion_segmentation_main.py --image_folders leftImg8bit_flow_farneback leftImg8bit --is_flow 1,0 0 --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1
+```
+**(Optical flow with confidence, Current frame):**
+```  
+python ./scripts/train_motion_segmentation_main.py --image_folders leftImg8bit_flow_farneback_confidence leftImg8bit --is_flow 1,0 0 --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1
+```
+
+## Results:
+## Cityscapes Motion segmentation
+| Inputs                        | mIOU (static class, moving class) |
+|-----------------------------------------------|-------------------|
+| Previous frame, Current frame                  | 83.1 (99.7,66.5) |
+| Optical flow, Current frame                    | 84.8 (99.7,69.9) |
+| Optical flow with confidence, Current frame    | 85.5 (99.7,71.3) |
+
+Using (optical flow, current frame), we get around 3.4% improvement for the moving class over the image pair input. Using confidence, we get an overall improvement of 4.8% over the image pair baseline. The results above show that we can achieve a significant improvement in motion segmentation using optical flow, and that it improves further with a confidence measure for the flow.
+## References
+[1]The Cityscapes Dataset for Semantic Urban Scene Understanding, Marius Cordts, Mohamed Omran, Sebastian Ramos, Timo Rehfeld, Markus Enzweiler, Rodrigo Benenson, Uwe Franke, Stefan Roth, Bernt Schiele, CVPR 2016, https://www.cityscapes-dataset.com/
diff --git a/docs/Multi_Task_Learning.md b/docs/Multi_Task_Learning.md
new file mode 100644 (file)
index 0000000..ef4e005
--- /dev/null
@@ -0,0 +1,65 @@
+# Multi task Network
+
+Multi task networks are trained to perform multiple tasks in a single inference of the network. This saves compute and hence is suitable for real-time applications. Theoretically, for correlated tasks, it can even improve performance compared to training each task individually.
+
+## Model
+We will discuss an example network that uses two inputs (optical flow, current frame) and performs three tasks, namely depth, semantic and motion estimation, for each input pixel. The model takes two inputs and produces three outputs, as shown in the figure below:
+
+<p float="centre">
+  <img src="multi_task_learning/multi_task_network.PNG" width="555" hspace="5"/>
+</p> 
+
+Two parallel encoders extract appearance and flow features separately and fuse them at a stride of 4. There is a separate decoder for each task, whereas the encoder is common across all the tasks.
+
+
+## Datasets: Cityscapes Multitask Dataset
+**Inputs:** The network takes (Optical flow, Current frame) as input. 
+* **Optical flow:** For the optical flow input, copy the directory **leftimg8bit_flow_farneback_confidence** from this [**repository**](https://bitbucket.itg.ti.com/projects/ALGO-DEVKIT/repos/cityscapes_motion_dataset/browse) into ./data/datasets/cityscapes/data/.
+* **Current frame:** This can be downloaded from https://www.cityscapes-dataset.com/. Download the zip file leftImg8bit_trainvaltest.zip and keep the directory leftImg8bit in ./data/datasets/cityscapes/data/.
+
+**Ground truth**
+Since we are training the network to infer depth, semantic and motion together, we need the ground truth for all these tasks for a common input.
+* **Depth:** This is available from https://www.cityscapes-dataset.com/. The folder named disparity must be kept in ./data/datasets/cityscapes/data.
+* **Semantic:** This is available from https://www.cityscapes-dataset.com/ as well. Keep the gtFine directory in ./data/datasets/cityscapes/data.
+* **Motion:** This [repository](https://bitbucket.itg.ti.com/projects/ALGO-DEVKIT/repos/cityscapes_motion_dataset/browse) contains the motion annotations inside **gtFine**. Move the gtFine directory into ./data/datasets/cityscapes/data/.
+Finally, the depth annotation must reside inside ./data/datasets/cityscapes/data, whereas both the semantic and motion annotations must go inside ./data/datasets/cityscapes/data/gtFine.
+
+Now, the final directory structure must look like this:
+```
+./data/datasets/cityscapes/data/
+    leftImg8bit                            - Current frame from the Cityscapes dataset
+    leftimg8bit_flow_farneback_confidence  - Optical flow and confidence generated from (current frame, previous frame), stored in the format (u', v', confidence)
+    gtFine                                 - Ground truth motion and semantic annotation
+    depth                                  - Ground truth depth annotation
+```
+
+## Multi task learning: Learning task-specific weights
+There are many practical challenges in training a network for multiple tasks without significant deterioration in performance, the first and foremost being the weight given to each task during training. Easier tasks may converge faster and hence require a lower weight. There have been significant advances in adaptively finding the optimum task-specific weights. We discuss a couple of them here along with vanilla multi-task learning.
+
+**Vanilla multi-task learning**: Here the weights are not learnt; they are all set to 1.0. To train a vanilla multi-task model, run the following command:<br>
+``` 
+python ./scripts/train_pixel2pixel_multitask_main.py --dataset_name cityscapes_image_dof_conf --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1
+```
+**Uncertainty based learning**: In this case, the weights are updated based on homoscedastic uncertainty [2]. To train an uncertainty based multi-task model, run the following command; a hedged sketch of the weighting scheme is shown after the command:<br>
+```
+python ./scripts/train_pixel2pixel_multitask_main.py --dataset_name cityscapes_image_dof_conf_uncerainty --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1
+```
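+A minimal sketch of homoscedastic uncertainty weighting in the spirit of [2] is given below; it is illustrative only and not the implementation used by the training script. The learnable log-variance parameters must be passed to the optimizer along with the model parameters.
+```
+import torch
+import torch.nn as nn
+
+class UncertaintyTaskWeighting(nn.Module):
+    """Learn one log-variance per task and use it to weight the task losses."""
+    def __init__(self, num_tasks=3):
+        super().__init__()
+        self.log_vars = nn.Parameter(torch.zeros(num_tasks))  # log(sigma^2) per task
+
+    def forward(self, task_losses):
+        total = 0.0
+        for loss, log_var in zip(task_losses, self.log_vars):
+            precision = torch.exp(-log_var)
+            # harder/noisier tasks learn a larger sigma and hence get a smaller weight
+            total = total + precision * loss + 0.5 * log_var
+        return total
+```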
+**Gradient normalization**: The weights are updated based on the norm of the backward gradient at the last common layer, along with the rate of learning of each task [3]. To train a model based on gradient normalization, run the following command:<br>
+```
+python ./scripts/train_pixel2pixel_multitask_main.py --dataset_name cityscapes_image_dof_conf_unceraintys --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1
+```
+## Results:
+## Cityscapes Multi Task Learning (Depth, Semantic, Motion)
+| Training Modality                       | Depth (ARD), Semantic (mIOU), Motion (mIOU) |
+|-----------------------------------------|---------------------------------------------|
+| Single Task Training                    | -----, -----, -----                         |
+| Vanilla Multi Task Training             | 12.31, 82.32, 80.52                         |
+| Uncertainty based Multi Task Training   | ----, ----, ----                            |
+| Gradient-norm based Multi Task Learning | 12.64, 85.53, 84.75                         |
+
+## References
+[1]The Cityscapes Dataset for Semantic Urban Scene Understanding, Marius Cordts, Mohamed Omran, Sebastian Ramos, Timo Rehfeld, Markus Enzweiler, Rodrigo Benenson, Uwe Franke, Stefan Roth, Bernt Schiele, CVPR 2016, https://www.cityscapes-dataset.com/
+
+[2]A. Kendall, Y. Gal, and R. Cipolla. Multi-task learning using uncertainty to weigh losses for scene geometry and semantics. In CVPR, 2018.
+
+[3]Z. Chen, V. Badrinarayanan, C. Lee, and A. Rabinovich. GradNorm: Gradient normalization for adaptive loss balancing in deep multitask networks. In ICML, 2018.
diff --git a/docs/Object_Detection.md b/docs/Object_Detection.md
new file mode 100644 (file)
index 0000000..34fa1a9
--- /dev/null
@@ -0,0 +1,122 @@
+# Object Detection From Images
+
+- Object detection training should be done with mmdetection.
+- We provide a custom fork of mmdetection that can export onnx model as well as the heads.
+
+### TODO: The following instructions have changed. To be updated.
+
+Object detection from images is one of the most important tasks in computer vision and Deep Learning. Here we show a few examples of training object detection models.
+
+Our scripts use a modified version of the mmdetection [3] code base for object detection. We have made only minimal modifications to mmdetection so that it can be integrated into our devkit. We sincerely thank the authors of mmdetection for integrating so many features and detectors into the code base.
+
+## Pascal VOC Dataset
+
+In this training we will use the VOC2007 and VOC2012 'trainval' sets for training and the VOC2007 'test' set for validation.
+Execute the following bash commands:
+### Download the data
+* Download the data using the following:
+    ```
+    mkdir /data/datasets/voc
+    cd /data/datasets/voc
+    wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
+    wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar
+    wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar
+    ```
+### Extract the data.
+* Extract the data as follows:
+    ```
+    tar -xvf VOCtrainval_11-May-2012.tar
+    tar -xvf VOCtrainval_06-Nov-2007.tar
+    tar -xvf VOCtest_06-Nov-2007.tar
+    ```
+
+### Training
+* Open the file ./scripts/train_object_detection_config.py and make sure 'config_script' is set to a 'voc' config file and not a 'coco' config file.
+
+* To start the training, run the following command from the base folder of the repository:<br>
+    ```
+    python ./scripts/train_object_detection_config.py
+    ```
+
+### Evaluation
+* To evaluate the detection accuracy, open the file ./scripts/test_object_detection_config.py and make sure 'config_script' is the same as the one used for training. Also set 'checkpoint_file' to the latest checkpoint obtained from training. Evaluation can then be done by:<br>
+    ```
+    python ./scripts/test_object_detection_config.py
+    ```
+
+## COCO 2017 Dataset
+### Download
+* Download the COCO 2017 dataset from http://cocodataset.org/#home
+    ```
+    mkdir /data/datasets/coco
+    cd /data/datasets/coco
+    wget http://images.cocodataset.org/zips/train2017.zip
+    wget http://images.cocodataset.org/zips/val2017.zip
+    wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
+    ```
+
+### Extract the files 
+* Extract the files to the folder /data/datasets/coco. This folder should have the subfolders train2017, val2017 and annotations.
+    ```
+    unzip "*.zip"
+    ```
+
+### Training & Evaluation
+* Open the file ./scripts/train_object_detection_config.py and make sure 'config_script' is set to a 'coco' config file and not a 'voc' config file.
+
+* To start the training, run the following command from the base folder of the repository:<br>
+   ```
+  python ./scripts/train_object_detection_config.py
+  ```
+
+* Evaluation can be done by<br>*python ./scripts/test_object_detection_config.py*
+
+## Cityscapes Dataset
+
+### Download
+Cityscapes dataset can be obtained from https://www.cityscapes-dataset.com/. You may need to register and get permission to download some files.
+
+### Convert 
+mmdetection supports only the COCO and VOC formats for object detection. In order to use the Cityscapes dataset, it has to be converted into the COCO format.
+
+A tool for this conversion is given along with the maskrcnn-benchmark source code: https://github.com/facebookresearch/maskrcnn-benchmark/tree/master/tools/cityscapes. The scripts that they have provided can be placed into https://github.com/mcordts/cityscapesScripts for this conversion. 
+
+We have skipped some details, but if you have expertise in python, it should be easy to figure things out from the error messages that you get while running.
+
+
+## Results
+
+
+|Dataset    |Model Name      |Backbone Model|Backbone Stride| Resolution|Complexity (GigaMACS)|MeanAP% |
+|---------  |----------      |-----------   |-------------- |---------- |--------             |-------- |
+|VOC2007*   |SSD             |MobileNetV1   |32             |512x512    |                     |**74.9**|
+|.
+|VOC2007*   |SSD[2]          |MobileNetV1   |32             |512x512    |                     |72.7     |
+|VOC2007*   |SSD[1]          |VGG16         |               |512x512    |                     |76.9     |
+
+*VOC2007+VOC2012 Train, VOC2007 validation<br><br>
+
+
+|Dataset    |Model Name      |Backbone Model|Backbone Stride| Resolution|Complexity (GigaMACS)|MeanAP[0.5:0.95]% |
+|---------  |----------      |-----------   |-------------- |---------- |--------             |-------- |
+|COCO2017** |SSD             |MobileNetV1   |32             |512x512    |                     |**22.2**|
+|.
+|COCO2017** |SSD[2]          |MobileNetV1   |32             |512x512    |                     |18.5     |
+|COCO2017** |SSD[3]          |VGG16         |               |512x512    |                     |29.3     |
+
+**COCO2017 validation set was used for validation
+
+## References
+[1] SSD: Single Shot MultiBox Detector, Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy, Scott Reed, Cheng-Yang Fu, Alexander C. Berg, arXiv preprint arXiv:1512.02325
+
+[2] Repository for Single Shot MultiBox Detector and its variants, implemented with pytorch, python3, https://github.com/ShuangXieIrene/ssds.pytorch
+
+[3] Open MMLab Detection Toolbox and Benchmark, Chen, Kai and Wang, Jiaqi and Pang, Jiangmiao and Cao, Yuhang and Xiong, Yu and Li, Xiaoxiao and Sun, Shuyang and Feng, Wansen and Liu, Ziwei and Xu, Jiarui and Zhang, Zheng and Cheng, Dazhi and Zhu, Chenchen and Cheng, Tianheng and Zhao, Qijie and Li, Buyu and Lu, Xin and Zhu, Rui and Wu, Yue and Dai, Jifeng and Wang, Jingdong and Shi, Jianping and Ouyang, Wanli and Loy, Chen Change and Lin, Dahua, arXiv preprint arXiv:1906.07155, https://github.com/open-mmlab/mmdetection
+
+[4] The PASCAL Visual Object Classes (VOC) Challenge, Everingham, M., Van Gool, L., Williams, C. K. I., Winn, J. and Zisserman, A. International Journal of Computer Vision, 88(2), 303-338, 2010, http://host.robots.ox.ac.uk/pascal/VOC/
+
+[5] Microsoft COCO: Common Objects in Context, Tsung-Yi Lin, Michael Maire, Serge Belongie, Lubomir Bourdev, Ross Girshick, James Hays, Pietro Perona, Deva Ramanan, C. Lawrence Zitnick, Piotr Dollár, arXiv preprint arXiv:1405.0312
+
+[6]The Cityscapes Dataset for Semantic Urban Scene Understanding, Marius Cordts, Mohamed Omran, Sebastian Ramos, Timo Rehfeld, Markus Enzweiler, Rodrigo Benenson, Uwe Franke, Stefan Roth, Bernt Schiele, CVPR 2016, https://www.cityscapes-dataset.com/
+
+
diff --git a/docs/Quantization.md b/docs/Quantization.md
new file mode 100644 (file)
index 0000000..14c4881
--- /dev/null
@@ -0,0 +1,166 @@
+# Quantization
+
+Quantization is the process of converting floating point data & operations to fixed point (integer). CNNs can be quantized to 8-bit integer data/operations without significant accuracy loss. This includes quantization of weights, feature maps and all operations (including convolutions). **We use power-of-2, symmetric quantization for both weights and activations**.
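+
+The snippet below is a hedged sketch of what power-of-2, symmetric 8-bit (fake) quantization of a tensor could look like; it is illustrative only and not the library's internal implementation.
+```
+import math
+import torch
+
+def fake_quantize_pow2_symmetric(x, num_bits=8):
+    max_abs = float(x.abs().max())
+    if max_abs == 0.0:
+        return x
+    # round the clipping range up to the nearest power of 2 (e.g. 3.2 -> 4.0)
+    clip_range = 2.0 ** math.ceil(math.log2(max_abs))
+    qmax = 2 ** (num_bits - 1) - 1           # 127 for 8 bits
+    scale = clip_range / qmax
+    xq = torch.round(x / scale).clamp(-qmax - 1, qmax)
+    return xq * scale                        # quantized-then-dequantized tensor
+```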
+
+There are two primary methods of quantization - Post Training Quantization and Trained Quantization. 
+
+## Post Training Calibration & Quantization
+
+Post Training Calibration & Quantization can take a model trained in floating point and with a few steps convert it to a model that is friendly for quantized inference. Compared to the alternative (Trained Quantization), the advantages of this method are:
+- Calibration is fast - a typical calibration finishes in a few minutes.  
+- Ground truth is not required - just input images are sufficient.
+- Loss function or backward (back-propagation) are not required. 
+
+Thus, this is the preferred method of quantization from an ease of use point of view. As explained earlier, in this method the training happens entirely in floating point. The inference (possibly on an embedded device) happens in fixed point. In between training and fixed point inference, the model goes through a step called Calibration with some sample images. The Calibration happens on a PC and the Quantized Inference happens on the embedded device. Calibration basically tries to make the quantized output similar to the floating point output - by choosing appropriate activation ranges, weights and biases. The step by step process is as follows:
+
+#### Model preparation:
+- Replace all the ReLU, ReLU6 layers in the model by PACT2. Insert PACT2 after Convolution+BatchNorm if a ReLU is missing after that. Insert PACT2 anywhere else required - wherever activation range clipping and range collection is needed, for example after the Fully Connected layer. We use forward post hooks of PyTorch nn.Modules to call these extra activation functions, so we are able to add these extra activations without disturbing the loading of existing pre-trained weights (a hedged sketch of this hook mechanism follows this list).
+- Clip the weights to an appropriate range if the weight range is very high.
+
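+A hedged sketch of the forward post hook mechanism mentioned above is given below; the clipping function and the names are illustrative, not the actual PACT2 implementation.
+```
+import torch
+import torch.nn as nn
+
+def attach_clipping_hook(module, clip_value=8.0):
+    # a forward post hook runs after module.forward and may replace its output,
+    # so an extra activation can be added without touching the module's weights
+    def clip_output(mod, inputs, output):
+        return torch.clamp(output, 0.0, clip_value)
+    return module.register_forward_hook(clip_output)
+
+conv_bn = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16))
+handle = attach_clipping_hook(conv_bn)    # behaves like a clipped activation after Conv+BN
+out = conv_bn(torch.rand(1, 3, 32, 32))   # output is now clipped to [0, clip_value]
+```
+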
+#### Forward iterations:
+- For each iteration, perform a forward pass in floating point using the original weights and biases. During this pass, the PACT2 layers collect output ranges using a histogram and running average.
+- In addition, perform Convolution+BatchNorm merging and quantization of the resulting weights. These quantized and de-quantized weights are used in a forward pass (see the sketch after this list). The ranges collected by PACT2 are used for activation quantization (and de-quantization) to generate the quantized output.
+- The floating point output and quantized output are compared using statistical measures. Using such measures, we can adjust the weights and biases of the Convolution and Batch Normalization layers - so that the quantized output becomes closer to the floating point output.
+- Within a few iterations, we could get reasonable quantization accuracy for several models that we tried this method on.
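+
+For intuition, the quantize/de-quantize of the merged weights with symmetric, power-of-2 scaling can be sketched as below (illustrative only - not the exact code used by our calibration module):
+
+```
+import math
+import torch
+
+def quantize_dequantize_p2(w, bitwidth=8):
+    # choose the smallest power-of-2 range that covers the absolute maximum of the tensor
+    max_abs = w.abs().max().item()
+    range_p2 = 2.0 ** math.ceil(math.log2(max_abs + 1e-12))
+    scale = (2 ** (bitwidth - 1)) / range_p2               # power-of-2 scale factor
+    qmax = 2 ** (bitwidth - 1) - 1                         # 127 for 8 bits
+    w_q = torch.round(w * scale).clamp(-qmax - 1, qmax)    # integer representation
+    return w_q / scale                                     # de-quantized weights used in the forward pass
+
+w = torch.randn(16, 3, 3, 3)                               # e.g. merged Convolution+BatchNorm weights
+print((w - quantize_dequantize_p2(w)).abs().max())         # quantization error
+```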
+
+Depending on how the activation range is collected and Quantization is done, we have a few variants of this basic scheme.  
+- Simple Calib: Calibration includes PACT2 for activation clipping, running average and range collection. In this method we use min-max for activation range collection (no histogram).
+- **Advanced Calib**: Calibration includes PACT2 with histogram based ranges, Weight clipping, Bias correction. 
+- Advanced DW Calib: Calibration includes Per-Channel Quantization of Weights for Depthwise layers, PACT2 with histogram based ranges, Weight clipping, Bias correction. One of the earliest papers that clearly explained the benefits of Per-Channel Quantization for weights only (while the activations are quantized as Per-Tensor) is [6].
+- Advanced Per-Chan Calib: Calibration includes Per-Channel Quantization for all layers, PACT2 with histogram based ranges, Weight clipping, Bias correction.
+
+Out of these methods, **Advanced Calib** is our recommended Calibration method as of now, as it has the best trade-off between accuracy and the features required during fixed point inference. All the Calibration scripts on this page use "Advanced Calib" by default. The other Calibration methods described here are for information only.
+
+In order to do Calibration easily, we have developed a wrapper module called QuantCalibrateModule, which is located in pytorch_jacinto_ai.xnn.quantize.QuantCalibrateModule. We make use of a kind of Parametric Activation called **PACT2** in order to store the calibrated ranges of activations. PACT2 is an improved form of PACT [1]. **PACT2 uses power-of-2 activation ranges** for activation clipping. PACT2 can learn ranges very quickly (using a statistical method) without back propagation - this feature makes it quite attractive for Calibration. Our wrapper module replaces all the ReLUs in the model with PACT2. It also inserts PACT2 in other places where activation ranges need to be collected. Statistical range clipping in PACT2 improves the quantized accuracy over simple min-max range clipping.
+
+As explained, our method of **Calibration does not need ground truth, loss function or back propagation.** However in our script, we make use of ground truth to measure the loss/accuracy even in the Calibration stage - although that is not necessary. 
+
+#### How to use  QuantCalibrateModule
+This section briefly explains how to make use of our helper/wrapper module to do the calibration of your model. For further details, please see pytorch_jacinto_ai.engine.train_classification.py and pytorch_jacinto_ai.engine.train_pixel2pixel.py.
+
+```
+import os
+import torch
+import pytorch_jacinto_ai.xnn.quantize   # the pytorch_jacinto_ai package is under modules/ in this repository
+
+# create your model here:
+model = ...
+
+# create a dummy input - this is required to analyze the model - fill in the input image size expected by your model.
+dummy_input = torch.rand((1,3,384,768))
+
+# wrap your model in QuantCalibrateModule. Once it is wrapped, the actual model is in model.module
+model = pytorch_jacinto_ai.xnn.quantize.QuantCalibrateModule(model, dummy_input=dummy_input)
+
+# load your pretrained weights here into model.module
+pretrained_data = torch.load(pretrained_path)
+model.module.load_state_dict(pretrained_data)
+
+# create your dataset here - the ground-truth/target that you provide in the dataset can be dummy and does not affect calibration.
+my_dataset_train, my_dataset_val = ...
+
+# do one epoch of calibration - in practice about 1000 iterations are sufficient.
+for images, target in my_dataset_train:
+    output = model(images)
+
+# save the model - the calibrated module is in model.module
+torch.save(model.module.state_dict(), os.path.join(save_path,'model.pth'))
+torch.onnx.export(model.module, dummy_input, os.path.join(save_path,'model.onnx'), export_params=True, verbose=False)
+
+```
+
+A few examples of calibration are provided below. These commands are also listed in the file **run_quantization.sh** for convenience.<br>
+
+#### Calibration of ImageNet Classification MobileNetV2 model 
+```
+python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification --pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth --batch_size 64 --quantize True --epochs 1 --epoch_size 100
+```
+
+#### Calibration of ImageNet Classification ResNet50 model 
+```
+python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification --pretrained https://download.pytorch.org/models/resnet50-19c8e357.pth --batch_size 64 --quantize True --epochs 1 --epoch_size 100
+```
+
+#### Calibration of Cityscapes Semantic Segmentation model 
+```
+python ./scripts/train_segmentation_main.py --phase calibration --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
+--pretrained ./data/modelzoo/semantic_segmentation/cityscapes/deeplabv3lite-mobilenetv2/cityscapes_segmentation_deeplabv3lite-mobilenetv2_2019-06-26-08-59-32.pth \
+--batch_size 12 --quantize True --epochs 1 --epoch_size 100
+```
+
+## Trained Quantization
+
+As explained in the previous section, Calibration is our preferred method of making a model quantization friendly. However, in exceptional cases, it is possible that the accuracy drop with Calibration is higher than acceptable. In that case, Trained Quantization can be used.
+
+Unlike Calibration, Trained Quantization involves ground truth, a loss function and back propagation. The most popular method of trained quantization is [4]. It takes care of merging Convolution layers with the adjacent Batch Normalization layers (on-the-fly) during quantized training (if this merging is not done correctly, quantized training may not improve the accuracy). In addition, we use Straight-Through Estimation (STE) [2,3] to improve the gradient flow in back-propagation. Also, the statistical range clipping in PACT2 improves the quantized accuracy over simple min-max range clipping.
+
+Note: Instead of STE and statistical ranges for PACT2, we also tried out the approximate gradients for scale and trained quantization thresholds proposed in [5] (we did not use the gradient normalization and log-domain training mentioned in the paper). That method was able to learn the clipping thresholds for the initial few epochs, but it became unstable after a few epochs and the loss became high. Compared to that learned-thresholds method, our statistical PACT2 ranges/thresholds combined with STE are simple and stable.
+
+In order to enable quantized training, we have developed the wrapper class pytorch_jacinto_ai.xnn.quantize.QuantTrainModule. The usage of this module can be seen in pytorch_jacinto_ai.engine.train_classification.py and pytorch_jacinto_ai.engine.train_pixel2pixel.py. 
+```
+model = pytorch_jacinto_ai.xnn.quantize.QuantTrainModule(model, dummy_input=dummy_input)
+```
+The resultant model can then be used for training as usual and it will take care of quantization constraints during the training forward and backward passes.
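+
+A minimal sketch of the training loop with the wrapped model is shown below (the data loader, loss, optimizer and the names my_train_loader/pretrained_path/save_path are placeholders - substitute your own; see train_classification.py for the complete flow):
+
+```
+# wrap, load pretrained weights and move to gpu
+model = pytorch_jacinto_ai.xnn.quantize.QuantTrainModule(model, dummy_input=dummy_input)
+model.module.load_state_dict(torch.load(pretrained_path))
+model = model.cuda()
+
+criterion = torch.nn.CrossEntropyLoss().cuda()
+optimizer = torch.optim.SGD(model.parameters(), lr=5e-5, momentum=0.9)
+
+model.train()
+for images, target in my_train_loader:
+    images, target = images.cuda(), target.cuda()
+    output = model(images)            # forward pass with quantization constraints
+    loss = criterion(output, target)
+    optimizer.zero_grad()
+    loss.backward()                   # STE keeps gradients flowing through the quantization steps
+    optimizer.step()
+
+# the trained module is in model.module
+torch.save(model.module.state_dict(), os.path.join(save_path, 'model.pth'))
+```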
+
+One word of caution: our current implementation of Trained Quantization is a bit slow. The reason for this slowdown is that our implementation uses the top-level Python layer of PyTorch and not the underlying C++ layer. But with PyTorch natively supporting the functionality required for quantization under the hood, we hope that this speed issue can be resolved in a future update.
+
+Example commands for trained quantization: 
+```
+python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification --pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth --batch_size 64 --quantize True --epochs 150 --epoch_size 1000 --lr 5e-5 --evaluate_start False
+```
+
+```
+python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 --pretrained ./data/modelzoo/semantic_segmentation/cityscapes/deeplabv3lite-mobilenetv2/cityscapes_segmentation_deeplabv3lite-mobilenetv2_2019-06-26-08-59-32.pth --batch_size 8 --quantize True --epochs 150 --lr 5e-5 --evaluate_start False
+```
+
+## Important Notes
+**Multi-GPU training/calibration/validation with DataParallel is not yet working with our quantization modules** QuantCalibrateModule, QuantTrainModule and QuantTestModule.
+- For now, we recommend not to wrap the modules in DataParallel if you are training/calibrating with quantization - i.e. if your model is wrapped in QuantCalibrateModule/QuantTrainModule/QuantTestModule. 
+- This may not be such a problem as calibration and quantization may not take as much time as the original training. 
+- If your calibration/training crashes with insufficient GPU memory, reduce the batch size and try again.
+- The original training (without quantization) can use Multi-GPU as usual and we do not have any restrictions on that.
+- Tools for Calibration and Trained Quantization have started appearing in mainstream Deep Learning training frameworks [7,8]. Using the tools natively provided by these frameworks may be faster compared to an implementation in the Python layer of these frameworks (like we have done) - but they may not be mature currently. 
+- In order that the activation range estimation is correct, **the same module should not be re-used multiple times within the model**. Unfortunately, in the torchvision ResNet models, the ReLU module in the BasicBlock and BottleneckBlock is re-used multiple times. We have corrected this by defining separate ReLU modules. This change is minor and **does not** affect the loading of existing pretrained weights. See [our modified ResNet model definition here](./modules/pytorch_jacinto_ai/vision/models/resnet.py).
+- Use Modules instead of functions as much as possible (we make use of modules to decide whether to do activation range clipping or not). For example, use torch.nn.AdaptiveAvgPool2d() instead of torch.nn.functional.adaptive_avg_pool2d(), torch.nn.Flatten() instead of torch.flatten(), etc. If you are using functions in your model and it is giving poor quantized accuracy, consider replacing those functions with the corresponding modules (see the sketch below).
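+
+A minimal sketch of these two guidelines (separate activation instances, modules instead of functionals - the block and its names are illustrative only):
+
+```
+import torch
+
+class ExampleBlock(torch.nn.Module):
+    def __init__(self, ch):
+        super().__init__()
+        self.conv1 = torch.nn.Conv2d(ch, ch, 3, padding=1)
+        self.relu1 = torch.nn.ReLU()                 # a separate ReLU instance for each use,
+        self.conv2 = torch.nn.Conv2d(ch, ch, 3, padding=1)
+        self.relu2 = torch.nn.ReLU()                 # so that each gets its own activation range
+        self.pool = torch.nn.AdaptiveAvgPool2d(1)    # module form instead of the functional form
+
+    def forward(self, x):
+        x = self.relu1(self.conv1(x))
+        x = self.relu2(self.conv2(x))   # do not call self.relu1 again here
+        return self.pool(x)
+```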
+
+
+## Results
+
+The table below shows the quantized accuracy with the various Calibration methods and also with Trained Quantization. Some of the commands used to generate these results are summarized in the file **run_quantization.sh** for convenience.
+
+###### Dataset: ImageNet Classification (Image Classification)
+
+|Model Name              |Backbone   |Stride|Resolution|Float Acc%|Simple Calib Acc%|Advanced Calib Acc%|Advanced DW Calib Acc%|Advanced Per-Chan Calib Acc%|Trained Quant Acc%|
+|----------              |-----------|------|----------|--------- |---              |---                |---                   |---                         |---               |
+|ResNet50(TorchVision)   |ResNet50   |32    |224x224   |**76.15** |75.56            |**75.56**          |75.56                 |75.39                       |                  |
+|MobileNetV2(TorchVision)|MobileNetV2|32    |224x224   |**71.89** |67.77            |**68.39**          |69.34                 |69.46                       |**70.55**         |
+|MobileNetV2(Shicai)     |MobileNetV2|32    |224x224   |**71.44** |45.60            |**68.81**          |70.65                 |70.75                       |                  |
+
+Notes:
+- For Image Classification, the accuracy measure used is % Top-1 Classification Accuracy. 'Top-1 Classification Accuracy' is abbreviated by Acc in the above table.<br>
+- MobileNetV2(Shicai) model is from https://github.com/shicai/MobileNet-Caffe (converted from caffe to PyTorch) - this is a tough model for quantization.<br><br>
+
+###### Dataset: Cityscapes Segmentation (Semantic Segmentation)
+
+|Model Name              |Backbone   |Stride|Resolution|Float Acc%|Simple Calib Acc%|Advanced Calib Acc%|Advanced DW Calib Acc%|Advanced Per-Chan Calib Acc%|Trained Quant Acc%|
+|----------              |-----------|------|----------|----------|---              |---                |---                   |---                         |---               |
+|DeepLabV3Lite           |MobileNetV2|16    |768x384   |**69.13** |61.71            |**67.95**          |68.47                 |68.56                       |**68.26**         |
+
+Notes: 
+ - For Semantic Segmentation, the accuracy measure used is MeanIoU Accuracy. 'MeanIoU Accuracy' is abbreviated as Acc in the above table.
+<br>
+
+
+## References 
+[1] PACT: Parameterized Clipping Activation for Quantized Neural Networks, Jungwook Choi, Zhuo Wang, Swagath Venkataramani, Pierce I-Jen Chuang, Vijayalakshmi Srinivasan, Kailash Gopalakrishnan, arXiv preprint, arXiv:1805.06085, 2018
+
+[2] Estimating or propagating gradients through stochastic neurons for conditional computation. Y. Bengio, N. Léonard, and A. Courville. arXiv preprint arXiv:1308.3432, 2013.
+
+[3] Understanding Straight-Through Estimator in training activation quantized neural nets, Penghang Yin, Jiancheng Lyu, Shuai Zhang, Stanley Osher, Yingyong Qi, Jack Xin, ICLR 2019
+
+[4] Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference, Benoit Jacob, Skirmantas Kligys, Bo Chen, Menglong Zhu, Matthew Tang, Andrew Howard, Hartwig Adam, Dmitry Kalenichenko, arXiv preprint, arXiv:1712.05877
+
+[5] Trained quantization thresholds for accurate and efficient fixed-point inference of Deep Learning Neural Networks, Sambhav R. Jain, Albert Gural, Michael Wu, Chris H. Dick, arXiv preprint, arXiv:1903.08066 
+
+[6] Quantizing deep convolutional networks for efficient inference: A whitepaper, Raghuraman Krishnamoorthi, arXiv preprint, arXiv:1806.08342
+
+[7] TensorFlow / Learn / For Mobile & IoT / Guide / Post-training quantization, https://www.tensorflow.org/lite/performance/post_training_quantization
+
+[8] QUANTIZATION / Introduction to Quantization, https://pytorch.org/docs/stable/quantization.html
+
diff --git a/docs/Semantic_Segmentation.md b/docs/Semantic_Segmentation.md
new file mode 100644 (file)
index 0000000..320c878
--- /dev/null
@@ -0,0 +1,112 @@
+# Semantic Segmentation
+
+Semantic segmentation assigns a class to each pixel of the image. It is useful for tasks such as lane detection, road segmentation etc. 
+
+Commonly used Training/Validation commands are listed in the file [run_segmentation.sh](../run_segmentation.sh). Uncomment one and run the file to start that run.
+
+## Models
+We have defined a set of example models for all pixel to pixel tasks. 
+
+These models can support multiple inputs (for example image and optical flow) as well as support multiple decoders for multi-task prediction (for example semantic segmentation + depth estimation + motion segmentation). 
+
+Whether to use multiple inputs and how many decoders to use are fully configurable. Our framework is also flexible enough to add different model architectures or backbone networks if you wish to do so.
+
+The models that we support off the shelf are:<br>
+
+* **deeplabv3lite_mobilenetv2_tv**: (default) This model is mostly similar to the DeepLabV3+ model [3] using a MobileNetV2 backbone. The difference with DeepLabV3+ is that we removed the convolutions after the shortcut and kept one set of depthwise separable convolutions to generate the prediction. The ASPP module that we use is a lightweight variant with depthwise separable convolutions (DWASPP). We found that this reduces complexity without sacrificing accuracy. Due to this we call this model DeepLabV3+(Lite) or simply DeepLabV3Lite. (Note: the suffix "_tv" indicates that our backbone model is from torchvision)<br>
+
+* **fpn_pixel2pixel_aspp_mobilenetv2_tv**: This is similar to Feature Pyramid Network [4], but adapted for pixel2pixel tasks. We stop the decoder at a stride of 4 and then upsample to the final resolution from there. We also use DWASPP module to improve the receptive field. We call this model FPNPixel2Pixel. 
+
+* **fpn_pixel2pixel_aspp_resnet50**: Feature Pyramid Network (FPN) based pixel2pixel using ResNet50 backbone with DWASPP.
+
+## Datasets: Cityscapes Dataset [1]
+
+* Download the cityscapes dataset from https://www.cityscapes-dataset.com/. You will need to register before the data can be downloaded. Unzip the data into the folder ./data/datasets/cityscapes/data. This folder should contain the leftImg8bit and gtFine folders of cityscapes.
+
+* These examples use two GPUs because we observed slightly higher accuracy when we restricted the number of GPUs used.
+
+* Training can be done as follows:<br>
+    ```
+    python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1
+    ```
+
+* During the training, **validation** accuracy will also be printed. But if you want to explicitly check the accuracy again on the **validation** set, it can be done as follows:<br>
+    ```
+    python ./scripts/train_segmentation_main.py --evaluate True --dataset_name cityscapes_segmentation --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1
+    ```
+  
+* It is possible to use a different image size. For example, we trained at 1536x768 resolution as follows (we used a crop size smaller than the resized image resolution to reduce GPU memory usage):<br>
+    ```
+    python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --data_path ./data/datasets/cityscapes/data --img_resize 768 1536 --rand_crop 512 1024 --output_size 1024 2048 --gpus 0 1
+    ```
+
+* Train FPNPixel2Pixel model at 1536x768 resolution (use 1024x512 crop to reduce memory usage):<br>
+    ```
+    python ./scripts/train_segmentation_main.py --model_name fpn_pixel2pixel_aspp_mobilenetv2_tv --dataset_name cityscapes_segmentation --data_path ./data/datasets/cityscapes/data --img_resize 768 1536 --rand_crop 512 1024 --output_size 1024 2048 --gpus 0 1
+    ```
+
+## Datasets: VOC Dataset [2]
+### Download the data
+* The dataset can be downloaded using the following:<br>
+    ```
+    mkdir ./data/datasets/voc
+    cd ./data/datasets/voc
+    wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
+    wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar
+    wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar
+    ```
+### Extract the data
+* Extract the dataset files into ./data/datasets/voc/VOCdevkit:
+    ```
+    tar -xvf VOCtrainval_11-May-2012.tar
+    tar -xvf VOCtrainval_06-Nov-2007.tar
+    tar -xvf VOCtest_06-Nov-2007.tar
+    ```
+* Download extra annotations: Download the augmented annotations as explained here: https://github.com/DrSleep/tensorflow-deeplab-resnet. For this, using a browser, download the zip file SegmentationClassAug.zip from: https://www.dropbox.com/s/oeu149j8qtbs1x0/SegmentationClassAug.zip?dl=0
+* Unzip SegmentationClassAug.zip and place the images in the folder ./data/datasets/voc/VOCdevkit/VOC2012/SegmentationClassAug
+* Create a list of those images in the ImageSets folder using the following:<br>
+    ```
+    cd VOCdevkit/VOC2012/ImageSets/Segmentation
+    ls -1 SegmentationClassAug | sed s/.png// > trainaug.txt
+    ```
+
+* Training can be done as follows from the base folder of the repository:<br>
+    ```
+    python ./scripts/train_segmentation_main.py --dataset_name voc_segmentation --data_path ./data/datasets/voc --img_resize 512 512 --output_size 512 512 --gpus 0 1
+    ```
+
+## Results
+
+### Cityscapes Segmentation
+
+|Dataset    |Model Architecture|Backbone Model|Model Name                         |Backbone Stride|Resolution |Complexity (GigaMACS)|MeanIoU%  |
+|---------  |----------       |-----------   |---------------------------------  |-------------- |-----------|--------             |----------|
+|Cityscapes |DeepLabV3Lite    |MobileNetV2   |deeplabv3lite_mobilenetv2_tv       |16             |768x384    |3.54                 |**69.13** |
+|Cityscapes |FPNPixel2Pixel   |MobileNetV2   |fpn_pixel2pixel_mobilenetv2_tv     |32             |768x384    |3.84                 |**70.39** |
+|Cityscapes |FPNPixel2Pixel   |MobileNetV2   |fpn_pixel2pixel_mobilenetv2_tv_es64|64             |1536x768   |3.96                 |**71.28** |
+|Cityscapes |FPNPixel2Pixel   |MobileNetV2   |fpn_pixel2pixel_mobilenetv2_tv_es64|64             |2048x1024  |7.03                 |          |
+|Cityscapes |DeepLabV3Lite    |MobileNetV2   |deeplabv3lite_mobilenetv2_tv       |16             |1536x768   |14.48                |**73.59** |
+|Cityscapes |FPNPixel2Pixel   |MobileNetV2   |fpn_pixel2pixel_mobilenetv2_tv     |32             |1536x768   |15.37                |**74.98** |
+|-
+|Cityscapes |ERFNet[5]        |              |                                   |               |1024x512   |27.705               |69.7      |
+|Cityscapes |SwiftNetMNV2[6]  |MobileNetV2   |                                   |               |2048x1024  |41.0                 |75.3      |
+|Cityscapes |DeepLabV3Plus[3] |MobileNetV2   |                                   |16             |           |21.27                |70.71     |
+|Cityscapes |DeepLabV3Plus[3] |Xception65    |                                   |16             |           |418.64               |78.79     |
+
+
+## References
+[1] The Cityscapes Dataset for Semantic Urban Scene Understanding, Marius Cordts, Mohamed Omran, Sebastian Ramos, Timo Rehfeld, Markus Enzweiler, Rodrigo Benenson, Uwe Franke, Stefan Roth, Bernt Schiele, CVPR 2016, https://www.cityscapes-dataset.com/
+
+[2] The PASCAL Visual Object Classes (VOC) Challenge, Everingham, M., Van Gool, L., Williams, C. K. I., Winn, J. and Zisserman, A., International Journal of Computer Vision, 88(2), 303-338, 2010, http://host.robots.ox.ac.uk/pascal/VOC/
+
+[3] Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation, Liang-Chieh Chen, Yukun Zhu, George Papandreou, Florian Schroff, and Hartwig Adam, CVPR 2018, https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zoo.md
+
+[4] Feature Pyramid Networks for Object Detection, Tsung-Yi Lin, Piotr Dollar, Ross Girshick, Kaiming He, Bharath Hariharan, Serge Belongie, CVPR 2017
+
+[5] ERFNet: Efficient Residual Factorized ConvNet for Real-time Semantic Segmentation, E. Romera, J. M. Alvarez, L. M. Bergasa and R. Arroyo, Transactions on Intelligent Transportation Systems (T-ITS), 2017
+
+[6] In Defense of Pre-trained ImageNet Architectures for Real-time Semantic Segmentation of Road-driving Images, Marin Orsic, Ivan Kreso, Petra Bevandic, Sinisa Segvic, CVPR 2019.
+
+
diff --git a/docs/motion_segmentation/motion_segmentation_network.PNG b/docs/motion_segmentation/motion_segmentation_network.PNG
new file mode 100755 (executable)
index 0000000..9e6cbf1
Binary files /dev/null and b/docs/motion_segmentation/motion_segmentation_network.PNG differ
diff --git a/docs/multi_task_learning/multi_task_network.PNG b/docs/multi_task_learning/multi_task_network.PNG
new file mode 100755 (executable)
index 0000000..1b46611
Binary files /dev/null and b/docs/multi_task_learning/multi_task_network.PNG differ
diff --git a/examples/write_onnx_model_example.py b/examples/write_onnx_model_example.py
new file mode 100644 (file)
index 0000000..64f2f8e
--- /dev/null
@@ -0,0 +1,31 @@
+import os
+import torch
+import torchvision
+import datetime
+
+# dependencies
+# Anaconda Python 3.7 for Linux - download and install from: https://www.anaconda.com/distribution/
+# pytorch, torchvision - install using: 
+# conda install pytorch torchvision -c pytorch
+
+# some parameters - modify as required
+date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+dataset_name = 'image_folder_classification'
+model_name = 'resnet50'
+img_resize = (256,256)
+rand_crop = (224,224)
+
+# the saving path - you can choose any path
+save_path = './data/checkpoints'
+save_path = os.path.join(save_path, dataset_name, date + '_' + dataset_name + '_' + model_name)
+save_path += '_resize{}x{}_traincrop{}x{}'.format(img_resize[1], img_resize[0], rand_crop[1], rand_crop[0])
+os.makedirs(save_path, exist_ok=True)
+
+# create the model - replace with your model
+model = torchvision.models.resnet50(pretrained=True)
+
+# create a rand input
+rand_input = torch.rand(1, 3, rand_crop[0], rand_crop[1])
+
+# write the onnx model
+torch.onnx.export(model, rand_input, os.path.join(save_path, 'model.onnx'), export_params=True, verbose=False)
diff --git a/modules/__init__.py b/modules/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/modules/pytorch_jacinto_ai/__init__.py b/modules/pytorch_jacinto_ai/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/modules/pytorch_jacinto_ai/engine/__init__.py b/modules/pytorch_jacinto_ai/engine/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/modules/pytorch_jacinto_ai/engine/evaluate_pixel2pixel.py b/modules/pytorch_jacinto_ai/engine/evaluate_pixel2pixel.py
new file mode 100644 (file)
index 0000000..5b10a02
--- /dev/null
@@ -0,0 +1,272 @@
+import sys
+import torch
+
+import numpy as np
+import cv2
+
+from .. import xnn
+
+
+##################################################
+# to avoid hangs in data loader with multi threads
+# this was observed after using cv2 image processing functions
+# https://github.com/pytorch/pytorch/issues/1355
+cv2.setNumThreads(0)
+
+##################################################
+np.set_printoptions(precision=3)
+
+##################################################
+def get_config():
+    args = xnn.utils.ConfigNode()
+
+    args.data_path = './'                   # path to dataset
+    args.num_classes = None                 # number of classes (for segmentation)
+
+    args.save_path = None        # checkpoints save path
+
+    args.img_resize = None                  # image size to be resized to
+
+    args.output_size = None                 # target output size to be resized to
+
+    args.upsample_mode = 'bilinear'         # upsample mode to use. choices=['nearest','bilinear']
+
+    args.eval_semantic = False              # set for 19 class segmentation
+
+    args.eval_semantic_five_class = False   # set for five class segmentation
+
+    args.eval_motion = False                # set for motion segmentation
+
+    args.eval_depth = False                 # set for depth estimation
+
+    args.scale_factor = 1.0                 # scale factor applied to the inference output before evaluation (used to improve depth accuracy)
+    args.verbose = True                     # whether to print scores for all frames or the final result
+
+    args.inf_suffix = ''                    # suffix for the different tasks: sem/mot/depth
+    args.frame_IOU = False                  # Check for framewise IOU for segmentation
+    args.phase = 'validation'
+    args.date = None
+    return args
+
+
+# ################################################
+def main(args):
+
+    assert not hasattr(args, 'model_surgery'), 'the argument model_surgery is deprecated, it is not needed now - remove it'
+
+    ################################
+    # print everything for log
+    print('=> args: ', args)
+
+    #################################################
+    if args.eval_semantic:
+        args.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
+        args.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
+        args.class_map = dict(zip(args.valid_classes, range(19)))
+        args.gt_suffix = '_labelIds.png'
+        if args.inf_suffix == '':
+          args.inf_suffix = '.png'
+        args.output_channels = 19
+    elif args.eval_semantic_five_class:
+        args.void_classes = [-1, 255]
+        args.valid_classes = [0, 1, 2, 3, 4]
+        args.class_map = dict(zip(args.valid_classes, range(5)))
+        args.gt_suffix = '_labelTrainIds.png'
+        if args.inf_suffix == '':        
+          args.inf_suffix = '1.png'
+        args.output_channels = 5
+    elif args.eval_motion:
+        args.void_classes = []
+        args.valid_classes = [0, 255]
+        args.class_map = dict(zip(args.valid_classes, range(2)))
+        args.gt_suffix = '_labelTrainIds_motion.png'
+        if args.inf_suffix == '':        
+          args.inf_suffix = '2.png'
+        args.output_channels = 2
+    elif args.eval_depth:
+        args.void_classes = []
+        args.valid_classes = [0, 255]
+        args.gt_suffix = '.png'
+        if args.inf_suffix == '':        
+          args.inf_suffix = '0.png'
+        args.max_depth = 20
+
+    print(args)
+    print("=> fetching gt labels in '{}'".format(args.label_path))
+    label_files = xnn.utils.recursive_glob(rootdir=args.label_path, suffix=args.gt_suffix)
+    label_files.sort()
+    print('=> {} gt label samples found'.format(len(label_files)))
+
+    print("=> fetching inferred images in '{}'".format(args.infer_path))
+    infer_files = xnn.utils.recursive_glob(rootdir=args.infer_path, suffix=args.inf_suffix)
+    infer_files.sort()
+    print('=> {} inferred samples found'.format(len(infer_files)))
+
+    assert len(label_files) == len(infer_files), 'the number of label files and inference files must be the same'
+
+    #################################################
+    if not args.eval_depth:
+        validate(args, label_files, infer_files)
+    else:
+        validate_depth(args, label_files, infer_files)
+
+
+def validate(args, label_files, infer_files):
+
+    confusion_matrix = np.zeros((args.output_channels, args.output_channels+1))
+
+    for iter, (label_file, infer_file) in enumerate(zip(label_files, infer_files)):
+        if args.frame_IOU:
+            confusion_matrix = np.zeros((args.output_channels, args.output_channels + 1))
+        gt = encode_segmap(args, cv2.imread(label_file, 0))
+        inference = cv2.imread(infer_file)
+        inference = inference[:,:,-1]
+        if inference.shape != gt.shape:
+            inference = np.expand_dims(np.expand_dims(inference, 0),0)
+            inference = torch.tensor(inference).float()
+            scale_factor = torch.tensor(args.scale_factor).float()
+            inference = inference/scale_factor
+            inference = torch.nn.functional.interpolate(inference, (gt.shape[0], gt.shape[1]), mode='nearest')
+            inference = np.ndarray.astype(np.squeeze(inference.numpy()), dtype=np.uint8)
+        confusion_matrix = eval_output(args, inference, gt, confusion_matrix, args.output_channels)
+        accuracy, mean_iou, iou, f1_score= compute_accuracy(args, confusion_matrix, args.output_channels)
+        if args.verbose:
+            print('{}/{} inferred image {} mIOU {}'.format(iter, len(label_files), label_file.split('/')[-1], mean_iou))
+            if iter % 100 ==0:
+                print('\npixel_accuracy={},\nmean_iou={},\niou={},\nf1_score = {}'.format(accuracy, mean_iou, iou, f1_score))
+                sys.stdout.flush()
+    print('\npixel_accuracy={},\nmean_iou={},\niou={},\nf1_score = {}'.format(accuracy, mean_iou, iou, f1_score))
+    #
+
+def encode_segmap(args, mask):
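+    # map raw label ids to train ids: void classes become 255 (ignored),
+    # valid classes are remapped to contiguous train ids via args.class_map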
+    ignore_index = 255
+    for _voidc in args.void_classes:
+        mask[mask == _voidc] = ignore_index
+    for _validc in args.valid_classes:
+        mask[mask == _validc] = args.class_map[_validc]
+    return mask
+
+
+def eval_output(args, output, label, confusion_matrix, n_classes):
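+    # accumulate a confusion matrix: rows are ground-truth classes, columns are predicted
+    # classes (the extra last column catches predictions clipped to n_classes);
+    # pixels with ground-truth label 255 are ignored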
+    if len(label.shape)>2:
+        label = label[:,:,0]
+    gt_labels = label.ravel()
+    det_labels = output.ravel().clip(0,n_classes)
+    gt_labels_valid_ind = np.where(gt_labels != 255)
+    gt_labels_valid = gt_labels[gt_labels_valid_ind]
+    det_labels_valid = det_labels[gt_labels_valid_ind]
+    for r in range(confusion_matrix.shape[0]):
+        for c in range(confusion_matrix.shape[1]):
+            confusion_matrix[r,c] += np.sum((gt_labels_valid==r) & (det_labels_valid==c))
+
+    return confusion_matrix
+
+
+def compute_accuracy(args, confusion_matrix, n_classes):
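+    # population: ground-truth pixel count per class (row sums), det: predicted pixel count
+    # per class (column sums), tp: correctly predicted pixels per class (diagonal);
+    # these are used to derive per-class IoU, mean IoU, pixel accuracy and F1 score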
+    num_selected_classes = n_classes
+    tp = np.zeros(n_classes)
+    population = np.zeros(n_classes)
+    det = np.zeros(n_classes)
+    iou = np.zeros(n_classes)
+    
+    for r in range(n_classes):
+      for c in range(n_classes):
+        population[r] += confusion_matrix[r][c]
+        det[c] += confusion_matrix[r][c]   
+        if r == c:
+          tp[r] += confusion_matrix[r][c]
+
+    for cls in range(num_selected_classes):
+      intersection = tp[cls]
+      union = population[cls] + det[cls] - tp[cls]
+      iou[cls] = (intersection / union) if union else 0     # For caffe jacinto script
+      #iou[cls] = (intersection / (union + np.finfo(np.float32).eps))  # For pytorch-jacinto script
+
+    num_nonempty_classes = 0
+    for pop in population:
+      if pop>0:
+        num_nonempty_classes += 1
+          
+    mean_iou = np.sum(iou) / num_nonempty_classes if num_nonempty_classes else 0
+    accuracy = np.sum(tp) / np.sum(population) if np.sum(population) else 0
+    
+    #F1 score calculation
+    fp = np.zeros(n_classes)
+    fn = np.zeros(n_classes)
+    precision = np.zeros(n_classes)
+    recall = np.zeros(n_classes)
+    f1_score = np.zeros(n_classes)
+
+    for cls in range(num_selected_classes):
+        fp[cls] = det[cls] - tp[cls]
+        fn[cls] = population[cls] - tp[cls]
+        precision[cls] = tp[cls] / (det[cls] + 1e-10)
+        recall[cls] = tp[cls] / (tp[cls] + fn[cls] + 1e-10)        
+        f1_score[cls] = 2 * precision[cls]*recall[cls] / (precision[cls] + recall[cls] + 1e-10)
+
+    return accuracy, mean_iou, iou, f1_score
+
+def validate_depth(args, label_files, infer_files):
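+    # depth evaluation: compare inferred depth against ground truth using the mean
+    # absolute relative difference (ARD), after clamping both to max_depth and
+    # ignoring invalid (zero/255) ground-truth pixels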
+    max_depth = args.max_depth
+    print("Max depth set to {} meters".format(max_depth))
+    ard_err = None
+    for iter, (label_file, infer_file) in enumerate(zip(label_files, infer_files)):
+        gt = cv2.imread(label_file, cv2.IMREAD_UNCHANGED)
+        inference = cv2.imread(infer_file, cv2.IMREAD_UNCHANGED)
+
+        if inference.shape != gt.shape:
+            inference = np.expand_dims(np.expand_dims(inference, 0),0)
+            inference = torch.tensor(inference).float()
+            scale_factor = torch.tensor(args.scale_factor).float()
+            inference = (inference/scale_factor).int().float()
+            inference = torch.nn.functional.interpolate(inference, (gt.shape[0], gt.shape[1]), mode='nearest')
+            inference = torch.squeeze(inference)
+
+        gt[gt==255] = 0
+        gt[gt>max_depth]=max_depth
+        gt = torch.tensor(gt).float()
+        inference[inference > max_depth] = max_depth
+
+        valid = (gt!=0)
+        gt = gt[valid]
+        inference = inference[valid]
+        if len(gt)>2:
+            if ard_err is None:
+                ard_err = [absreldiff_rng3to80(inference, gt).mean()]
+            else:
+                ard_err.append(absreldiff_rng3to80(inference, gt).mean())
+        else:   # too few valid pixels for a meaningful ARD
+            if ard_err is None:
+                ard_err = [0.0]
+            else:
+                ard_err.append(0.0)
+
+        if args.verbose:
+            print('{}/{} Inferred Frame {} ARD {}'.format(iter+1, len(label_files), label_file.split('/')[-1], float(ard_err[-1])))
+
+    print('ARD_final {}'.format(float(torch.tensor(ard_err).mean())))
+
+def absreldiff(x, y, eps = 0.0, max_val=None):
+    assert x.size() == y.size(), 'tensor dimension mismatch'
+    if max_val is not None:
+        x = torch.clamp(x, -max_val, max_val)
+        y = torch.clamp(y, -max_val, max_val)
+    #
+
+    diff = torch.abs(x - y)
+    y = torch.abs(y)
+
+    den_valid = (y == 0).float()
+    eps_arr = (den_valid * (1e-6))   # Just to avoid divide by zero
+
+    large_arr = (y > eps).float()    # ARD is not a good measure for small ref values. Avoid them.
+    out = (diff / (y + eps_arr)) * large_arr
+    return out
+
+def absreldiff_rng3to80(x, y):
+    return absreldiff(x, y, eps = 3.0, max_val=80.0)
+
+
+if __name__ == '__main__':
+    train_args = get_config()
+    main(train_args)
diff --git a/modules/pytorch_jacinto_ai/engine/infer_pixel2pixel.py b/modules/pytorch_jacinto_ai/engine/infer_pixel2pixel.py
new file mode 100644 (file)
index 0000000..585ff3e
--- /dev/null
@@ -0,0 +1,946 @@
+import os
+import time
+import sys
+import math
+import copy
+
+import torch
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.optim
+import torch.utils.data
+import datetime
+import numpy as np
+import random
+import cv2
+import matplotlib.pyplot as plt
+
+from .. import xnn
+from .. import vision
+
+#sys.path.insert(0, '../devkit-datasets/TI/')
+#from fisheye_calib import r_fish_to_theta_rect
+
+# ################################################
+def get_config():
+    args = xnn.utils.ConfigNode()
+
+    args.dataset_config = xnn.utils.ConfigNode()
+    args.dataset_config.split_name = 'val'
+    args.dataset_config.max_depth_bfr_scaling = 80
+    args.dataset_config.depth_scale = 1
+    args.dataset_config.train_depth_log = 1
+    args.use_semseg_for_depth = False
+
+    args.model_config = xnn.utils.ConfigNode()
+    args.model_name = 'deeplabv2lite_mobilenetv2'       # model architecture, overwritten if pretrained is specified
+    args.dataset_name = 'flying_chairs'              # dataset type
+
+    args.data_path = './data/datasets'                       # path to dataset
+    args.save_path = None            # checkpoints save path
+    args.pretrained = None
+
+    args.model_config.output_type = ['flow']                # the network is used to predict flow or depth or sceneflow
+    args.model_config.output_channels = None                 # number of output channels
+    args.model_config.input_channels = None                  # number of input channels
+    args.model_config.num_classes = None                       # number of classes (for segmentation)
+
+    args.model_config.num_decoders = None               # number of decoders to use. [options: 0, 1, None]
+    args.sky_dir = False
+
+    args.logger = None                          # logger stream to output into
+
+    args.split_file = None                      # train_val split file
+    args.split_files = None                     # split list files. eg: train.txt val.txt
+    args.split_value = 0.8                      # test_val split proportion (between 0 (only test) and 1 (only train))
+
+    args.workers = 8                            # number of data loading workers
+
+    args.epoch_size = 0                         # manual epoch size (will match dataset size if not specified)
+    args.epoch_size_val = 0                     # manual epoch size (will match dataset size if not specified)
+    args.batch_size = 8                         # mini_batch_size
+    args.total_batch_size = None                # accumulated batch size. total_batch_size = batch_size*iter_size
+    args.iter_size = 1                          # iteration size. total_batch_size = batch_size*iter_size
+
+    args.tensorboard_num_imgs = 5               # number of imgs to display in tensorboard
+    args.phase = 'validation'                        # evaluate model on validation set
+    args.pretrained = None                      # path to pre_trained model
+    args.date = None                            # don't append date timestamp to folder
+    args.print_freq = 10                        # print frequency (default: 100)
+
+    args.div_flow = 1.0                         # value by which flow will be divided. Original value is 20 but 1 with batchNorm gives good results
+    args.losses = ['supervised_loss']           # loss functions to minimize
+    args.metrics = ['supervised_error']         # metric/measurement/error functions for train/validation
+    args.class_weights = None                   # class weights
+
+    args.multistep_gamma = 0.5                  # steps for step scheduler
+    args.polystep_power = 1.0                   # power for polynomial scheduler
+    args.train_fwbw = False                     # do forward backward step while training
+
+    args.rand_seed = 1                          # random seed
+    args.img_border_crop = None                 # image border crop rectangle. can be relative or absolute
+    args.target_mask = None                      # mask rectangle. can be relative or absolute. last value is the mask value
+    args.img_resize = None                      # image size to be resized to
+    args.rand_scale = (1,1.25)                  # random scale range for training
+    args.rand_crop = None                       # image size to be cropped to
+    args.output_size = None                     # target output size to be resized to
+
+    args.count_flops = True                     # count flops and report
+
+    args.shuffle = True                         # shuffle or not
+    args.is_flow = None                         # whether entries in images and targets lists are optical flow or not
+
+    args.multi_decoder = True                   # whether to use multiple decoders or unified decoder
+
+    args.create_video = False                   # whether to create video out of the inferred images
+
+    args.input_tensor_name = ['0']              # list of input tensor names
+
+    args.upsample_mode = 'nearest'              # upsample mode to use. choices=['nearest','bilinear']
+
+    args.image_prenorm = True                   # whether normalization is done before all the other transforms
+    args.image_mean = [128.0]                   # image mean for input image normalization
+    args.image_scale = [1.0/(0.25*256)]         # image scaling/mult for input image normalization
+    args.quantize = False                       # apply quantized inference or not
+    #args.model_surgery = None                   # replace activations with PAct2 activation module. Helpful in quantized training.
+    args.bitwidth_weights = 8                   # bitwidth for weights
+    args.bitwidth_activations = 8               # bitwidth for activations
+    args.histogram_range = True                 # histogram range for calibration
+    args.per_channel_q = False                  # apply separate quantization factor for each channel in depthwise or not
+    args.bias_calibration = False                # apply bias correction during quantized inference calibration
+
+    args.frame_IOU = False                      # Print mIOU for each frame
+    args.make_score_zero_mean = False           # to make score and desc zero mean
+    args.learn_scaled_values_interest_pt = True
+    args.save_mod_files = False                 # saves modified files after last commit. Also stores commit id.
+    args.gpu_mode = True                        # False will make inference run on CPU
+    args.write_layer_ip_op = False              # True will tap the inputs/outputs of each layer
+    args.file_format = 'none'                   # file format for tapped layer inputs/outputs; 'none': nothing is written to file, but prints still appear
+    args.generate_onnx = True
+    args.remove_ignore_lbls_in_pred = False     # True: do not visualize prediction pixels where the GT has the ignore label
+    args.do_pred_cordi_f2r = False              # True: do f2r operation on detected locations for the interest point task
+    args.depth_cmap_plasma = False
+    args.visualize_gt = False                   # whether to visualize the prediction or the GT
+    args.viz_depth_color_type = 'plasma'        # color type for depth visualization
+    args.depth = [False]
+    return args
+
+
+# ################################################
+# to avoid hangs in data loader with multi threads
+# this was observed after using cv2 image processing functions
+# https://github.com/pytorch/pytorch/issues/1355
+cv2.setNumThreads(0)
+
+##################################################
+np.set_printoptions(precision=3)
+
+
+#################################################
+def shape_as_string(shape=[]):
+    shape_str = ''
+    for dim in shape:
+        shape_str += '_' + str(dim)
+    return shape_str
+
+def write_tensor_int(m = [], tensor = [], suffix='op', bitwidth = 8, power2_scaling = True, file_format='bin', rnd_type = 'rnd_sym'):
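+    # quantize the given tensor to fixed point (symmetric rounding for weights/bias, upward
+    # rounding for activations), print its range/scale and dump the integer values to a
+    # .bin or .npy file for layer-level debug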
+    mn = tensor.min()
+    mx = tensor.max()
+
+    print('{:6}, {:32}, {:10}, {:7.2f}, {:7.2f}'.format(suffix, m.name, m.__class__.__name__, tensor.min(), tensor.max()),  end =" ")
+
+    [tensor_scale, clamp_limits] = xnn.quantize.compute_tensor_scale(tensor, mn, mx, bitwidth, power2_scaling)
+    print("{:30} : {:15} : {:8.2f}".format(str(tensor.shape), str(tensor.dtype), tensor_scale), end =" ")
+    
+    print_weight_bias = False
+    if rnd_type == 'rnd_sym':
+      #use best rounding for offline quantities
+      if suffix == 'weight' and print_weight_bias:
+          no_idx = 0
+          torch.set_printoptions(precision=32)
+          print("tensor_scale: ", tensor_scale)
+          print(tensor[no_idx])
+      if tensor.dtype != torch.int64:
+          tensor = xnn.quantize.symmetric_round_tensor(tensor * tensor_scale)
+      if suffix == 'weight'  and print_weight_bias:
+          print(tensor[no_idx])
+    else:  
+      #for activation use HW friendly rounding  
+      if tensor.dtype != torch.int64:
+          tensor = xnn.quantize.upward_round_tensor(tensor * tensor_scale)
+    tensor = tensor.clamp(clamp_limits[0], clamp_limits[1]).float()
+
+    if bitwidth == 8:
+      data_type = np.int8
+    elif bitwidth == 16:
+      data_type = np.int16
+    elif bitwidth == 32:
+      data_type = np.int32
+    else:
+      exit("Bit widths other than 8,16,32 are not supported for writing layer level outputs")
+
+    tensor = tensor.cpu().numpy().astype(data_type)
+
+    print("{:7} : {:7d} : {:7d}".format(str(tensor.dtype), tensor.min(), tensor.max()))
+
+    root = os.getcwd()
+    tensor_dir = root + '/checkpoints/debug/test_vecs/' + '{}_{}_{}_scale_{:010.4f}'.format(m.name, m.__class__.__name__, suffix,  tensor_scale)
+
+    if not os.path.exists(tensor_dir):
+        os.makedirs(tensor_dir)
+
+    if file_format == 'bin':
+        tensor_name = tensor_dir + "/{}_shape{}.bin".format(m.name, shape_as_string(shape=tensor.shape))
+        tensor.tofile(tensor_name)
+    elif file_format == 'npy':
+        tensor_name = tensor_dir + "/{}_shape{}.npy".format(m.name, shape_as_string(shape=tensor.shape))
+        np.save(tensor_name, tensor)
+
+    #utils_hist.comp_hist_tensor3d(x=tensor, name=m.name, en=True, dir=m.name, log=True, ch_dim=0)
+
+
+def write_tensor_float(m = [], tensor = [], suffix='op'):
+    
+    mn = tensor.min()
+    mx = tensor.max()
+
+    print('{:6}, {:32}, {:10}, {:7.2f}, {:7.2f}'.format(suffix, m.name, m.__class__.__name__, tensor.min(), tensor.max()))
+    root = os.getcwd()
+    tensor_dir = root + '/checkpoints/debug/test_vecs/' + '{}_{}_{}'.format(m.name, m.__class__.__name__, suffix)
+
+    if not os.path.exists(tensor_dir):
+        os.makedirs(tensor_dir)
+
+    tensor_name = tensor_dir + "/{}_shape{}.npy".format(m.name, shape_as_string(shape=tensor.shape))
+    np.save(tensor_name, tensor.data)
+
+def write_tensor(data_type = 'int', m = [], tensor = [], suffix='op', bitwidth = 8, power2_scaling = True, file_format='bin', 
+    rnd_type = 'rnd_sym'):
+    
+    if data_type == 'int':
+        write_tensor_int(m=m, tensor=tensor, suffix=suffix, rnd_type =rnd_type, file_format=file_format)
+    elif  data_type == 'float':
+        write_tensor_float(m=m, tensor=tensor, suffix=suffix)
+
+enable_hook_function = True
+def write_tensor_hook_function(m, inp, out, file_format='none'):
+    if not enable_hook_function:
+        return
+
+    #Output
+    if isinstance(out, (torch.Tensor)):
+        write_tensor(m=m, tensor=out, suffix='op', rnd_type ='rnd_up', file_format=file_format)
+
+    #Input(s)
+    if type(inp) is tuple:
+        #if there are more than 1 inputs
+        for index, sub_ip in enumerate(inp[0]):
+            if isinstance(sub_ip, (torch.Tensor)):
+                write_tensor(m=m, tensor=sub_ip, suffix='ip_{}'.format(index), rnd_type ='rnd_up', file_format=file_format)
+    elif isinstance(inp, (torch.Tensor)):
+         write_tensor(m=m, tensor=inp, suffix='ip', rnd_type ='rnd_up', file_format=file_format)
+
+    #weights
+    if hasattr(m, 'weight'):
+        if isinstance(m.weight,torch.Tensor):
+            write_tensor(m=m, tensor=m.weight, suffix='weight', rnd_type ='rnd_sym', file_format=file_format)
+
+    #bias
+    if hasattr(m, 'bias'):
+        if m.bias is not None:
+            write_tensor(m=m, tensor=m.bias, suffix='bias', rnd_type ='rnd_sym', file_format=file_format)
+
+# ################################################
+def main(args):
+
+    assert not hasattr(args, 'model_surgery'), 'the argument model_surgery is deprecated, it is not needed now - remove it'
+
+    #################################################
+    # global settings. rand seeds for repeatability
+    random.seed(args.rand_seed)
+    np.random.seed(args.rand_seed)
+    torch.manual_seed(args.rand_seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = True
+
+    ################################
+    # args check and config
+    assert args.iter_size == 1 or args.total_batch_size is None, "only one of --iter_size or --total_batch_size must be set"
+    if args.total_batch_size is not None:
+        args.iter_size = args.total_batch_size//args.batch_size
+    else:
+        args.total_batch_size = args.batch_size*args.iter_size
+    #
+
+    assert args.pretrained is not None, 'pretrained path must be provided'
+
+    # onnx generation is failing for the post-quantized module
+    args.generate_onnx = False if (args.quantize) else args.generate_onnx
+
+    #################################################
+    # set some global flags and initializations
+    # keep it in args for now - although they don't belong here strictly
+    # using pin_memory is seen to cause issues, especially when a lot of memory is used.
+    args.use_pinned_memory = False
+    args.n_iter = 0
+    args.best_metric = -1
+
+    #################################################
+    if args.save_path is None:
+        save_path = get_save_path(args)
+    else:
+        save_path = args.save_path
+    #
+    print('=> will save everything to {}'.format(save_path))
+    if not os.path.exists(save_path):
+        os.makedirs(save_path)
+
+    #################################################
+    if args.logger is None:
+        log_file = os.path.splitext(os.path.basename(__file__))[0] + '.log'
+        args.logger = xnn.utils.TeeLogger(filename=os.path.join(save_path,log_file))
+
+    ################################
+    # print everything for log
+    print('=> args: ', args)
+
+    if args.save_mod_files:
+        # store all the files modified since the last commit.
+        mod_files_path = save_path+'/mod_files'
+        os.makedirs(mod_files_path)
+        
+        cmd = "git ls-files --modified | xargs -i cp {} {}".format("{}", mod_files_path)
+        print("cmd:", cmd)    
+        os.system(cmd)
+
+        # store the last commit id.
+        cmd = "git log -n 1  >> {}".format(mod_files_path + '/commit_id.txt')
+        print("cmd:", cmd)    
+        os.system(cmd)
+
+    transforms = get_transforms(args)
+
+    print("=> fetching img pairs in '{}'".format(args.data_path))
+    split_arg = args.split_file if args.split_file else (args.split_files if args.split_files else args.split_value)
+
+    val_dataset = vision.datasets.pixel2pixel.__dict__[args.dataset_name](args.dataset_config, args.data_path, split=split_arg, transforms=transforms)
+
+    print('=> {} val samples found'.format(len(val_dataset)))
+    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
+        num_workers=args.workers, pin_memory=args.use_pinned_memory, shuffle=args.shuffle)
+
+    #################################################
+    if (args.model_config.input_channels is None):
+        args.model_config.input_channels = (3,)
+        print("=> input channels is not given - setting to {}".format(args.model_config.input_channels))
+
+    if (args.model_config.output_channels is None):
+        if ('num_classes' in dir(val_dataset)):
+            args.model_config.output_channels = val_dataset.num_classes()
+        else:
+            args.model_config.output_channels = (2 if args.model_config.output_type == 'flow' else args.model_config.output_channels)
+            xnn.utils.print_yellow("=> output channels is not given - setting to {} - not sure whether this will work".format(args.model_config.output_channels))
+        #
+        if not isinstance(args.model_config.output_channels,(list,tuple)):
+            args.model_config.output_channels = [args.model_config.output_channels]
+
+    #################################################
+    pretrained_data = None
+    model_surgery_quantize = False
+    if args.pretrained and args.pretrained != "None":
+        if args.pretrained.startswith('http://') or args.pretrained.startswith('https://'):
+            pretrained_file = vision.datasets.utils.download_url(args.pretrained, './data/downloads')
+        else:
+            pretrained_file = args.pretrained
+        #
+        print(f'=> using pre-trained weights from: {args.pretrained}')
+        pretrained_data = torch.load(pretrained_file)
+        model_surgery_quantize = pretrained_data['quantize'] if 'quantize' in pretrained_data else False
+    #
+
+    #################################################
+    # the portion before comma is used as the model name
+    # the string after the comma (if present) is used as decoder names in the decoder ModuleDict()
+    model = vision.models.pixel2pixel.__dict__[args.model_name](args.model_config)
+
+    # check if we got the model as well as parameters to change the names in pretrained
+    model, change_names_dict = model if isinstance(model, (list,tuple)) else (model,None)
+
+    #################################################
+    if args.quantize:
+        # dummy input is used by quantized models to analyze graph
+        is_cuda = next(model.parameters()).is_cuda
+        dummy_input = create_rand_inputs(args, is_cuda=is_cuda)
+        # Note: bias_calibration is not enabled in test
+        model = xnn.quantize.QuantTestModule(model, per_channel_q=args.per_channel_q,
+                        bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
+                        histogram_range=args.histogram_range, model_surgery_quantize=model_surgery_quantize,
+                        dummy_input=dummy_input)
+
+    # load pretrained weights
+    xnn.utils.load_weights_check(get_model_orig(model), pretrained=pretrained_data, change_names_dict=change_names_dict)
+
+    #################################################
+    # multi gpu mode is not yet supported with quantization in evaluate
+    if args.gpu_mode and (args.phase=='training'):
+        model = torch.nn.DataParallel(model)
+
+    #################################################
+    model = model.cuda()
+
+    #################################################
+    if args.write_layer_ip_op:
+        # for dumping module outputs
+        for name, module in model.named_modules():
+            module.name = name
+            print(name)
+            #if 'module.encoder.features.0.' in name:
+            module.register_forward_hook(write_tensor_hook_function)
+        print('{:7} {:33} {:12} {:8} {:6} {:30} : {:17} : {:4} : {:11} : {:7} : {:7}'.format("type",  "name", "layer", "min", "max", "tensor_shape", "dtype", "scale", "dtype", "min", "max"))
+
+    #################################################
+    args.loss_modules = copy.deepcopy(args.losses)
+    for task_dx, task_losses in enumerate(args.losses):
+        for loss_idx, loss_fn in enumerate(task_losses):
+            kw_args = {}
+            loss_args = vision.losses.__dict__[loss_fn].args()
+            for arg in loss_args:
+                #if arg == 'weight':
+                #    kw_args.update({arg:args.class_weights[task_dx]})
+                if arg == 'num_classes':
+                    kw_args.update({arg:args.model_config.output_channels[task_dx]})
+                elif arg == 'sparse':
+                    kw_args.update({arg:args.sparse})
+                #
+            #
+            loss_fn = vision.losses.__dict__[loss_fn](**kw_args)
+            loss_fn = loss_fn.cuda()
+            args.loss_modules[task_dx][loss_idx] = loss_fn
+
+    args.metric_modules = copy.deepcopy(args.metrics)
+    for task_dx, task_metrics in enumerate(args.metrics):
+        for midx, metric_fn in enumerate(task_metrics):
+            kw_args = {}
+            loss_args = vision.losses.__dict__[metric_fn].args()
+            for arg in loss_args:
+                if arg == 'weight':
+                    kw_args.update({arg:args.class_weights[task_dx]})
+                elif arg == 'num_classes':
+                    kw_args.update({arg:args.model_config.output_channels[task_dx]})
+                elif arg == 'sparse':
+                    kw_args.update({arg:args.sparse})
+                #
+            #
+            metric_fn = vision.losses.__dict__[metric_fn](**kw_args)
+            metric_fn = metric_fn.cuda()
+            args.metric_modules[task_dx][midx] = metric_fn
+
+    #################################################
+    if args.palette:
+        print('Creating palette')
+        args.palette = val_dataset.create_palette()
+        for i, p in enumerate(args.palette):
+            args.palette[i] = np.array(p, dtype = np.uint8)
+            args.palette[i] = args.palette[i][..., ::-1]  # RGB->BGR, since palette is expected to be given in RGB format
+
+    infer_path = []
+    for i, p in enumerate(args.model_config.output_channels):
+        infer_path.append(os.path.join(save_path, 'Task{}'.format(i)))
+        if not os.path.exists(infer_path[i]):
+            os.makedirs(infer_path[i])
+
+    #################################################
+    with torch.no_grad():
+        validate(args, val_dataset, val_loader, model, 0, infer_path)
+
+    if args.create_video:
+        create_video(args, infer_path=infer_path)
+
+
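+# runs inference over the validation loader, accumulates per-task confusion matrices / depth error,
+# and writes label images, blended overlays, depth visualizations or descriptor files depending on the args flags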
+def validate(args, val_dataset, val_loader, model, epoch, infer_path):
+    data_time = xnn.utils.AverageMeter()
+    avg_metric = [xnn.utils.AverageMeter(print_avg=(not task_metric[0].info()['is_avg'])) for task_metric in args.metric_modules]
+
+    # switch to evaluate mode
+    model.eval()
+    metric_name = "Metric"
+    end_time = time.time()
+    writer_idx = 0
+    last_update_iter = -1
+    metric_ctx = [None] * len(args.metric_modules)
+
+    confusion_matrix = []
+    for n_cls in args.model_config.output_channels:
+        confusion_matrix.append(np.zeros((n_cls, n_cls+1)))
+    metric_txt = []
+    ard_err = None
+    for iter, (input_list, target_list, input_path, target_path) in enumerate(val_loader):
+        file_name =  input_path[-1][0]
+        print("started inference of file_name:", file_name)
+        data_time.update(time.time() - end_time)
+        if args.gpu_mode:
+            input_list = [img.cuda() for img in input_list]
+        outputs = model(input_list)
+        if args.output_size is not None and target_list:
+            target_sizes = [tgt.shape for tgt in target_list]
+            outputs = upsample_tensors(outputs, target_sizes, args.upsample_mode)
+        elif args.output_size is not None and not target_list:
+            target_sizes = [args.output_size for _ in range(len(outputs))]
+            outputs = upsample_tensors(outputs, target_sizes, args.upsample_mode)
+        outputs = [out.cpu() for out in outputs]
+
+        for task_index in range(len(outputs)):
+            output = outputs[task_index]
+            gt_target = target_list[task_index] if target_list else None
+            if args.visualize_gt and target_list:
+                if args.model_config.output_type[task_index] == 'depth':
+                    output = gt_target
+                else:
+                    output = gt_target.to(dtype=torch.int8)
+                
+            if args.remove_ignore_lbls_in_pred and (args.model_config.output_type[task_index] != 'depth') and target_list:
+                output[gt_target == 255] = args.palette[task_index-1].shape[0]-1
+            for index in range(output.shape[0]):
+                if args.frame_IOU:
+                    confusion_matrix[task_index] = np.zeros((args.model_config.output_channels[task_index], args.model_config.output_channels[task_index] + 1))
+                prediction = np.array(output[index])
+                if output.shape[1]>1:
+                    prediction = np.argmax(prediction, axis=0)
+                #
+                prediction = np.squeeze(prediction)
+
+                if target_list:
+                    label = np.squeeze(np.array(target_list[task_index][index]))
+                    if args.model_config.output_type[task_index] != 'depth':
+                        confusion_matrix[task_index] = eval_output(args, prediction, label, confusion_matrix[task_index], args.model_config.output_channels[task_index])
+                        accuracy, mean_iou, iou, f1_score= compute_accuracy(args, confusion_matrix[task_index], args.model_config.output_channels[task_index])
+                        temp_txt = []
+                        temp_txt.append(input_path[-1][index])
+                        temp_txt.extend(iou)
+                        metric_txt.append(temp_txt)
+                        print('{}/{} Inferred Frame {} mean_iou={},'.format((args.batch_size*iter+index+1), len(val_dataset), input_path[-1][index], mean_iou))
+                        if index == output.shape[0]-1:
+                            print('Task={},\npixel_accuracy={},\nmean_iou={},\niou={},\nf1_score = {}'.format(task_index, accuracy, mean_iou, iou, f1_score))
+                            sys.stdout.flush()
+                    elif args.model_config.output_type[task_index] == 'depth':
+                        valid = (label != 0)
+                        gt = torch.tensor(label[valid]).float()
+                        inference = torch.tensor(prediction[valid]).float()
+                        if len(gt) > 2:
+                            if ard_err is None:
+                                ard_err = [absreldiff_rng3to80(inference, gt).mean()]
+                            else:
+                                ard_err.append(absreldiff_rng3to80(inference, gt).mean())
+                        elif len(gt) < 2:
+                            if ard_err is None:
+                                ard_err = [0.0]
+                            else:
+                                ard_err.append(0.0)
+
+                        print('{}/{} ARD: {}'.format((args.batch_size * iter + index), len(val_dataset),torch.tensor(ard_err).mean()))
+
+                seq = input_path[-1][index].split('/')[-4]
+                base_file = os.path.basename(input_path[-1][index])
+
+                if args.label_infer:
+                    output_image = prediction
+                    output_name = os.path.join(infer_path[task_index], seq + '_' + input_path[-1][index].split('/')[-3] + '_' + base_file)
+                    cv2.imwrite(output_name, output_image)
+                    print('{}/{}'.format((args.batch_size*iter+index), len(val_dataset)))
+
+                if hasattr(args, 'interest_pt') and args.interest_pt[task_index]:
+                    print('{}/{}'.format((args.batch_size * iter + index), len(val_dataset)))
+                    output_name = os.path.join(infer_path[task_index], seq + '_' + input_path[-1][index].split('/')[-3] + '_' + base_file)
+                    output_name_short = os.path.join(infer_path[task_index], os.path.basename(input_path[-1][index]))
+                    wrapper_write_desc(args=args, task_index=task_index, outputs=outputs, index=index, output_name=output_name, output_name_short=output_name_short)
+                    
+                if args.model_config.output_type[task_index] == 'depth':
+                    output_name = os.path.join(infer_path[task_index], seq + '_' + input_path[-1][index].split('/')[-3] + '_' + base_file)
+                    viz_depth(prediction = prediction, args=args, output_name = output_name, input_name=input_path[-1][task_index])
+                    print('{}/{}'.format((args.batch_size * iter + index), len(val_dataset)))
+
+                if args.blend[task_index]:
+                    prediction_size = (prediction.shape[0], prediction.shape[1], 3)
+                    output_image = args.palette[task_index-1][prediction.ravel()].reshape(prediction_size)
+                    input_bgr = cv2.imread(input_path[-1][index]) #Read the actual RGB image
+                    input_bgr = cv2.resize(input_bgr, dsize=(prediction.shape[1],prediction.shape[0]))
+                    output_image = xnn.utils.chroma_blend(input_bgr, output_image)
+                    output_name = os.path.join(infer_path[task_index], input_path[-1][index].split('/')[-4] + '_' + input_path[-1][index].split('/')[-3] + '_' +os.path.basename(input_path[-1][index]))
+                    cv2.imwrite(output_name, output_image)
+                    print('{}/{}'.format((args.batch_size*iter+index), len(val_dataset)))
+                #
+
+                if args.car_mask:   # generating car_mask (required for localization)
+                    car_mask = np.logical_or.reduce((prediction == 13, prediction == 14, prediction == 16, prediction == 17))
+                    prediction[car_mask] = 255
+                    prediction[np.invert(car_mask)] = 0
+                    output_image = prediction
+                    output_name = os.path.join(infer_path[task_index], os.path.basename(input_path[-1][index]))
+                    cv2.imwrite(output_name, output_image)
+    np.savetxt('metric.txt', metric_txt, fmt='%s')
+
+
+
+
+###############################################################
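+# builds the save path from the dataset name, model name, resize/crop settings and phase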
+def get_save_path(args, phase=None):
+    date = args.date if args.date else datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
+    save_path = os.path.join('./data/checkpoints', args.dataset_name, date + '_' + args.dataset_name + '_' + args.model_name)
+    save_path += '_resize{}x{}'.format(args.img_resize[1], args.img_resize[0])
+    if args.rand_crop:
+        save_path += '_crop{}x{}'.format(args.rand_crop[1], args.rand_crop[0])
+    #
+    phase = phase if (phase is not None) else args.phase
+    save_path = os.path.join(save_path, phase)
+    return save_path
+
+
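+# unwraps DataParallel/DistributedDataParallel and quantization wrappers to get the underlying model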
+def get_model_orig(model):
+    is_parallel_model = isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel))
+    model_orig = (model.module if is_parallel_model else model)
+    model_orig = (model_orig.module if isinstance(model_orig, (xnn.quantize.QuantBaseModule)) else model_orig)
+    return model_orig
+
+
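+# creates one random input tensor per input stream at the configured resize resolution
+# (used as dummy input for graph analysis in the quantized modules)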
+def create_rand_inputs(args, is_cuda):
+    dummy_input = []
+    for i_ch in args.model_config.input_channels:
+        x = torch.rand((1, i_ch, args.img_resize[0], args.img_resize[1]))
+        x = x.cuda() if is_cuda else x
+        dummy_input.append(x)
+    #
+    return dummy_input
+
+
+# FIX_ME:SN move to utils
+def store_desc(args=[], output_name=[], write_dense=False, desc_tensor=[], prediction=[],
+               scale_to_write_kp_loc_to_orig_res=[1.0, 1.0],
+               learn_scaled_values=True):
+    sys.path.insert(0, './scripts/')
+    import write_desc as write_desc
+
+    if args.write_desc_type != 'NONE':
+        txt_file_name = output_name.replace(".png", ".txt")
+        if write_dense:
+            # write desc
+            desc_tensor = desc_tensor.astype(np.int16)
+            print("writing dense desc(64 ch) op: {} : {} : {} : {}".format(desc_tensor.shape, desc_tensor.dtype,
+                                                                           desc_tensor.min(), desc_tensor.max()))
+            desc_tensor_name = output_name.replace(".png", "_desc.npy")
+            np.save(desc_tensor_name, desc_tensor)
+
+            # utils_hist.comp_hist_tensor3d(x=desc_tensor, name='desc_64ch', en=True, dir='desc_64ch', log=True, ch_dim=0)
+
+            # write score channel
+            prediction = prediction.astype(np.int16)
+
+            print("writing dense score ch op: {} : {} : {} : {}".format(prediction.shape, prediction.dtype,
+                                                                        prediction.min(),
+                                                                        prediction.max()))
+            score_tensor_name = output_name.replace(".png", "_score.npy")
+            np.save(score_tensor_name, prediction)
+
+            # utils_hist.hist_tensor2D(x_ch = prediction, dir='score', name='score', en=True, log=True)
+        else:
+            prediction[prediction < 0.0] = 0.0
+
+            if learn_scaled_values:
+                img_interest_pt_cur = prediction.astype(np.uint16)
+                score_th = 127
+            else:
+                img_interest_pt_cur = prediction
+                score_th = 0.001
+
+            # at boundary pixels score/desc will be wrong, so do not write the predicted quantities near the border.
+            guard_band = 32 if args.write_desc_type == 'PRED' else 0
+
+            write_desc.write_score_desc_as_text(desc_tensor_cur=desc_tensor, img_interest_pt_cur=img_interest_pt_cur,
+                                                txt_file_name=txt_file_name, score_th=score_th,
+                                                skip_fac_for_reading_desc=1, en_nms=args.en_nms,
+                                                scale_to_write_kp_loc_to_orig_res=scale_to_write_kp_loc_to_orig_res,
+                                                recursive_nms=True, learn_scaled_values=learn_scaled_values,
+                                                guard_band=guard_band, true_nms=True, f2r=args.do_pred_cordi_f2r)
+
+
+        # utils_hist.hist_tensor2D(x_ch=prediction, dir='score', name='score', en=True, log=True)
+    else:
+        prediction[prediction < 0.0] = 0.0
+
+        if learn_scaled_values:
+            img_interest_pt_cur = prediction.astype(np.uint16)
+            score_th = 127
+        else:
+            img_interest_pt_cur = prediction
+            score_th = 0.001
+
+        # at boundary pixels score/desc will be wrong, so do not write the predicted quantities near the border.
+        guard_band = 32 if args.write_desc_type == 'PRED' else 0
+
+        write_desc.write_score_desc_as_text(desc_tensor_cur=desc_tensor, img_interest_pt_cur=img_interest_pt_cur,
+                                            txt_file_name=txt_file_name, score_th=score_th,
+                                            skip_fac_for_reading_desc=1, en_nms=args.en_nms,
+                                            scale_to_write_kp_loc_to_orig_res=scale_to_write_kp_loc_to_orig_res,
+                                            recursive_nms=True, learn_scaled_values=learn_scaled_values,
+                                            guard_band=guard_band, true_nms=True, f2r=args.do_pred_cordi_f2r)
+
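+# colorizes a depth prediction for visualization; the output format depends on args.viz_depth_color_type
+# ('rainbow', 'rainbow_blend', 'bone', 'raw_depth', 'plasma' or 'log_greys')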
+def viz_depth(prediction = [], args=[], output_name=[], input_name=[]):
+    max_value_depth = args.max_depth
+    output_image = torch.tensor(prediction)
+    if args.viz_depth_color_type == 'rainbow':
+        not_valid_indices = output_image == 0
+        output_image = 255*xnn.utils.tensor2array(output_image, max_value=max_value_depth, colormap='rainbow')[:,:,::-1]
+        output_image[not_valid_indices] = 0
+    elif args.viz_depth_color_type == 'rainbow_blend':
+        print(max_value_depth)
+        #scale_mul = 1 if args.visualize_gt else 255
+        print(output_image.min())
+        print(output_image.max())
+        not_valid_indices = output_image == 0
+        output_image = 255*xnn.utils.tensor2array(output_image, max_value=max_value_depth, colormap='rainbow')[:,:,::-1]
+        print(output_image.max())
+        #output_image[label == 1] = 0
+        input_bgr = cv2.imread(input_name)  # Read the actual RGB image
+        input_bgr = cv2.resize(input_bgr, dsize=(prediction.shape[1], prediction.shape[0]))
+        if args.sky_dir:
+            label_file = os.path.join(args.sky_dir, seq, seq + '_image_00_' + base_file)
+            label = cv2.imread(label_file)
+            label = cv2.resize(label, dsize=(prediction.shape[1], prediction.shape[0]),
+                                interpolation=cv2.INTER_NEAREST)
+            output_image[label == 1] = 0
+        output_image[not_valid_indices] = 0
+        output_image = xnn.utils.chroma_blend(input_bgr, output_image)  # chroma_blend(input_bgr, output_image)
+
+    elif args.viz_depth_color_type == 'bone':
+        output_image = 255 * xnn.utils.tensor2array(output_image, max_value=max_value_depth, colormap='bone')
+    elif args.viz_depth_color_type == 'raw_depth':
+        output_image = np.array(output_image)
+        output_image[output_image > max_value_depth] = max_value_depth
+        output_image[output_image < 0] = 0
+        scale = 2.0**16 - 1.0 #255
+        output_image = (output_image / max_value_depth) * scale
+        output_image = output_image.astype(np.uint16)
+        # output_image[(label[:,:,0]==1)|(label[:,:,0]==4)]=255
+    elif args.viz_depth_color_type == 'plasma':
+        plt.imsave(output_name, output_image, cmap='plasma', vmin=0, vmax=max_value_depth)
+    elif args.viz_depth_color_type == 'log_greys':        
+        plt.imsave(output_name, np.log10(output_image), cmap='Greys', vmin=0, vmax=np.log10(max_value_depth))
+        #plt.imsave(output_name, output_image, cmap='Greys', vmin=0, vmax=max_value_depth)
+    else:
+        print("undefined color type for visualization")
+        exit(0)
+
+    if args.viz_depth_color_type != 'plasma':
+        # plasma type will be handled by imsave
+        cv2.imwrite(output_name, output_image)
+
+
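+# writes the interest point score channel as an image and passes the descriptor channels to store_desc()
+# (GT or predicted tensors are used depending on args.write_desc_type)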
+def wrapper_write_desc(args=[], task_index=0, outputs=[], index=0, output_name=[], output_name_short=[]):
+    if args.write_desc_type == 'GT':
+        # write GT desc
+        tensor_to_write = target_list[task_index]
+    elif args.write_desc_type == 'PRED':
+        # write predicted desc
+        tensor_to_write = outputs[task_index]
+
+    interest_pt_score = np.array(tensor_to_write[index, 0, ...])
+
+    if args.make_score_zero_mean:
+        # visualization code assumes range [0,255]. Add 128 to make the range the same in the zero-mean case too.
+        interest_pt_score += 128.0
+
+    if args.write_desc_type == 'NONE':
+        # scale + clip score between 0-255 and convert score_array to image
+        # scale_range = 127.0/0.005
+        # scale_range = 255.0/np.max(interest_pt_score)
+        scale_range = 1.0
+        interest_pt_score = np.clip(interest_pt_score * scale_range, 0.0, 255.0)
+        interest_pt_score = np.asarray(interest_pt_score, 'uint8')
+
+    interest_pt_descriptor = np.array(tensor_to_write[index, 1:, ...])
+
+    # output_name = os.path.join(infer_path[task_index], seq + '_' + input_path[-1][index].split('/')[-3] + '_' + base_file)
+    cv2.imwrite(output_name, interest_pt_score)
+
+    # output_name_short = os.path.join(infer_path[task_index], os.path.basename(input_path[-1][index]))
+
+    scale_to_write_kp_loc_to_orig_res = args.scale_to_write_kp_loc_to_orig_res
+    if args.scale_to_write_kp_loc_to_orig_res[0] == -1:
+        scale_to_write_kp_loc_to_orig_res[0] = input_list[task_index].shape[2] / target_list[task_index].shape[2]
+        scale_to_write_kp_loc_to_orig_res[1] = scale_to_write_kp_loc_to_orig_res[0]
+
+    print("scale_to_write_kp_loc_to_orig_res: ", scale_to_write_kp_loc_to_orig_res)
+    store_desc(args=args, output_name=output_name_short, desc_tensor=interest_pt_descriptor,
+               prediction=interest_pt_score,
+               scale_to_write_kp_loc_to_orig_res=scale_to_write_kp_loc_to_orig_res,
+               learn_scaled_values=args.learn_scaled_values_interest_pt,
+               write_dense=False)
+
+
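+# builds the test-time transform pipeline: normalization (pre or post), image alignment, target masking,
+# border crop, resize to img_resize/output_size and tensor conversion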
+def get_transforms(args):
+    # image normalization can be at the beginning of transforms or at the end
+    args.image_mean = np.array(args.image_mean, dtype=np.float32)
+    args.image_scale = np.array(args.image_scale, dtype=np.float32)
+    image_prenorm = vision.transforms.image_transforms.NormalizeMeanScale(mean=args.image_mean, scale=args.image_scale) if args.image_prenorm else None
+    image_postnorm = vision.transforms.image_transforms.NormalizeMeanScale(mean=args.image_mean, scale=args.image_scale) if (not image_prenorm) else None
+
+    #target size must be according to output_size. prediction will be resized to output_size before evaluation.
+    test_transform = vision.transforms.image_transforms.Compose([
+        image_prenorm,
+        vision.transforms.image_transforms.AlignImages(),
+        vision.transforms.image_transforms.MaskTarget(args.target_mask, 0),
+        vision.transforms.image_transforms.CropRect(args.img_border_crop),
+        vision.transforms.image_transforms.Scale(args.img_resize, target_size=args.output_size, is_flow=args.is_flow),
+        image_postnorm,
+        vision.transforms.image_transforms.ConvertToTensor()
+        ])
+
+    return test_transform
+
+
+def _upsample_impl(tensor, output_size, upsample_mode):
+    # upsample of long tensor is not supported currently. convert to float, just to avoid the error.
+    # we can do this only in the case of nearest mode, otherwise the output will have invalid values.
+    convert_to_float = False
+    if isinstance(tensor, (torch.LongTensor,torch.cuda.LongTensor)):
+        convert_to_float = True
+        tensor = tensor.float()
+        upsample_mode = 'nearest'
+    #
+
+    dim_added = False
+    if len(tensor.shape) < 4:
+        tensor = tensor[np.newaxis,...]
+        dim_added = True
+    #
+    if (tensor.size()[-2:] != output_size):
+        tensor = torch.nn.functional.interpolate(tensor, output_size, mode=upsample_mode)
+    # --
+    if dim_added:
+        tensor = tensor[0,...]
+    #
+
+    if convert_to_float:
+        tensor = tensor.long()
+    #
+    return tensor
+
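+# upsamples a tensor or a list of tensors to the given output sizes (H,W taken from the last two dims)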
+def upsample_tensors(tensors, output_sizes, upsample_mode):
+    if not output_sizes:
+        return tensors
+    #
+    if isinstance(tensors, (list,tuple)):
+        for tidx, tensor in enumerate(tensors):
+            tensors[tidx] = _upsample_impl(tensor, output_sizes[tidx][-2:], upsample_mode)
+        #
+    else:
+        tensors = _upsample_impl(tensors, output_sizes[0][-2:], upsample_mode)
+    return tensors
+
+
+
+
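+# accumulates the confusion matrix between predicted and ground truth labels, ignoring pixels labelled 255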
+def eval_output(args, output, label, confusion_matrix, n_classes):
+    if len(label.shape)>2:
+        label = label[:,:,0]
+    gt_labels = label.ravel()
+    det_labels = output.ravel().clip(0,n_classes)
+    gt_labels_valid_ind = np.where(gt_labels != 255)
+    gt_labels_valid = gt_labels[gt_labels_valid_ind]
+    det_labels_valid = det_labels[gt_labels_valid_ind]
+    for r in range(confusion_matrix.shape[0]):
+        for c in range(confusion_matrix.shape[1]):
+            confusion_matrix[r,c] += np.sum((gt_labels_valid==r) & (det_labels_valid==c))
+
+    return confusion_matrix
+    
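+# derives pixel accuracy, per-class IoU, mean IoU (over non-empty classes) and F1 score from the confusion matrix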
+def compute_accuracy(args, confusion_matrix, n_classes):
+    num_selected_classes = n_classes
+    tp = np.zeros(n_classes)
+    population = np.zeros(n_classes)
+    det = np.zeros(n_classes)
+    iou = np.zeros(n_classes)
+    
+    for r in range(n_classes):
+        for c in range(n_classes):
+            population[r] += confusion_matrix[r][c]
+            det[c] += confusion_matrix[r][c]
+            if r == c:
+                tp[r] += confusion_matrix[r][c]
+
+    for cls in range(num_selected_classes):
+        intersection = tp[cls]
+        union = population[cls] + det[cls] - tp[cls]
+        iou[cls] = (intersection / union) if union else 0     # For caffe jacinto script
+        #iou[cls] = (intersection / (union + np.finfo(np.float32).eps))  # For pytorch-jacinto script
+
+    num_nonempty_classes = 0
+    for pop in population:
+        if pop > 0:
+            num_nonempty_classes += 1
+
+    mean_iou = np.sum(iou) / num_nonempty_classes if num_nonempty_classes else 0
+    accuracy = np.sum(tp) / np.sum(population) if np.sum(population) else 0
+    
+    #F1 score calculation
+    fp = np.zeros(n_classes)
+    fn = np.zeros(n_classes)
+    precision = np.zeros(n_classes)
+    recall = np.zeros(n_classes)
+    f1_score = np.zeros(n_classes)
+
+    for cls in range(num_selected_classes):
+        fp[cls] = det[cls] - tp[cls]
+        fn[cls] = population[cls] - tp[cls]
+        precision[cls] = tp[cls] / (det[cls] + 1e-10)
+        recall[cls] = tp[cls] / (tp[cls] + fn[cls] + 1e-10)        
+        f1_score[cls] = 2 * precision[cls]*recall[cls] / (precision[cls] + recall[cls] + 1e-10)
+
+    return accuracy, mean_iou, iou, f1_score
+    
+        
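+# reads frames from the input video with imageio, runs infer_blob() on each frame and writes the results to an output video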
+def infer_video(args, net):
+    videoIpHandle = imageio.get_reader(args.input, 'ffmpeg')
+    fps = math.ceil(videoIpHandle.get_meta_data()['fps'])
+    print(videoIpHandle.get_meta_data())
+    numFrames = min(len(videoIpHandle), args.num_images)
+    videoOpHandle = imageio.get_writer(args.output,'ffmpeg', fps=fps)
+    for num in range(numFrames):
+        print(num, end=' ')
+        sys.stdout.flush()
+        input_blob = videoIpHandle.get_data(num)
+        input_blob = input_blob[...,::-1]    #RGB->BGR
+        output_blob = infer_blob(args, net, input_blob)     
+        output_blob = output_blob[...,::-1]  #BGR->RGB            
+        videoOpHandle.append_data(output_blob)
+    videoOpHandle.close()
+    return
+
+
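+# absolute relative difference |x-y|/|y|, computed only where the reference magnitude is above eps
+# (both tensors can optionally be clamped to [-max_val, max_val] first)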
+def absreldiff(x, y, eps = 0.0, max_val=None):
+    assert x.size() == y.size(), 'tensor dimension mismatch'
+    if max_val is not None:
+        x = torch.clamp(x, -max_val, max_val)
+        y = torch.clamp(y, -max_val, max_val)
+    #
+
+    diff = torch.abs(x - y)
+    y = torch.abs(y)
+
+    den_valid = (y == 0).float()
+    eps_arr = (den_valid * (1e-6))   # Just to avoid divide by zero
+
+    large_arr = (y > eps).float()    # ARD is not a good measure for small ref values. Avoid them.
+    out = (diff / (y + eps_arr)) * large_arr
+    return out
+
+
+def absreldiff_rng3to80(x, y):
+    return absreldiff(x, y, eps = 3.0, max_val=80.0)
+
+
+
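+# stitches the inferred png frames under infer_path into an mp4 using ffmpeg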
+def create_video(args, infer_path):
+    op_file_name = args.data_path.split('/')[-1]
+    os.system(' ffmpeg -framerate 30 -pattern_type glob -i "{}/*.png" -c:v libx264 -vb 50000k -qmax 20 -r 10 -vf scale=1024:512  -pix_fmt yuv420p {}.MP4'.format(infer_path, op_file_name))
+
+if __name__ == '__main__':
+    train_args = get_config()
+    main(train_args)
diff --git a/modules/pytorch_jacinto_ai/engine/test_classification.py b/modules/pytorch_jacinto_ai/engine/test_classification.py
new file mode 100644 (file)
index 0000000..554cc5e
--- /dev/null
@@ -0,0 +1,529 @@
+import os
+import sys
+import shutil
+import time
+import datetime
+
+import random
+import numpy as np
+from colorama import Fore
+import progiter
+import warnings
+
+import torch
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.optim
+import torch.utils.data
+import torch.utils.data.distributed
+
+from .. import xnn
+from .. import vision
+
+
+# ################################################
+def get_config():
+    args = xnn.utils.ConfigNode()
+    args.model_config = xnn.utils.ConfigNode()
+    args.dataset_config = xnn.utils.ConfigNode()
+
+    args.model_name = 'mobilenet_v2_classification'     # model architecture
+    args.dataset_name = 'imagenet_classification'       # image folder classification
+
+    args.data_path = './data/datasets/ilsvrc'           # path to dataset
+    args.save_path = None                               # checkpoints save path
+    args.pretrained = './data/modelzoo/pretrained/pytorch/imagenet_classification/ericsun99/MobileNet-V2-Pytorch/mobilenetv2_Top1_71.806_Top2_90.410.pth.tar' # path to pre_trained model
+
+    args.workers = 8                                    # number of data loading workers (default: 4)
+    args.batch_size = 256                               # mini_batch size (default: 256)
+    args.print_freq = 100                               # print frequency (default: 100)
+
+    args.img_resize = 256                               # image resize
+    args.img_crop = 224                                 # image crop
+
+    args.image_mean = (123.675, 116.28, 103.53)         # image mean for input image normalization
+    args.image_scale = (0.017125, 0.017507, 0.017429)   # image scaling/mult for input image normalization
+
+    args.logger = None                                  # logger stream to output into
+
+    args.data_augument = 'inception'                    # data augmentation method, choices=['inception','resize','adaptive_resize']
+    args.dataset_format = 'folder'                      # dataset format, choices=['folder','lmdb']
+    args.count_flops = True                             # count flops and report
+
+    args.lr_calib = 0.1                                 # lr for bias calibration
+
+    args.rand_seed = 1                                  # random seed
+    args.generate_onnx = False                          # export an onnx model or not
+    args.print_model = False                            # print the model to text
+    args.run_soon = True                                # set to False if only cfs files/onnx models are needed but no training
+    args.parallel_model = True                          # parallel or not
+    args.shuffle = True                                 # shuffle or not
+    args.epoch_size = 0                                 # epoch size
+    args.date = None                                    # date to add to save path. if this is None, current date will be added.
+    args.write_layer_ip_op = False
+
+    args.quantize = False                               # apply quantized inference or not
+    #args.model_surgery = None                           # replace activations with PAct2 activation module. Helpful in quantized training.
+    args.bitwidth_weights = 8                           # bitwidth for weights
+    args.bitwidth_activations = 8                       # bitwidth for activations
+    args.histogram_range = True                         # histogram range for calibration
+    args.per_channel_q = False                          # apply separate quantization factor for each channel in depthwise or not
+    args.bias_calibration = False                        # apply bias correction during quantized inference calibration
+    return args
+
+
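+# usage sketch (an assumption, mirroring the __main__ block of infer_pixel2pixel): a caller typically does
+#     args = get_config()
+#     args.pretrained = './path/to/checkpoint.pth'   # hypothetical checkpoint path
+#     args.phase = 'validation'
+#     main(args)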
+def main(args):
+    assert not hasattr(args, 'model_surgery'), 'the argument model_surgery is deprecated, it is not needed now - remove it'
+
+    if (args.phase == 'validation' and args.bias_calibration):
+        args.bias_calibration = False
+        warnings.warn('switching off bias calibration in validation')
+    #
+
+    #################################################
+    # onnx generation is failing for the post-quantization module
+    args.generate_onnx = False if (args.quantize) else args.generate_onnx
+
+    if args.save_path is None:
+        save_path = get_save_path(args)
+    else:
+        save_path = args.save_path
+    #
+
+    if not os.path.exists(save_path):
+        os.makedirs(save_path)
+    #
+
+    #################################################
+    if args.logger is None:
+        log_file = os.path.splitext(os.path.basename(__file__))[0] + '.log'
+        args.logger = xnn.utils.TeeLogger(filename=os.path.join(save_path,log_file))
+
+    #################################################
+    # global settings. rand seeds for repeatability
+    random.seed(args.rand_seed)
+    np.random.seed(args.rand_seed)
+    torch.manual_seed(args.rand_seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = True
+
+    ################################
+    # print everything for log
+    # reset character color, in case it is different
+    print('{}'.format(Fore.RESET))
+    print("=> args: ", args)
+    print("=> resize resolution: {}".format(args.img_resize))
+    print("=> crop resolution  : {}".format(args.img_crop))
+    sys.stdout.flush()
+
+    #################################################
+    pretrained_data = None
+    model_surgery_quantize = False
+    if args.pretrained and args.pretrained != "None":
+        if args.pretrained.startswith('http://') or args.pretrained.startswith('https://'):
+            pretrained_file = vision.datasets.utils.download_url(args.pretrained, './data/downloads')
+        else:
+            pretrained_file = args.pretrained
+        #
+        print(f'=> using pre-trained weights from: {args.pretrained}')
+        pretrained_data = torch.load(pretrained_file)
+        model_surgery_quantize = pretrained_data['quantize'] if 'quantize' in pretrained_data else False
+    #
+
+    ################################
+    # create model
+    print("=> creating model '{}'".format(args.model_name))
+    model = vision.models.classification.__dict__[args.model_name](args.model_config)
+
+    # check if we got the model as well as parameters to change the names in pretrained
+    model, change_names_dict = model if isinstance(model, (list,tuple)) else (model,None)
+
+    #################################################
+    if args.quantize:
+        # dummy input is used by quantized models to analyze graph
+        is_cuda = next(model.parameters()).is_cuda
+        dummy_input = create_rand_inputs(args, is_cuda=is_cuda)
+        #
+        if args.phase == 'training':
+            model = xnn.quantize.QuantTrainModule(model, per_channel_q=args.per_channel_q,
+                        bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
+                        dummy_input=dummy_input)
+        elif args.phase == 'calibration':
+            model = xnn.quantize.QuantCalibrateModule(model, per_channel_q=args.per_channel_q,
+                        bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
+                        bias_calibration=args.bias_calibration, lr_calib=args.lr_calib,
+                        dummy_input=dummy_input)
+        elif args.phase == 'validation':
+            # Note: bias_calibration is not enabled in test
+            model = xnn.quantize.QuantTestModule(model, per_channel_q=args.per_channel_q,
+                        bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
+                        histogram_range=args.histogram_range, model_surgery_quantize=model_surgery_quantize,
+                        dummy_input=dummy_input)
+        else:
+            assert False, f'invalid phase {args.phase}'
+    #
+
+
+    # load pretrained
+    xnn.utils.load_weights_check(get_model_orig(model), pretrained=pretrained_data, change_names_dict=change_names_dict)
+
+    #################################################
+    if args.count_flops:
+        count_flops(args, model)
+
+    #################################################
+    if args.generate_onnx:
+        write_onnx_model(args, get_model_orig(model), save_path)
+    #
+
+    #################################################
+    if args.print_model:
+        print(model)
+    else:
+        args.logger.debug(str(model))
+
+    #################################################
+    if (not args.run_soon):
+        print("Training not needed for now")
+        exit()
+
+    #################################################
+    # multi gpu mode is not yet supported with quantization in evaluate
+    if args.parallel_model and (args.phase=='training'):
+        model = torch.nn.DataParallel(model)
+
+    #################################################
+    model = model.cuda()
+
+    #################################################
+    if args.write_layer_ip_op:
+        # for dumping module outputs
+        for name, module in model.named_modules():
+            module.name = name
+            print(name)
+            #if 'module.encoder.features.0.' in name:
+            module.register_forward_hook(write_tensor_hook_function)
+        print('{:7} {:33} {:12} {:8} {:6} {:30} : {:17} : {:4} : {:11} : {:7} : {:7}'.format("type",  "name", "layer", "min", "max", "tensor_shape", "dtype", "scale", "dtype", "min", "max"))
+
+
+    #################################################
+    # define loss function (criterion) and optimizer
+    criterion = torch.nn.CrossEntropyLoss().cuda()
+
+    val_loader = get_data_loaders(args)
+    validate(args, val_loader, model, criterion)
+
+
+
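+# evaluates the model on the validation loader and reports loss, top-1 and top-5 accuracy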
+def validate(args, val_loader, model, criterion):
+    # switch to evaluate mode
+    model.eval()
+
+    # change color to green
+    print('{}'.format(Fore.GREEN), end='')
+
+    with torch.no_grad():
+        batch_time = AverageMeter()
+        losses = AverageMeter()
+        top1 = AverageMeter()
+        top5 = AverageMeter()
+        use_progressbar = True
+        epoch_size = get_epoch_size(args, val_loader, args.epoch_size)
+
+        if use_progressbar:
+            progress_bar = progiter.ProgIter(np.arange(epoch_size), chunksize=1)
+            last_update_iter = -1
+
+        end = time.time()
+        for iteration, (input, target) in enumerate(val_loader):
+            target = target.cuda(non_blocking=True)
+            input = torch.cat([j.cuda() for j in input], dim=1) if (type(input) in (list,tuple)) else input.cuda()
+
+            # compute output
+            output = model(input)
+            if type(output) in (list, tuple):
+                output = output[0]
+            #
+
+            loss = criterion(output, target)
+
+            # measure accuracy and record loss
+            prec1, prec5 = accuracy(output, target, topk=(1, 5))
+            losses.update(loss.item(), input.size(0))
+            top1.update(prec1[0], input.size(0))
+            top5.update(prec5[0], input.size(0))
+
+            # measure elapsed time
+            batch_time.update(time.time() - end)
+            end = time.time()
+            final_iter = (iteration >= (epoch_size-1))
+
+            if ((iteration % args.print_freq) == 0) or final_iter:
+                status_str = 'Time {batch_time.val:.2f}({batch_time.avg:.2f}) LR {cur_lr:.4f} ' \
+                             'Loss {loss.val:.2f}({loss.avg:.2f}) Prec@1 {top1.val:.2f}({top1.avg:.2f}) Prec@5 {top5.val:.2f}({top5.avg:.2f})' \
+                             .format(batch_time=batch_time, cur_lr=0.0, loss=losses, top1=top1, top5=top5)
+                #
+                prefix = '**' if final_iter else '=>'
+                if use_progressbar:
+                    progress_bar.set_description('{} validation'.format(prefix))
+                    progress_bar.set_postfix(Epoch='{}'.format(status_str))
+                    progress_bar.update(iteration - last_update_iter)
+                    last_update_iter = iteration
+                else:
+                    iter_str = '{:6}/{:6}    : '.format(iteration+1, len(val_loader))
+                    status_str = prefix + ' ' + iter_str + status_str
+                    if final_iter:
+                        xnn.utils.print_color(status_str, color=Fore.GREEN)
+                    else:
+                        xnn.utils.print_color(status_str)
+
+            if final_iter:
+                break
+
+        if use_progressbar:
+            progress_bar.close()
+
+        # to print a new line - do not provide end=''
+        print('{}'.format(Fore.RESET), end='')
+
+    return top1.avg
+
+
+#######################################################################
+def get_save_path(args, phase=None):
+    date = args.date if args.date else datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+    save_path_base = os.path.join('./data/checkpoints', args.dataset_name, date + '_' + args.dataset_name + '_' + args.model_name)
+    save_path = save_path_base + '_resize{}_crop{}'.format(args.img_resize, args.img_crop)
+    phase = phase if (phase is not None) else args.phase
+    save_path = os.path.join(save_path, phase)
+    return save_path
+
+
+def get_model_orig(model):
+    is_parallel_model = isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel))
+    model_orig = (model.module if is_parallel_model else model)
+    model_orig = (model_orig.module if isinstance(model_orig, (xnn.quantize.QuantBaseModule)) else model_orig)
+    return model_orig
+
+
+def create_rand_inputs(args, is_cuda=True):
+    x = torch.rand((1, args.model_config.input_channels, args.img_crop, args.img_crop))
+    x = x.cuda() if is_cuda else x
+    return x
+
+
+def count_flops(args, model):
+    is_cuda = next(model.parameters()).is_cuda
+    input_list = create_rand_inputs(args, is_cuda)
+    model.eval()
+    flops = xnn.utils.forward_count_flops(model, input_list)
+    gflops = flops/1e9
+    print('=> Resize = {}, Crop = {}, GFLOPs = {}, GMACs = {}'.format(args.img_resize, args.img_crop, gflops, gflops/2))
+
+
+def write_onnx_model(args, model, save_path, name='checkpoint.onnx'):
+    is_cuda = next(model.parameters()).is_cuda
+    dummy_input = create_rand_inputs(args, is_cuda)
+    #
+    model.eval()
+    torch.onnx.export(model, dummy_input, os.path.join(save_path,name), export_params=True, verbose=False)
+
+def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
+    torch.save(state, filename)
+    if is_best:
+        shutil.copyfile(filename, 'model_best.pth.tar')
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
+def accuracy(output, target, topk=(1,)):
+    """Computes the precision@k for the specified values of k"""
+    with torch.no_grad():
+        maxk = max(topk)
+        batch_size = target.size(0)
+
+        _, pred = output.topk(maxk, 1, True, True)
+        pred = pred.t()
+        correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+        res = []
+        for k in topk:
+            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+            res.append(correct_k.mul_(100.0 / batch_size))
+        return res
+
+
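+# epoch_size semantics: 0 means the full loader, a value below 1 means that fraction of the loader,
+# and any other value is an absolute cap on the number of iterations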
+def get_epoch_size(args, loader, args_epoch_size):
+    if args_epoch_size == 0:
+        epoch_size = len(loader)
+    elif args_epoch_size < 1:
+        epoch_size = int(len(loader) * args_epoch_size)
+    else:
+        epoch_size = min(len(loader), int(args_epoch_size))
+    return epoch_size
+
+
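+# builds the validation data loader with resize, center-crop, float conversion and normalization transforms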
+def get_data_loaders(args):
+    # Data loading code
+    normalize = vision.transforms.NormalizeMeanScale(mean=args.image_mean, scale=args.image_scale) \
+                        if (args.image_mean is not None and args.image_scale is not None) else None
+
+    # pass tuple to Resize() to resize to exact size without respecting aspect ratio (typical caffe style)
+    val_transform = vision.transforms.Compose([vision.transforms.Resize(size=args.img_resize),
+                                               vision.transforms.CenterCrop(size=args.img_crop),
+                                               vision.transforms.ToFloat(),
+                                               vision.transforms.ToTensor(),
+                                               normalize])
+
+    train_dataset, val_dataset = vision.datasets.classification.__dict__[args.dataset_name](args.dataset_config, args.data_path, transforms=(None,val_transform))
+
+    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=args.shuffle, num_workers=args.workers,
+                                             pin_memory=True, drop_last=False)
+
+    return val_loader
+
+
+#################################################
+def shape_as_string(shape=[]):
+    shape_str = ''
+    for dim in shape:
+        shape_str += '_' + str(dim)
+    return shape_str
+
+
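+# quantizes a tensor to the given bitwidth using the scale computed by xnn.quantize and dumps it
+# as a bin/npy file under ./data/checkpoints/debug/test_vecs for layer-level debug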
+def write_tensor_int(m=[], tensor=[], suffix='op', bitwidth=8, power2_scaling=True, file_format='bin',
+                     rnd_type='rnd_sym'):
+    mn = tensor.min()
+    mx = tensor.max()
+
+    print(
+        '{:6}, {:32}, {:10}, {:7.2f}, {:7.2f}'.format(suffix, m.name, m.__class__.__name__, tensor.min(), tensor.max()),
+        end=" ")
+
+    [tensor_scale, clamp_limits] = xnn.quantize.compute_tensor_scale(tensor, mn, mx, bitwidth, power2_scaling)
+    print("{:30} : {:15} : {:8.2f}".format(str(tensor.shape), str(tensor.dtype), tensor_scale), end=" ")
+
+    print_weight_bias = False
+    if rnd_type == 'rnd_sym':
+        # use best rounding for offline quantities
+        if suffix == 'weight' and print_weight_bias:
+            no_idx = 0
+            torch.set_printoptions(precision=32)
+            print("tensor_scale: ", tensor_scale)
+            print(tensor[no_idx])
+        if tensor.dtype != torch.int64:
+            tensor = xnn.quantize.symmetric_round_tensor(tensor * tensor_scale)
+        if suffix == 'weight' and print_weight_bias:
+            print(tensor[no_idx])
+    else:
+        # for activation use HW friendly rounding
+        if tensor.dtype != torch.int64:
+            tensor = xnn.quantize.upward_round_tensor(tensor * tensor_scale)
+    tensor = tensor.clamp(clamp_limits[0], clamp_limits[1]).float()
+
+    if bitwidth == 8:
+        data_type = np.int8
+    elif bitwidth == 16:
+        data_type = np.int16
+    elif bitwidth == 32:
+        data_type = np.int32
+    else:
+        exit("Bit width other than 8,16,32 is not supported for writing layer level outputs")
+
+    tensor = tensor.cpu().numpy().astype(data_type)
+
+    print("{:7} : {:7d} : {:7d}".format(str(tensor.dtype), tensor.min(), tensor.max()))
+
+    tensor_dir = './data/checkpoints/debug/test_vecs/' + '{}_{}_{}_scale_{:010.4f}'.format(m.name,
+                                                                                            m.__class__.__name__,
+                                                                                            suffix, tensor_scale)
+
+    if not os.path.exists(tensor_dir):
+        os.makedirs(tensor_dir)
+
+    if file_format == 'bin':
+        tensor_name = tensor_dir + "/{}_shape{}.bin".format(m.name, shape_as_string(shape=tensor.shape))
+        tensor.tofile(tensor_name)
+    elif file_format == 'npy':
+        tensor_name = tensor_dir + "/{}_shape{}.npy".format(m.name, shape_as_string(shape=tensor.shape))
+        np.save(tensor_name, tensor)
+    else:
+        warnings.warn('unknown file_format for write_tensor - no file written')
+    #
+
+    # utils_hist.comp_hist_tensor3d(x=tensor, name=m.name, en=True, dir=m.name, log=True, ch_dim=0)
+
+
+def write_tensor_float(m=[], tensor=[], suffix='op'):
+    mn = tensor.min()
+    mx = tensor.max()
+
+    print(
+        '{:6}, {:32}, {:10}, {:7.2f}, {:7.2f}'.format(suffix, m.name, m.__class__.__name__, tensor.min(), tensor.max()))
+    root = os.getcwd()
+    tensor_dir = root + '/checkpoints/debug/test_vecs/' + '{}_{}_{}'.format(m.name, m.__class__.__name__, suffix)
+
+    if not os.path.exists(tensor_dir):
+        os.makedirs(tensor_dir)
+
+    tensor_name = tensor_dir + "/{}_shape{}.npy".format(m.name, shape_as_string(shape=tensor.shape))
+    np.save(tensor_name, tensor.data)
+
+
+def write_tensor(data_type='int', m=[], tensor=[], suffix='op', bitwidth=8, power2_scaling=True, file_format='bin',
+                 rnd_type='rnd_sym'):
+    if data_type == 'int':
+        write_tensor_int(m=m, tensor=tensor, suffix=suffix, rnd_type=rnd_type, file_format=file_format)
+    elif data_type == 'float':
+        write_tensor_float(m=m, tensor=tensor, suffix=suffix)
+
+
+enable_hook_function = True
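+# forward hook that dumps the output, input(s), weight and bias of a module using write_tensor()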
+def write_tensor_hook_function(m, inp, out, file_format='bin'):
+    if not enable_hook_function:
+        return
+
+    #Output
+    if isinstance(out, (torch.Tensor)):
+        write_tensor(m=m, tensor=out, suffix='op', rnd_type ='rnd_up', file_format=file_format)
+
+    #Input(s)
+    if type(inp) is tuple:
+        #if there are more than 1 inputs
+        for index, sub_ip in enumerate(inp[0]):
+            if isinstance(sub_ip, (torch.Tensor)):
+                write_tensor(m=m, tensor=sub_ip, suffix='ip_{}'.format(index), rnd_type ='rnd_up', file_format=file_format)
+    elif isinstance(inp, (torch.Tensor)):
+         write_tensor(m=m, tensor=inp, suffix='ip', rnd_type ='rnd_up', file_format=file_format)
+
+    #weights
+    if hasattr(m, 'weight'):
+        if isinstance(m.weight,torch.Tensor):
+            write_tensor(m=m, tensor=m.weight, suffix='weight', rnd_type ='rnd_sym', file_format=file_format)
+
+    #bias
+    if hasattr(m, 'bias'):
+        if m.bias is not None:
+            write_tensor(m=m, tensor=m.bias, suffix='bias', rnd_type ='rnd_sym', file_format=file_format)
+
+
+if __name__ == '__main__':
+    main(get_config())
diff --git a/modules/pytorch_jacinto_ai/engine/test_pixel2pixel_onnx.py b/modules/pytorch_jacinto_ai/engine/test_pixel2pixel_onnx.py
new file mode 100644 (file)
index 0000000..458d9ab
--- /dev/null
@@ -0,0 +1,476 @@
+import os
+import time
+import sys
+
+import torch
+import torch.nn.parallel
+import torch.optim
+import torch.utils.data
+import datetime
+import numpy as np
+import random
+import cv2
+import PIL
+import PIL.Image
+
+import onnx
+import caffe2
+import caffe2.python.onnx.backend
+
+from .. import xnn
+from .. import vision
+
+
+
+# ################################################
+def get_config():
+    args = xnn.utils.ConfigNode()
+    args.model_config = xnn.utils.ConfigNode()
+    args.dataset_config = xnn.utils.ConfigNode()
+
+    args.dataset_name = 'flying_chairs'              # dataset type
+    args.model_name = 'flownets'                # model architecture, overwritten if pretrained is specified
+
+    args.data_path = './data/datasets'                       # path to dataset
+    args.save_path = None            # checkpoints save path
+    args.pretrained = None
+
+    args.model_config.output_type = ['flow']                # the network is used to predict flow or depth or sceneflow
+    args.model_config.output_channels = None                 # number of output channels
+    args.model_config.input_channels = None                  # number of input channels
+    args.n_classes = None                       # number of classes (for segmentation)
+
+    args.logger = None                          # logger stream to output into
+
+    args.prediction_type = 'flow'               # the network is used to predict flow or depth or sceneflow
+    args.split_file = None                      # train_val split file
+    args.split_files = None                     # split list files. eg: train.txt val.txt
+    args.split_value = 0.8                      # test_val split proportion (between 0 (only test) and 1 (only train))
+
+    args.workers = 8                            # number of data loading workers
+
+    args.epoch_size = 0                         # manual epoch size (will match dataset size if not specified)
+    args.epoch_size_val = 0                     # manual epoch size (will match dataset size if not specified)
+    args.batch_size = 8                         # mini_batch_size
+    args.total_batch_size = None                # accumulated batch size. total_batch_size = batch_size*iter_size
+    args.iter_size = 1                          # iteration size. total_batch_size = batch_size*iter_size
+
+    args.tensorboard_num_imgs = 5               # number of imgs to display in tensorboard
+    args.phase = 'validation'                   # evaluate model on validation set
+    args.pretrained = None                      # path to pre_trained model
+    args.date = None                            # don't append date timestamp to folder
+    args.print_freq = 10                        # print frequency (default: 100)
+
+    args.div_flow = 1.0                         # value by which flow will be divided. Original value is 20 but 1 with batchNorm gives good results
+    args.milestones = [100,150,200]             # epochs at which learning rate is divided by 2
+    args.losses = ['supervised_loss']           # loss functions to minimize
+    args.metrics = ['supervised_error']         # metric/measurement/error functions for train/validation
+    args.class_weights = None                   # class weights
+
+    args.multistep_gamma = 0.5                  # steps for step scheduler
+    args.polystep_power = 1.0                   # power for polynomial scheduler
+    args.train_fwbw = False                     # do forward backward step while training
+
+    args.rand_seed = 1                          # random seed
+    args.img_border_crop = None                 # image border crop rectangle. can be relative or absolute
+    args.target_mask = None                      # mask rectangle. can be relative or absolute. last value is the mask value
+    args.img_resize = None                      # image size to be resized to
+    args.rand_scale = (1,1.25)                  # random scale range for training
+    args.rand_crop = None                       # image size to be cropped to
+    args.output_size = None                     # target output size to be resized to
+
+    args.count_flops = True                     # count flops and report
+
+    args.shuffle = True                         # shuffle or not
+    args.is_flow = None                         # whether entries in images and targets lists are optical flow or not
+
+    args.multi_decoder = True                   # whether to use multiple decoders or unified decoder
+
+    args.create_video = False                   # whether to create video out of the inferred images
+
+    args.input_tensor_name = ['0']              # list of input tensor names
+
+    args.upsample_mode = 'nearest'              # upsample mode to use., choices=['nearest','bilinear']
+
+    args.image_prenorm = True                   # whether normalization is done before all the other transforms
+    args.image_mean = [128.0]                   # image mean for input image normalization
+    args.image_scale = [1.0/(0.25*256)]         # image scaling/mult for input image normalization
+    return args
+
+
+# ################################################
+# to avoid hangs in data loader with multi threads
+# this was observed after using cv2 image processing functions
+# https://github.com/pytorch/pytorch/issues/1355
+cv2.setNumThreads(0)
+
+
+# ################################################
+def main(args):
+
+    assert not hasattr(args, 'model_surgery'), 'the argument model_surgery is deprecated, it is not needed now - remove it'
+
+    #################################################
+    # global settings. rand seeds for repeatability
+    random.seed(args.rand_seed)
+    np.random.seed(args.rand_seed)
+    torch.manual_seed(args.rand_seed)
+
+    ################################
+    # print everything for log
+    print('=> args: ', args)
+
+    ################################
+    # args check and config
+    assert args.iter_size == 1 or args.total_batch_size is None, "only one of --iter_size or --total_batch_size must be set"
+    if args.total_batch_size is not None:
+        args.iter_size = args.total_batch_size//args.batch_size
+    else:
+        args.total_batch_size = args.batch_size*args.iter_size
+    #
+
+    assert args.pretrained is not None, 'pretrained onnx model path should be provided'
+
+    #################################################
+    # set some global flags and initializations
+    # keep it in args for now - although they don't belong here strictly
+    # using pin_memory is seen to cause issues, especially when a lot of memory is used.
+    args.use_pinned_memory = False
+    args.n_iter = 0
+    args.best_metric = -1
+
+    #################################################
+    if args.save_path is None:
+        save_path = get_save_path(args)
+    else:
+        save_path = args.save_path
+    #
+
+    print('=> will save everything to {}'.format(save_path))
+    
+    if not os.path.exists(save_path):
+        os.makedirs(save_path)
+
+    #################################################
+    if args.logger is None:
+        log_file = os.path.splitext(os.path.basename(__file__))[0] + '.log'
+        args.logger = xnn.utils.TeeLogger(filename=os.path.join(save_path,log_file))
+    transforms = get_transforms(args)
+
+    print("=> fetching img pairs in '{}'".format(args.data_path))
+    split_arg = args.split_file if args.split_file else (args.split_files if args.split_files else args.split_value)
+    val_dataset = vision.datasets.pixel2pixel.__dict__[args.dataset_name](args.dataset_config, args.data_path, split=split_arg, transforms=transforms)
+
+    print('=> {} val samples found'.format(len(val_dataset)))
+    
+    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
+        num_workers=args.workers, pin_memory=args.use_pinned_memory, shuffle=args.shuffle)
+    #
+
+    #################################################
+    args.model_config.output_channels = val_dataset.num_classes() if (args.model_config.output_channels is None and 'num_classes' in dir(val_dataset)) else args.model_config.output_channels
+    args.n_classes = args.model_config.output_channels[0]
+
+    #################################################
+    # create model
+    print("=> creating model '{}'".format(args.model_name))
+
+    model = onnx.load(args.pretrained)
+    # check the ONNX model and prepare it to run with the Caffe2 backend
+    onnx.checker.check_model(model)
+    model = caffe2.python.onnx.backend.prepare(model)
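+    # note: onnx.checker.check_model only validates the graph; the Caffe2 backend returned by
+    # prepare() is what actually runs inference - validate() below calls it as, e.g. (illustrative):
+    #   output = model.run({'0': input_array})[0]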
+
+
+    #################################################
+    if args.palette:
+        print('Creating palette')
+        eval_string = args.palette
+        palette = eval(eval_string)
+        args.palette = np.zeros((256,3))
+        for i, p in enumerate(palette):
+            args.palette[i,0] = p[0]
+            args.palette[i,1] = p[1]
+            args.palette[i,2] = p[2]
+        args.palette = args.palette[...,::-1] #RGB->BGR, since palette is expected to be given in RGB format
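+        # note (illustrative): args.palette is a python expression string that evaluates to a
+        # list of per-class RGB triplets, e.g. '[[128,64,128],[244,35,232],[70,70,70]]';
+        # it is converted above into a BGR lookup table indexed by class id.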
+
+    infer_path = os.path.join(save_path, 'inference')
+    if not os.path.exists(infer_path):
+        os.makedirs(infer_path)
+
+    #################################################
+    with torch.no_grad():
+        validate(args, val_dataset, val_loader, model, 0, infer_path)
+
+    if args.create_video:
+        create_video(args, infer_path=infer_path)
+
+def validate(args, val_dataset, val_loader, model, epoch, infer_path):
+    data_time = xnn.utils.AverageMeter()
+    avg_metric = xnn.utils.AverageMeter()
+
+    # switch to evaluate mode
+    #model.eval()
+    metric_name = "Metric"
+    end_time = time.time()
+    writer_idx = 0
+    last_update_iter = -1
+    metric_ctx = [None] * len(args.metric_modules)
+
+    if args.label:
+        confusion_matrix = np.zeros((args.n_classes, args.n_classes+1))
+        for iter, (input_list, target_list, input_path, target_path) in enumerate(val_loader):
+            data_time.update(time.time() - end_time)
+            target_sizes = [tgt.shape for tgt in target_list]
+
+            input_dict = {args.input_tensor_name[idx]: input_list[idx].numpy() for idx in range(len(input_list))}
+            output = model.run(input_dict)[0]
+
+            list_output = type(output) in (list, tuple)
+            output_pred = output[0] if list_output else output
+
+            if args.output_size is not None:
+                output_pred = upsample_tensors(output_pred, target_sizes, args.upsample_mode)
+            #
+            if args.blend:
+                for index in range(output_pred.shape[0]):
+                    prediction = np.squeeze(np.array(output_pred[index]))
+                    #prediction = np.argmax(prediction, axis = 0)
+                    prediction_size = (prediction.shape[0], prediction.shape[1], 3)
+                    output_image = args.palette[prediction.ravel()].reshape(prediction_size)
+                    input_bgr = cv2.imread(input_path[0][index]) # read the input image (cv2 loads in BGR)
+                    input_bgr = cv2.resize(input_bgr, dsize=(prediction.shape[1],prediction.shape[0]))
+
+                    output_image = chroma_blend(input_bgr, output_image)
+                    output_name = os.path.join(infer_path, os.path.basename(input_path[0][index]))
+                    cv2.imwrite(output_name, output_image)
+
+            if args.label:
+                for index in range(output_pred.shape[0]): 
+                    prediction = np.array(output_pred[index])
+                    #prediction = np.argmax(prediction, axis = 0)
+                    label = np.squeeze(np.array(target_list[0][index]))
+                    confusion_matrix = eval_output(args, prediction, label, confusion_matrix)
+                    accuracy, mean_iou, iou, f1_score= compute_accuracy(args, confusion_matrix)
+                print('pixel_accuracy={}, mean_iou={}, iou={}, f1_score = {}'.format(accuracy, mean_iou, iou, f1_score))
+                sys.stdout.flush()
+    else:
+        for iter, (input_list, _ , input_path, _) in enumerate(val_loader):
+            data_time.update(time.time() - end_time)
+
+            input_dict = {args.input_tensor_name[idx]: input_list[idx].numpy() for idx in range(len(input_list))}
+            output = model.run(input_dict)[0]
+
+            list_output = type(output) in (list, tuple)
+            output_pred = output[0] if list_output else output
+            input_path = input_path[0]
+            
+            if args.output_size is not None:
+                target_sizes = [args.output_size]
+                output_pred = upsample_tensors(output_pred, target_sizes, args.upsample_mode)
+            #
+            if args.blend:
+                for index in range(output_pred.shape[0]):
+                    prediction = np.squeeze(np.array(output_pred[index])) #np.squeeze(np.array(output_pred[index].cpu()))
+                    prediction_size = (prediction.shape[0], prediction.shape[1], 3)
+                    output_image = args.palette[prediction.ravel()].reshape(prediction_size)
+                    input_bgr = cv2.imread(input_path[index]) # read the input image (cv2 loads in BGR)
+                    input_bgr = cv2.resize(input_bgr, (args.img_resize[1], args.img_resize[0]), interpolation=cv2.INTER_LINEAR)
+                    output_image = chroma_blend(input_bgr, output_image)
+                    output_name = os.path.join(infer_path, os.path.basename(input_path[index]))
+                    cv2.imwrite(output_name, output_image)
+                    print('Inferred image {}'.format(input_path[index]))
+            if args.car_mask:   #generating car_mask (required for localization)
+                for index in range(output_pred.shape[0]):
+                    prediction = np.array(output_pred[index])
+                    prediction = np.argmax(prediction, axis = 0)
+                    car_mask = np.logical_or.reduce((prediction == 13, prediction == 14, prediction == 16, prediction == 17))
+                    prediction[car_mask] = 255
+                    prediction[np.invert(car_mask)] = 0
+                    output_image = prediction
+                    output_name = os.path.join(infer_path, os.path.basename(input_path[index]))
+                    cv2.imwrite(output_name, output_image)
+
+
+def get_save_path(args, phase=None):
+    date = args.date if args.date else datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+    save_path = os.path.join('./data/checkpoints', args.dataset_name, '{}_{}_'.format(date, args.model_name))
+    save_path += 'b{}'.format(args.batch_size)
+    phase = phase if (phase is not None) else args.phase
+    save_path = os.path.join(save_path, phase)
+    return save_path
+
+
+def get_transforms(args):
+    # image normalization can be at the beginning of transforms or at the end
+    args.image_mean = np.array(args.image_mean, dtype=np.float32)
+    args.image_scale = np.array(args.image_scale, dtype=np.float32)
+    image_prenorm = vision.transforms.image_transforms.NormalizeMeanScale(mean=args.image_mean, scale=args.image_scale) if args.image_prenorm else None
+    image_postnorm = vision.transforms.image_transforms.NormalizeMeanScale(mean=args.image_mean, scale=args.image_scale) if (not image_prenorm) else None
+
+    #target size must be according to output_size. prediction will be resized to output_size before evaluation.
+    test_transform = vision.transforms.image_transforms.Compose([
+        image_prenorm,
+        vision.transforms.image_transforms.AlignImages(),
+        vision.transforms.image_transforms.MaskTarget(args.target_mask, 0),
+        vision.transforms.image_transforms.CropRect(args.img_border_crop),
+        vision.transforms.image_transforms.Scale(args.img_resize, target_size=args.output_size, is_flow=args.is_flow),
+        image_postnorm,
+        vision.transforms.image_transforms.ConvertToTensor()
+        ])
+
+    return test_transform
+
+
+def _upsample_impl(tensor, output_size, upsample_mode):
+    # upsample of long tensor is not supported currently. convert to float, just to avoid error.
+    # we can do this only in the case of nearest mode, otherwise output will have invalid values.
+    convert_tensor_to_float = False
+    convert_np_to_float = False
+    if isinstance(tensor, (torch.LongTensor,torch.cuda.LongTensor)):
+        convert_tensor_to_float = True
+        original_dtype = tensor.dtype
+        tensor = tensor.float()
+    elif isinstance(tensor, np.ndarray) and (tensor.dtype != np.float32):
+        convert_np_to_float = True
+        original_dtype = tensor.dtype
+        tensor = tensor.astype(np.float32)
+    #
+
+    dim_added = False
+    if len(tensor.shape) < 4:
+        tensor = tensor[np.newaxis,...]
+        dim_added = True
+    #
+    if (tensor.shape[-2:] != output_size):
+        assert tensor.shape[1] == 1, 'TODO: add code for multi channel resizing'
+        out_tensor = np.zeros((tensor.shape[0],tensor.shape[1],output_size[0],output_size[1]),dtype=np.float32)
+        for b_idx in range(tensor.shape[0]):
+            b_tensor = PIL.Image.fromarray(tensor[b_idx,0])
+            b_tensor = b_tensor.resize((output_size[1],output_size[0]), PIL.Image.NEAREST)
+            out_tensor[b_idx,0,...] = np.asarray(b_tensor)
+        #
+        tensor = out_tensor
+    #
+    if dim_added:
+        tensor = tensor[0]
+    #
+
+    if convert_tensor_to_float:
+        tensor = tensor.long()
+    elif convert_np_to_float:
+        tensor = tensor.astype(original_dtype)
+    #
+    return tensor
+
+def upsample_tensors(tensors, output_sizes, upsample_mode):
+    if isinstance(tensors, (list,tuple)):
+        for tidx, tensor in enumerate(tensors):
+            tensors[tidx] = _upsample_impl(tensor, output_sizes[tidx][-2:], upsample_mode)
+        #
+    else:
+        tensors = _upsample_impl(tensors, output_sizes[0][-2:], upsample_mode)
+    return tensors
+
+
+def chroma_blend(image, color):
+    image_yuv = cv2.cvtColor(image.astype(np.uint8), cv2.COLOR_BGR2YUV)
+    image_y,image_u,image_v = cv2.split(image_yuv)
+    color_yuv = cv2.cvtColor(color.astype(np.uint8), cv2.COLOR_BGR2YUV)
+    color_y,color_u,color_v = cv2.split(color_yuv)
+    image_y = np.uint8(image_y)
+    color_u = np.uint8(color_u)
+    color_v = np.uint8(color_v)
+    image_yuv = cv2.merge((image_y,color_u,color_v))
+    image = cv2.cvtColor(image_yuv.astype(np.uint8), cv2.COLOR_YUV2BGR)
+    return image    
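+# note: chroma_blend keeps the luma (Y) channel of the input image and takes the chroma (U,V)
+# channels from the color-coded prediction, so image detail is preserved under the overlay.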
+
+
+
+def eval_output(args, output, label, confusion_matrix):
+
+    if len(label.shape)>2:
+        label = label[:,:,0]
+    gt_labels = label.ravel()
+    det_labels = output.ravel().clip(0,args.n_classes)
+    gt_labels_valid_ind = np.where(gt_labels != 255)
+    gt_labels_valid = gt_labels[gt_labels_valid_ind]
+    det_labels_valid = det_labels[gt_labels_valid_ind]
+    for r in range(confusion_matrix.shape[0]):
+        for c in range(confusion_matrix.shape[1]):
+            confusion_matrix[r,c] += np.sum((gt_labels_valid==r) & (det_labels_valid==c))
+
+    return confusion_matrix
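+# note: the confusion matrix has shape (n_classes, n_classes+1) - rows are ground-truth classes,
+# columns are predicted classes, and the extra last column collects predictions clipped to
+# n_classes; ground-truth pixels equal to 255 are treated as ignore and excluded.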
+    
+def compute_accuracy(args, confusion_matrix):
+
+    #pdb.set_trace()
+    num_selected_classes = args.n_classes
+    tp = np.zeros(args.n_classes)
+    population = np.zeros(args.n_classes)
+    det = np.zeros(args.n_classes)
+    iou = np.zeros(args.n_classes)
+    
+    for r in range(args.n_classes):
+      for c in range(args.n_classes):   
+        population[r] += confusion_matrix[r][c]
+        det[c] += confusion_matrix[r][c]   
+        if r == c:
+          tp[r] += confusion_matrix[r][c]
+
+    for cls in range(num_selected_classes):
+      intersection = tp[cls]
+      union = population[cls] + det[cls] - tp[cls]
+      iou[cls] = (intersection / union) if union else 0     # For caffe jacinto script
+      #iou[cls] = (intersection / (union + np.finfo(np.float32).eps))  # For pytorch-jacinto script
+
+    num_nonempty_classes = 0
+    for pop in population:
+      if pop>0:
+        num_nonempty_classes += 1
+          
+    mean_iou = np.sum(iou) / num_nonempty_classes
+    accuracy = np.sum(tp) / np.sum(population)
+    
+    #F1 score calculation
+    fp = np.zeros(args.n_classes)
+    fn = np.zeros(args.n_classes)
+    precision = np.zeros(args.n_classes)
+    recall = np.zeros(args.n_classes)
+    f1_score = np.zeros(args.n_classes)
+
+    for cls in range(num_selected_classes):
+        fp[cls] = det[cls] - tp[cls]
+        fn[cls] = population[cls] - tp[cls]
+        precision[cls] = tp[cls] / (det[cls] + 1e-10)
+        recall[cls] = tp[cls] / (tp[cls] + fn[cls] + 1e-10)        
+        f1_score[cls] = 2 * precision[cls]*recall[cls] / (precision[cls] + recall[cls] + 1e-10)
+
+    return accuracy, mean_iou, iou, f1_score
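+# summary of the metrics computed above (per class):
+#   iou       = tp / (tp + fp + fn)        where union = population + det - tp
+#   precision = tp / det,  recall = tp / population
+#   f1_score  = 2*precision*recall / (precision + recall)
+# mean_iou averages iou only over classes that actually occur in the ground truth.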
+    
+        
+def infer_video(args, net):
+    videoIpHandle = imageio.get_reader(args.input, 'ffmpeg')
+    fps = math.ceil(videoIpHandle.get_meta_data()['fps'])
+    print(videoIpHandle.get_meta_data())
+    numFrames = min(len(videoIpHandle), args.num_images)
+    videoOpHandle = imageio.get_writer(args.output,'ffmpeg', fps=fps)
+    for num in range(numFrames):
+        print(num, end=' ')
+        sys.stdout.flush()
+        input_blob = videoIpHandle.get_data(num)
+        input_blob = input_blob[...,::-1]    #RGB->BGR
+        output_blob = infer_blob(args, net, input_blob)     
+        output_blob = output_blob[...,::-1]  #BGR->RGB            
+        videoOpHandle.append_data(output_blob)
+    videoOpHandle.close()
+    return
+
+def create_video(args, infer_path):
+    op_file_name = args.data_path.split('/')[-1]
+    os.system(' ffmpeg -framerate 30 -pattern_type glob -i "{}/*.png" -c:v libx264 -vb 50000k -qmax 20 -r 10 -vf \
+                 scale=1024:512  -pix_fmt yuv420p {}.MP4'.format(infer_path, op_file_name))
+
+if __name__ == '__main__':
+    train_args = get_config()
+    main(train_args)
diff --git a/modules/pytorch_jacinto_ai/engine/train_classification.py b/modules/pytorch_jacinto_ai/engine/train_classification.py
new file mode 100644 (file)
index 0000000..83a0e9d
--- /dev/null
@@ -0,0 +1,709 @@
+import os
+import shutil
+import time
+
+import random
+import numpy as np
+from colorama import Fore
+import math
+import progiter
+import warnings
+
+import torch
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+import torch.optim
+import torch.utils.data
+import torch.utils.data.distributed
+
+import sys
+import datetime
+
+from .. import xnn
+from .. import vision
+
+
+#################################################
+def get_config():
+    args = xnn.utils.ConfigNode()
+    args.model_config = xnn.utils.ConfigNode()
+    args.dataset_config = xnn.utils.ConfigNode()
+
+    args.model_config.input_channels = 3                # num input channels
+
+    args.data_path = './data/datasets/ilsvrc'           # path to dataset
+    args.model_name = 'mobilenet_v2_classification'     # model architecture
+    args.dataset_name = 'imagenet_classification'   # image folder classification
+    args.save_path = None                               # checkpoints save path
+    args.phase = 'training'                             # training/calibration/validation
+    args.date = None                                    # date to add to save path. if this is None, current date will be added.
+
+    args.workers = 8                                    # number of data loading workers (default: 8)
+    args.logger = None                                  # logger stream to output into
+
+    args.epochs = 90                                    # number of total epochs to run
+    args.warmup_epochs = None                           # number of epochs to warm up by linearly increasing lr
+
+    args.epoch_size = 0                                 # fraction of training epoch to use each time. 0 indicates full
+    args.start_epoch = 0                                # manual epoch number to start
+    args.stop_epoch = None                              # manual epoch number to stop
+    args.batch_size = 256                               # mini_batch size (default: 256)
+    args.total_batch_size = None                        # accumulated batch size. total_batch_size = batch_size*iter_size
+    args.iter_size = 1                                  # iteration size. total_batch_size = batch_size*iter_size
+
+    args.lr = 0.1                                       # initial learning rate
+    args.lr_clips = None                                 # use args.lr itself if it is None
+    args.lr_calib = 0.1                                 # lr for bias calibration
+    args.momentum = 0.9                                 # momentum
+    args.weight_decay = 1e-4                            # weight decay (default: 1e-4)
+    args.bias_decay = None                              # bias decay (default: 0.0)
+
+    args.rand_seed = 1                                  # random seed
+    args.print_freq = 100                               # print frequency (default: 100)
+    args.resume = None                                  # path to latest checkpoint (default: none)
+    args.evaluate_start = True                          # evaluate right at the beginning of training or not
+    args.world_size = 1                                 # number of distributed processes
+    args.dist_url = 'tcp://224.66.41.62:23456'          # url used to set up distributed training
+    args.dist_backend = 'gloo'                          # distributed backend
+
+    args.optimizer = 'sgd'                              # solver algorithms, choices=['adam','sgd','sgd_nesterov','rmsprop']
+    args.scheduler = 'step'                             # scheduler algorithms, choices=['step','poly','exponential', 'cosine']
+    args.milestones = (30, 60, 90)                      # epochs at which learning rate is divided
+    args.multistep_gamma = 0.1                          # multi step gamma (default: 0.1)
+    args.polystep_power = 1.0                           # poly step gamma (default: 1.0)
+    args.step_size = 1                                  # step size for exp lr decay
+
+    args.beta = 0.999                                   # beta parameter for adam
+    args.pretrained = None                              # path to pre_trained model
+    args.img_resize = 256                               # image resize
+    args.img_crop = 224                                 # image crop
+    args.rand_scale = (0.08,1.0)                        # random scale range for training
+    args.data_augument = 'inception'                    # data augmentation method, choices=['inception','resize','adaptive_resize']
+    args.count_flops = True                             # count flops and report
+
+    args.generate_onnx = True                           # generate onnx model or not
+    args.print_model = False                            # print the model to text
+    args.run_soon = True                                # Set to false if only config files/onnx models are needed but no training
+
+    args.multi_color_modes = None                       # input modes created with multi color transform
+    args.image_mean = (123.675, 116.28, 103.53)         # image mean for input image normalization
+    args.image_scale = (0.017125, 0.017507, 0.017429)   # image scaling/mult for input image normalization
+
+    args.parallel_model = True                          # Use data parallel for model
+
+    args.quantize = False                               # apply quantized inference or not
+    #args.model_surgery = None                           # replace activations with PAct2 activation module. Helpful in quantized training.
+    args.bitwidth_weights = 8                           # bitwidth for weights
+    args.bitwidth_activations = 8                       # bitwidth for activations
+    args.histogram_range = True                         # histogram range for calibration
+    args.bias_calibration = True                        # apply bias correction during quantized inference calibration
+    args.per_channel_q = False                          # apply separate quantization factor for each channel in depthwise or not
+
+    args.freeze_bn = False                              # freeze the statistics of bn
+
+    return args
+
+
+#################################################
+cudnn.benchmark = True
+#cudnn.enabled = False
+
+
+
+def main(args):
+    assert (not hasattr(args, 'evaluate')), 'args.evaluate is deprecated. use args.phase=training or calibration or validation'
+    assert args.phase in ('training', 'calibration', 'validation'), f'invalid phase {args.phase}'
+    assert not hasattr(args, 'model_surgery'), 'the argument model_surgery is deprecated, it is not needed now - remove it'
+
+    if (args.phase == 'validation' and args.bias_calibration):
+        args.bias_calibration = False
+        warnings.warn('switching off bias calibration in validation')
+    #
+
+    #################################################
+    if args.save_path is None:
+        save_path = get_save_path(args)
+    else:
+        save_path = args.save_path
+    #
+
+    args.best_prec1 = -1
+
+    # resume has higher priority
+    args.pretrained = None if (args.resume is not None) else args.pretrained
+
+    if not os.path.exists(save_path):
+        os.makedirs(save_path)
+
+    #################################################
+    if args.logger is None:
+        log_file = os.path.splitext(os.path.basename(__file__))[0] + '.log'
+        args.logger = xnn.utils.TeeLogger(filename=os.path.join(save_path,log_file))
+
+    ################################
+    args.pretrained = None if (args.pretrained == 'None') else args.pretrained
+    args.num_inputs = len(args.multi_color_modes) if (args.multi_color_modes is not None) else 1
+
+    assert args.iter_size == 1 or args.total_batch_size is None, "only one of --iter_size or --total_batch_size must be set"
+    if args.total_batch_size is not None:
+        args.iter_size = args.total_batch_size//args.batch_size
+    else:
+        args.total_batch_size = args.batch_size*args.iter_size
+
+    args.stop_epoch = args.stop_epoch if args.stop_epoch else args.epochs
+
+    args.distributed = args.world_size > 1
+
+    if args.distributed:
+        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                                world_size=args.world_size)
+
+    #################################################
+    # global settings. rand seeds for repeatability
+    random.seed(args.rand_seed)
+    np.random.seed(args.rand_seed)
+    torch.manual_seed(args.rand_seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = True
+    # torch.autograd.set_detect_anomaly(True)
+
+    ################################
+    # print everything for log
+    # reset character color, in case it is different
+    print('{}'.format(Fore.RESET))
+    print("=> args: ", args)
+    print("=> resize resolution: {}".format(args.img_resize))
+    print("=> crop resolution  : {}".format(args.img_crop))
+    sys.stdout.flush()
+
+    #################################################
+    pretrained_data = None
+    model_surgery_quantize = False
+    if args.pretrained and args.pretrained != "None":
+        if args.pretrained.startswith('http://') or args.pretrained.startswith('https://'):
+            pretrained_file = vision.datasets.utils.download_url(args.pretrained, './data/downloads')
+        else:
+            pretrained_file = args.pretrained
+        #
+        print(f'=> using pre-trained weights from: {args.pretrained}')
+        pretrained_data = torch.load(pretrained_file)
+        model_surgery_quantize = pretrained_data['quantize'] if 'quantize' in pretrained_data else False
+    #
+
+    #################################################
+    # create model
+    print("=> creating model '{}'".format(args.model_name))
+    model = vision.models.classification.__dict__[args.model_name](args.model_config)
+
+    # check if we got the model as well as parameters to change the names in pretrained
+    model, change_names_dict = model if isinstance(model, (list,tuple)) else (model,None)
+
+    #################################################
+    if args.quantize:
+        # dummy input is used by quantized models to analyze graph
+        is_cuda = next(model.parameters()).is_cuda
+        dummy_input = create_rand_inputs(args, is_cuda=is_cuda)
+        #
+        if args.phase == 'training':
+            model = xnn.quantize.QuantTrainModule(model, per_channel_q=args.per_channel_q,
+                        histogram_range=args.histogram_range, bitwidth_weights=args.bitwidth_weights,
+                        bitwidth_activations=args.bitwidth_activations,
+                        dummy_input=dummy_input)
+        elif args.phase == 'calibration':
+            model = xnn.quantize.QuantCalibrateModule(model, per_channel_q=args.per_channel_q,
+                        bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
+                        histogram_range=args.histogram_range, bias_calibration=args.bias_calibration, lr_calib=args.lr_calib,
+                        dummy_input=dummy_input)
+        elif args.phase == 'validation':
+            # Note: bias_calibration is not used in test
+            model = xnn.quantize.QuantTestModule(model, per_channel_q=args.per_channel_q,
+                        bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
+                        histogram_range=args.histogram_range, model_surgery_quantize=model_surgery_quantize,
+                        dummy_input=dummy_input)
+        else:
+            assert False, f'invalid phase {args.phase}'
+    #
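+    # note: the three wrappers above map to the three phases - QuantTrainModule for
+    # quantization-aware training, QuantCalibrateModule for calibration (optionally with
+    # bias correction), and QuantTestModule for evaluating an already quantized/calibrated model.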
+
+    # load pretrained
+    xnn.utils.load_weights_check(get_model_orig(model), pretrained=pretrained_data, change_names_dict=change_names_dict)
+
+    #################################################
+    if args.count_flops:
+        count_flops(args, model)
+
+    #################################################
+    if args.generate_onnx and ((args.phase in ('training','calibration')) or (args.run_soon == False)):
+        write_onnx_model(args, get_model_orig(model), save_path)
+    #
+
+    #################################################
+    if args.print_model:
+        print(model)
+    else:
+        args.logger.debug(str(model))
+
+    #################################################
+    if (not args.run_soon):
+        print("Training not needed for now")
+        close(args)
+        exit()
+
+    #################################################
+    # multi gpu mode is not working for quantized model
+    if args.parallel_model and (not args.quantize):
+        if args.distributed:
+            model = torch.nn.parallel.DistributedDataParallel(model)
+        else:
+            model = torch.nn.DataParallel(model)
+    #
+
+    #################################################
+    model = model.cuda()
+
+    #################################################
+    # define loss function (criterion) and optimizer
+    criterion = torch.nn.CrossEntropyLoss().cuda()
+
+    model_module = model.module if hasattr(model, 'module') else model
+    if args.lr_clips is not None:
+        learning_rate_clips = args.lr_clips if args.phase == 'training' else 0.0
+        clips_decay = args.bias_decay if (args.bias_decay is not None and args.bias_decay != 0.0) else args.weight_decay
+        clips_params = [p for n,p in model_module.named_parameters() if 'clips' in n]
+        other_params = [p for n,p in model_module.named_parameters() if 'clips' not in n]
+        param_groups = [{'params': clips_params, 'weight_decay': clips_decay, 'lr': learning_rate_clips},
+                        {'params': other_params, 'weight_decay': args.weight_decay}]
+    else:
+        param_groups = [{'params': filter(lambda p: p.requires_grad, model.parameters()), 'weight_decay': args.weight_decay}]
+    #
+
+    print("=> args: ", args)          
+    print("=> optimizer type   : {}".format(args.optimizer))
+    print("=> learning rate    : {}".format(args.lr))
+    print("=> resize resolution: {}".format(args.img_resize))
+    print("=> crop resolution  : {}".format(args.img_crop))
+    print("=> batch size       : {}".format(args.batch_size))
+    print("=> total batch size : {}".format(args.total_batch_size))
+    print("=> epoch size       : {}".format(args.epoch_size))
+    print("=> data augument    : {}".format(args.data_augument))
+    print("=> epochs           : {}".format(args.epochs))
+    if args.scheduler == 'step':
+        print("=> milestones       : {}".format(args.milestones))
+
+    learning_rate = args.lr if (args.phase == 'training') else 0.0
+    if args.optimizer == 'adam':
+        optimizer = torch.optim.Adam(param_groups, learning_rate, betas=(args.momentum, args.beta))
+    elif args.optimizer == 'sgd':
+        optimizer = torch.optim.SGD(param_groups, learning_rate, momentum=args.momentum)
+    elif args.optimizer == 'sgd_nesterov':
+        optimizer = torch.optim.SGD(param_groups, learning_rate, momentum=args.momentum, nesterov=True)
+    elif args.optimizer == 'rmsprop':
+        optimizer = torch.optim.RMSprop(param_groups, learning_rate, momentum=args.momentum)
+    else:
+        raise ValueError('Unknown optimizer type {}'.format(args.optimizer))
+        
+    # optionally resume from a checkpoint
+    if args.resume:
+        if os.path.isfile(args.resume):
+            print("=> resuming from checkpoint '{}'".format(args.resume))
+            checkpoint = torch.load(args.resume)
+            if args.start_epoch == 0:
+                args.start_epoch = checkpoint['epoch'] + 1
+                
+            args.best_prec1 = checkpoint['best_prec1']
+            model = xnn.utils.load_weights_check(model, checkpoint)
+            optimizer.load_state_dict(checkpoint['optimizer'])
+            print("=> resuming from checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
+        else:
+            print("=> no checkpoint found at '{}'".format(args.resume))
+
+    train_loader, val_loader = get_data_loaders(args)
+
+    # number of train iterations per epoch
+    args.iters = get_epoch_size(train_loader, args.epoch_size)
+
+    args.cur_lr = adjust_learning_rate(args, optimizer, args.start_epoch)
+
+    if args.evaluate_start or args.phase=='validation':
+        validate(args, val_loader, model, criterion, args.start_epoch)
+
+    if args.phase == 'validation':
+        close(args)
+        return
+
+    for epoch in range(args.start_epoch, args.stop_epoch):
+        if args.distributed:
+            train_loader.sampler.set_epoch(epoch)
+
+        # train for one epoch
+        train(args, train_loader, model, criterion, optimizer, epoch)
+
+        # evaluate on validation set
+        prec1 = validate(args, val_loader, model, criterion, epoch)
+
+        # remember best prec@1 and save checkpoint
+        is_best = prec1 > args.best_prec1
+        args.best_prec1 = max(prec1, args.best_prec1)
+
+        model_orig = get_model_orig(model)
+
+        save_dict = {'epoch': epoch, 'arch': args.model_name, 'state_dict': model_orig.state_dict(),
+                     'best_prec1': args.best_prec1, 'optimizer' : optimizer.state_dict(),
+                     'quantize' : args.quantize}
+
+        save_checkpoint(args, save_path, model_orig, save_dict, is_best)
+    #
+
+    # for n, m in model.named_modules():
+    #     if hasattr(m, 'num_batches_tracked'):
+    #         print(f'name={n}, num_batches_tracked={m.num_batches_tracked}')
+    # #
+
+    # close and cleanup
+    close(args)
+#
+
+###################################################################
+def close(args):
+    if args.logger is not None:
+        del args.logger
+        args.logger = None
+    #
+    args.best_prec1 = -1
+#
+
+
+def get_save_path(args, phase=None):
+    date = args.date if args.date else datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+    save_path_base = os.path.join('./data/checkpoints', args.dataset_name, date + '_' + args.dataset_name + '_' + args.model_name)
+    save_path = save_path_base + '_resize{}_crop{}'.format(args.img_resize, args.img_crop)
+    phase = phase if (phase is not None) else args.phase
+    save_path = os.path.join(save_path, phase)
+    return save_path
+
+
+def get_model_orig(model):
+    is_parallel_model = isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel))
+    model_orig = (model.module if is_parallel_model else model)
+    model_orig = (model_orig.module if isinstance(model_orig, (xnn.quantize.QuantBaseModule)) else model_orig)
+    return model_orig
+
+
+def create_rand_inputs(args, is_cuda):
+    dummy_input = torch.rand((1, args.model_config.input_channels, args.img_crop, args.img_crop))
+    dummy_input = dummy_input.cuda() if is_cuda else dummy_input
+    return dummy_input
+
+
+def count_flops(args, model):
+    is_cuda = next(model.parameters()).is_cuda
+    dummy_input = create_rand_inputs(args, is_cuda)
+    #
+    model.eval()
+    flops = xnn.utils.forward_count_flops(model, dummy_input)
+    gflops = flops/1e9
+    print('=> Resize = {}, Crop = {}, GFLOPs = {}, GMACs = {}'.format(args.img_resize, args.img_crop, gflops, gflops/2))
+
+
+def write_onnx_model(args, model, save_path, name='checkpoint.onnx'):
+    is_cuda = next(model.parameters()).is_cuda
+    dummy_input = create_rand_inputs(args, is_cuda)
+    #
+    model.eval()
+    torch.onnx.export(model, dummy_input, os.path.join(save_path,name), export_params=True, verbose=False)
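+    # note: the export traces the model with the dummy input, so the onnx graph has a fixed
+    # input shape of (1, input_channels, img_crop, img_crop); dynamic shapes would need
+    # dynamic_axes, which is not used here.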
+
+
+def train(args, train_loader, model, criterion, optimizer, epoch):
+    # actual training code
+    batch_time = AverageMeter()
+    data_time = AverageMeter()
+    losses = AverageMeter()
+    top1 = AverageMeter()
+    top5 = AverageMeter()
+
+    # switch to train mode
+    model.train()
+    if args.freeze_bn:
+        xnn.utils.freeze_bn(model)
+    #
+    
+    progress_bar = progiter.ProgIter(np.arange(args.iters), chunksize=1)
+    args.cur_lr = adjust_learning_rate(args, optimizer, epoch)
+
+    end = time.time()
+    train_iter = iter(train_loader)
+    last_update_iter = -1
+
+    progressbar_color = (Fore.YELLOW if args.phase=='calibration' else Fore.WHITE)
+    print('{}'.format(progressbar_color), end='')
+
+    for iteration in range(args.iters):
+        (input, target) = next(train_iter)
+        input = [inp.cuda() for inp in input] if xnn.utils.is_list(input) else input.cuda()
+        input_size = input[0].size() if xnn.utils.is_list(input) else input.size()
+        target = target.cuda(non_blocking=True)
+
+        data_time.update(time.time() - end)
+
+        # compute output
+        output = model(input)
+
+        # compute loss
+        loss = criterion(output, target) / args.iter_size
+
+        # measure accuracy and record loss
+        prec1, prec5 = accuracy(output, target, topk=(1, 5))
+        losses.update(loss.item(), input_size[0])
+        top1.update(prec1[0], input_size[0])
+        top5.update(prec5[0], input_size[0])
+
+        if args.phase == 'training':
+            # zero gradients so that we can accumulate gradients
+            if (iteration % args.iter_size) == 0:
+                optimizer.zero_grad()
+
+            loss.backward()
+
+            if ((iteration+1) % args.iter_size) == 0:
+                optimizer.step()
+        #
+
+        # measure elapsed time
+        batch_time.update(time.time() - end)
+        end = time.time()
+        final_iter = (iteration >= (args.iters-1))
+
+        if ((iteration % args.print_freq) == 0) or final_iter:
+            epoch_str = '{}/{}'.format(epoch + 1, args.epochs)
+            status_str = '{epoch} LR={cur_lr:.5f} Time={batch_time.avg:0.3f} DataTime={data_time.avg:0.3f} Loss={loss.avg:0.3f} Prec@1={top1.avg:0.3f} Prec@5={top5.avg:0.3f}' \
+                         .format(epoch=epoch_str, cur_lr=args.cur_lr, batch_time=batch_time, data_time=data_time, loss=losses, top1=top1, top5=top5)
+
+            progress_bar.set_description(f'=> {args.phase}  ')
+            progress_bar.set_postfix(Epoch='{}'.format(status_str))
+            progress_bar.update(iteration-last_update_iter)
+            last_update_iter = iteration
+
+    progress_bar.close()
+
+    # to print a new line - do not provide end=''
+    print('{}'.format(Fore.RESET), end='')
+
+    ##########################
+    if args.quantize:
+        def debug_format(v):
+            return ('{:.3f}'.format(v) if v is not None else 'None')
+        #
+        clips_act = [m.get_clips_act()[1] for n,m in model.named_modules() if isinstance(m,xnn.layers.PAct2)]
+        if len(clips_act) > 0:
+            args.logger.debug('\nclips_act : ' + ' '.join(map(debug_format, clips_act)))
+            args.logger.debug('')
+    #
+
+
+
+def validate(args, val_loader, model, criterion, epoch):
+    batch_time = AverageMeter()
+    losses = AverageMeter()
+    top1 = AverageMeter()
+    top5 = AverageMeter()
+
+    # switch to evaluate mode
+    model.eval()
+
+    progress_bar = progiter.ProgIter(np.arange(len(val_loader)), chunksize=1)
+    last_update_iter = -1
+
+    # change color to green
+    print('{}'.format(Fore.GREEN), end='')
+
+    with torch.no_grad():
+        end = time.time()
+        for iteration, (input, target) in enumerate(val_loader):
+            input = [inp.cuda() for inp in input] if xnn.utils.is_list(input) else input.cuda()
+            input_size = input[0].size() if xnn.utils.is_list(input) else input.size()
+            target = target.cuda(non_blocking=True)
+
+            # compute output
+            output = model(input)
+            loss = criterion(output, target)
+
+            # measure accuracy and record loss
+            prec1, prec5 = accuracy(output, target, topk=(1, 5))
+            losses.update(loss.item(), input_size[0])
+            top1.update(prec1[0], input_size[0])
+            top5.update(prec5[0], input_size[0])
+
+            # measure elapsed time
+            batch_time.update(time.time() - end)
+            end = time.time()
+            final_iter = (iteration >= (len(val_loader)-1))
+
+            if ((iteration % args.print_freq) == 0) or final_iter:
+                epoch_str = '{}/{}'.format(epoch+1,args.epochs)
+                status_str = '{epoch} LR={cur_lr:.5f} Time={batch_time.avg:0.3f} Loss={loss.avg:0.3f} Prec@1={top1.avg:0.3f} Prec@5={top5.avg:0.3f}' \
+                             .format(epoch=epoch_str, cur_lr=args.cur_lr, batch_time=batch_time, loss=losses, top1=top1, top5=top5)
+                           
+                prefix = '**' if final_iter else '=>'
+                progress_bar.set_description('{} {}'.format(prefix, 'validation'))
+                progress_bar.set_postfix(Epoch='{}'.format(status_str))
+                progress_bar.update(iteration - last_update_iter)
+                last_update_iter = iteration
+
+        progress_bar.close()
+
+        # to print a new line - do not provide end=''
+        print('{}'.format(Fore.RESET), end='')
+
+    return top1.avg
+
+
+def save_checkpoint(args, save_path, model, state, is_best, filename='checkpoint.pth.tar'):
+    filename = os.path.join(save_path, filename)
+    torch.save(state, filename)
+    if is_best:
+        bestname = os.path.join(save_path, 'model_best.pth.tar')
+        shutil.copyfile(filename, bestname)
+    #
+    if args.generate_onnx:
+        write_onnx_model(args, model, save_path, name='checkpoint.onnx')
+        if is_best:
+            write_onnx_model(args, model, save_path, name='model_best.onnx')
+
+#
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
+def adjust_learning_rate(args, optimizer, epoch):
+    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
+    cur_lr = args.cur_lr if hasattr(args, 'cur_lr') else args.lr
+
+    if (args.warmup_epochs is not None) and (epoch < (args.warmup_epochs-1)):
+        cur_lr = (epoch + 1) * args.lr / args.warmup_epochs
+    elif args.scheduler == 'poly':
+        epoch_frac = (args.epochs - epoch) / args.epochs
+        epoch_frac = max(epoch_frac, 0)
+        cur_lr = args.lr * (epoch_frac ** args.polystep_power)
+        for param_group in optimizer.param_groups:
+            param_group['lr'] = cur_lr
+        #
+    elif args.scheduler == 'step':                                            # step
+        num_milestones = 0
+        for m in args.milestones:
+            num_milestones += (1 if epoch >= m else 0)
+        #
+        cur_lr = args.lr * (args.multistep_gamma ** num_milestones)
+    elif args.scheduler == 'exponential':                                   # exponential
+        cur_lr = args.lr * (args.multistep_gamma ** (epoch//args.step_size))
+    elif args.scheduler == 'cosine':                                        # cosine
+        if epoch == 0:
+            cur_lr = args.lr
+        else:
+            lr_min = 0
+            cur_lr = (args.lr - lr_min) * (1 + math.cos(math.pi * epoch / args.epochs)) / 2.0  + lr_min
+        #
+    else:
+        raise ValueError('Unknown scheduler {}'.format(args.scheduler))
+    #
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = cur_lr
+    #
+    return cur_lr
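+# illustrative example of the default 'step' schedule (lr=0.1, milestones=(30,60,90), gamma=0.1):
+#   epochs 0-29 -> 0.1, epochs 30-59 -> 0.01, epochs 60-89 -> 0.001
+# when warmup_epochs is set, the lr ramps up linearly to args.lr over the first warmup epochs.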
+
+
+def accuracy(output, target, topk=(1,)):
+    """Computes the precision@k for the specified values of k"""
+    with torch.no_grad():
+        maxk = max(topk)
+        batch_size = target.size(0)
+
+        _, pred = output.topk(maxk, 1, True, True)
+        pred = pred.t()
+        correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+        res = []
+        for k in topk:
+            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+            res.append(correct_k.mul_(100.0 / batch_size))
+        return res
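+# note: accuracy() returns precision@k as percentages - pred holds the top-k predicted class
+# indices per sample and 'correct' marks positions where any of the top-k equals the target.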
+
+
+def get_epoch_size(loader, args_epoch_size):
+    if args_epoch_size == 0:
+        epoch_size = len(loader)
+    elif args_epoch_size < 1:
+        epoch_size = int(len(loader) * args_epoch_size)
+    else:
+        epoch_size = min(len(loader), int(args_epoch_size))
+    #
+    return epoch_size
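+# note: epoch_size==0 uses the full loader; a fractional value (e.g. 0.5) uses that fraction
+# of the iterations per epoch; a value >= 1 is an absolute iteration count, capped at len(loader).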
+    
+
+def get_train_transform(args):
+    normalize = vision.transforms.NormalizeMeanScale(mean=args.image_mean, scale=args.image_scale) \
+        if (args.image_mean is not None and args.image_scale is not None) else None
+    multi_color_transform = vision.transforms.MultiColor(args.multi_color_modes) if (args.multi_color_modes is not None) else None
+
+    train_resize_crop_transform = vision.transforms.RandomResizedCrop(size=args.img_crop, scale=args.rand_scale) \
+        if args.rand_scale else vision.transforms.RandomCrop(size=args.img_crop)
+    train_transform = vision.transforms.Compose([train_resize_crop_transform,
+                                                 vision.transforms.RandomHorizontalFlip(),
+                                                 multi_color_transform,
+                                                 vision.transforms.ToFloat(),
+                                                 vision.transforms.ToTensor(),
+                                                 normalize])
+    return train_transform
+
+def get_validation_transform(args):
+    normalize = vision.transforms.NormalizeMeanScale(mean=args.image_mean, scale=args.image_scale) \
+        if (args.image_mean is not None and args.image_scale is not None) else None
+    multi_color_transform = vision.transforms.MultiColor(args.multi_color_modes) if (args.multi_color_modes is not None) else None
+
+    # pass tuple to Resize() to resize to exact size without respecting aspect ratio (typical caffe style)
+    val_resize_crop_transform = vision.transforms.Resize(size=args.img_resize) if args.img_resize else vision.transforms.Bypass()
+    val_transform = vision.transforms.Compose([val_resize_crop_transform,
+                                               vision.transforms.CenterCrop(size=args.img_crop),
+                                               multi_color_transform,
+                                               vision.transforms.ToFloat(),
+                                               vision.transforms.ToTensor(),
+                                               normalize])
+    return val_transform
+
+def get_transforms(args):
+    train_transform = get_train_transform(args)
+    val_transform = get_validation_transform(args)
+    return train_transform, val_transform
+
+def get_data_loaders(args):
+    train_transform, val_transform = get_transforms(args)
+
+    train_dataset, val_dataset = vision.datasets.classification.__dict__[args.dataset_name](args.dataset_config, args.data_path, transforms=(train_transform,val_transform))
+
+    train_shuffle = (not args.distributed)
+    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
+    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=train_shuffle, num_workers=args.workers,
+        pin_memory=True, sampler=train_sampler)
+
+    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers,
+                                             pin_memory=True, drop_last=False)
+
+    return train_loader, val_loader
+
+
+if __name__ == '__main__':
+    train_args = get_config()
+    main(train_args)
diff --git a/modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py b/modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py
new file mode 100644 (file)
index 0000000..4589a1a
--- /dev/null
@@ -0,0 +1,1089 @@
+import os
+import shutil
+import time
+import math
+import copy
+
+import torch
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.optim
+import torch.utils.data
+import torch.onnx
+import onnx
+
+import datetime
+from tensorboardX import SummaryWriter
+import numpy as np
+import random
+import cv2
+from colorama import Fore
+import progiter
+from packaging import version
+import warnings
+
+from .. import xnn
+from .. import vision
+
+
+##################################################
+warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)
+
+##################################################
+def get_config():
+    args = xnn.utils.ConfigNode()
+
+    args.dataset_config = xnn.utils.ConfigNode()
+    args.dataset_config.split_name = 'val'
+    args.dataset_config.max_depth_bfr_scaling = 80
+    args.dataset_config.depth_scale = 1
+    args.dataset_config.train_depth_log = 1
+    args.use_semseg_for_depth = False
+
+    # model config
+    args.model_config = xnn.utils.ConfigNode()
+    args.model_config.output_type = ['segmentation']   # task(s) the network predicts: segmentation, flow, depth or sceneflow
+    args.model_config.output_channels = None            # number of output channels
+    args.model_config.input_channels = None             # number of input channels
+    args.model_config.output_range = None               # max range of output
+    args.model_config.num_decoders = None               # number of decoders to use. [options: 0, 1, None]
+    args.model_config.freeze_encoder = False            # do not update encoder weights
+    args.model_config.freeze_decoder = False            # do not update decoder weights
+    args.model_config.multi_task_type = 'learned'       # find out loss multiplier by learning, choices=[None, 'learned', 'uncertainty', 'grad_norm', 'dwa_grad_norm']
+    
+    args.model_name = 'deeplabv2lite_mobilenetv2'       # model architecture, overwritten if pretrained is specified
+    args.dataset_name = 'cityscapes_segmentation'       # dataset type
+    args.data_path = './data/cityscapes'                # 'path to dataset'
+    args.save_path = None                               # checkpoints save path
+    args.phase = 'training'                             # training/calibration/validation
+    args.date = None                                    # date to add to save path. if this is None, current date will be added.
+
+    args.logger = None                                  # logger stream to output into
+    args.show_gpu_usage = False                         # Shows gpu usage at the beginning of each training epoch
+
+    args.split_file = None                              # train_val split file
+    args.split_files = None                             # split list files. eg: train.txt val.txt
+    args.split_value = None                             # test_val split proportion (between 0 (only test) and 1 (only train))
+
+    args.solver = 'adam'                                # solver algorithms, choices=['adam','sgd']
+    args.scheduler = 'step'                             # scheduler algorithms, choices=['step','poly', 'cosine']
+    args.workers = 8                                    # number of data loading workers
+
+    args.epochs = 250                                   # number of total epochs to run
+    args.start_epoch = 0                                # manual epoch number (useful on restarts)
+
+    args.epoch_size = 0                                 # manual epoch size (will match dataset size if not specified)
+    args.epoch_size_val = 0                             # manual epoch size (will match dataset size if not specified)
+    args.batch_size = 12                                # mini_batch size
+    args.total_batch_size = None                        # accumulated batch size. total_batch_size = batch_size*iter_size
+    args.iter_size = 1                                  # iteration size. total_batch_size = batch_size*iter_size
+
+    args.lr = 1e-4                                      # initial learning rate
+    args.lr_clips = None                                 # use args.lr itself if it is None
+    args.lr_calib = 0.1                                 # lr for bias calibration
+    args.warmup_epochs = 5                              # number of epochs to warmup
+
+    args.momentum = 0.9                                 # momentum for sgd, alpha parameter for adam
+    args.beta = 0.999                                   # beta parameter for adam
+    args.weight_decay = 1e-4                            # weight decay
+    args.bias_decay = None                              # bias decay
+
+    args.sparse = True                                  # exclude invalid/ignored target pixels from loss computation; use NEAREST for interpolation
+
+    args.tensorboard_num_imgs = 5                       # number of imgs to display in tensorboard
+    args.pretrained = None                              # path to pre_trained model
+    args.resume = None                                  # path to latest checkpoint (default: none)
+    args.no_date = False                                # don't append date timestamp to folder
+    args.print_freq = 100                               # print frequency (default: 100)
+
+    args.milestones = (100, 200)                        # epochs at which learning rate is divided by 2
+
+    args.losses = ['segmentation_loss']                 # loss functions to use
+    args.metrics = ['segmentation_metrics']  # metric/measurement/error functions for train/validation
+    args.multi_task_factors = None                      # loss mult factors
+    args.class_weights = None                           # class weights
+
+    args.loss_mult_factors = None                       # fixed loss mult factors - per loss - note: this is different from multi_task_factors (which is per task)
+
+    args.multistep_gamma = 0.5                          # steps for step scheduler
+    args.polystep_power = 1.0                           # power for polynomial scheduler
+    args.train_fwbw = False                             # do forward backward step while training
+
+    args.rand_seed = 1                                  # random seed
+    args.img_border_crop = None                         # image border crop rectangle. can be relative or absolute
+    args.target_mask = None                              # mask rectangle. can be relative or absolute. last value is the mask value
+
+    args.rand_resize = None                             # random image size to be resized to during training
+    args.rand_output_size = None                        # output size to be resized to during training
+    args.rand_scale = (1.0, 2.0)                        # random scale range for training
+    args.rand_crop = None                               # image size to be cropped to
+
+    args.img_resize = None                              # image size to be resized to during evaluation
+    args.output_size = None                             # target output size to be resized to
+
+    args.count_flops = True                             # count flops and report
+
+    args.shuffle = True                                 # shuffle or not
+
+    args.transform_rotation = 0.                        # apply rotation augmentation. value is rotation in degrees. 0 indicates no rotation
+    args.is_flow = None                                 # whether entries in images and targets lists are optical flow or not
+
+    args.upsample_mode = 'bilinear'                     # upsample mode to use, choices=['nearest','bilinear']
+
+    args.image_prenorm = True                           # whether normalization is done before all the other transforms
+    args.image_mean = (128.0,)                          # image mean for input image normalization
+    args.image_scale = (1.0 / (0.25 * 256),)            # image scaling/mult for input image normalization
+
+    args.max_depth = 80                                 # maximum depth to be used for visualization
+
+    args.pivot_task_idx = 0                             # task id to select best model
+
+    args.parallel_model = True                          # Use data parallel for model
+    args.parallel_criterion = True                      # Use data parallel for loss and metric
+
+    args.evaluate_start = True                          # evaluate right at the begining of training or not
+    args.generate_onnx = True                           # generate onnx model or not
+    args.print_model = False                            # print the model to text
+    args.run_soon = True                                # To start training after generating configs/models
+
+    args.quantize = False                               # apply quantized inference or not
+    #args.model_surgery = None                           # replace activations with PAct2 activation module. Helpful in quantized training.
+    args.bitwidth_weights = 8                           # bitwidth for weights
+    args.bitwidth_activations = 8                       # bitwidth for activations
+    args.histogram_range = True                         # histogram range for calibration
+    args.bias_calibration = True                        # apply bias correction during quantized inference calibration
+    args.per_channel_q = False                          # apply separate quantization factor for each channel in depthwise or not
+
+    args.save_mod_files = False                         # saves modified files after last commit. Also  stores commit id.
+    args.make_score_zero_mean = False                   # make score zero mean while learning
+    args.no_q_for_dws_layer_idx = 0                     # no_q_for_dws_layer_idx
+
+    args.viz_colormap = 'rainbow'                       # colormap for tensorboard: 'rainbow', 'plasma', 'magma', 'bone'
+
+    args.freeze_bn = False                              # freeze the statistics of bn
+
+    return args
+
+
+# ################################################
+# to avoid hangs in data loader with multi threads
+# this was observed after using cv2 image processing functions
+# https://github.com/pytorch/pytorch/issues/1355
+cv2.setNumThreads(0)
+
+# ################################################
+def main(args):
+    # ensure pytorch version is 1.1 or higher
+    assert version.parse(torch.__version__) >= version.parse('1.1'), \
+        'torch version must be 1.1 or higher, due to the change in scheduler.step() and optimiser.step() call order'
+
+    assert (not hasattr(args, 'evaluate')), 'args.evaluate is deprecated. use args.phase=training or calibration or validation'
+    assert args.phase in ('training', 'calibration', 'validation'), f'invalid phase {args.phase}'
+    assert not hasattr(args, 'model_surgery'), 'the argument model_surgery is deprecated, it is not needed now - remove it'
+
+    if (args.phase == 'validation' and args.bias_calibration):
+        args.bias_calibration = False
+        warnings.warn('switching off bias calibration in validation')
+    #
+
+    #################################################
+    args.rand_resize = args.img_resize if args.rand_resize is None else args.rand_resize
+    args.rand_crop = args.img_resize if args.rand_crop is None else args.rand_crop
+    args.output_size = args.img_resize if args.output_size is None else args.output_size
+    # resume has higher priority
+    args.pretrained = None if (args.resume is not None) else args.pretrained
+
+    if args.save_path is None:
+        save_path = get_save_path(args)
+    else:
+        save_path = args.save_path
+    #
+    if not os.path.exists(save_path):
+        os.makedirs(save_path)
+
+    if args.save_mod_files:
+        # store the modified (uncommitted) files
+        mod_files_path = save_path+'/mod_files'
+        os.makedirs(mod_files_path)
+        
+        cmd = "git ls-files --modified | xargs -i cp {} {}".format("{}", mod_files_path)
+        print("cmd:", cmd)    
+        os.system(cmd)
+
+        # store the last commit id
+        cmd = "git log -n 1  >> {}".format(mod_files_path + '/commit_id.txt')
+        print("cmd:", cmd)    
+        os.system(cmd)
+
+    #################################################
+    if args.logger is None:
+        log_file = os.path.splitext(os.path.basename(__file__))[0] + '.log'
+        args.logger = xnn.utils.TeeLogger(filename=os.path.join(save_path,log_file))
+
+    #################################################
+    # global settings. rand seeds for repeatability
+    random.seed(args.rand_seed)
+    np.random.seed(args.rand_seed)
+    torch.manual_seed(args.rand_seed)
+    torch.cuda.manual_seed(args.rand_seed)
+
+    ################################
+    # args check and config
+    assert args.iter_size == 1 or args.total_batch_size is None, "only one of --iter-size or --total-batch-size must be set"
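+    # the effective (total) batch size is batch_size * iter_size: gradients are
+    # accumulated over iter_size iterations before each optimizer step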
+    if args.total_batch_size is not None:
+        args.iter_size = args.total_batch_size//args.batch_size
+    else:
+        args.total_batch_size = args.batch_size*args.iter_size
+
+    #################################################
+    # set some global flags and initializations
+    # keep it in args for now - although they don't belong here strictly
+    # using pin_memory is seen to cause issues, especially when a lot of memory is used.
+    args.use_pinned_memory = False
+    args.n_iter = 0
+    args.best_metric = -1
+    cudnn.benchmark = True
+    # torch.autograd.set_detect_anomaly(True)
+
+    ################################
+    # reset character color, in case it is different
+    print('{}'.format(Fore.RESET))
+    # print everything for log
+    print('=> args: {}'.format(args))
+    print('=> will save everything to {}'.format(save_path))
+
+    #################################################
+    train_writer = SummaryWriter(os.path.join(save_path,'train'))
+    val_writer = SummaryWriter(os.path.join(save_path,'val'))
+    transforms = get_transforms(args)
+
+    print("=> fetching images in '{}'".format(args.data_path))
+    split_arg = args.split_file if args.split_file else (args.split_files if args.split_files else args.split_value)
+    train_dataset, val_dataset = vision.datasets.pixel2pixel.__dict__[args.dataset_name](args.dataset_config, args.data_path, split=split_arg, transforms=transforms)
+
+    #################################################
+    train_sampler = None
+    val_sampler = None
+    print('=> {} samples found, {} train samples and {} test samples '.format(len(train_dataset)+len(val_dataset),
+        len(train_dataset), len(val_dataset)))
+    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
+        num_workers=args.workers, pin_memory=args.use_pinned_memory, sampler=train_sampler, shuffle=args.shuffle)
+
+    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
+        num_workers=args.workers, pin_memory=args.use_pinned_memory, sampler=val_sampler, shuffle=args.shuffle)
+
+    #################################################
+    if (args.model_config.input_channels is None):
+        args.model_config.input_channels = (3,)
+        print("=> input channels is not given - setting to {}".format(args.model_config.input_channels))
+
+    if (args.model_config.output_channels is None):
+        if ('num_classes' in dir(train_dataset)):
+            args.model_config.output_channels = train_dataset.num_classes()
+        else:
+            args.model_config.output_channels = (2 if args.model_config.output_type == 'flow' else args.model_config.output_channels)
+            xnn.utils.print_yellow("=> output channels is not given - setting to {} - this may not work".format(args.model_config.output_channels))
+        #
+        if not isinstance(args.model_config.output_channels,(list,tuple)):
+            args.model_config.output_channels = [args.model_config.output_channels]
+
+    if (args.class_weights is None) and ('class_weights' in dir(train_dataset)):
+        args.class_weights = train_dataset.class_weights()
+        if not isinstance(args.class_weights, (list,tuple)):
+            args.class_weights = [args.class_weights]
+        #
+        print("=> class weights available for dataset: {}".format(args.class_weights))
+
+    #################################################
+    pretrained_data = None
+    model_surgery_quantize = False
+    if args.pretrained and args.pretrained != "None":
+        if args.pretrained.startswith('http://') or args.pretrained.startswith('https://'):
+            pretrained_file = vision.datasets.utils.download_url(args.pretrained, './data/downloads')
+        else:
+            pretrained_file = args.pretrained
+        #
+        print(f'=> using pre-trained weights from: {args.pretrained}')
+        pretrained_data = torch.load(pretrained_file)
+        model_surgery_quantize = pretrained_data['quantize'] if 'quantize' in pretrained_data else False
+    #
+
+    #################################################
+    # create model
+    xnn.utils.print_yellow("=> creating model '{}'".format(args.model_name))
+    model = vision.models.pixel2pixel.__dict__[args.model_name](args.model_config)
+
+    # check if we got the model as well as parameters to change the names in pretrained
+    model, change_names_dict = model if isinstance(model, (list,tuple)) else (model,None)
+
+    if args.quantize:
+        # dummy input is used by quantized models to analyze graph
+        is_cuda = next(model.parameters()).is_cuda
+        dummy_input = create_rand_inputs(args, is_cuda=is_cuda)
+        #
+        if args.phase == 'training':
+            model = xnn.quantize.QuantTrainModule(model, per_channel_q=args.per_channel_q,
+                        histogram_range=args.histogram_range, bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations, dummy_input=dummy_input)
+        elif args.phase == 'calibration':
+            model = xnn.quantize.QuantCalibrateModule(model, per_channel_q=args.per_channel_q,
+                        bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
+                        histogram_range=args.histogram_range, bias_calibration=args.bias_calibration, lr_calib=args.lr_calib, dummy_input=dummy_input)
+        elif args.phase == 'validation':
+            # Note: bias_calibration is not enabled
+            model = xnn.quantize.QuantTestModule(model, per_channel_q=args.per_channel_q,
+                        bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
+                        histogram_range=args.histogram_range, model_surgery_quantize=model_surgery_quantize,
+                        dummy_input=dummy_input)
+        else:
+            assert False, f'invalid phase {args.phase}'
+    #
+
+    # load pretrained model
+    xnn.utils.load_weights_check(get_model_orig(model), pretrained=pretrained_data, change_names_dict=change_names_dict)
+
+    #################################################
+    if args.count_flops:
+        count_flops(args, model)
+
+    #################################################
+    if args.generate_onnx and ((args.phase in ('training','calibration')) or (not args.run_soon)):
+        write_onnx_model(args, get_model_orig(model), save_path)
+    #
+
+    #################################################
+    if args.print_model:
+        print(model)
+        print('\n')
+    else:
+        args.logger.debug(str(model))
+        args.logger.debug('\n')
+
+    #################################################
+    if (not args.run_soon):
+        print("Training not needed for now")
+        close(args)
+        exit()
+
+    #################################################
+    # multi gpu mode does not work for calibration/training for quantization
+    # so use it only when args.quantize is False
+    if args.parallel_model and ((not args.quantize)):
+        model = torch.nn.DataParallel(model)
+
+    #################################################
+    model = model.cuda()
+
+    #################################################
+    # for help in debug/print
+    for name, module in model.named_modules():
+        module.name = name
+
+    #################################################
+    args.loss_modules = copy.deepcopy(args.losses)
+    for task_dx, task_losses in enumerate(args.losses):
+        for loss_idx, loss_fn in enumerate(task_losses):
+            kw_args = {}
+            loss_args = vision.losses.__dict__[loss_fn].args()
+            for arg in loss_args:
+                if arg == 'weight' and (args.class_weights is not None):
+                    kw_args.update({arg:args.class_weights[task_dx]})
+                elif arg == 'num_classes':
+                    kw_args.update({arg:args.model_config.output_channels[task_dx]})
+                elif arg == 'sparse':
+                    kw_args.update({arg:args.sparse})
+                #
+            #
+            loss_fn_raw = vision.losses.__dict__[loss_fn](**kw_args)
+            if args.parallel_criterion:
+                loss_fn = torch.nn.DataParallel(loss_fn_raw).cuda()
+                loss_fn.info = loss_fn_raw.info
+                loss_fn.clear = loss_fn_raw.clear
+            else:
+                loss_fn = loss_fn_raw.cuda()
+            #
+            args.loss_modules[task_dx][loss_idx] = loss_fn
+    #
+
+    args.metric_modules = copy.deepcopy(args.metrics)
+    for task_dx, task_metrics in enumerate(args.metrics):
+        for midx, metric_fn in enumerate(task_metrics):
+            kw_args = {}
+            loss_args = vision.losses.__dict__[metric_fn].args()
+            for arg in loss_args:
+                if arg == 'weight':
+                    kw_args.update({arg:args.class_weights[task_dx]})
+                elif arg == 'num_classes':
+                    kw_args.update({arg:args.model_config.output_channels[task_dx]})
+                elif arg == 'sparse':
+                    kw_args.update({arg:args.sparse})
+
+            metric_fn_raw = vision.losses.__dict__[metric_fn](**kw_args)
+            if args.parallel_criterion:
+                metric_fn = torch.nn.DataParallel(metric_fn_raw).cuda()
+                metric_fn.info = metric_fn_raw.info
+                metric_fn.clear = metric_fn_raw.clear
+            else:
+                metric_fn = metric_fn_raw.cuda()
+            #
+            args.metric_modules[task_dx][midx] = metric_fn
+    #
+
+    #################################################
+    if args.phase=='validation':
+        with torch.no_grad():
+            validate(args, val_dataset, val_loader, model, 0, val_writer)
+        #
+        close(args)
+        return
+
+    #################################################
+    assert(args.solver in ['adam', 'sgd'])
+    print('=> setting {} solver'.format(args.solver))
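+    # parameters whose names contain 'clips' (e.g. PAct2 activation clip values) can use a
+    # separate learning rate and weight decay, derived from args.lr_clips and args.bias_decay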
+    if args.lr_clips is not None:
+        learning_rate_clips = args.lr_clips if args.phase == 'training' else 0.0
+        clips_decay = args.bias_decay if (args.bias_decay is not None and args.bias_decay != 0.0) else args.weight_decay
+        clips_params = [p for n,p in model.named_parameters() if 'clips' in n]
+        other_params = [p for n,p in model.named_parameters() if 'clips' not in n]
+        param_groups = [{'params': clips_params, 'weight_decay': clips_decay, 'lr': learning_rate_clips},
+                        {'params': other_params, 'weight_decay': args.weight_decay}]
+    else:
+        param_groups = [{'params': filter(lambda p: p.requires_grad, model.parameters()), 'weight_decay': args.weight_decay}]
+    #
+
+    learning_rate = args.lr if (args.phase == 'training') else 0.0
+    if args.solver == 'adam':
+        optimizer = torch.optim.Adam(param_groups, learning_rate, betas=(args.momentum, args.beta))
+    elif args.solver == 'sgd':
+        optimizer = torch.optim.SGD(param_groups, learning_rate, momentum=args.momentum)
+    else:
+        raise ValueError('Unknown optimizer type {}'.format(args.solver))
+    #
+
+    #################################################
+    epoch_size = get_epoch_size(args, train_loader, args.epoch_size)
+    max_iter = args.epochs * epoch_size
+    scheduler = xnn.optim.lr_scheduler.SchedulerWrapper(args.scheduler, optimizer, args.epochs, args.start_epoch, \
+                                                            args.warmup_epochs, max_iter=max_iter, polystep_power=args.polystep_power,
+                                                            milestones=args.milestones, multistep_gamma=args.multistep_gamma)
+
+    # optionally resume from a checkpoint
+    if args.resume:
+        if not os.path.isfile(args.resume):
+            print("=> no checkpoint found at '{}'".format(args.resume))
+        else:
+            print("=> loading checkpoint '{}'".format(args.resume))
+            checkpoint = torch.load(args.resume)
+            model = xnn.utils.load_weights_check(model, checkpoint)
+
+            if args.start_epoch == 0:
+                args.start_epoch = checkpoint['epoch']
+
+            if 'best_metric' in list(checkpoint.keys()):
+                args.best_metric = checkpoint['best_metric']
+
+            if 'optimizer' in list(checkpoint.keys()):
+                optimizer.load_state_dict(checkpoint['optimizer'])
+
+            if 'scheduler' in list(checkpoint.keys()):
+                scheduler.load_state_dict(checkpoint['scheduler'])
+
+            if 'multi_task_factors' in list(checkpoint.keys()):
+                args.multi_task_factors = checkpoint['multi_task_factors']
+
+            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
+
+    #################################################
+    if args.evaluate_start:
+        with torch.no_grad():
+            validate(args, val_dataset, val_loader, model, args.start_epoch, val_writer)
+
+    for epoch in range(args.start_epoch, args.epochs):
+        if train_sampler:
+            train_sampler.set_epoch(epoch)
+        if val_sampler:
+            val_sampler.set_epoch(epoch)
+
+        # train for one epoch
+        train(args, train_dataset, train_loader, model, optimizer, epoch, train_writer, scheduler)
+
+        # evaluate on validation set
+        with torch.no_grad():
+            val_metric, metric_name = validate(args, val_dataset, val_loader, model, epoch, val_writer)
+
+        if args.best_metric < 0:
+            args.best_metric = val_metric
+
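+        # for iou/accuracy style metrics higher is better; for error/loss style metrics lower is better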
+        if "iou" in metric_name.lower() or "acc" in metric_name.lower():
+            is_best = val_metric >= args.best_metric
+            args.best_metric = max(val_metric, args.best_metric)
+        elif "error" in metric_name.lower() or "diff" in metric_name.lower() or "norm" in metric_name.lower() \
+                or "loss" in metric_name.lower() or "outlier" in metric_name.lower():
+            is_best = val_metric <= args.best_metric
+            args.best_metric = min(val_metric, args.best_metric)
+        else:
+            raise ValueError("Metric is not known. Best model could not be saved.")
+        #
+
+        checkpoint_dict = { 'epoch': epoch + 1, 'model_name': args.model_name,
+                            'state_dict': get_model_orig(model).state_dict(),
+                            'optimizer': optimizer.state_dict(),
+                            'scheduler': scheduler.state_dict(),
+                            'best_metric': args.best_metric,
+                            'multi_task_factors': args.multi_task_factors,
+                            'quantize' : args.quantize}
+
+        save_checkpoint(args, save_path, get_model_orig(model), checkpoint_dict, is_best)
+        
+        train_writer.file_writer.flush()
+        val_writer.file_writer.flush()
+
+        # adjust the learning rate using lr scheduler
+        if args.phase == 'training':
+            scheduler.step()
+        #
+    #
+
+    # close and cleanup
+    close(args)
+#
+
+
+
+###################################################################
+def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writer, scheduler):
+    batch_time = xnn.utils.AverageMeter()
+    data_time = xnn.utils.AverageMeter()
+    # if the loss/ metric is already an average, no need to further average
+    avg_loss = [xnn.utils.AverageMeter(print_avg=(not task_loss[0].info()['is_avg'])) for task_loss in args.loss_modules]
+    avg_loss_orig = [xnn.utils.AverageMeter(print_avg=(not task_loss[0].info()['is_avg'])) for task_loss in args.loss_modules]
+    avg_metric = [xnn.utils.AverageMeter(print_avg=(not task_metric[0].info()['is_avg'])) for task_metric in args.metric_modules]
+    epoch_size = get_epoch_size(args, train_loader, args.epoch_size)
+
+    ##########################
+    # switch to train mode
+    model.train()
+    if args.freeze_bn:
+        xnn.utils.freeze_bn(model)
+    #
+
+    ##########################
+    for task_dx, task_losses in enumerate(args.loss_modules):
+        for loss_idx, loss_fn in enumerate(task_losses):
+            loss_fn.clear()
+    for task_dx, task_metrics in enumerate(args.metric_modules):
+        for midx, metric_fn in enumerate(task_metrics):
+            metric_fn.clear()
+
+    progress_bar = progiter.ProgIter(np.arange(epoch_size), chunksize=1)
+    metric_name = "Metric"
+    metric_ctx = [None] * len(args.metric_modules)
+    end_time = time.time()
+    writer_idx = 0
+    last_update_iter = -1
+
+    # change color to yellow for calibration
+    progressbar_color = (Fore.YELLOW if args.phase=='calibration' else Fore.WHITE)
+    print('{}'.format(progressbar_color), end='')
+
+    ##########################
+    for iter, (inputs, targets) in enumerate(train_loader):
+        # measure data loading time
+        data_time.update(time.time() - end_time)
+
+        lr = scheduler.get_lr()[0]
+
+        input_list = [img.cuda() for img in inputs]
+        target_list = [tgt.cuda(non_blocking=True) for tgt in targets]
+        target_sizes = [tgt.shape for tgt in target_list]
+        batch_size_cur = target_sizes[0][0]
+
+        ##########################
+        # compute output
+        task_outputs = model(input_list)
+        # upsample output to target resolution
+        task_outputs = upsample_tensors(task_outputs, target_sizes, args.upsample_mode)
+
+        if args.model_config.multi_task_type is not None and len(args.model_config.output_channels) > 1:
+            args.multi_task_factors, args.multi_task_offsets = xnn.layers.get_loss_scales(model)
+        else:
+            args.multi_task_factors = None
+            args.multi_task_offsets = None
+
+        loss_total, loss_list, loss_names, loss_types, loss_list_orig = \
+            compute_task_objectives(args, args.loss_modules, input_list, task_outputs, target_list,
+                         task_mults=args.multi_task_factors, task_offsets=args.multi_task_offsets,
+                         loss_mult_factors=args.loss_mult_factors)
+
+        metric_total, metric_list, metric_names, metric_types, _ = \
+            compute_task_objectives(args, args.metric_modules, input_list, task_outputs, target_list)
+
+        if args.model_config.multi_task_type is not None and len(args.model_config.output_channels) > 1:
+            xnn.layers.set_losses(model, loss_list_orig)
+
+        if args.phase == 'training':
+            # zero gradients so that we can accumulate gradients
+            if (iter % args.iter_size) == 0:
+                optimizer.zero_grad()
+
+            # accumulate gradients
+            loss_total.backward()
+            # optimization step
+            if ((iter+1) % args.iter_size) == 0:
+                optimizer.step()
+        #
+
+        # record loss.
+        for task_idx, task_losses in enumerate(args.loss_modules):
+            avg_loss[task_idx].update(float(loss_list[task_idx].cpu()), batch_size_cur)
+            avg_loss_orig[task_idx].update(float(loss_list_orig[task_idx].cpu()), batch_size_cur)
+            train_writer.add_scalar('Training/Task{}_{}_Loss_Iter'.format(task_idx,loss_names[task_idx]), float(loss_list[task_idx]), args.n_iter)
+            if args.model_config.multi_task_type is not None and len(args.model_config.output_channels) > 1:
+                train_writer.add_scalar('Training/multi_task_Factor_Task{}_{}'.format(task_idx,loss_names[task_idx]), float(args.multi_task_factors[task_idx]), args.n_iter)
+
+        # record error/accuracy.
+        for task_idx, task_metrics in enumerate(args.metric_modules):
+            avg_metric[task_idx].update(float(metric_list[task_idx].cpu()), batch_size_cur)
+
+        ##########################
+        write_output(args, 'Training_', epoch_size, iter, epoch, train_dataset, train_writer, input_list, task_outputs, target_list, metric_name, writer_idx)
+        if ((iter % args.print_freq) == 0) or (iter == (epoch_size-1)):
+            output_string = ''
+            for task_idx, task_metrics in enumerate(args.metric_modules):
+                output_string += '[{}={}]'.format(metric_names[task_idx], str(avg_metric[task_idx]))
+
+            epoch_str = '{}/{}'.format(epoch + 1, args.epochs)
+            progress_bar.set_description("{}=> {}  ".format(progressbar_color, args.phase))
+            multi_task_factors_print = ['{:.3f}'.format(float(lmf)) for lmf in args.multi_task_factors] if args.multi_task_factors is not None else None
+            progress_bar.set_postfix(Epoch=epoch_str, LR=lr, DataTime=str(data_time), LossMult=multi_task_factors_print, Loss=avg_loss, Output=output_string)
+            progress_bar.update(iter-last_update_iter)
+            last_update_iter = iter
+
+        args.n_iter += 1
+        end_time = time.time()
+        writer_idx = (writer_idx + 1) % args.tensorboard_num_imgs
+
+        # add onnx graph to tensorboard
+        # commenting out due to issues in transitioning to pytorch 0.4
+        # (bilinear mode in upsampling causes hang or crash - may be due to align_borders change, nearest is fine)
+        #if epoch == 0 and iter == 0:
+        #    input_zero = torch.zeros(input_var.shape)
+        #    train_writer.add_graph(model, input_zero)
+
+        torch.cuda.empty_cache()
+
+        if iter >= epoch_size:
+            break
+
+    progress_bar.close()
+
+    # reset the terminal color (no extra new line is printed, since end='' is given)
+    print('{}'.format(Fore.RESET), end='')
+
+    for task_idx, task_losses in enumerate(args.loss_modules):
+        train_writer.add_scalar('Training/Task{}_{}_Loss_Epoch'.format(task_idx,loss_names[task_idx]), float(avg_loss[task_idx]), epoch)
+
+    for task_idx, task_metrics in enumerate(args.metric_modules):
+        train_writer.add_scalar('Training/Task{}_{}_Metric_Epoch'.format(task_idx,metric_names[task_idx]), float(avg_metric[task_idx]), epoch)
+
+    output_name = metric_names[args.pivot_task_idx]
+    output_metric = float(avg_metric[args.pivot_task_idx])
+
+    ##########################
+    if args.quantize:
+        def debug_format(v):
+            return ('{:.3f}'.format(v) if v is not None else 'None')
+        #
+        clips_act = [m.get_clips_act()[1] for n,m in model.named_modules() if isinstance(m,xnn.layers.PAct2)]
+        if len(clips_act) > 0:
+            args.logger.debug('\nclips_act : ' + ' '.join(map(debug_format, clips_act)))
+            args.logger.debug('')
+    #
+    return output_metric, output_name
+
+
+###################################################################
+def validate(args, val_dataset, val_loader, model, epoch, val_writer):
+    data_time = xnn.utils.AverageMeter()
+    # if the loss/ metric is already an average, no need to further average
+    avg_metric = [xnn.utils.AverageMeter(print_avg=(not task_metric[0].info()['is_avg'])) for task_metric in args.metric_modules]
+    epoch_size = get_epoch_size(args, val_loader, args.epoch_size_val)
+
+    ##########################
+    # switch to evaluate mode
+    model.eval()
+
+    ##########################
+    for task_dx, task_metrics in enumerate(args.metric_modules):
+        for midx, metric_fn in enumerate(task_metrics):
+            metric_fn.clear()
+
+    metric_name = "Metric"
+    end_time = time.time()
+    writer_idx = 0
+    last_update_iter = -1
+    metric_ctx = [None] * len(args.metric_modules)
+    progress_bar = progiter.ProgIter(np.arange(epoch_size), chunksize=1)
+
+    # change color to green
+    print('{}'.format(Fore.GREEN), end='')
+
+    ##########################
+    for iter, (inputs, targets) in enumerate(val_loader):
+        data_time.update(time.time() - end_time)
+        input_list = [j.cuda() for j in inputs]
+        target_list = [j.cuda(non_blocking=True) for j in targets]
+        target_sizes = [tgt.shape for tgt in target_list]
+        batch_size_cur = target_sizes[0][0]
+
+        # compute output
+        task_outputs = model(input_list)
+
+        task_outputs = upsample_tensors(task_outputs, target_sizes, args.upsample_mode)
+
+        metric_total, metric_list, metric_names, metric_types, _ = \
+            compute_task_objectives(args, args.metric_modules, input_list, task_outputs, target_list)
+
+        # record error/accuracy.
+        for task_idx, task_metrics in enumerate(args.metric_modules):
+            avg_metric[task_idx].update(float(metric_list[task_idx].cpu()), batch_size_cur)
+
+        write_output(args, 'Validation_', epoch_size, iter, epoch, val_dataset, val_writer, input_list, task_outputs, target_list, metric_names, writer_idx)
+
+        if ((iter % args.print_freq) == 0) or (iter == (epoch_size-1)):
+            output_string = ''
+            for task_idx, task_metrics in enumerate(args.metric_modules):
+                output_string += '[{}={}]'.format(metric_names[task_idx], str(avg_metric[task_idx]))
+
+            epoch_str = '{}/{}'.format(epoch + 1, args.epochs)
+            progress_bar.set_description("=> validation")
+            progress_bar.set_postfix(Epoch=epoch_str, DataTime=data_time, Output="{}".format(output_string))
+            progress_bar.update(iter-last_update_iter)
+            last_update_iter = iter
+
+        end_time = time.time()
+        writer_idx = (writer_idx + 1) % args.tensorboard_num_imgs
+
+        if iter >= epoch_size:
+            break
+
+    progress_bar.close()
+
+    # reset the terminal color (no extra new line is printed, since end='' is given)
+    print('{}'.format(Fore.RESET), end='')
+
+    for task_idx, task_metrics in enumerate(args.metric_modules):
+        val_writer.add_scalar('Validation/Task{}_{}_Metric_Epoch'.format(task_idx,metric_names[task_idx]), float(avg_metric[task_idx]), epoch)
+
+    output_name = metric_names[args.pivot_task_idx]
+    output_metric = float(avg_metric[args.pivot_task_idx])
+    return output_metric, output_name
+
+
+###################################################################
+def close(args):
+    if args.logger is not None:
+        del args.logger
+        args.logger = None
+    #
+    args.best_metric = -1
+#
+
+
+def get_save_path(args, phase=None):
+    date = args.date if args.date else datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+    save_path = os.path.join('./data/checkpoints', args.dataset_name, date + '_' + args.dataset_name + '_' + args.model_name)
+    save_path += '_resize{}x{}_traincrop{}x{}'.format(args.img_resize[1], args.img_resize[0], args.rand_crop[1], args.rand_crop[0])
+    phase = phase if (phase is not None) else args.phase
+    save_path = os.path.join(save_path, phase)
+    return save_path
+
+
+def get_model_orig(model):
+    is_parallel_model = isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel))
+    model_orig = (model.module if is_parallel_model else model)
+    model_orig = (model_orig.module if isinstance(model_orig, (xnn.quantize.QuantBaseModule)) else model_orig)
+    return model_orig
+
+
+def create_rand_inputs(args, is_cuda):
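+    # create one random NCHW input per input stream at the resize resolution - used for
+    # flop counting, onnx export and graph analysis in the quantization wrappers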
+    dummy_input = []
+    for i_ch in args.model_config.input_channels:
+        x = torch.rand((1, i_ch, args.img_resize[0], args.img_resize[1]))
+        x = x.cuda() if is_cuda else x
+        dummy_input.append(x)
+    #
+    return dummy_input
+
+
+def count_flops(args, model):
+    is_cuda = next(model.parameters()).is_cuda
+    dummy_input = create_rand_inputs(args, is_cuda)
+    #
+    model.eval()
+    flops = xnn.utils.forward_count_flops(model, dummy_input)
+    gflops = flops/1e9
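+    # one multiply-accumulate (MAC) is counted as two flops here, hence GMACs = GFLOPs/2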
+    print('=> Size = {}, GFLOPs = {}, GMACs = {}'.format(args.img_resize, gflops, gflops/2))
+
+
+def derive_node_name(input_name):
+    #take last entry of input names for deciding node name
+    #print("input_name[-1]: ", input_name[-1])
+    node_name = input_name[-1].rsplit('.', 1)[0]
+    #print("formed node_name: ", node_name)
+    return node_name
+
+
+#torch onnx export does not update names. Do it using onnx.save
+def add_node_names(onnx_model_name):
+    onnx_model = onnx.load(onnx_model_name)
+    for i in range(len(onnx_model.graph.node)):
+        for j in range(len(onnx_model.graph.node[i].input)):
+            #print('-'*60)
+            #print("name: ", onnx_model.graph.node[i].name)
+            #print("input: ", onnx_model.graph.node[i].input)
+            #print("output: ", onnx_model.graph.node[i].output)
+            onnx_model.graph.node[i].input[j] = onnx_model.graph.node[i].input[j].split(':')[0]
+            onnx_model.graph.node[i].name = derive_node_name(onnx_model.graph.node[i].input)
+        #
+    #
+    #update model inplace
+    onnx.save(onnx_model, onnx_model_name)
+
+def write_onnx_model(args, model, save_path, name='checkpoint.onnx'):
+    is_cuda = next(model.parameters()).is_cuda
+    input_list = create_rand_inputs(args, is_cuda=is_cuda)
+    #
+    model.eval()
+    torch.onnx.export(model, input_list, os.path.join(save_path, name), export_params=True, verbose=False)
+    #torch onnx export does not update names. Do it using onnx.save
+    add_node_names(onnx_model_name = os.path.join(save_path, name))
+
+
+###################################################################
+def write_output(args, prefix, val_epoch_size, iter, epoch, dataset, output_writer, input_images, task_outputs, task_targets, metric_names, writer_idx):
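+    # write images only for a random subset of iterations, so that roughly
+    # args.tensorboard_num_imgs samples are written to tensorboard per epoch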
+    write_freq = (args.tensorboard_num_imgs / float(val_epoch_size))
+    write_prob = np.random.random()
+    if (write_prob > write_freq):
+        return
+
+    batch_size = input_images[0].shape[0]
+    b_index = random.randint(0, batch_size - 1)
+
+    input_image = None
+    for img_idx, img in enumerate(input_images):
+        input_image = input_images[img_idx][b_index].cpu().numpy().transpose((1, 2, 0))
+        # convert back to original input range (0-255)
+        input_image = input_image / args.image_scale + args.image_mean
+        if args.is_flow and args.is_flow[0][img_idx]:
+            #input corresponding to flow is assumed to have been generated by adding 128
+            flow = input_image - 128
+            flow_hsv = xnn.utils.flow2hsv(flow.transpose(2, 0, 1), confidence=False).transpose(2, 0, 1)
+            #flow_hsv = (flow_hsv / 255.0).clip(0, 1) #TODO: check this
+            output_writer.add_image(prefix +'Input{}/{}'.format(img_idx, writer_idx), flow_hsv, epoch)
+        else:
+            input_image = (input_image/255.0).clip(0,1) #.astype(np.uint8)
+            output_writer.add_image(prefix + 'Input{}/{}'.format(img_idx, writer_idx), input_image.transpose((2,0,1)), epoch)
+
+    # for sparse data, chroma blending does not look good
+    for task_idx, output_type in enumerate(args.model_config.output_type):
+        # metric_name = metric_names[task_idx]
+        output = task_outputs[task_idx]
+        target = task_targets[task_idx]
+        if (output_type == 'segmentation') and hasattr(dataset, 'decode_segmap'):
+            segmentation_target = dataset.decode_segmap(target[b_index,0].cpu().numpy())
+            segmentation_output = output.max(dim=1,keepdim=True)[1].data.cpu().numpy() if(output.shape[1]>1) else output.data.cpu().numpy()
+            segmentation_output = dataset.decode_segmap(segmentation_output[b_index,0])
+            segmentation_output_blend = xnn.utils.chroma_blend(input_image, segmentation_output)
+            #
+            output_writer.add_image(prefix+'Task{}_{}_GT/{}'.format(task_idx,output_type,writer_idx), segmentation_target.transpose(2,0,1), epoch)
+            if not args.sparse:
+                segmentation_target_blend = xnn.utils.chroma_blend(input_image, segmentation_target)
+                output_writer.add_image(prefix + 'Task{}_{}_GT_ColorBlend/{}'.format(task_idx, output_type, writer_idx), segmentation_target_blend.transpose(2, 0, 1), epoch)
+            #
+            output_writer.add_image(prefix+'Task{}_{}_Output/{}'.format(task_idx,output_type,writer_idx), segmentation_output.transpose(2,0,1), epoch)
+            output_writer.add_image(prefix+'Task{}_{}_Output_ColorBlend/{}'.format(task_idx,output_type,writer_idx), segmentation_output_blend.transpose(2,0,1), epoch)
+        elif (output_type in ('depth', 'disparity')):
+            depth_chanidx = 0
+            output_writer.add_image(prefix+'Task{}_{}_GT_Color_Visualization/{}'.format(task_idx,output_type,writer_idx), xnn.utils.tensor2array(target[b_index][depth_chanidx].cpu(), max_value=args.max_depth, colormap=args.viz_colormap).transpose(2,0,1), epoch)
+            if not args.sparse:
+                output_writer.add_image(prefix + 'Task{}_{}_GT_ColorBlend_Visualization/{}'.format(task_idx, output_type, writer_idx), xnn.utils.tensor2array(target[b_index][depth_chanidx].cpu(), max_value=args.max_depth, colormap=args.viz_colormap, input_blend=input_image).transpose(2, 0, 1), epoch)
+            #
+            output_writer.add_image(prefix+'Task{}_{}_Output_Color_Visualization/{}'.format(task_idx,output_type,writer_idx), xnn.utils.tensor2array(output.data[b_index][depth_chanidx].cpu(), max_value=args.max_depth, colormap=args.viz_colormap).transpose(2,0,1), epoch)
+            output_writer.add_image(prefix + 'Task{}_{}_Output_ColorBlend_Visualization/{}'.format(task_idx, output_type, writer_idx),xnn.utils.tensor2array(output.data[b_index][depth_chanidx].cpu(), max_value=args.max_depth, colormap=args.viz_colormap, input_blend=input_image).transpose(2, 0, 1), epoch)
+        elif (output_type == 'flow'):
+            max_value_flow = 10.0 # only for visualization
+            output_writer.add_image(prefix+'Task{}_{}_GT/{}'.format(task_idx,output_type,writer_idx), xnn.utils.flow2hsv(target[b_index][:2].cpu().numpy(), max_value=max_value_flow).transpose(2,0,1), epoch)
+            output_writer.add_image(prefix+'Task{}_{}_Output/{}'.format(task_idx,output_type,writer_idx), xnn.utils.flow2hsv(output.data[b_index][:2].cpu().numpy(), max_value=max_value_flow).transpose(2,0,1), epoch)
+        elif (output_type == 'interest_pt'):
+            score_chanidx = 0
+            target_score_to_write = target[b_index][score_chanidx].cpu()
+            output_score_to_write = output.data[b_index][score_chanidx].cpu()
+            
+            #if score is learnt as zero mean add offset to make it [0-255]
+            if args.make_score_zero_mean:
+                # target_score_to_write!=0 : a value of 0 indicates that the GT is unavailable. Leave those as 0.
+                target_score_to_write[target_score_to_write!=0] += 128.0
+                output_score_to_write += 128.0
+
+            max_value_score = float(torch.max(target_score_to_write)) #0.002
+            output_writer.add_image(prefix+'Task{}_{}_GT_Bone_Visualization/{}'.format(task_idx,output_type,writer_idx), xnn.utils.tensor2array(target_score_to_write, max_value=max_value_score, colormap='bone').transpose(2,0,1), epoch)
+            output_writer.add_image(prefix+'Task{}_{}_Output_Bone_Visualization/{}'.format(task_idx,output_type,writer_idx), xnn.utils.tensor2array(output_score_to_write, max_value=max_value_score, colormap='bone').transpose(2,0,1), epoch)
+        #
+
+
+def compute_task_objectives(args, objective_fns, input_var, task_outputs, task_targets, task_mults=None, task_offsets=None, loss_mult_factors=None):
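+    # accumulate the objectives (losses or metrics) over all tasks; returns the total objective
+    # (with per-task multipliers and offsets applied) along with per-task values, names and types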
+    ##########################
+    objective_total = torch.zeros_like(task_outputs[0].view(-1)[0])
+    objective_list = []
+    objective_list_orig = []
+    objective_names = []
+    objective_types = []
+    for task_idx, task_objectives in enumerate(objective_fns):
+        output_type = args.model_config.output_type[task_idx]
+        objective_sum_value = torch.zeros_like(task_outputs[task_idx].view(-1)[0])
+        objective_sum_name = ''
+        objective_sum_type = ''
+
+        task_mult = task_mults[task_idx] if task_mults is not None else 1.0
+        task_offset = task_offsets[task_idx] if task_offsets is not None else 0.0
+
+        for oidx, objective_fn in enumerate(task_objectives):
+            objective_batch = objective_fn(input_var, task_outputs[task_idx], task_targets[task_idx])
+            objective_batch = objective_batch.mean() if isinstance(objective_fn, torch.nn.DataParallel) else objective_batch
+            objective_name = objective_fn.info()['name']
+            objective_type = objective_fn.info()['is_avg']
+            loss_mult = loss_mult_factors[task_idx][oidx] if (loss_mult_factors is not None) else 1.0
+            # --
+            objective_batch_not_nan = (objective_batch if not torch.isnan(objective_batch) else 0.0)
+            objective_sum_value = objective_batch_not_nan*loss_mult + objective_sum_value
+            objective_sum_name += (objective_name if (objective_sum_name == '') else ('+' + objective_name))
+            assert (objective_sum_type == '' or objective_sum_type == objective_type), 'metric types (avg/val) for a given task should match'
+            objective_sum_type = objective_type
+
+        objective_list.append(objective_sum_value)
+        objective_list_orig.append(objective_sum_value)
+        objective_names.append(objective_sum_name)
+        objective_types.append(objective_sum_type)
+
+        objective_total = objective_sum_value*task_mult + task_offset + objective_total
+
+    return objective_total, objective_list, objective_names, objective_types, objective_list_orig
+
+
+
+def save_checkpoint(args, save_path, model, checkpoint_dict, is_best, filename='checkpoint.pth.tar'):
+    torch.save(checkpoint_dict, os.path.join(save_path,filename))
+    if is_best:
+        shutil.copyfile(os.path.join(save_path,filename), os.path.join(save_path,'model_best.pth.tar'))
+    #
+    if args.generate_onnx:
+        write_onnx_model(args, model, save_path, name='checkpoint.onnx')
+        if is_best:
+            write_onnx_model(args, model, save_path, name='model_best.onnx')
+    #
+
+
+def get_epoch_size(args, loader, args_epoch_size):
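+    # args_epoch_size == 0: use the full loader; 0 < args_epoch_size < 1: use that fraction of
+    # the loader; args_epoch_size >= 1: use that many iterations (capped at the loader length)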
+    if args_epoch_size == 0:
+        epoch_size = len(loader)
+    elif args_epoch_size < 1:
+        epoch_size = int(len(loader) * args_epoch_size)
+    else:
+        epoch_size = min(len(loader), int(args_epoch_size))
+    return epoch_size
+
+
+def get_train_transform(args):
+    # image normalization can be at the beginning of transforms or at the end
+    image_mean = np.array(args.image_mean, dtype=np.float32)
+    image_scale = np.array(args.image_scale, dtype=np.float32)
+    image_prenorm = vision.transforms.image_transforms.NormalizeMeanScale(mean=image_mean, scale=image_scale) if args.image_prenorm else None
+    image_postnorm = vision.transforms.image_transforms.NormalizeMeanScale(mean=image_mean, scale=image_scale) if (not image_prenorm) else None
+
+    # crop size used only for training
+    image_train_output_scaling = vision.transforms.image_transforms.Scale(args.rand_resize, target_size=args.rand_output_size, is_flow=args.is_flow) \
+        if (args.rand_output_size and args.rand_output_size != args.rand_resize) else None
+    train_transform = vision.transforms.image_transforms.Compose([
+        image_prenorm,
+        vision.transforms.image_transforms.AlignImages(),
+        vision.transforms.image_transforms.MaskTarget(args.target_mask, 0),
+        vision.transforms.image_transforms.CropRect(args.img_border_crop),
+        vision.transforms.image_transforms.RandomRotate(args.transform_rotation, is_flow=args.is_flow) if args.transform_rotation else None,
+        vision.transforms.image_transforms.RandomScaleCrop(args.rand_resize, scale_range=args.rand_scale, is_flow=args.is_flow),
+        vision.transforms.image_transforms.RandomHorizontalFlip(is_flow=args.is_flow),
+        vision.transforms.image_transforms.RandomCrop(args.rand_crop),
+        vision.transforms.image_transforms.RandomColor2Gray(is_flow=args.is_flow, random_threshold=0.5) if 'tiad' in args.dataset_name else None,
+        image_train_output_scaling,
+        image_postnorm,
+        vision.transforms.image_transforms.ConvertToTensor()
+        ])
+    return train_transform
+
+
+def get_validation_transform(args):
+    # image normalization can be at the beginning of transforms or at the end
+    image_mean = np.array(args.image_mean, dtype=np.float32)
+    image_scale = np.array(args.image_scale, dtype=np.float32)
+    image_prenorm = vision.transforms.image_transforms.NormalizeMeanScale(mean=image_mean, scale=image_scale) if args.image_prenorm else None
+    image_postnorm = vision.transforms.image_transforms.NormalizeMeanScale(mean=image_mean, scale=image_scale) if (not image_prenorm) else None
+
+    # prediction is resized to output_size before evaluation.
+    val_transform = vision.transforms.image_transforms.Compose([
+        image_prenorm,
+        vision.transforms.image_transforms.AlignImages(),
+        vision.transforms.image_transforms.MaskTarget(args.target_mask, 0),
+        vision.transforms.image_transforms.CropRect(args.img_border_crop),
+        vision.transforms.image_transforms.Scale(args.img_resize, target_size=args.output_size, is_flow=args.is_flow),
+        image_postnorm,
+        vision.transforms.image_transforms.ConvertToTensor()
+        ])
+    return val_transform
+
+
+def get_transforms(args):
+    train_transform = get_train_transform(args)
+    val_transform = get_validation_transform(args)
+    return train_transform, val_transform
+
+
+def _upsample_impl(tensor, output_size, upsample_mode):
+    # upsampling a long tensor is not supported currently. convert to float, just to avoid an error.
+    # we can do this only in nearest mode, otherwise the output will have invalid values.
+    convert_to_float = False
+    if isinstance(tensor, (torch.LongTensor,torch.cuda.LongTensor)):
+        convert_to_float = True
+        original_dtype = tensor.dtype
+        tensor = tensor.float()
+        upsample_mode = 'nearest'
+
+    dim_added = False
+    if len(tensor.shape) < 4:
+        tensor = tensor[np.newaxis,...]
+        dim_added = True
+
+    if (tensor.size()[-2:] != output_size):
+        tensor = torch.nn.functional.interpolate(tensor, output_size, mode=upsample_mode)
+
+    if dim_added:
+        tensor = tensor[0,...]
+
+    if convert_to_float:
+        tensor = tensor.long() #tensor.astype(original_dtype)
+
+    return tensor
+
+
+def upsample_tensors(tensors, output_sizes, upsample_mode):
+    if isinstance(tensors, (list,tuple)):
+        for tidx, tensor in enumerate(tensors):
+            tensors[tidx] = _upsample_impl(tensor, output_sizes[tidx][-2:], upsample_mode)
+        #
+    else:
+        tensors = _upsample_impl(tensors, output_sizes[0][-2:], upsample_mode)
+    return tensors
+
+
+if __name__ == '__main__':
+    train_args = get_config()
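+    # note: entries of train_args can be overridden here before calling main,
+    # for example (illustrative): train_args.phase = 'validation'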
+    main(train_args)
diff --git a/modules/pytorch_jacinto_ai/vision/__init__.py b/modules/pytorch_jacinto_ai/vision/__init__.py
new file mode 100644 (file)
index 0000000..9493a00
--- /dev/null
@@ -0,0 +1,36 @@
+from . import models
+from . import datasets
+from . import ops
+from . import transforms
+from . import utils
+from . import losses
+
+try:
+    from .version import __version__  # noqa: F401
+except ImportError:
+    pass
+
+_image_backend = 'PIL'
+
+
+def set_image_backend(backend):
+    """
+    Specifies the package used to load images.
+
+    Args:
+        backend (string): Name of the image backend. One of {'PIL', 'accimage'}.
+            The :mod:`accimage` package uses the Intel IPP library. It is
+            generally faster than PIL, but does not support as many operations.
+    """
+    global _image_backend
+    if backend not in ['PIL', 'accimage']:
+        raise ValueError("Invalid backend '{}'. Options are 'PIL' and 'accimage'"
+                         .format(backend))
+    _image_backend = backend
+
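+# example usage (illustrative; the optional accimage package must be installed for this backend):
+#   set_image_backend('accimage')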
+
+def get_image_backend():
+    """
+    Gets the name of the package used to load images
+    """
+    return _image_backend
diff --git a/modules/pytorch_jacinto_ai/vision/datasets/__init__.py b/modules/pytorch_jacinto_ai/vision/datasets/__init__.py
new file mode 100644 (file)
index 0000000..59d08e5
--- /dev/null
@@ -0,0 +1,41 @@
+from .lsun import LSUN, LSUNClass
+from .folder import ImageFolder, DatasetFolder
+from .coco import CocoCaptions, CocoDetection
+from .cifar import CIFAR10, CIFAR100
+from .stl10 import STL10
+from .mnist import MNIST, EMNIST, FashionMNIST, KMNIST, QMNIST
+from .svhn import SVHN
+from .phototour import PhotoTour
+from .fakedata import FakeData
+from .semeion import SEMEION
+from .omniglot import Omniglot
+from .sbu import SBU
+from .flickr import Flickr8k, Flickr30k
+from .voc import VOCSegmentation, VOCDetection
+from .cityscapes import Cityscapes
+from .imagenet import ImageNet
+from .caltech import Caltech101, Caltech256
+from .celeba import CelebA
+from .sbd import SBDataset
+from .vision import VisionDataset
+from .usps import USPS
+from .kinetics import Kinetics400
+from .hmdb51 import HMDB51
+from .ucf101 import UCF101
+
+from . import classification
+from . import pixel2pixel
+
+# utils
+from . import utils
+from . import video_utils
+
+__all__ = ('LSUN', 'LSUNClass',
+           'ImageFolder', 'DatasetFolder', 'FakeData',
+           'CocoCaptions', 'CocoDetection',
+           'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', 'QMNIST',
+           'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
+           'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k',
+           'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
+           'Caltech101', 'Caltech256', 'CelebA', 'SBDataset', 'VisionDataset',
+           'USPS', 'Kinetics400', 'HMDB51', 'UCF101')
diff --git a/modules/pytorch_jacinto_ai/vision/datasets/caltech.py b/modules/pytorch_jacinto_ai/vision/datasets/caltech.py
new file mode 100644 (file)
index 0000000..e18349d
--- /dev/null
@@ -0,0 +1,203 @@
+from __future__ import print_function
+from PIL import Image
+import os
+import os.path
+
+from .vision import VisionDataset
+from .utils import download_and_extract_archive, makedir_exist_ok, verify_str_arg
+
+
+class Caltech101(VisionDataset):
+    """`Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where directory
+            ``caltech101`` exists or will be saved to if download is set to True.
+        target_type (string or list, optional): Type of target to use, ``category`` or
+            ``annotation``. Can also be a list to output a tuple with all specified target types.
+            ``category`` represents the target class, and ``annotation`` is a list of points
+            from a hand-generated outline. Defaults to ``category``.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+    """
+
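+    # example usage (illustrative paths):
+    #   dataset = Caltech101(root='./data', target_type='category', download=True)
+    #   img, target = dataset[0]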
+    def __init__(self, root, target_type="category", transform=None,
+                 target_transform=None, download=False):
+        super(Caltech101, self).__init__(os.path.join(root, 'caltech101'),
+                                         transform=transform,
+                                         target_transform=target_transform)
+        makedir_exist_ok(self.root)
+        if not isinstance(target_type, list):
+            target_type = [target_type]
+        self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation"))
+                            for t in target_type]
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError('Dataset not found or corrupted.' +
+                               ' You can use download=True to download it')
+
+        self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories")))
+        self.categories.remove("BACKGROUND_Google")  # this is not a real class
+
+        # For some reason, the category names in "101_ObjectCategories" and
+        # "Annotations" do not always match. This is a manual map between the
+        # two. Defaults to using same name, since most names are fine.
+        name_map = {"Faces": "Faces_2",
+                    "Faces_easy": "Faces_3",
+                    "Motorbikes": "Motorbikes_16",
+                    "airplanes": "Airplanes_Side_2"}
+        self.annotation_categories = list(map(lambda x: name_map[x] if x in name_map else x, self.categories))
+
+        self.index = []
+        self.y = []
+        for (i, c) in enumerate(self.categories):
+            n = len(os.listdir(os.path.join(self.root, "101_ObjectCategories", c)))
+            self.index.extend(range(1, n + 1))
+            self.y.extend(n * [i])
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where the type of target specified by target_type.
+        """
+        import scipy.io
+
+        img = Image.open(os.path.join(self.root,
+                                      "101_ObjectCategories",
+                                      self.categories[self.y[index]],
+                                      "image_{:04d}.jpg".format(self.index[index])))
+
+        target = []
+        for t in self.target_type:
+            if t == "category":
+                target.append(self.y[index])
+            elif t == "annotation":
+                data = scipy.io.loadmat(os.path.join(self.root,
+                                                     "Annotations",
+                                                     self.annotation_categories[self.y[index]],
+                                                     "annotation_{:04d}.mat".format(self.index[index])))
+                target.append(data["obj_contour"])
+        target = tuple(target) if len(target) > 1 else target[0]
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def _check_integrity(self):
+        # can be more robust and check hash of files
+        return os.path.exists(os.path.join(self.root, "101_ObjectCategories"))
+
+    def __len__(self):
+        return len(self.index)
+
+    def download(self):
+        if self._check_integrity():
+            print('Files already downloaded and verified')
+            return
+
+        download_and_extract_archive(
+            "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
+            self.root,
+            filename="101_ObjectCategories.tar.gz",
+            md5="b224c7392d521a49829488ab0f1120d9")
+        download_and_extract_archive(
+            "http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
+            self.root,
+            filename="101_Annotations.tar",
+            md5="6f83eeb1f24d99cab4eb377263132c91")
+
+    def extra_repr(self):
+        return "Target type: {target_type}".format(**self.__dict__)
+
+
+class Caltech256(VisionDataset):
+    """`Caltech 256 <http://www.vision.caltech.edu/Image_Datasets/Caltech256/>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where directory
+            ``caltech256`` exists or will be saved to if download is set to True.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    def __init__(self, root, transform=None, target_transform=None, download=False):
+        super(Caltech256, self).__init__(os.path.join(root, 'caltech256'),
+                                         transform=transform,
+                                         target_transform=target_transform)
+        makedir_exist_ok(self.root)
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError('Dataset not found or corrupted.' +
+                               ' You can use download=True to download it')
+
+        self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories")))
+        self.index = []
+        self.y = []
+        for (i, c) in enumerate(self.categories):
+            n = len(os.listdir(os.path.join(self.root, "256_ObjectCategories", c)))
+            self.index.extend(range(1, n + 1))
+            self.y.extend(n * [i])
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is index of the target class.
+        """
+        img = Image.open(os.path.join(self.root,
+                                      "256_ObjectCategories",
+                                      self.categories[self.y[index]],
+                                      "{:03d}_{:04d}.jpg".format(self.y[index] + 1, self.index[index])))
+
+        target = self.y[index]
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def _check_integrity(self):
+        # can be more robust and check hash of files
+        return os.path.exists(os.path.join(self.root, "256_ObjectCategories"))
+
+    def __len__(self):
+        return len(self.index)
+
+    def download(self):
+        if self._check_integrity():
+            print('Files already downloaded and verified')
+            return
+
+        download_and_extract_archive(
+            "http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar",
+            self.root,
+            filename="256_ObjectCategories.tar",
+            md5="67b4f42ca05d46448c6bb8ecd2220f6d")
diff --git a/modules/pytorch_jacinto_ai/vision/datasets/celeba.py b/modules/pytorch_jacinto_ai/vision/datasets/celeba.py
new file mode 100644 (file)
index 0000000..8d59321
--- /dev/null
@@ -0,0 +1,151 @@
+from functools import partial
+import torch
+import os
+import PIL
+from .vision import VisionDataset
+from .utils import download_file_from_google_drive, check_integrity, verify_str_arg
+
+
+class CelebA(VisionDataset):
+    """`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.
+
+    Args:
+        root (string): Root directory where images are downloaded to.
+        split (string): One of {'train', 'valid', 'test', 'all'}.
+            Selects the corresponding subset of the dataset.
+        target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
+            or ``landmarks``. Can also be a list to output a tuple with all specified target types.
+            The targets represent:
+                ``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes
+                ``identity`` (int): label for each person (data points with the same identity are the same person)
+                ``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height)
+                ``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
+                    righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)
+            Defaults to ``attr``.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.ToTensor``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    base_folder = "celeba"
+    # There currently does not appear to be an easy way to extract 7z archives in Python (without introducing additional
+    # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
+    # right now.
+    file_list = [
+        # File ID                         MD5 Hash                            Filename
+        ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"),
+        # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"),
+        # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"),
+        ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"),
+        ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"),
+        ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"),
+        ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"),
+        # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"),
+        ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
+    ]
+
+    def __init__(self, root, split="train", target_type="attr", transform=None,
+                 target_transform=None, download=False):
+        import pandas
+        super(CelebA, self).__init__(root, transform=transform,
+                                     target_transform=target_transform)
+        self.split = split
+        if isinstance(target_type, list):
+            self.target_type = target_type
+        else:
+            self.target_type = [target_type]
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError('Dataset not found or corrupted.' +
+                               ' You can use download=True to download it')
+
+        split_map = {
+            "train": 0,
+            "valid": 1,
+            "test": 2,
+            "all": None,
+        }
+        split = split_map[verify_str_arg(split.lower(), "split",
+                                         ("train", "valid", "test", "all"))]
+
+        fn = partial(os.path.join, self.root, self.base_folder)
+        splits = pandas.read_csv(fn("list_eval_partition.txt"), delim_whitespace=True, header=None, index_col=0)
+        identity = pandas.read_csv(fn("identity_CelebA.txt"), delim_whitespace=True, header=None, index_col=0)
+        bbox = pandas.read_csv(fn("list_bbox_celeba.txt"), delim_whitespace=True, header=1, index_col=0)
+        landmarks_align = pandas.read_csv(fn("list_landmarks_align_celeba.txt"), delim_whitespace=True, header=1)
+        attr = pandas.read_csv(fn("list_attr_celeba.txt"), delim_whitespace=True, header=1)
+
+        mask = slice(None) if split is None else (splits[1] == split)
+
+        self.filename = splits[mask].index.values
+        self.identity = torch.as_tensor(identity[mask].values)
+        self.bbox = torch.as_tensor(bbox[mask].values)
+        self.landmarks_align = torch.as_tensor(landmarks_align[mask].values)
+        self.attr = torch.as_tensor(attr[mask].values)
+        self.attr = (self.attr + 1) // 2  # map from {-1, 1} to {0, 1}
+        self.attr_names = list(attr.columns)
+
+    def _check_integrity(self):
+        for (_, md5, filename) in self.file_list:
+            fpath = os.path.join(self.root, self.base_folder, filename)
+            _, ext = os.path.splitext(filename)
+            # Allow original archive to be deleted (zip and 7z)
+            # Only need the extracted images
+            if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
+                return False
+
+        # Should check a hash of the images
+        return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))
+
+    def download(self):
+        import zipfile
+
+        if self._check_integrity():
+            print('Files already downloaded and verified')
+            return
+
+        for (file_id, md5, filename) in self.file_list:
+            download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)
+
+        with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f:
+            f.extractall(os.path.join(self.root, self.base_folder))
+
+    def __getitem__(self, index):
+        X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))
+
+        target = []
+        for t in self.target_type:
+            if t == "attr":
+                target.append(self.attr[index, :])
+            elif t == "identity":
+                target.append(self.identity[index, 0])
+            elif t == "bbox":
+                target.append(self.bbox[index, :])
+            elif t == "landmarks":
+                target.append(self.landmarks_align[index, :])
+            else:
+                # TODO: refactor with utils.verify_str_arg
+                raise ValueError("Target type \"{}\" is not recognized.".format(t))
+        target = tuple(target) if len(target) > 1 else target[0]
+
+        if self.transform is not None:
+            X = self.transform(X)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return X, target
+
+    def __len__(self):
+        return len(self.attr)
+
+    def extra_repr(self):
+        lines = ["Target type: {target_type}", "Split: {split}"]
+        return '\n'.join(lines).format(**self.__dict__)
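A minimal usage sketch for CelebA; it assumes pandas is installed (it is imported in __init__) and that the Google Drive downloads succeed, and the dataset root is illustrative:

    from pytorch_jacinto_ai.vision.datasets.celeba import CelebA

    # downloads img_align_celeba.zip and the annotation files into <root>/celeba
    train_set = CelebA('./data/datasets', split='train',
                       target_type=['attr', 'identity'], download=True)
    img, (attr, identity) = train_set[0]   # PIL image, 40-dim 0/1 attribute tensor, person id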
diff --git a/modules/pytorch_jacinto_ai/vision/datasets/cifar.py b/modules/pytorch_jacinto_ai/vision/datasets/cifar.py
new file mode 100644 (file)
index 0000000..6230a64
--- /dev/null
@@ -0,0 +1,174 @@
+from __future__ import print_function
+from PIL import Image
+import os
+import os.path
+import numpy as np
+import sys
+
+if sys.version_info[0] == 2:
+    import cPickle as pickle
+else:
+    import pickle
+
+from .vision import VisionDataset
+from .utils import check_integrity, download_and_extract_archive
+
+
+class CIFAR10(VisionDataset):
+    """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where directory
+            ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
+        train (bool, optional): If True, creates dataset from training set, otherwise
+            creates from test set.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+    """
+    base_folder = 'cifar-10-batches-py'
+    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
+    filename = "cifar-10-python.tar.gz"
+    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
+    train_list = [
+        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
+        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
+        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
+        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
+        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
+    ]
+
+    test_list = [
+        ['test_batch', '40351d587109b95175f43aff81a1287e'],
+    ]
+    meta = {
+        'filename': 'batches.meta',
+        'key': 'label_names',
+        'md5': '5ff9c542aee3614f3951f8cda6e48888',
+    }
+
+    def __init__(self, root, train=True, transform=None, target_transform=None,
+                 download=False):
+
+        super(CIFAR10, self).__init__(root, transform=transform,
+                                      target_transform=target_transform)
+
+        self.train = train  # training set or test set
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError('Dataset not found or corrupted.' +
+                               ' You can use download=True to download it')
+
+        if self.train:
+            downloaded_list = self.train_list
+        else:
+            downloaded_list = self.test_list
+
+        self.data = []
+        self.targets = []
+
+        # now load the picked numpy arrays
+        for file_name, checksum in downloaded_list:
+            file_path = os.path.join(self.root, self.base_folder, file_name)
+            with open(file_path, 'rb') as f:
+                if sys.version_info[0] == 2:
+                    entry = pickle.load(f)
+                else:
+                    entry = pickle.load(f, encoding='latin1')
+                self.data.append(entry['data'])
+                if 'labels' in entry:
+                    self.targets.extend(entry['labels'])
+                else:
+                    self.targets.extend(entry['fine_labels'])
+
+        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
+        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
+
+        self._load_meta()
+
+    def _load_meta(self):
+        path = os.path.join(self.root, self.base_folder, self.meta['filename'])
+        if not check_integrity(path, self.meta['md5']):
+            raise RuntimeError('Dataset metadata file not found or corrupted.' +
+                               ' You can use download=True to download it')
+        with open(path, 'rb') as infile:
+            if sys.version_info[0] == 2:
+                data = pickle.load(infile)
+            else:
+                data = pickle.load(infile, encoding='latin1')
+            self.classes = data[self.meta['key']]
+        self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is index of the target class.
+        """
+        img, target = self.data[index], self.targets[index]
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(img)
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self):
+        return len(self.data)
+
+    def _check_integrity(self):
+        root = self.root
+        for fentry in (self.train_list + self.test_list):
+            filename, md5 = fentry[0], fentry[1]
+            fpath = os.path.join(root, self.base_folder, filename)
+            if not check_integrity(fpath, md5):
+                return False
+        return True
+
+    def download(self):
+        if self._check_integrity():
+            print('Files already downloaded and verified')
+            return
+        download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
+
+    def extra_repr(self):
+        return "Split: {}".format("Train" if self.train is True else "Test")
+
+
+class CIFAR100(CIFAR10):
+    """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
+
+    This is a subclass of the `CIFAR10` Dataset.
+    """
+    base_folder = 'cifar-100-python'
+    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
+    filename = "cifar-100-python.tar.gz"
+    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
+    train_list = [
+        ['train', '16019d7e3df5f24257cddd939b257f8d'],
+    ]
+
+    test_list = [
+        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
+    ]
+    meta = {
+        'filename': 'meta',
+        'key': 'fine_label_names',
+        'md5': '7973b15100ade9c7d40fb424638fde48',
+    }
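A minimal usage sketch for CIFAR10/CIFAR100; the dataset root below is an illustrative assumption and no transform is passed, so samples come back as PIL images:

    from pytorch_jacinto_ai.vision.datasets.cifar import CIFAR10, CIFAR100

    train_set = CIFAR10('./data/datasets', train=True, download=True)
    test_set = CIFAR100('./data/datasets', train=False, download=True)
    img, target = train_set[0]
    print(train_set.classes[target], len(train_set), len(test_set))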
diff --git a/modules/pytorch_jacinto_ai/vision/datasets/cityscapes.py b/modules/pytorch_jacinto_ai/vision/datasets/cityscapes.py
new file mode 100644 (file)
index 0000000..56ff20b
--- /dev/null
@@ -0,0 +1,207 @@
+import json
+import os
+from collections import namedtuple
+import zipfile
+
+from .utils import extract_archive, verify_str_arg, iterable_to_str
+from .vision import VisionDataset
+from PIL import Image
+
+
+class Cityscapes(VisionDataset):
+    """`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where directory ``leftImg8bit``
+            and ``gtFine`` or ``gtCoarse`` are located.
+        split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine",
+            otherwise ``train``, ``train_extra`` or ``val``
+        mode (string, optional): The quality mode to use, ``fine`` or ``coarse``
+        target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
+            or ``color``. Can also be a list to output a tuple with all specified target types.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        transforms (callable, optional): A function/transform that takes input sample and its target as entry
+            and returns a transformed version.
+
+    Examples:
+
+        Get semantic segmentation target
+
+        .. code-block:: python
+
+            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
+                                 target_type='semantic')
+
+            img, smnt = dataset[0]
+
+        Get multiple targets
+
+        .. code-block:: python
+
+            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
+                                 target_type=['instance', 'color', 'polygon'])
+
+            img, (inst, col, poly) = dataset[0]
+
+        Validate on the "coarse" set
+
+        .. code-block:: python
+
+            dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
+                                 target_type='semantic')
+
+            img, smnt = dataset[0]
+    """
+
+    # Based on https://github.com/mcordts/cityscapesScripts
+    CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id',
+                                                     'has_instances', 'ignore_in_eval', 'color'])
+
+    classes = [
+        CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)),
+        CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)),
+        CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)),
+        CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)),
+        CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)),
+        CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)),
+        CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)),
+        CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
+        CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
+        CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)),
+        CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)),
+        CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
+        CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
+        CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
+        CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)),
+        CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)),
+        CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)),
+        CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
+        CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)),
+        CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
+        CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
+        CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
+        CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
+        CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
+        CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
+        CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
+        CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
+        CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
+        CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
+        CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)),
+        CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)),
+        CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
+        CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
+        CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
+        CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)),
+    ]
+
+    def __init__(self, root, split='train', mode='fine', target_type='instance',
+                 transform=None, target_transform=None, transforms=None):
+        super(Cityscapes, self).__init__(root, transforms, transform, target_transform)
+        self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
+        self.images_dir = os.path.join(self.root, 'leftImg8bit', split)
+        self.targets_dir = os.path.join(self.root, self.mode, split)
+        self.target_type = target_type
+        self.split = split
+        self.images = []
+        self.targets = []
+
+        verify_str_arg(mode, "mode", ("fine", "coarse"))
+        if mode == "fine":
+            valid_modes = ("train", "test", "val")
+        else:
+            valid_modes = ("train", "train_extra", "val")
+        msg = ("Unknown value '{}' for argument split if mode is '{}'. "
+               "Valid values are {{{}}}.")
+        msg = msg.format(split, mode, iterable_to_str(valid_modes))
+        verify_str_arg(split, "split", valid_modes, msg)
+
+        if not isinstance(target_type, list):
+            self.target_type = [target_type]
+        [verify_str_arg(value, "target_type",
+                        ("instance", "semantic", "polygon", "color"))
+         for value in self.target_type]
+
+        if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
+
+            if split == 'train_extra':
+                image_dir_zip = os.path.join(self.root, 'leftImg8bit{}'.format('_trainextra.zip'))
+            else:
+                image_dir_zip = os.path.join(self.root, 'leftImg8bit{}'.format('_trainvaltest.zip'))
+
+            if self.mode == 'gtFine':
+                target_dir_zip = os.path.join(self.root, '{}{}'.format(self.mode, '_trainvaltest.zip'))
+            elif self.mode == 'gtCoarse':
+                target_dir_zip = os.path.join(self.root, '{}{}'.format(self.mode, '.zip'))
+
+            if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip):
+                extract_archive(from_path=image_dir_zip, to_path=self.root)
+                extract_archive(from_path=target_dir_zip, to_path=self.root)
+            else:
+                raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
+                                   ' specified "split" and "mode" are inside the "root" directory')
+
+        for city in os.listdir(self.images_dir):
+            img_dir = os.path.join(self.images_dir, city)
+            target_dir = os.path.join(self.targets_dir, city)
+            for file_name in os.listdir(img_dir):
+                target_types = []
+                for t in self.target_type:
+                    target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0],
+                                                 self._get_target_suffix(self.mode, t))
+                    target_types.append(os.path.join(target_dir, target_name))
+
+                self.images.append(os.path.join(img_dir, file_name))
+                self.targets.append(target_types)
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+        Returns:
+            tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
+            than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
+        """
+
+        image = Image.open(self.images[index]).convert('RGB')
+
+        targets = []
+        for i, t in enumerate(self.target_type):
+            if t == 'polygon':
+                target = self._load_json(self.targets[index][i])
+            else:
+                target = Image.open(self.targets[index][i])
+
+            targets.append(target)
+
+        target = tuple(targets) if len(targets) > 1 else targets[0]
+
+        if self.transforms is not None:
+            image, target = self.transforms(image, target)
+
+        return image, target
+
+    def __len__(self):
+        return len(self.images)
+
+    def extra_repr(self):
+        lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"]
+        return '\n'.join(lines).format(**self.__dict__)
+
+    def _load_json(self, path):
+        with open(path, 'r') as file:
+            data = json.load(file)
+        return data
+
+    def _get_target_suffix(self, mode, target_type):
+        if target_type == 'instance':
+            return '{}_instanceIds.png'.format(mode)
+        elif target_type == 'semantic':
+            return '{}_labelIds.png'.format(mode)
+        elif target_type == 'color':
+            return '{}_color.png'.format(mode)
+        else:
+            return '{}_polygons.json'.format(mode)
diff --git a/modules/pytorch_jacinto_ai/vision/datasets/classification/__init__.py b/modules/pytorch_jacinto_ai/vision/datasets/classification/__init__.py
new file mode 100644 (file)
index 0000000..bafa3ce
--- /dev/null
@@ -0,0 +1,64 @@
+import os
+from .. import folder
+from .. import cifar
+from .. import imagenet
+
+__all__ = ['image_folder_classification_train', 'image_folder_classification_validation', 'image_folder_classification',
+           'imagenet_classification_train', 'imagenet_classification_validation', 'imagenet_classification',
+           'cifar10_classification', 'cifar100_classification']
+
+########################################################################
+def image_folder_classification_train(dataset_config, root, split=None, transforms=None):
+    split = 'train' if split is None else split
+    traindir = os.path.join(root, split)
+    assert os.path.exists(traindir), f'dataset training folder does not exist {traindir}'
+    train_transform = transforms[0] if isinstance(transforms,(list,tuple)) else transforms
+    train_dataset = folder.ImageFolder(traindir, train_transform)
+    return train_dataset
+
+def image_folder_classification_validation(dataset_config, root, split=None, transforms=None):
+    split = 'val' if split is None else split
+    # validation folder can be either 'val' or 'validation'
+    if (split == 'val') and (not os.path.exists(os.path.join(root,split))):
+        split = 'validation'
+    #
+    valdir = os.path.join(root, split)
+    assert os.path.exists(valdir), f'dataset validation folder does not exist {valdir}'
+    val_transform = transforms[1] if isinstance(transforms,(list,tuple)) else transforms
+    val_dataset = folder.ImageFolder(valdir, val_transform)
+    return val_dataset
+
+def image_folder_classification(dataset_config, root, split=None, transforms=None):
+    split = ('train', 'val') if split is None else split
+    train_transform, val_transform = transforms
+    train_dataset = image_folder_classification_train(dataset_config, root, split[0], train_transform)
+    val_dataset = image_folder_classification_validation(dataset_config, root, split[1], val_transform)
+    return train_dataset, val_dataset
+
+########################################################################
+def imagenet_classification_train(dataset_config, root, split=None, transforms=None):
+    train_transform = transforms[0] if isinstance(transforms,(list,tuple)) else transforms
+    train_dataset = imagenet.ImageNet(root, split='train', transform=train_transform, target_transform=None, download=True)
+    return train_dataset
+
+def imagenet_classification_validation(dataset_config, root, split=None, transforms=None):
+    val_transform = transforms[1] if isinstance(transforms,(list,tuple)) else transforms
+    val_dataset = imagenet.ImageNet(root, split='val', transform=val_transform, target_transform=None, download=True)
+    return val_dataset
+
+def imagenet_classification(dataset_config, root, split=None, transforms=None):
+    train_dataset = imagenet_classification_train(dataset_config, root, split, transforms)
+    val_dataset = imagenet_classification_validation(dataset_config, root, split, transforms)
+    return train_dataset, val_dataset
+
+
+########################################################################
+def cifar10_classification(dataset_config, root, split=None, transforms=None):
+    train_dataset = cifar.CIFAR10(root, train=True, transform=transforms[0], target_transform=None, download=True)
+    val_dataset = cifar.CIFAR10(root, train=False, transform=transforms[1], target_transform=None, download=True)
+    return train_dataset, val_dataset
+
+def cifar100_classification(dataset_config, root, split=None, transforms=None):
+    train_dataset = cifar.CIFAR100(root, train=True, transform=transforms[0], target_transform=None, download=True)
+    val_dataset = cifar.CIFAR100(root, train=False, transform=transforms[1], target_transform=None, download=True)
+    return train_dataset, val_dataset
\ No newline at end of file
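A minimal usage sketch for the folder-based classification factory above; dataset_config is unused by these helpers, the root path is illustrative, and it must contain train/ and val/ (or validation/) subfolders:

    from pytorch_jacinto_ai.vision.datasets.classification import image_folder_classification

    train_dataset, val_dataset = image_folder_classification(
        dataset_config=None,
        root='./data/datasets/my_dataset',   # expects <root>/train and <root>/val
        transforms=(None, None))             # (train_transform, val_transform) pair
    print(len(train_dataset), len(val_dataset))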
diff --git a/modules/pytorch_jacinto_ai/vision/datasets/coco.py b/modules/pytorch_jacinto_ai/vision/datasets/coco.py
new file mode 100644 (file)
index 0000000..9dd3c7a
--- /dev/null
@@ -0,0 +1,123 @@
+from .vision import VisionDataset
+from PIL import Image
+import os
+import os.path
+
+
+class CocoCaptions(VisionDataset):
+    """`MS Coco Captions <http://mscoco.org/dataset/#captions-challenge2015>`_ Dataset.
+
+    Args:
+        root (string): Root directory where images are downloaded to.
+        annFile (string): Path to json annotation file.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.ToTensor``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        transforms (callable, optional): A function/transform that takes input sample and its target as entry
+            and returns a transformed version.
+
+    Example:
+
+        .. code:: python
+
+            import torchvision.datasets as dset
+            import torchvision.transforms as transforms
+            cap = dset.CocoCaptions(root = 'dir where images are',
+                                    annFile = 'json annotation file',
+                                    transform=transforms.ToTensor())
+
+            print('Number of samples: ', len(cap))
+            img, target = cap[3] # load 4th sample
+
+            print("Image Size: ", img.size())
+            print(target)
+
+        Output: ::
+
+            Number of samples: 82783
+            Image Size: (3L, 427L, 640L)
+            [u'A plane emitting smoke stream flying over a mountain.',
+            u'A plane darts across a bright blue sky behind a mountain covered in snow',
+            u'A plane leaves a contrail above the snowy mountain top.',
+            u'A mountain that has a plane flying overheard in the distance.',
+            u'A mountain view with a plume of smoke in the background']
+
+    """
+
+    def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):
+        super(CocoCaptions, self).__init__(root, transforms, transform, target_transform)
+        from pycocotools.coco import COCO
+        self.coco = COCO(annFile)
+        self.ids = list(sorted(self.coco.imgs.keys()))
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: Tuple (image, target). target is a list of captions for the image.
+        """
+        coco = self.coco
+        img_id = self.ids[index]
+        ann_ids = coco.getAnnIds(imgIds=img_id)
+        anns = coco.loadAnns(ann_ids)
+        target = [ann['caption'] for ann in anns]
+
+        path = coco.loadImgs(img_id)[0]['file_name']
+
+        img = Image.open(os.path.join(self.root, path)).convert('RGB')
+
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)
+
+        return img, target
+
+    def __len__(self):
+        return len(self.ids)
+
+
+class CocoDetection(VisionDataset):
+    """`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.
+
+    Args:
+        root (string): Root directory where images are downloaded to.
+        annFile (string): Path to json annotation file.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.ToTensor``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        transforms (callable, optional): A function/transform that takes input sample and its target as entry
+            and returns a transformed version.
+    """
+
+    def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):
+        super(CocoDetection, self).__init__(root, transforms, transform, target_transform)
+        from pycocotools.coco import COCO
+        self.coco = COCO(annFile)
+        self.ids = list(sorted(self.coco.imgs.keys()))
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
+        """
+        coco = self.coco
+        img_id = self.ids[index]
+        ann_ids = coco.getAnnIds(imgIds=img_id)
+        target = coco.loadAnns(ann_ids)
+
+        path = coco.loadImgs(img_id)[0]['file_name']
+
+        img = Image.open(os.path.join(self.root, path)).convert('RGB')
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)
+
+        return img, target
+
+    def __len__(self):
+        return len(self.ids)
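A minimal usage sketch for CocoDetection; it requires pycocotools, and the image directory and annotation file paths below are illustrative assumptions:

    from pytorch_jacinto_ai.vision.datasets.coco import CocoDetection

    dataset = CocoDetection(root='./data/datasets/coco/val2017',
                            annFile='./data/datasets/coco/annotations/instances_val2017.json')
    img, target = dataset[0]   # PIL image and the list of annotation dicts from coco.loadAnns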
diff --git a/modules/pytorch_jacinto_ai/vision/datasets/fakedata.py b/modules/pytorch_jacinto_ai/vision/datasets/fakedata.py
new file mode 100644 (file)
index 0000000..f079c1a
--- /dev/null
@@ -0,0 +1,58 @@
+import torch
+from .vision import VisionDataset
+from .. import transforms
+
+
+class FakeData(VisionDataset):
+    """A fake dataset that returns randomly generated images and returns them as PIL images
+
+    Args:
+        size (int, optional): Size of the dataset. Default: 1000 images
+        image_size (tuple, optional): Size of the returned images. Default: (3, 224, 224)
+        num_classes (int, optional): Number of classes in the dataset. Default: 10
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        random_offset (int): Offsets the index-based random seed used to
+            generate each image. Default: 0
+
+    """
+
+    def __init__(self, size=1000, image_size=(3, 224, 224), num_classes=10,
+                 transform=None, target_transform=None, random_offset=0):
+        super(FakeData, self).__init__(None, transform=transform,
+                                       target_transform=target_transform)
+        self.size = size
+        self.num_classes = num_classes
+        self.image_size = image_size
+        self.random_offset = random_offset
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is class_index of the target class.
+        """
+        # create random image that is consistent with the index id
+        if index >= len(self):
+            raise IndexError("{} index out of range".format(self.__class__.__name__))
+        rng_state = torch.get_rng_state()
+        torch.manual_seed(index + self.random_offset)
+        img = torch.randn(*self.image_size)
+        target = torch.randint(0, self.num_classes, size=(1,), dtype=torch.long)[0]
+        torch.set_rng_state(rng_state)
+
+        # convert to PIL Image
+        img = transforms.ToPILImage()(img)
+        if self.transform is not None:
+            img = self.transform(img)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self):
+        return self.size
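FakeData needs no downloads, so it is handy as a quick smoke test for an input pipeline; a minimal sketch:

    from pytorch_jacinto_ai.vision.datasets.fakedata import FakeData

    dataset = FakeData(size=8, image_size=(3, 224, 224), num_classes=10)
    img, target = dataset[0]            # PIL image and a scalar class-index tensor
    print(img.size, int(target))        # (224, 224) and a label in [0, 10)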
diff --git a/modules/pytorch_jacinto_ai/vision/datasets/flickr.py b/modules/pytorch_jacinto_ai/vision/datasets/flickr.py
new file mode 100644 (file)
index 0000000..af8b1fe
--- /dev/null
@@ -0,0 +1,154 @@
+from collections import defaultdict
+from PIL import Image
+from six.moves import html_parser
+
+import glob
+import os
+from .vision import VisionDataset
+
+
+class Flickr8kParser(html_parser.HTMLParser):
+    """Parser for extracting captions from the Flickr8k dataset web page."""
+
+    def __init__(self, root):
+        super(Flickr8kParser, self).__init__()
+
+        self.root = root
+
+        # Data structure to store captions
+        self.annotations = {}
+
+        # State variables
+        self.in_table = False
+        self.current_tag = None
+        self.current_img = None
+
+    def handle_starttag(self, tag, attrs):
+        self.current_tag = tag
+
+        if tag == 'table':
+            self.in_table = True
+
+    def handle_endtag(self, tag):
+        self.current_tag = None
+
+        if tag == 'table':
+            self.in_table = False
+
+    def handle_data(self, data):
+        if self.in_table:
+            if data == 'Image Not Found':
+                self.current_img = None
+            elif self.current_tag == 'a':
+                img_id = data.split('/')[-2]
+                img_id = os.path.join(self.root, img_id + '_*.jpg')
+                img_id = glob.glob(img_id)[0]
+                self.current_img = img_id
+                self.annotations[img_id] = []
+            elif self.current_tag == 'li' and self.current_img:
+                img_id = self.current_img
+                self.annotations[img_id].append(data.strip())
+
+
+class Flickr8k(VisionDataset):
+    """`Flickr8k Entities <http://nlp.cs.illinois.edu/HockenmaierGroup/8k-pictures.html>`_ Dataset.
+
+    Args:
+        root (string): Root directory where images are downloaded to.
+        ann_file (string): Path to annotation file.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.ToTensor``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+    """
+
+    def __init__(self, root, ann_file, transform=None, target_transform=None):
+        super(Flickr8k, self).__init__(root, transform=transform,
+                                       target_transform=target_transform)
+        self.ann_file = os.path.expanduser(ann_file)
+
+        # Read annotations and store in a dict
+        parser = Flickr8kParser(self.root)
+        with open(self.ann_file) as fh:
+            parser.feed(fh.read())
+        self.annotations = parser.annotations
+
+        self.ids = list(sorted(self.annotations.keys()))
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: Tuple (image, target). target is a list of captions for the image.
+        """
+        img_id = self.ids[index]
+
+        # Image
+        img = Image.open(img_id).convert('RGB')
+        if self.transform is not None:
+            img = self.transform(img)
+
+        # Captions
+        target = self.annotations[img_id]
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self):
+        return len(self.ids)
+
+
+class Flickr30k(VisionDataset):
+    """`Flickr30k Entities <http://web.engr.illinois.edu/~bplumme2/Flickr30kEntities/>`_ Dataset.
+
+    Args:
+        root (string): Root directory where images are downloaded to.
+        ann_file (string): Path to annotation file.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.ToTensor``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+    """
+
+    def __init__(self, root, ann_file, transform=None, target_transform=None):
+        super(Flickr30k, self).__init__(root, transform=transform,
+                                        target_transform=target_transform)
+        self.ann_file = os.path.expanduser(ann_file)
+
+        # Read annotations and store in a dict
+        self.annotations = defaultdict(list)
+        with open(self.ann_file) as fh:
+            for line in fh:
+                img_id, caption = line.strip().split('\t')
+                self.annotations[img_id[:-2]].append(caption)
+
+        self.ids = list(sorted(self.annotations.keys()))
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: Tuple (image, target). target is a list of captions for the image.
+        """
+        img_id = self.ids[index]
+
+        # Image
+        filename = os.path.join(self.root, img_id)
+        img = Image.open(filename).convert('RGB')
+        if self.transform is not None:
+            img = self.transform(img)
+
+        # Captions
+        target = self.annotations[img_id]
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self):
+        return len(self.ids)
diff --git a/modules/pytorch_jacinto_ai/vision/datasets/folder.py b/modules/pytorch_jacinto_ai/vision/datasets/folder.py
new file mode 100644 (file)
index 0000000..88ef557
--- /dev/null
@@ -0,0 +1,213 @@
+from .vision import VisionDataset
+
+from PIL import Image
+
+import os
+import os.path
+import sys
+import warnings
+
+warnings.filterwarnings('ignore', 'Corrupt EXIF data', UserWarning)
+warnings.filterwarnings('ignore', 'Possibly corrupt EXIF data', UserWarning)
+
+def has_file_allowed_extension(filename, extensions):
+    """Checks if a file is an allowed extension.
+
+    Args:
+        filename (string): path to a file
+        extensions (tuple of strings): extensions to consider (lowercase)
+
+    Returns:
+        bool: True if the filename ends with one of given extensions
+    """
+    return filename.lower().endswith(extensions)
+
+
+def is_image_file(filename):
+    """Checks if a file is an allowed image extension.
+
+    Args:
+        filename (string): path to a file
+
+    Returns:
+        bool: True if the filename ends with a known image extension
+    """
+    return has_file_allowed_extension(filename, IMG_EXTENSIONS)
+
+
+def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None):
+    images = []
+    dir = os.path.expanduser(dir)
+    if not ((extensions is None) ^ (is_valid_file is None)):
+        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
+    if extensions is not None:
+        def is_valid_file(x):
+            return has_file_allowed_extension(x, extensions)
+    for target in sorted(class_to_idx.keys()):
+        d = os.path.join(dir, target)
+        if not os.path.isdir(d):
+            continue
+        for root, _, fnames in sorted(os.walk(d)):
+            for fname in sorted(fnames):
+                path = os.path.join(root, fname)
+                if is_valid_file(path):
+                    item = (path, class_to_idx[target])
+                    images.append(item)
+
+    return images
+
+
+class DatasetFolder(VisionDataset):
+    """A generic data loader where the samples are arranged in this way: ::
+
+        root/class_x/xxx.ext
+        root/class_x/xxy.ext
+        root/class_x/xxz.ext
+
+        root/class_y/123.ext
+        root/class_y/nsdf3.ext
+        root/class_y/asd932_.ext
+
+    Args:
+        root (string): Root directory path.
+        loader (callable): A function to load a sample given its path.
+        extensions (tuple[string]): A list of allowed extensions.
+            both extensions and is_valid_file should not be passed.
+        transform (callable, optional): A function/transform that takes in
+            a sample and returns a transformed version.
+            E.g, ``transforms.RandomCrop`` for images.
+        target_transform (callable, optional): A function/transform that takes
+            in the target and transforms it.
+        is_valid_file (callable, optional): A function that takes the path of an image file
+            and checks if the file is a valid file (used to filter out corrupt files)
+            both extensions and is_valid_file should not be passed.
+
+     Attributes:
+        classes (list): List of the class names.
+        class_to_idx (dict): Dict with items (class_name, class_index).
+        samples (list): List of (sample path, class_index) tuples
+        targets (list): The class_index value for each image in the dataset
+    """
+
+    def __init__(self, root, loader, extensions=None, transform=None,
+                 target_transform=None, is_valid_file=None):
+        super(DatasetFolder, self).__init__(root, transform=transform,
+                                            target_transform=target_transform)
+        classes, class_to_idx = self._find_classes(self.root)
+        samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
+        if len(samples) == 0:
+            raise (RuntimeError("Found 0 files in subfolders of: " + self.root + "\n"
+                                "Supported extensions are: " + ",".join(extensions)))
+
+        self.loader = loader
+        self.extensions = extensions
+
+        self.classes = classes
+        self.class_to_idx = class_to_idx
+        self.samples = samples
+        self.targets = [s[1] for s in samples]
+
+    def _find_classes(self, dir):
+        """
+        Finds the class folders in a dataset.
+
+        Args:
+            dir (string): Root directory path.
+
+        Returns:
+            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
+
+        Ensures:
+            No class is a subdirectory of another.
+        """
+        if sys.version_info >= (3, 5):
+            # Faster and available in Python 3.5 and above
+            classes = [d.name for d in os.scandir(dir) if d.is_dir()]
+        else:
+            classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
+        classes.sort()
+        class_to_idx = {classes[i]: i for i in range(len(classes))}
+        return classes, class_to_idx
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (sample, target) where target is class_index of the target class.
+        """
+        path, target = self.samples[index]
+        sample = self.loader(path)
+        if self.transform is not None:
+            sample = self.transform(sample)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return sample, target
+
+    def __len__(self):
+        return len(self.samples)
+
+
+IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
+
+
+def pil_loader(path):
+    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
+    with open(path, 'rb') as f:
+        img = Image.open(f)
+        return img.convert('RGB')
+
+
+def accimage_loader(path):
+    import accimage
+    try:
+        return accimage.Image(path)
+    except IOError:
+        # Potentially a decoding problem, fall back to PIL.Image
+        return pil_loader(path)
+
+
+def default_loader(path):
+    from ...vision import get_image_backend
+    if get_image_backend() == 'accimage':
+        return accimage_loader(path)
+    else:
+        return pil_loader(path)
+
+
+class ImageFolder(DatasetFolder):
+    """A generic data loader where the images are arranged in this way: ::
+
+        root/dog/xxx.png
+        root/dog/xxy.png
+        root/dog/xxz.png
+
+        root/cat/123.png
+        root/cat/nsdf3.png
+        root/cat/asd932_.png
+
+    Args:
+        root (string): Root directory path.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        loader (callable, optional): A function to load an image given its path.
+        is_valid_file (callable, optional): A function that takes the path of an image file
+            and checks if the file is a valid file (used to filter out corrupt files)
+
+     Attributes:
+        classes (list): List of the class names.
+        class_to_idx (dict): Dict with items (class_name, class_index).
+        imgs (list): List of (image path, class_index) tuples
+    """
+
+    def __init__(self, root, transform=None, target_transform=None,
+                 loader=default_loader, is_valid_file=None):
+        super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
+                                          transform=transform,
+                                          target_transform=target_transform,
+                                          is_valid_file=is_valid_file)
+        self.imgs = self.samples
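A minimal usage sketch for ImageFolder; the root path is illustrative and must contain one subdirectory per class, as shown in the docstring above:

    from pytorch_jacinto_ai.vision.datasets.folder import ImageFolder

    dataset = ImageFolder('./data/datasets/my_dataset/train')
    print(dataset.classes)              # sorted class-folder names
    print(dataset.class_to_idx)         # {'cat': 0, 'dog': 1, ...}
    img, target = dataset[0]            # PIL image and the index of its class folder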
diff --git a/modules/pytorch_jacinto_ai/vision/datasets/hmdb51.py b/modules/pytorch_jacinto_ai/vision/datasets/hmdb51.py
new file mode 100644 (file)
index 0000000..e601eaf
--- /dev/null
@@ -0,0 +1,97 @@
+import glob
+import os
+
+from .video_utils import VideoClips
+from .utils import list_dir
+from .folder import make_dataset
+from .vision import VisionDataset
+
+
+class HMDB51(VisionDataset):
+    """
+    `HMDB51 <http://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/>`_
+    dataset.
+
+    HMDB51 is an action recognition video dataset.
+    This dataset considers every video as a collection of video clips of fixed size, specified
+    by ``frames_per_clip``, where the step in frames between each clip is given by
+    ``step_between_clips``.
+
+    To give an example, for 2 videos with 10 and 15 frames respectively, if ``frames_per_clip=5``
+    and ``step_between_clips=5``, the dataset size will be (2 + 3) = 5, where the first two
+    elements will come from video 1, and the next three elements from video 2.
+    Note that we drop clips which do not have exactly ``frames_per_clip`` elements, so not all
+    frames in a video might be present.
+
+    Internally, it uses a VideoClips object to handle clip creation.
+
+    Args:
+        root (string): Root directory of the HMDB51 Dataset.
+        annotation_path (str): path to the folder containing the split files
+        frames_per_clip (int): number of frames in a clip.
+        step_between_clips (int): number of frames between each clip.
+        fold (int, optional): which fold to use. Should be between 1 and 3.
+        train (bool, optional): if ``True``, creates a dataset from the train split,
+            otherwise from the ``test`` split.
+        transform (callable, optional): A function/transform that  takes in a TxHxWxC video
+            and returns a transformed version.
+
+    Returns:
+        video (Tensor[T, H, W, C]): the `T` video frames
+        audio(Tensor[K, L]): the audio frames, where `K` is the number of channels
+            and `L` is the number of points
+        label (int): class of the video clip
+    """
+
+    data_url = "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/hmdb51_org.rar"
+    splits = {
+        "url": "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar",
+        "md5": "15e67781e70dcfbdce2d7dbb9b3344b5"
+    }
+
+    def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
+                 fold=1, train=True, transform=None):
+        super(HMDB51, self).__init__(root)
+        if not 1 <= fold <= 3:
+            raise ValueError("fold should be between 1 and 3, got {}".format(fold))
+
+        extensions = ('avi',)
+        self.fold = fold
+        self.train = train
+
+        classes = list(sorted(list_dir(root)))
+        class_to_idx = {classes[i]: i for i in range(len(classes))}
+        self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
+        self.classes = classes
+        video_list = [x[0] for x in self.samples]
+        video_clips = VideoClips(video_list, frames_per_clip, step_between_clips)
+        self.indices = self._select_fold(video_list, annotation_path, fold, train)
+        self.video_clips = video_clips.subset(self.indices)
+        self.transform = transform
+
+    def _select_fold(self, video_list, annotation_path, fold, train):
+        target_tag = 1 if train else 2
+        name = "*test_split{}.txt".format(fold)
+        files = glob.glob(os.path.join(annotation_path, name))
+        selected_files = []
+        for f in files:
+            with open(f, "r") as fid:
+                data = fid.readlines()
+                data = [x.strip().split(" ") for x in data]
+                data = [x[0] for x in data if int(x[1]) == target_tag]
+                selected_files.extend(data)
+        selected_files = set(selected_files)
+        indices = [i for i in range(len(video_list)) if os.path.basename(video_list[i]) in selected_files]
+        return indices
+
+    def __len__(self):
+        return self.video_clips.num_clips()
+
+    def __getitem__(self, idx):
+        video, audio, info, video_idx = self.video_clips.get_clip(idx)
+        label = self.samples[self.indices[video_idx]][1]
+
+        if self.transform is not None:
+            video = self.transform(video)
+
+        return video, audio, label
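A minimal usage sketch for HMDB51; the video and split-file paths below are illustrative, and the videos must already be extracted into one folder per class under the root:

    from pytorch_jacinto_ai.vision.datasets.hmdb51 import HMDB51

    dataset = HMDB51(root='./data/datasets/hmdb51/videos',
                     annotation_path='./data/datasets/hmdb51/splits',
                     frames_per_clip=16, step_between_clips=8, fold=1, train=True)
    video, audio, label = dataset[0]    # video is a Tensor[T, H, W, C] clip
    print(video.shape, label)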
diff --git a/modules/pytorch_jacinto_ai/vision/datasets/imagenet.py b/modules/pytorch_jacinto_ai/vision/datasets/imagenet.py
new file mode 100644 (file)
index 0000000..14a256c
--- /dev/null
@@ -0,0 +1,171 @@
+from __future__ import print_function
+import os
+import shutil
+import tempfile
+import torch
+from .folder import ImageFolder
+from .utils import check_integrity, download_and_extract_archive, extract_archive, \
+    verify_str_arg
+
+ARCHIVE_DICT = {
+    'train': {
+        'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar',
+        'md5': '1d675b47d978889d74fa0da5fadfb00e',
+    },
+    'val': {
+        'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar',
+        'md5': '29b22e2961454d5413ddabcf34fc5622',
+    },
+    'devkit': {
+        'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_devkit_t12.tar.gz',
+        'md5': 'fa75699e90414af021442c21a62c3abf',
+    }
+}
+
+
+class ImageNet(ImageFolder):
+    """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.
+
+    Args:
+        root (string): Root directory of the ImageNet Dataset.
+        split (string, optional): The dataset split, supports ``train``, or ``val``.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        loader (callable, optional): A function to load an image given its path.
+
+     Attributes:
+        classes (list): List of the class name tuples.
+        class_to_idx (dict): Dict with items (class_name, class_index).
+        wnids (list): List of the WordNet IDs.
+        wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
+        imgs (list): List of (image path, class_index) tuples
+        targets (list): The class_index value for each image in the dataset
+    """
+
+    def __init__(self, root, split='train', download=False, **kwargs):
+        root = self.root = os.path.expanduser(root)
+        self.split = verify_str_arg(split, "split", ("train", "val"))
+
+        if download:
+            self.download()
+        wnid_to_classes = self._load_meta_file()[0]
+
+        super(ImageNet, self).__init__(self.split_folder, **kwargs)
+        self.root = root
+
+        self.wnids = self.classes
+        self.wnid_to_idx = self.class_to_idx
+        self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
+        self.class_to_idx = {cls: idx
+                             for idx, clss in enumerate(self.classes)
+                             for cls in clss}
+
+    def download(self):
+        if not check_integrity(self.meta_file):
+            tmp_dir = tempfile.mkdtemp()
+
+            archive_dict = ARCHIVE_DICT['devkit']
+            download_and_extract_archive(archive_dict['url'], self.root,
+                                         extract_root=tmp_dir,
+                                         md5=archive_dict['md5'])
+            devkit_folder = _splitexts(os.path.basename(archive_dict['url']))[0]
+            meta = parse_devkit(os.path.join(tmp_dir, devkit_folder))
+            self._save_meta_file(*meta)
+
+            shutil.rmtree(tmp_dir)
+
+        if not os.path.isdir(self.split_folder):
+            archive_dict = ARCHIVE_DICT[self.split]
+            download_and_extract_archive(archive_dict['url'], self.root,
+                                         extract_root=self.split_folder,
+                                         md5=archive_dict['md5'])
+
+            if self.split == 'train':
+                prepare_train_folder(self.split_folder)
+            elif self.split == 'val':
+                val_wnids =