From 91481a2708e533165329df37a30a226fd4774f65 Mon Sep 17 00:00:00 2001 From: Manu Mathew Date: Mon, 6 Jan 2020 11:14:00 +0530 Subject: [PATCH] release commit --- README.md | 1 - docs/Image_Classification.md | 2 +- docs/Quantization.md | 73 ++++++++++++++----- docs/Semantic_Segmentation.md | 34 +++++---- .../engine/infer_pixel2pixel.py | 18 ++++- .../engine/test_classification.py | 9 +-- .../engine/train_classification.py | 29 +++++--- .../engine/train_pixel2pixel.py | 29 +++++--- .../vision/losses/segmentation_loss.py | 2 +- .../vision/models/mobilenetv1.py | 13 ++-- .../vision/models/mobilenetv2.py | 15 ++-- .../models/pixel2pixel/fpn_pixel2pixel.py | 41 +++++++++-- .../vision/models/resnet.py | 4 +- .../xnn/layers/multi_task.py | 2 + .../xnn/quantize/quant_graph_module.py | 6 +- run_classification.sh | 6 +- run_depth.sh | 3 +- run_quantization.sh | 58 +++++++-------- run_segmentation.sh | 35 ++++++--- scripts/train_classification_main.py | 28 ++++--- scripts/train_depth_main.py | 10 +-- scripts/train_segmentation_main.py | 10 +-- 22 files changed, 272 insertions(+), 156 deletions(-) diff --git a/README.md b/README.md index 8eafb6b..3d0a207 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,6 @@ The following examples are currently available. Click on each of the links below * Object Detection - coming soon..
* Object Keypoint Estimation - coming soon..
* Quantization
- * [**Post Training Calibration For Quantization**](docs/Quantization.md)
* [**Quantization Aware Training**](docs/Quantization.md)
Some of the common training and validation commands are provided in shell scripts (.sh files) in the root folder. diff --git a/docs/Image_Classification.md b/docs/Image_Classification.md index a05cc5d..e832ee8 100644 --- a/docs/Image_Classification.md +++ b/docs/Image_Classification.md @@ -112,7 +112,7 @@ |Dataset |Mode Name |Resize Resolution|Crop Resolution|Complexity (GigaMACS)|MeanIoU% | |---------|---------- |----------- |---------- |-------- |-------- | |ImageNet |MobileNetV1 |256x256 |224x224 |0.568 |**71.83**| -|ImageNet |MobileNetV2 |256x256 |224x224 |0.296 |**71.89**| +|ImageNet |MobileNetV2 |256x256 |224x224 |0.296 |**72.13**| |ImageNet |ResNet50 |256x256 |224x224 | | | |. |ImageNet |MobileNetV1[1]|256x256 |224x224 |0.569 |70.60 | diff --git a/docs/Quantization.md b/docs/Quantization.md index ff07b76..376a29a 100644 --- a/docs/Quantization.md +++ b/docs/Quantization.md @@ -6,16 +6,26 @@ In order to make the activations quantization friendly, it is important to clip

This code also contains two wrapper modules called QuantCalibrateModule and QuantTrainModule that will handle such tasks as replacing ReLUs with PACT2, quantizing weights, quantizing activations etc - so the user does not need to do any of these things - the only thing that is required is to wrap the user's module in these wapper modules as explained below. -There are two primary methods of quantization - Post Training Calibration For Quantization and Trained Quantization (a.k.a Quantization Aware Training). +There are two primary methods of quantization - Post Training Calibration For Quantization and Trained Quantization (a.k.a Quantization Aware Training). These required the following (only for information - automatically done by QuantCalibrateModule and QuantTrainModule): +- Replace all the ReLU, ReLU6 layers in the model by PACT2. Insert PACT2 after Convolution+BatchNorm if a ReLU is missing after that. Insert PACT2 anywhere else required - where activation range clipping and range collection is required. For example it can ne after the Fully Connected Layer. We use forward post hooks of PyTorch nn.Modules to call these extra activation functions. Thus we are able to add these extra activations without disturbing the loading of existing pre-trained weights. +- Clip the weights to an appropriate range if the weight range is very high. +- Quantize the weights and activations during the forward pass. +- Other modifications required for Calibration / Trained Quantization -## Post Training Calibration For Quantization +Trained Quantization typically provides better accuracy compared to Post Training Calibration. Trained Quantization is also easy to incorporate into your existing PyTorch training code - hence we recommend to use it. Here, we explain both methods for the sake of completeness. +## Post Training Calibration For Quantization +Note: Although Calibration methods are described in detail here, this is not our recommended method as this does not yield the best results in some cases. There are specific scenarios where Trained Quantization is not possible (eg. to import a floating point model for embedded inference) and Calibration is mainly meant for such scenarios. However, Calibration methods are improving at a rapid pace and it is possible that Calibration methods will be good enough for all cases in the future - but for now, you can skip to the section on Trained Quantization as it is the recommended method. + Post Training Calibration & Quantization can take a model trained in floating point and with a few steps convert it to a model that is friendly for quantized inference. Compared to the alternative (Trained Quantization), the advantages of this method are: - Calibration is fast - a typical calibration finishes in a few minutes. - Ground truth is not required - just input images are sufficient. - Loss function or backward (back-propagation) are not required. -Thus, this is the preferred method of quantization from an ease of use point of view. As explained earlier, in this method, the training happens entirely in floating point. The inference (possibly in an embedded device) happens in fixed point. In between training and fixed point inference, the model goes through the step called Calibration with some sample images. The Calibration happens in PC and Quantized Inference happens in the embedded device. Calibration basically tries to make the quantized output similar to the floating point output - by choosing appropriate activation ranges, weights and biases. +The disadvantage is: +- This is a fast and approximate method of quantizing the model and may not always yield the best accuracy. + +In Post Training Calibration, the training happens entirely in floating point. The inference (possibly in an embedded device) happens in fixed point. In between training and fixed point inference, the model goes through the step called Calibration with some sample images. The Calibration happens in PC and Quantized Inference happens in the embedded device. Calibration basically tries to make the quantized output similar to the floating point output - by choosing appropriate activation ranges, weights and biases. A block diagram of Post Training Calibration is shown below:

@@ -28,25 +38,19 @@ Depending on how the activation range is collected and Quantization is done, we Out of these methods, **Advanced Calib** is our recommended Calibration method as of now, as it has the best trade-off between the Accuracy and the features required during fixed point inference. All the Calibration scripts that we have in this page uses "Advanced Calib" by default. Other Calibration methods described here are for information only. -#### How to use QuantCalibrateModule -- The section briefly explains how to make use of our helper/wrapper module to do the calibration of your model. For further details, please see pytorch_jacinto_ai.engine.train_classification.py and pytorch_jacinto_ai.engine.train_pixel2pixel.py. The step by step process is as follows: - -###### Model preparation: -- Replace all the ReLU, ReLU6 layers in the model by PACT2. Insert PACT2 after Convolution+BatchNorm if a ReLU is missing after that. Insert PACT2 anywhere else required - where activation range clipping and range collection is required. For example it can ne after the Fully Connected Layer. We use forward post hooks of PyTorch nn.Modules to call these extra activation functions. Thus we are able to add these extra activations without disturbing the loading of existing pre-trained weights. -- Clip the weights to an appropriate range if the weight range is very high. -- Note that this Model preparation is automatically done by QuantCalibrateModule +In order to do Calibration easily we have a developed a wrapper module called QuantCalibrateModule, which is located in pytorch_jacinto_ai.xnn.quantize.QuantCalibrateModule. We make use of a kind of Parametric Activation called **PACT2** in order to store the calibrated ranges of activations. PACT2 is a improved form of PACT [[1]]. **PACT2 uses power of 2 activation ranges** for activation clipping. PACT2 can learn ranges very quickly (using a statistic method) without back propagation - this feature makes it quite attractive for Calibration. Our wrapper module replaces all the ReLUs in the model with PACT2. It also inserts PACT2 in other places where activation ranges need to be collected. Statistical range clipping in PACT2 improves the Quantized Accuracy over simple min-max range clipping. -###### Forward iterations: +#### What happens during Calibration? - For each iteration perform a forward in floating point using the original weights and biases. During this pass PACT2 layers will collect output ranges using histogram and running average. - In addition, perform Convolution+BatchNorm merging and quantization of the resulting weights. These quantized and de-quantized weights are used in a forward pass. Ranges collected by PACT2 is used for activation quantization (and de-quantization) to generate quantized output. - The floating point output and quantized output are compared using statistic measures. Using such statistic measures, we can adjust the weights and biases of Convolutions and Batch Normalization layers - so that the quantized output becomes closer to the floating point output. -- Within a few iterations, we could get reasonable quantization accuracy for several models that we tried this method on. - -In order to do Calibration easily we have a developed a wrapper module called QuantCalibrateModule, which is located in pytorch_jacinto_ai.xnn.quantize.QuantCalibrateModule. We make use of a kind of Parametric Activation called **PACT2** in order to store the calibrated ranges of activations. PACT2 is a improved form of PACT [[1]]. **PACT2 uses power of 2 activation ranges** for activation clipping. PACT2 can learn ranges very quickly (using a statistic method) without back propagation - this feature makes it quite attractive for Calibration. Our wrapper module replaces all the ReLUs in the model with PACT2. It also inserts PACT2 in other places where activation ranges need to be collected. Statistical range clipping in PACT2 improves the Quantized Accuracy over simple min-max range clipping. +- Within a few iterations, we should get reasonable quantization accuracy. +#### How to use QuantCalibrateModule As explained, the method of **Calibration does not need ground truth, loss function or back propagation.** However in the calibration script, we make use of ground truth to measure the loss/accuracy even in the Calibration stage - although that is not necessary. -###### Sample Code Snippet: +The section briefly explains how to make use of our helper/wrapper module to do the calibration of your model. For further details, please see pytorch_jacinto_ai.engine.train_classification.py and pytorch_jacinto_ai.engine.train_pixel2pixel.py. The step by step process is as follows: + ``` # create your model here: model = ... @@ -73,6 +77,7 @@ torch.save(model.module.state_dict(), os.path.join(save_path,'model.pth')) torch.onnx.export(model.module, dummy_input, os.path.join(save_path,'model.onnx'), export_params=True, verbose=False) ``` +Careful attention needs to be given to how the pretrained model is loaded and trained model is saved as shown in the above code snippet. Few examples of calibration are provided below. These commands are also listed in the file **run_quantization.sh** for convenience.
@@ -93,20 +98,52 @@ python ./scripts/train_segmentation_main.py --phase calibration --dataset_name c --batch_size 12 --quantize True --epochs 1 --epoch_size 100 ``` -## Trained Quantization (a.k.a Quantization Aware Training) -As explained in the previous section, Calibration is our preferred method of making a model quantization friendly. However, in exceptional cases, it is possible that the drop in accuracy during calibration is more than acceptable. In this case, Trained Quantization can be used. +## Trained Quantization a.k.a Quantization Aware Training (recommended method) +Trained Quantization typically provides better accuracy compared to Post Training Calibration. -Unlike Calibration, Trained Quantization involves ground truth, loss function and back propagation. The most popular method of trained quantization is [[4]]. It takes care of merging Convolution layers with the adjascent Batch Normalization layers (on-the-fly) during the quantized training (if this merging is not correctly done, quantized training may not improve the accuracy). In addition, we use Straight-Through Estimation (STE) [[2,3]] to improve the gradient flow in back-propagation. Also, the statistical range clipping in PACT2 improves the Quantized Accuracy over simple min-max range clipping. +Unlike Calibration, Trained Quantization involves ground truth, loss function and back propagation - hence it is very similar to the original floating point training. + +The most popular method of trained quantization is [[4]]. It takes care of merging Convolution layers with the adjascent Batch Normalization layers (on-the-fly) during the quantized training (if this merging is not correctly done, quantized training may not improve the accuracy). In addition, we use Straight-Through Estimation (STE) [[2,3]] to improve the gradient flow in back-propagation. Also, the statistical range clipping in PACT2 improves the Quantized Accuracy over simple min-max range clipping. Note: Instead of STE and statistical ranges for PACT2, we also tried out approximate gradients for scale and trained quantization thresholds proposed in [[5]] (We did not use the gradient nomralization and log-domain training mentioned in the paper). We found that method to be able to learn the clipping thresholds for initial few epochs, but became unstable after a few epochs and loss became high. Compared to that learned thresholds method, our statistical PACT2 ranges/thresholds combined with STE is simple and stable. A block diagram of Trained Quantization is shown below:

+#### What happens during Trained Quantization? +- For each iteration perform a forward in floating point using the original weights and biases. During this pass PACT2 layers will collect output ranges using histogram and running average. +- In addition, perform Convolution+BatchNorm merging and quantization of the resulting weights. These quantized and de-quantized weights are used in a forward pass. Ranges collected by PACT2 is used for activation quantization (and de-quantization) to generate quantized output. +- Backpropagation with STE will update the parameters of the model to reduce the loss with quantization. +- Within a few epochs, we should get reasonable quantization accuracy. + +#### How to use QuantTrainModule In order to enable quantized training, we have developed the wrapper class pytorch_jacinto_ai.xnn.quantize.QuantTrainModule. The usage of this module can be seen in pytorch_jacinto_ai.engine.train_classification.py and pytorch_jacinto_ai.engine.train_pixel2pixel.py. ``` +# create your model here: +model = ... + +# create a dummy input - this is required to analyze the model - fill in the input image size expected by your model. +dummy_input = torch.rand((1,3,384,768)) + +#wrap your model in QuantTrainModule. Once it is wrapped, the actual model is in model.module model = pytorch_jacinto_ai.xnn.quantize.QuantTrainModule(model, dummy_input=dummy_input) + +# load your pretrained weights here into model.module +pretrained_data = torch.load(pretrained_path) +model.module.load_state_dict(pretrained_data) + +# your training loop here with with loss, backward, optimizer and scheduler. +# this is the usual training loop - but use a lower learning rate such as 5e-5 +.... +.... + +# save the model - the calibrated module is in model.module +torch.save(model.module.state_dict(), os.path.join(save_path,'model.pth')) +torch.onnx.export(model.module, dummy_input, os.path.join(save_path,'model.onnx'), export_params=True, verbose=False) ``` + +As can be seen, it is very easy to incorporate QuantTrainModule in your existing training code as the only thing required is to wrap your original model in QuantTrainModule. Careful attention needs to be given to how the pretrained model is loaded and trained model is saved as shown in the above code snippet. + The Model Preparation steps used for Calibration applies for Trained Quantization as well. - Note that this Model preparation is automatically done by QuantTrainModule. The resultant model can then be used for training as usual and it will take care of quantization constraints during the training forward and backward passes. One word of caution is that our current implementation of Trained Quantization is a bit slow. The reason for this slowdown is that our implementation is using the top-level python layer of PyTorch and not the underlying C++ layer. But with PyTorch natively supporting the functionality required for quantization under the hood - we hope that this speed issue can be resolved in a future update. diff --git a/docs/Semantic_Segmentation.md b/docs/Semantic_Segmentation.md index a341c86..b2c86b7 100644 --- a/docs/Semantic_Segmentation.md +++ b/docs/Semantic_Segmentation.md @@ -83,22 +83,24 @@ Inference can be done as follows (fill in the path to the pretrained model):
### Cityscapes Segmentation -|Dataset |Mode Architecture |Backbone Model|Backbone Stride|Resolution |Complexity (GigaMACS)|MeanIoU% |Model Configuration Name | -|--------- |---------- |----------- |-------------- |-----------|-------- |----------|--------------------------------- | -|Cityscapes |DeepLabV3Lite |MobileNetV2 |16 |768x384 |3.54 |**69.13** |deeplabv3lite_mobilenetv2_tv | -|Cityscapes |FPNPixel2Pixel |MobileNetV2 |32 |768x384 |3.84 |**70.39** |fpn_pixel2pixel_mobilenetv2_tv | -|Cityscapes |FPNPixel2Pixel |MobileNetV2 |64 |1536x768 |3.96 |**71.28** |fpn_pixel2pixel_mobilenetv2_tv_es64| -|Cityscapes |FPNPixel2Pixel |MobileNetV2 |64 |2048x1024 |7.03 | |fpn_pixel2pixel_mobilenetv2_tv_es64| -|Cityscapes |DeepLabV3Lite |MobileNetV2 |16 |1536x768 |14.48 |**73.59** |deeplabv3lite_mobilenetv2_tv | -|Cityscapes |FPNPixel2Pixel |MobileNetV2 |32 |1536x768 |15.37 |**74.98** |fpn_pixel2pixel_mobilenetv2_tv | - -|Dataset |Mode Architecture |Backbone Model|Backbone Stride|Resolution |Complexity (GigaMACS)|MeanIoU% |Model Configuration Name | -|--------- |---------- |----------- |-------------- |-----------|-------- |----------|--------------------------------- | -|Cityscapes |ERFNet[[4]] | | |1024x512 |27.705 |69.7 |N/A | -|Cityscapes |SwiftNetMNV2[[5]] |MobileNetV2 | |2048x1024 |41.0 |75.3 |N/A | -|Cityscapes |DeepLabV3Plus[[6,7]]|MobileNetV2 |16 | |21.27 |70.71 |N/A | -|Cityscapes |DeepLabV3Plus[[6,7]]|Xception65 |16 | |418.64 |78.79 |N/A | - +|Dataset |Mode Architecture |Backbone Model|Backbone Stride|Resolution |Complexity (GigaMACS)|MeanIoU% |Model Configuration Name | +|--------- |---------- |----------- |-------------- |-----------|-------- |----------|---------------------------------------- | +|Cityscapes |FPNPixel2Pixel with DWASPP|MobileNetV2 |64 |768x384 |0.99 |62.43 |fpn_pixel2pixel_aspp_mobilenetv2_tv_es64 | +|Cityscapes |DeepLabV3Lite with DWASPP |MobileNetV2 |16 |768x384 |**3.54** |**69.13** |**deeplabv3lite_mobilenetv2_tv** | +|Cityscapes |FPNPixel2Pixel |MobileNetV2 |32 |768x384 |3.66 |- |fpn_pixel2pixel_mobilenetv2_tv | +|Cityscapes |FPNPixel2Pixel with DWASPP|MobileNetV2 |32 |768x384 |3.84 |70.39 |fpn_pixel2pixel_aspp_mobilenetv2_tv | +|Cityscapes |FPNPixel2Pixel |MobileNetV2 |64 |1536x768 |3.85 |69.82 |fpn_pixel2pixel_mobilenetv2_tv_es64 | +|Cityscapes |FPNPixel2Pixel with DWASPP|MobileNetV2 |64 |1536x768 |**3.96** |**71.28** |**fpn_pixel2pixel_aspp_mobilenetv2_tv_es64**| +|Cityscapes |FPNPixel2Pixel with DWASPP|MobileNetV2 |64 |2048x1024 |7.03 |72.67 |fpn_pixel2pixel_aspp_mobilenetv2_tv_es64 | +|Cityscapes |DeepLabV3Lite with DWASPP |MobileNetV2 |16 |1536x768 |14.48 |73.59 |deeplabv3lite_mobilenetv2_tv | +|Cityscapes |FPNPixel2Pixel with DWASPP|MobileNetV2 |32 |1536x768 |**15.37** |**74.98** |**fpn_pixel2pixel_aspp_mobilenetv2_tv** | + +|Dataset |Mode Architecture |Backbone Model|Backbone Stride|Resolution |Complexity (GigaMACS)|MeanIoU% |Model Configuration Name | +|--------- |---------- |----------- |-------------- |-----------|-------- |----------|---------------------------------------- | +|Cityscapes |ERFNet[[4]] | | |1024x512 |27.705 |69.7 |N/A | +|Cityscapes |SwiftNetMNV2[[5]] |MobileNetV2 | |2048x1024 |41.0 |75.3 |N/A | +|Cityscapes |DeepLabV3Plus[[6,7]] |MobileNetV2 |16 | |21.27 |70.71 |N/A | +|Cityscapes |DeepLabV3Plus[[6,7]] |Xception65 |16 | |418.64 |78.79 |N/A | ## References [1]The Cityscapes Dataset for Semantic Urban Scene Understanding, Marius Cordts, Mohamed Omran, Sebastian Ramos, Timo Rehfeld, Markus Enzweiler, Rodrigo Benenson, Uwe Franke, Stefan Roth, Bernt Schiele, CVPR 2016, https://www.cityscapes-dataset.com/ diff --git a/modules/pytorch_jacinto_ai/engine/infer_pixel2pixel.py b/modules/pytorch_jacinto_ai/engine/infer_pixel2pixel.py index 585ff3e..d61ac15 100644 --- a/modules/pytorch_jacinto_ai/engine/infer_pixel2pixel.py +++ b/modules/pytorch_jacinto_ai/engine/infer_pixel2pixel.py @@ -44,6 +44,7 @@ def get_config(): args.model_config.output_channels = None # number of output channels args.model_config.input_channels = None # number of input channels args.model_config.num_classes = None # number of classes (for segmentation) + args.model_config.output_range = None # max range of output args.model_config.num_decoders = None # number of decoders to use. [options: 0, 1, None] args.sky_dir = False @@ -275,8 +276,7 @@ def main(args): assert args.pretrained is not None, 'pretrained path must be provided' # onnx generation is filing for post quantized module - args.generate_onnx = False if (args.quantize) else args.generate_onnx - + # args.generate_onnx = False if (args.quantize) else args.generate_onnx ################################################# # set some global flags and initializations # keep it in args for now - although they don't belong here strictly @@ -380,9 +380,11 @@ def main(args): # load pretrained weights xnn.utils.load_weights_check(get_model_orig(model), pretrained=pretrained_data, change_names_dict=change_names_dict) + if args.generate_onnx: + write_onnx_model(args, model, save_path, name='model_best.onnx') ################################################# # multi gpu mode is not yet supported with quantization in evaluate - if args.gpu_mode and (args.phase=='training'): + if args.gpu_mode and ('training' in args.phase): model = torch.nn.DataParallel(model) ################################################# @@ -941,6 +943,16 @@ def create_video(args, infer_path): op_file_name = args.data_path.split('/')[-1] os.system(' ffmpeg -framerate 30 -pattern_type glob -i "{}/*.png" -c:v libx264 -vb 50000k -qmax 20 -r 10 -vf scale=1024:512 -pix_fmt yuv420p {}.MP4'.format(infer_path, op_file_name)) +def write_onnx_model(args, model, save_path, name='checkpoint.onnx'): + is_cuda = next(model.parameters()).is_cuda + input_list = create_rand_inputs(args, is_cuda=is_cuda) + # + model.eval() + torch.onnx.export(model, input_list, os.path.join(save_path, name), export_params=True, verbose=False) + # torch onnx export does not update names. Do it using onnx.save + + + if __name__ == '__main__': train_args = get_config() main(train_args) diff --git a/modules/pytorch_jacinto_ai/engine/test_classification.py b/modules/pytorch_jacinto_ai/engine/test_classification.py index 554cc5e..909f0f5 100644 --- a/modules/pytorch_jacinto_ai/engine/test_classification.py +++ b/modules/pytorch_jacinto_ai/engine/test_classification.py @@ -146,16 +146,16 @@ def main(args): is_cuda = next(model.parameters()).is_cuda dummy_input = create_rand_inputs(args, is_cuda=is_cuda) # - if args.phase == 'training': + if 'training' in args.phase: model = xnn.quantize.QuantTrainModule(model, per_channel_q=args.per_channel_q, bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations, dummy_input=dummy_input) - elif args.phase == 'calibration': + elif 'calibration' in args.phase: model = xnn.quantize.QuantCalibrateModule(model, per_channel_q=args.per_channel_q, bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations, bias_calibration=args.bias_calibration, lr_calib=args.lr_calib, dummy_input=dummy_input) - elif args.phase == 'validation': + elif 'validation' in args.phase: # Note: bias_calibration is not enabled in test model = xnn.quantize.QuantTestModule(model, per_channel_q=args.per_channel_q, bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations, @@ -165,7 +165,6 @@ def main(args): assert False, f'invalid phase {args.phase}' # - # load pretrained xnn.utils.load_weights_check(get_model_orig(model), pretrained=pretrained_data, change_names_dict=change_names_dict) @@ -191,7 +190,7 @@ def main(args): ################################################# # multi gpu mode is not yet supported with quantization in evaluate - if args.parallel_model and (args.phase=='training'): + if args.parallel_model and ('training' in args.phase): model = torch.nn.DataParallel(model) ################################################# diff --git a/modules/pytorch_jacinto_ai/engine/train_classification.py b/modules/pytorch_jacinto_ai/engine/train_classification.py index 83a0e9d..bf9ffec 100644 --- a/modules/pytorch_jacinto_ai/engine/train_classification.py +++ b/modules/pytorch_jacinto_ai/engine/train_classification.py @@ -113,7 +113,7 @@ cudnn.benchmark = True def main(args): assert (not hasattr(args, 'evaluate')), 'args.evaluate is deprecated. use args.phase=training or calibration or validation' - assert args.phase in ('training', 'calibration', 'validation'), f'invalid phase {args.phase}' + assert is_valid_phase(args.phase), f'invalid phase {args.phase}' assert not hasattr(args, 'model_surgery'), 'the argument model_surgery is deprecated, it is not needed now - remove it' if (args.phase == 'validation' and args.bias_calibration): @@ -205,17 +205,17 @@ def main(args): is_cuda = next(model.parameters()).is_cuda dummy_input = create_rand_inputs(args, is_cuda=is_cuda) # - if args.phase == 'training': + if 'training' in args.phase: model = xnn.quantize.QuantTrainModule(model, per_channel_q=args.per_channel_q, histogram_range=args.histogram_range, bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations, dummy_input=dummy_input) - elif args.phase == 'calibration': + elif 'calibration' in args.phase: model = xnn.quantize.QuantCalibrateModule(model, per_channel_q=args.per_channel_q, bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations, histogram_range=args.histogram_range, bias_calibration=args.bias_calibration, lr_calib=args.lr_calib, dummy_input=dummy_input) - elif args.phase == 'validation': + elif 'validation' in args.phase: # Note: bias_calibration is not used in test model = xnn.quantize.QuantTestModule(model, per_channel_q=args.per_channel_q, bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations, @@ -233,7 +233,7 @@ def main(args): count_flops(args, model) ################################################# - if args.generate_onnx and ((args.phase in ('training','calibration')) or (args.run_soon == False)): + if args.generate_onnx and (any(p in args.phase for p in ('training','calibration')) or (args.run_soon == False)): write_onnx_model(args, get_model_orig(model), save_path) # @@ -267,7 +267,7 @@ def main(args): model_module = model.module if hasattr(model, 'module') else model if args.lr_clips is not None: - learning_rate_clips = args.lr_clips if args.phase == 'training' else 0.0 + learning_rate_clips = args.lr_clips if 'training' in args.phase else 0.0 clips_decay = args.bias_decay if (args.bias_decay is not None and args.bias_decay != 0.0) else args.weight_decay clips_params = [p for n,p in model_module.named_parameters() if 'clips' in n] other_params = [p for n,p in model_module.named_parameters() if 'clips' not in n] @@ -290,7 +290,7 @@ def main(args): if args.scheduler == 'step': print("=> milestones : {}".format(args.milestones)) - learning_rate = args.lr if (args.phase == 'training') else 0.0 + learning_rate = args.lr if ('training'in args.phase) else 0.0 if args.optimizer == 'adam': optimizer = torch.optim.Adam(param_groups, learning_rate, betas=(args.momentum, args.beta)) elif args.optimizer == 'sgd': @@ -364,13 +364,17 @@ def main(args): # ################################################################### +def is_valid_phase(phase): + phases = ('training', 'calibration', 'validation') + return any(p in phase for p in phases) + + def close(args): if args.logger is not None: del args.logger args.logger = None # args.best_prec1 = -1 -# def get_save_path(args, phase=None): @@ -434,7 +438,7 @@ def train(args, train_loader, model, criterion, optimizer, epoch): train_iter = iter(train_loader) last_update_iter = -1 - progressbar_color = (Fore.YELLOW if args.phase=='calibration' else Fore.WHITE) + progressbar_color = (Fore.YELLOW if (('calibration' in args.phase) or ('training' in args.phase and args.quantize)) else Fore.WHITE) print('{}'.format(progressbar_color), end='') for iteration in range(args.iters): @@ -457,7 +461,7 @@ def train(args, train_loader, model, criterion, optimizer, epoch): top1.update(prec1[0], input_size[0]) top5.update(prec5[0], input_size[0]) - if args.phase == 'training': + if 'training' in args.phase: # zero gradients so that we can accumulate gradients if (iteration % args.iter_size) == 0: optimizer.zero_grad() @@ -685,7 +689,10 @@ def get_validation_transform(args): return val_transform def get_transforms(args): - train_transform = get_train_transform(args) + # Provision to train with val transform - provide rand_scale as (0, 0) + # Fixing the train-test resolution discrepancy, https://arxiv.org/abs/1906.06423 + always_use_val_transform = (args.rand_scale[0] == 0) + train_transform = get_validation_transform(args) if always_use_val_transform else get_train_transform(args) val_transform = get_validation_transform(args) return train_transform, val_transform diff --git a/modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py b/modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py index 4589a1a..9273786 100644 --- a/modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py +++ b/modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py @@ -178,7 +178,7 @@ def main(args): 'torch version must be 1.1 or higher, due to the change in scheduler.step() and optimiser.step() call order' assert (not hasattr(args, 'evaluate')), 'args.evaluate is deprecated. use args.phase=training or calibration or validation' - assert args.phase in ('training', 'calibration', 'validation'), f'invalid phase {args.phase}' + assert is_valid_phase(args.phase), f'invalid phase {args.phase}' assert not hasattr(args, 'model_surgery'), 'the argument model_surgery is deprecated, it is not needed now - remove it' if (args.phase == 'validation' and args.bias_calibration): @@ -321,14 +321,14 @@ def main(args): is_cuda = next(model.parameters()).is_cuda dummy_input = create_rand_inputs(args, is_cuda=is_cuda) # - if args.phase == 'training': + if 'training' in args.phase: model = xnn.quantize.QuantTrainModule(model, per_channel_q=args.per_channel_q, histogram_range=args.histogram_range, bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations, dummy_input=dummy_input) - elif args.phase == 'calibration': + elif 'calibration' in args.phase: model = xnn.quantize.QuantCalibrateModule(model, per_channel_q=args.per_channel_q, bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations, histogram_range=args.histogram_range, bias_calibration=args.bias_calibration, lr_calib=args.lr_calib, dummy_input=dummy_input) - elif args.phase == 'validation': + elif 'validation' in args.phase: # Note: bias_calibration is not emabled model = xnn.quantize.QuantTestModule(model, per_channel_q=args.per_channel_q, bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations, @@ -346,7 +346,7 @@ def main(args): count_flops(args, model) ################################################# - if args.generate_onnx and ((args.phase in ('training','calibration')) or (args.run_soon == False)): + if args.generate_onnx and (any(args.phase in p for p in ('training','calibration')) or (args.run_soon == False)): write_onnx_model(args, get_model_orig(model), save_path) # @@ -440,7 +440,7 @@ def main(args): assert(args.solver in ['adam', 'sgd']) print('=> setting {} solver'.format(args.solver)) if args.lr_clips is not None: - learning_rate_clips = args.lr_clips if args.phase == 'training' else 0.0 + learning_rate_clips = args.lr_clips if 'training' in args.phase else 0.0 clips_decay = args.bias_decay if (args.bias_decay is not None and args.bias_decay != 0.0) else args.weight_decay clips_params = [p for n,p in model.named_parameters() if 'clips' in n] other_params = [p for n,p in model.named_parameters() if 'clips' not in n] @@ -450,7 +450,7 @@ def main(args): param_groups = [{'params': filter(lambda p: p.requires_grad, model.parameters()), 'weight_decay': args.weight_decay}] # - learning_rate = args.lr if (args.phase == 'training') else 0.0 + learning_rate = args.lr if ('training'in args.phase) else 0.0 if args.solver == 'adam': optimizer = torch.optim.Adam(param_groups, learning_rate, betas=(args.momentum, args.beta)) elif args.solver == 'sgd': @@ -539,7 +539,7 @@ def main(args): val_writer.file_writer.flush() # adjust the learning rate using lr scheduler - if args.phase == 'training': + if 'training' in args.phase: scheduler.step() # # @@ -548,6 +548,10 @@ def main(args): close(args) # +################################################################### +def is_valid_phase(phase): + phases = ('training', 'calibration', 'validation') + return any(p in phase for p in phases) ################################################################### @@ -583,7 +587,7 @@ def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writ last_update_iter = -1 # change color to yellow for calibration - progressbar_color = (Fore.YELLOW if args.phase=='calibration' else Fore.WHITE) + progressbar_color = (Fore.YELLOW if (('calibration' in args.phase) or ('training' in args.phase and args.quantize)) else Fore.WHITE) print('{}'.format(progressbar_color), end='') ########################## @@ -621,7 +625,7 @@ def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writ if args.model_config.multi_task_type is not None and len(args.model_config.output_channels) > 1: xnn.layers.set_losses(model, loss_list_orig) - if args.phase == 'training': + if 'training' in args.phase: # zero gradients so that we can accumulate gradients if (iter % args.iter_size) == 0: optimizer.zero_grad() @@ -1042,7 +1046,10 @@ def get_validation_transform(args): def get_transforms(args): - train_transform = get_train_transform(args) + # Provision to train with val transform - provide rand_scale as (0, 0) + # Fixing the train-test resolution discrepancy, https://arxiv.org/abs/1906.06423 + always_use_val_transform = (args.rand_scale[0] == 0) + train_transform = get_validation_transform(args) if always_use_val_transform else get_train_transform(args) val_transform = get_validation_transform(args) return train_transform, val_transform diff --git a/modules/pytorch_jacinto_ai/vision/losses/segmentation_loss.py b/modules/pytorch_jacinto_ai/vision/losses/segmentation_loss.py index 60ea113..22e7daf 100755 --- a/modules/pytorch_jacinto_ai/vision/losses/segmentation_loss.py +++ b/modules/pytorch_jacinto_ai/vision/losses/segmentation_loss.py @@ -106,7 +106,7 @@ class SegmentationMetricsCalc(object): class SegmentationLoss(torch.nn.Module): def __init__(self, *args, ignore_index = 255, weight=None, **kwargs): super().__init__() - self.register_buffer('weight', torch.FloatTensor(weight)) if weight is not None else None + self.weight = None if weight is None else self.register_buffer('weight', torch.FloatTensor(weight)) self.ignore_index = ignore_index self.is_avg = False # diff --git a/modules/pytorch_jacinto_ai/vision/models/mobilenetv1.py b/modules/pytorch_jacinto_ai/vision/models/mobilenetv1.py index eaa0549..c0701a1 100644 --- a/modules/pytorch_jacinto_ai/vision/models/mobilenetv1.py +++ b/modules/pytorch_jacinto_ai/vision/models/mobilenetv1.py @@ -13,7 +13,7 @@ def get_config(): model_config.num_classes = 1000 model_config.width_mult = 1. model_config.expand_ratio = 6 - model_config.strides = (2,2,2,2,2) + model_config.strides = None #(2,2,2,2,2) model_config.activation = xnn.layers.DefaultAct2d model_config.use_blocks = False model_config.kernel_size = 3 @@ -37,11 +37,12 @@ class MobileNetV1Base(torch.nn.Module): self.num_classes = self.model_config.num_classes # strides of various layers - s0 = model_config.strides[0] - s1 = model_config.strides[1] - s2 = model_config.strides[2] - s3 = model_config.strides[3] - s4 = model_config.strides[4] + strides = model_config.strides if (model_config.strides is not None) else (2,2,2,2,2) + s0 = strides[0] + s1 = strides[1] + s2 = strides[2] + s3 = strides[3] + s4 = strides[4] if self.model_config.layer_setting is None: self.model_config.layer_setting = [ diff --git a/modules/pytorch_jacinto_ai/vision/models/mobilenetv2.py b/modules/pytorch_jacinto_ai/vision/models/mobilenetv2.py index e351aff..a47d4b1 100644 --- a/modules/pytorch_jacinto_ai/vision/models/mobilenetv2.py +++ b/modules/pytorch_jacinto_ai/vision/models/mobilenetv2.py @@ -13,7 +13,7 @@ def get_config(): model_config.num_classes = 1000 model_config.width_mult = 1. model_config.expand_ratio = 6 - model_config.strides = (2,2,2,2,2) + model_config.strides = None #(2,2,2,2,2) model_config.activation = xnn.layers.DefaultAct2d model_config.use_blocks = False model_config.kernel_size = 3 @@ -28,8 +28,6 @@ model_urls = { } - - class InvertedResidual(torch.nn.Module): def __init__(self, inp, oup, stride, expand_ratio, activation, kernel_size, linear_dw): super(InvertedResidual, self).__init__() @@ -88,12 +86,13 @@ class MobileNetV2TVBase(torch.nn.Module): self.num_classes = self.model_config.num_classes # strides of various layers - s0 = model_config.strides[0] + strides = model_config.strides if (model_config.strides is not None) else (2,2,2,2,2) + s0 = strides[0] sf = 2 if model_config.fastdown else 1 # extra stride if fastdown - s1 = model_config.strides[1] - s2 = model_config.strides[2] - s3 = model_config.strides[3] - s4 = model_config.strides[4] + s1 = strides[1] + s2 = strides[2] + s3 = strides[3] + s4 = strides[4] if self.model_config.layer_setting is None: ex = self.model_config.expand_ratio diff --git a/modules/pytorch_jacinto_ai/vision/models/pixel2pixel/fpn_pixel2pixel.py b/modules/pytorch_jacinto_ai/vision/models/pixel2pixel/fpn_pixel2pixel.py index 31afca9..b84bfb9 100644 --- a/modules/pytorch_jacinto_ai/vision/models/pixel2pixel/fpn_pixel2pixel.py +++ b/modules/pytorch_jacinto_ai/vision/models/pixel2pixel/fpn_pixel2pixel.py @@ -7,8 +7,12 @@ from ..multi_input_net import MobileNetV2TVMI4, ResNet50MI4 __all__ = ['FPNPixel2PixelASPP', 'FPNPixel2PixelDecoder', - 'fpn_pixel2pixel_aspp_mobilenetv2_tv', 'fpn_pixel2pixel_aspp_resnet50', - 'fpn_pixel2pixel_aspp_mobilenetv2_tv_es64', 'fpn_pixel2pixel_aspp_resnet50_es64'] + 'fpn_pixel2pixel_aspp_mobilenetv2_tv', 'fpn_pixel2pixel_aspp_mobilenetv2_tv_es64', + # no aspp models + 'fpn_pixel2pixel_mobilenetv2_tv', 'fpn_pixel2pixel_mobilenetv2_tv_es64', + # resnet models + 'fpn_pixel2pixel_aspp_resnet50', 'fpn_pixel2pixel_aspp_resnet50_es64', + ] ########################################### @@ -19,20 +23,21 @@ class FPNPixel2PixelDecoder(torch.nn.Module): activation = self.model_config.activation self.output_type = model_config.output_type - current_channels = self.model_config.shortcut_channels[-1] self.decoder_channels = decoder_channels = round(self.model_config.decoder_chan*self.model_config.decoder_factor) - aspp_channels = round(self.model_config.aspp_chan*self.model_config.decoder_factor) self.rfblock = None if self.model_config.use_aspp: + current_channels = self.model_config.shortcut_channels[-1] + aspp_channels = round(self.model_config.aspp_chan * self.model_config.decoder_factor) self.rfblock = xnn.layers.DWASPPLiteBlock(current_channels, aspp_channels, decoder_channels, dilation=self.model_config.aspp_dil, avg_pool=False, activation=activation) elif self.model_config.use_extra_strides: # a low complexity pyramid + current_channels = self.model_config.shortcut_channels[-3] self.rfblock = torch.nn.Sequential(xnn.layers.ConvDWSepNormAct2d(current_channels, current_channels, kernel_size=3, stride=2, activation=(activation, activation)), xnn.layers.ConvDWSepNormAct2d(current_channels, decoder_channels, kernel_size=3, stride=2, activation=(activation, activation))) - self.model_config.shortcut_strides += [64, 128] else: + current_channels = self.model_config.shortcut_channels[-1] self.rfblock = xnn.layers.ConvNormAct2d(current_channels, decoder_channels, kernel_size=1, stride=1) current_channels = decoder_channels @@ -176,6 +181,7 @@ def fpn_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=None): return model, change_names_dict +# fast down sampling model (encoder stride 64 model) def fpn_pixel2pixel_aspp_mobilenetv2_tv_es64(model_config, pretrained=None): model_config = get_config_fpnp2p_mnv2().merge_from(model_config) model_config.fastdown = True @@ -187,6 +193,31 @@ def fpn_pixel2pixel_aspp_mobilenetv2_tv_es64(model_config, pretrained=None): return fpn_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=pretrained) +################## +# similar to the original fpn model with extra convolutions with strides (no aspp) +def fpn_pixel2pixel_mobilenetv2_tv(model_config, pretrained=None): + model_config = get_config_fpnp2p_mnv2().merge_from(model_config) + model_config.use_aspp = False + model_config.use_extra_strides = True + model_config.shortcut_strides = (4, 8, 16, 32, 64, 128) + model_config.shortcut_channels = (24, 32, 96, 320, 320, 256) + return fpn_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=pretrained) + + +# similar to the original fpn model with extra convolutions with strides (no aspp) - fast down sampling model (encoder stride 64 model) +def fpn_pixel2pixel_mobilenetv2_tv_es64(model_config, pretrained=None): + model_config = get_config_fpnp2p_mnv2().merge_from(model_config) + model_config.use_aspp = False + model_config.use_extra_strides = True + model_config.fastdown = True + model_config.strides = (2,2,2,2,2) + model_config.shortcut_strides = (8, 16, 32, 64, 128, 256) + model_config.shortcut_channels = (24, 32, 96, 320, 320, 256) + model_config.decoder_chan = 256 + model_config.aspp_chan = 256 + return fpn_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=pretrained) + + ########################################### def get_config_fpnp2p_resnet50(): # only the delta compared to the one defined for mobilenetv2 diff --git a/modules/pytorch_jacinto_ai/vision/models/resnet.py b/modules/pytorch_jacinto_ai/vision/models/resnet.py index e63b5b5..0f8e3ee 100644 --- a/modules/pytorch_jacinto_ai/vision/models/resnet.py +++ b/modules/pytorch_jacinto_ai/vision/models/resnet.py @@ -129,7 +129,6 @@ class ResNet(nn.Module): groups=1, width_per_group=64, replace_stride_with_dilation=None, norm_layer=None, input_channels=3, strides=None, width_mult=1.0, fastdown=False): super(ResNet, self).__init__() - strides = strides if strides is not None else (2,2,2,2,2) if norm_layer is None: norm_layer = nn.BatchNorm2d self._norm_layer = norm_layer @@ -145,6 +144,7 @@ class ResNet(nn.Module): raise ValueError(f"replace_stride_with_dilation should be None or a 3-element tuple, got {replace_stride_with_dilation}") # strides of various layers + strides = strides if (strides is not None) else (2,2,2,2,2) s0 = strides[0] s1 = strides[1] sf = 2 if fastdown else 1 # additional stride if fast down is true @@ -362,7 +362,7 @@ def get_config(): model_config = xnn.utils.ConfigNode() model_config.input_channels = 3 model_config.num_classes = 1000 - model_config.strides = (2,2,2,2,2) + model_config.strides = None #(2,2,2,2,2) model_config.fastdown = False return model_config diff --git a/modules/pytorch_jacinto_ai/xnn/layers/multi_task.py b/modules/pytorch_jacinto_ai/xnn/layers/multi_task.py index d3c1975..e45908d 100644 --- a/modules/pytorch_jacinto_ai/xnn/layers/multi_task.py +++ b/modules/pytorch_jacinto_ai/xnn/layers/multi_task.py @@ -148,6 +148,8 @@ class MultiTask(torch.nn.Module): self.loss_scales[task_idx] = torch.nn.functional.tanh(loss_scale) if clip_scale else loss_scale # self.loss_offset = self.uncertainty_factors + elif self.multi_task_type == "dtp": #dynamic task priority + pass # del dy_norms_smooth_mean, dy_norms_mean diff --git a/modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py b/modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py index 070fb8b..0ce62c8 100644 --- a/modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py +++ b/modules/pytorch_jacinto_ai/xnn/quantize/quant_graph_module.py @@ -253,7 +253,7 @@ class QuantGraphModule(HookedModule): else: quantize_out = True # - elif utils.is_conv(module): + elif utils.is_conv(module) or utils.is_deconv(module): if len(next_module)==1 and (utils.is_normalization(next_module[0]) or utils.is_activation(next_module[0])): quantize_out = False else: @@ -269,8 +269,8 @@ class QuantGraphModule(HookedModule): # quantize_out = True # # - qparams.quantize_w = isinstance(module, (torch.nn.Conv2d,torch.nn.Linear)) # all conv layers will be quantized - qparams.quantize_b = isinstance(module, (torch.nn.Conv2d,torch.nn.Linear)) # all conv layers will be quantized + qparams.quantize_w = isinstance(module, (torch.nn.Conv2d,torch.nn.ConvTranspose2d,torch.nn.Linear)) # all conv/deconv layers will be quantized + qparams.quantize_b = isinstance(module, (torch.nn.Conv2d,torch.nn.ConvTranspose2d,torch.nn.Linear)) # all conv/deconv layers will be quantized qparams.quantize_out = quantize_out # selectively quantize output qparams.quantize_in = qparams.is_input # only top modules's input need to be quantized qparams.align_in = isinstance(module, (layers.AddBlock, layers.CatBlock,torch.nn.AdaptiveAvgPool2d))# all tensors to be made same q at the input diff --git a/run_classification.sh b/run_classification.sh index abc67eb..430f686 100755 --- a/run_classification.sh +++ b/run_classification.sh @@ -6,11 +6,11 @@ ## ===================================================================================== ## Cifar100 Classification (Automatic Download) #### Training with MobileNetV2 -#python ./scripts/train_classification_main.py --dataset_name cifar100_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/cifar100_classification --img_resize 32 --img_crop 32 --rand_scale 0.5 1.0 +#python ./scripts/train_classification_main.py --dataset_name cifar100_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/cifar100_classification --img_resize 32 --img_crop 32 --rand_scale 0.5 1.0 --strides 1 1 1 2 2 ## Cifar10 Classification (Automatic Download) #### Training with MobileNetV2 -#python ./scripts/train_classification_main.py --dataset_name cifar10_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/cifar10_classification --img_resize 32 --img_crop 32 --rand_scale 0.5 1.0 +#python ./scripts/train_classification_main.py --dataset_name cifar10_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/cifar10_classification --img_resize 32 --img_crop 32 --rand_scale 0.5 1.0 --strides 1 1 1 2 2 ## ImageNet Classification (Automatic Download) #### Training with MobileNetV2 @@ -19,6 +19,8 @@ ## ImageNet Classification (Manual Download) #### Training with MobileNetV2 #python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification +#### Training with MobileNetV2 - Small Resolution +#python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification --img_resize 146 --img_crop 128 --batch_size 1024 --lr 0.2 --workers 16 #### Training with MobileNetV2 with 2x channels and expansion factor of 2 #python ./scripts/train_classification_main.py --dataset_name image_folder_classification --model_name mobilenetv2_tv_x2_t2 --data_path ./data/datasets/image_folder_classification --batch_size 256 diff --git a/run_depth.sh b/run_depth.sh index b0ceaab..f96d1af 100755 --- a/run_depth.sh +++ b/run_depth.sh @@ -5,7 +5,8 @@ ## Training ## ===================================================================================== #### KITTI Depth (Manual Download) - Training with MobileNetV2+DeeplabV3Lite -#python ./scripts/train_depth_main.py --dataset_name kitti_depth --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/kitti/kitti_depth/data --img_resize 384 768 --output_size 374 1242 +#python ./scripts/train_depth_main.py --dataset_name kitti_depth --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/kitti/kitti_depth/data --img_resize 384 768 --output_size 374 1242 \ +#--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth #### KITTI Depth (Manual Download) - Training with ResNet50+FPN #python ./scripts/train_depth_main.py --dataset_name kitti_depth --model_name fpn_pixel2pixel_aspp_resnet50 --data_path ./data/datasets/kitti/kitti_depth/data --img_resize 384 768 --output_size 374 1242 \ diff --git a/run_quantization.sh b/run_quantization.sh index b7ec51f..a867cc9 100755 --- a/run_quantization.sh +++ b/run_quantization.sh @@ -1,35 +1,7 @@ # Quantization ## ===================================================================================== -## Post Training Calibration & Quantization - this will write out a quantization friendly model very quickly -## ===================================================================================== -# -#### Image Classification - Post Training Calibration & Quantization - ResNet50 -#python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name resnet50_x1 --data_path ./data/datasets/image_folder_classification \ -#--pretrained https://download.pytorch.org/models/resnet50-19c8e357.pth \ -#--batch_size 64 --quantize True --epochs 1 --epoch_size 100 -# -# -#### Image Classification - Post Training Calibration & Quantization - MobileNetV2 -#python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification \ -#--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth \ -#--batch_size 64 --quantize True --epochs 1 --epoch_size 100 -# -# -#### Image Classification - Post Training Calibration & Quantization for a TOUGH MobileNetV2 pretrained model -#python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name mobilenetv2_shicai_x1 --data_path ./data/datasets/image_folder_classification \ -#--pretrained ./data/modelzoo/pretrained/pytorch/others/shicai/MobileNet-Caffe/mobilenetv2_shicai_rgb.tar \ -#--batch_size 64 --quantize True --epochs 1 --epoch_size 100 -# -# -#### Semantic Segmentation - Post Training Calibration & Quantization for MobileNetV2+DeeplabV3Lite -#python ./scripts/train_segmentation_main.py --phase calibration --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \ -#--pretrained ./data/modelzoo/semantic_segmentation/cityscapes/deeplabv3lite-mobilenetv2/cityscapes_segmentation_deeplabv3lite-mobilenetv2_2019-06-26-08-59-32.pth \ -#--batch_size 12 --quantize True --epochs 1 --epoch_size 100 -# -# -## ===================================================================================== -## Trained Quantization - If And Only if Post Training Calibration and Quantization (above) doesn't work +## Trained Quantization ## ===================================================================================== # #### Image Classification - Trained Quantization - MobileNetV2 @@ -49,6 +21,7 @@ #--pretrained ./data/modelzoo/semantic_segmentation/cityscapes/deeplabv3lite-mobilenetv2/cityscapes_segmentation_deeplabv3lite-mobilenetv2_2019-06-26-08-59-32.pth \ #--batch_size 12 --quantize True --epochs 150 --lr 5e-5 --evaluate_start False + ## ===================================================================================== ## Acuracy Evaluation with Post Training Quantization - cannot save quantized model - only accuracy evaluation ## ===================================================================================== @@ -74,6 +47,31 @@ #--batch_size 1 --quantize True - +## ===================================================================================== +## Post Training Calibration & Quantization - this is fast, but may not always yield best quantized accuracy (not recommended) +## ===================================================================================== +# +#### Image Classification - Post Training Calibration & Quantization - ResNet50 +#python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name resnet50_x1 --data_path ./data/datasets/image_folder_classification \ +#--pretrained https://download.pytorch.org/models/resnet50-19c8e357.pth \ +#--batch_size 64 --quantize True --epochs 1 --epoch_size 100 +# +# +#### Image Classification - Post Training Calibration & Quantization - MobileNetV2 +#python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name mobilenetv2_tv_x1 --data_path ./data/datasets/image_folder_classification \ +#--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth \ +#--batch_size 64 --quantize True --epochs 1 --epoch_size 100 +# +# +#### Image Classification - Post Training Calibration & Quantization for a TOUGH MobileNetV2 pretrained model +#python ./scripts/train_classification_main.py --phase calibration --dataset_name image_folder_classification --model_name mobilenetv2_shicai_x1 --data_path ./data/datasets/image_folder_classification \ +#--pretrained ./data/modelzoo/pretrained/pytorch/others/shicai/MobileNet-Caffe/mobilenetv2_shicai_rgb.tar \ +#--batch_size 64 --quantize True --epochs 1 --epoch_size 100 +# +# +#### Semantic Segmentation - Post Training Calibration & Quantization for MobileNetV2+DeeplabV3Lite +#python ./scripts/train_segmentation_main.py --phase calibration --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \ +#--pretrained ./data/modelzoo/semantic_segmentation/cityscapes/deeplabv3lite-mobilenetv2/cityscapes_segmentation_deeplabv3lite-mobilenetv2_2019-06-26-08-59-32.pth \ +#--batch_size 12 --quantize True --epochs 1 --epoch_size 100 diff --git a/run_segmentation.sh b/run_segmentation.sh index 9e05986..fef2a5e 100755 --- a/run_segmentation.sh +++ b/run_segmentation.sh @@ -1,35 +1,46 @@ # Summary of commands - uncomment one and run this script -#### For the datasets in sections marked as "Automatic Download", dataset will be downloaded automatically downloaded before training begins. For "Manual Download", it is expected that it is manually downloaded and kept in the folder specified agaianst the --data_path option. +#### Manual Download: It is expected that the dataset is manually downloaded and kept in the folder specified agaianst the --data_path option. ## ===================================================================================== ## Training ## ===================================================================================== -#### Cityscapes Semantic Segmentation (Manual Download) - Training with MobileNetV2+DeeplabV3Lite -#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 +#### Cityscapes Semantic Segmentation - Training with MobileNetV2+DeeplabV3Lite +#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \ +#--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth -#### Cityscapes Semantic Segmentation (Manual Download) - Training with MobileNetV2+DeeplabV3Lite, Higher Resolution -#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 768 1536 --rand_crop 512 1024 --output_size 1024 2048 --gpus 0 1 +#### Cityscapes Semantic Segmentation - original fpn - no aspp model, stride 64 model - Low Complexity Model +#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name fpn_pixel2pixel_aspp_mobilenetv2_tv_es64 --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \ +#--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth -#### Cityscapes Semantic Segmentation (Manual Download) - Training with ResNet50+FPN +#### Cityscapes Semantic Segmentation - Training with MobileNetV2+DeeplabV3Lite, Higher Resolution +#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 768 1536 --rand_crop 512 1024 --output_size 1024 2048 --gpus 0 1 \ +#--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth + +#### Cityscapes Semantic Segmentation - original fpn - no aspp model, stride 64 model, Higher Resolution - Low Complexity Model +#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name fpn_pixel2pixel_aspp_mobilenetv2_tv_es64 --data_path ./data/datasets/cityscapes/data --img_resize 768 1536 --rand_crop 512 1024 --output_size 1024 2048 --gpus 0 1 \ +#--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth + +#### Cityscapes Semantic Segmentation - Training with ResNet50+FPN #python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name fpn_pixel2pixel_aspp_resnet50 --data_path ./data/datasets/cityscapes/data --img_resize 768 1536 --rand_crop 512 1024 --output_size 1024 2048 --gpus 0 1 \ #--pretrained https://download.pytorch.org/models/resnet50-19c8e357.pth -#### VOC Segmentation (Manual Download) - Training with MobileNetV2+DeeplabV3Lite -#python ./scripts/train_segmentation_main.py --dataset_name voc_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/voc --img_resize 512 512 --output_size 512 512 --gpus 0 1 +#### VOC Segmentation - Training with MobileNetV2+DeeplabV3Lite +#python ./scripts/train_segmentation_main.py --dataset_name voc_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/voc --img_resize 512 512 --output_size 512 512 --gpus 0 1 \ +#--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth ## ===================================================================================== ## Validation ## ===================================================================================== -#### Validation - Cityscapes Semantic Segmentation (Manual Download) - Validation with MobileNetV2+DeeplabV3Lite - populate the pretrained filename in ?? +#### Validation - Cityscapes Semantic Segmentation - Validation with MobileNetV2+DeeplabV3Lite - populate the pretrained filename in ?? #python ./scripts/train_segmentation_main.py --evaluate True --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \ #--pretrained ?? -#### Inference - Cityscapes Semantic Segmentation (Manual Download) - Inference with MobileNetV2+DeeplabV3Lite - populate the pretrained filename in ?? +#### Inference - Cityscapes Semantic Segmentation - Inference with MobileNetV2+DeeplabV3Lite - populate the pretrained filename in ?? #python ./scripts/infer_segmentation_main.py --dataset_name cityscapes_segmentation_measure --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \ -#--pretrained ./data/modelzoo/semantic_segmentation/cityscapes/deeplabv3lite-mobilenetv2/cityscapes_segmentation_deeplabv3lite-mobilenetv2_2019-06-26-08-59-32.pth +#--pretrained ??? -#### Validation - VOC Segmentation (Manual Download) - Validation with MobileNetV2+DeeplabV3Lite - populate the pretrained filename in ?? +#### Validation - VOC Segmentation - Validation with MobileNetV2+DeeplabV3Lite - populate the pretrained filename in ?? #python ./scripts/train_segmentation_main.py --evaluate True --dataset_name voc_segmentation_measure --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/voc --img_resize 512 512 --output_size 512 512 --gpus 0 1 #--evaluate True --pretrained ??? diff --git a/scripts/train_classification_main.py b/scripts/train_classification_main.py index 8149824..2887be1 100755 --- a/scripts/train_classification_main.py +++ b/scripts/train_classification_main.py @@ -17,6 +17,7 @@ parser = argparse.ArgumentParser() parser.add_argument('--save_path', type=str, default=None, help='checkpoint save folder') parser.add_argument('--gpus', type=int, nargs='*', default=None, help='Base learning rate') parser.add_argument('--batch_size', type=int, default=None, help='Batch size') +parser.add_argument('--strides', type=int, nargs='*', default=None, help='strides in the model') parser.add_argument('--lr', type=float, default=None, help='Base learning rate') parser.add_argument('--lr_calib', type=float, default=None, help='Learning rate for calibration') parser.add_argument('--model_name', type=str, default=None, help='model name') @@ -33,6 +34,7 @@ parser.add_argument('--pretrained', type=str, default=None, help='pretrained mod parser.add_argument('--resume', type=str, default=None, help='resume an unfinished training from this model') parser.add_argument('--phase', type=str, default=None, help='training/calibration/validation') parser.add_argument('--evaluate_start', type=str2bool, default=None, help='Whether to run validation before the training') +parser.add_argument('--workers', type=int, default=None, help='number of workers for dataloading') # parser.add_argument('--quantize', type=str2bool, default=None, help='Quantize the model') parser.add_argument('--histogram_range', type=str2bool, default=None, help='run only evaluation and no training') @@ -89,7 +91,10 @@ args.data_path = f'./data/datasets/{args.dataset_name}' args.model_config.input_channels = 3 args.model_config.output_type = 'classification' args.model_config.output_channels = None +args.model_config.strides = None #(2,2,2,2,2) +args.img_resize = 256 +args.img_crop = 224 args.solver = 'sgd' #'sgd' #'adam' args.epochs = 150 #150 #120 args.start_epoch = 0 #0 @@ -121,8 +126,13 @@ args.date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") ################################ for key in vars(cmds): - if key == 'gpus': - pass # already taken care above, since this has to be done before importing pytorch + if key == 'gpus': # already taken care above, since this has to be done before importing pytorch + pass + elif key == 'strides': # strides is in model_config + value = getattr(cmds, key) + if value != 'None' and value is not None: + setattr(args.model_config, key, value) + # elif hasattr(args, key): value = getattr(cmds, key) if value != 'None' and value is not None: @@ -133,10 +143,8 @@ for key in vars(cmds): ################################ # these dependent on the dataset chosen -args.img_resize = (args.img_resize if args.img_resize else 256) -args.img_crop = (args.img_crop if args.img_crop else 224) args.model_config.num_classes = (100 if 'cifar100' in args.dataset_name else (10 if 'cifar10' in args.dataset_name else 1000)) -args.model_config.strides = (1,1,1,2,2) if args.img_crop<56 else ((1,1,2,2,2) if args.img_crop<112 else ((1,2,2,2,2) if args.img_crop<224 else (2,2,2,2,2))) + ################################ @@ -145,19 +153,19 @@ train_classification.main(args) ################################ # In addition run a quantized calibration, starting from the trained model -if args.phase == 'training' and (not args.quantize): +if 'training' in args.phase and (not args.quantize): save_path = train_classification.get_save_path(args) args.pretrained = os.path.join(save_path, 'model_best.pth.tar') - args.phase = 'calibration' + args.phase = 'training_quantize' args.quantize = True - args.epochs = 1 - args.epoch_size = 100 + args.epochs = 25 + args.lr = 5e-5 train_classification.main(args) # ################################ # In addition run a separate validation, starting from the calibrated model - to estimate the quantized accuracy accurately -if args.phase == 'training' or args.phase == 'calibration': +if 'training' in args.phase or 'calibration' in args.phase: save_path = train_classification.get_save_path(args) args.pretrained = os.path.join(save_path, 'model_best.pth.tar') args.phase = 'validation' diff --git a/scripts/train_depth_main.py b/scripts/train_depth_main.py index 63ce617..286c084 100755 --- a/scripts/train_depth_main.py +++ b/scripts/train_depth_main.py @@ -144,19 +144,19 @@ train_pixel2pixel.main(args) ################################ # In addition run a quantized calibration, starting from the trained model -if args.phase == 'training' and (not args.quantize): +if 'training'in args.phase and (not args.quantize): save_path = train_pixel2pixel.get_save_path(args) args.pretrained = os.path.join(save_path, 'model_best.pth.tar') - args.phase = 'calibration' + args.phase = 'training_quantize' args.quantize = True - args.epochs = 1 - args.epoch_size = 100 + args.epochs = 25 + args.lr = 5e-5 train_pixel2pixel.main(args) # ################################ # In addition run a separate validation, starting from the calibrated model - to estimate the quantized accuracy accurately -if args.phase == 'training' or args.phase == 'calibration': +if 'training' in args.phase or 'calibration' in args.phase: save_path = train_pixel2pixel.get_save_path(args) args.pretrained = os.path.join(save_path, 'model_best.pth.tar') args.phase = 'validation' diff --git a/scripts/train_segmentation_main.py b/scripts/train_segmentation_main.py index c1e54aa..e2297a3 100755 --- a/scripts/train_segmentation_main.py +++ b/scripts/train_segmentation_main.py @@ -144,19 +144,19 @@ train_pixel2pixel.main(args) ################################ # In addition run a quantized calibration, starting from the trained model -if args.phase == 'training' and (not args.quantize): +if 'training' in args.phase and (not args.quantize): save_path = train_pixel2pixel.get_save_path(args) args.pretrained = os.path.join(save_path, 'model_best.pth.tar') - args.phase = 'calibration' + args.phase = 'training_quantize' args.quantize = True - args.epochs = 1 - args.epoch_size = 100 + args.epochs = 25 + args.lr = 5e-5 train_pixel2pixel.main(args) # ################################ # In addition run a separate validation, starting from the calibrated model - to estimate the quantized accuracy accurately -if args.phase == 'training' or args.phase == 'calibration': +if 'training' in args.phase or 'calibration' in args.phase: save_path = train_pixel2pixel.get_save_path(args) args.pretrained = os.path.join(save_path, 'model_best.pth.tar') args.phase = 'validation' -- 2.39.2