]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - scripts/train_pixel2pixel_multitask_main.py
renamed pytorch_jacinto_ai.vision to pytorch_jacinto_ai.xvision
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / scripts / train_pixel2pixel_multitask_main.py
1 import sys
2 import os
3 import cv2
4 import argparse
5 import datetime
6 import numpy as np
8 ################################
9 from pytorch_jacinto_ai.xnn.utils import str2bool
10 parser = argparse.ArgumentParser()
11 parser.add_argument('--save_path', type=str, default=None, help='checkpoint save folder')
12 parser.add_argument('--gpus', type=int, nargs='*', default=None, help='Base learning rate')
13 parser.add_argument('--batch_size', type=int, default=None, help='Batch size')
14 parser.add_argument('--lr', type=float, default=None, help='Base learning rate')
15 parser.add_argument('--lr_clips', type=float, default=None, help='Learning rate for clips in PAct2')
16 parser.add_argument('--lr_calib', type=float, default=None, help='Learning rate for calibration')
17 parser.add_argument('--model_name', type=str, default=None, help='model name')
18 parser.add_argument('--dataset_name', type=str, default=None, help='dataset name')
19 parser.add_argument('--image_folders', type=str, default=('leftImg8bit_flow_confidence_768x384', 'leftImg8bit'), nargs='*', help='image_folders')
20 parser.add_argument('--data_path', type=str, default=None, help='data path')
21 parser.add_argument('--epoch_size', type=float, default=None, help='epoch size. using a fraction will reduce the data used for one epoch')
22 parser.add_argument('--epochs', type=int, default=None, help='number of epochs')
23 parser.add_argument('--warmup_epochs', type=int, default=None, help='number of epochs for the learning rate to increase and reach base value')
24 parser.add_argument('--milestones', type=int, nargs='*', default=None, help='change lr at these milestones')
25 parser.add_argument('--img_resize', type=int, nargs=2, default=None, help='img_resize size. for training this will be modified according to rand_scale')
26 parser.add_argument('--rand_scale', type=float, nargs=2, default=None, help='random scale factors for training')
27 parser.add_argument('--rand_crop', type=int, nargs=2, default=None, help='random crop for training')
28 parser.add_argument('--output_size', type=int, nargs=2, default=None, help='output size of the evaluation - prediction/groundtruth. this is not used while training as it blows up memory requirement')
29 parser.add_argument('--pretrained', type=str, default=None, help='pretrained model')
30 parser.add_argument('--resume', type=str, default=None, help='resume an unfinished training from this model')
31 parser.add_argument('--phase', type=str, default=None, help='training/calibration/validation')
32 parser.add_argument('--evaluate_start', type=str2bool, default=None, help='Whether to run validation before the training')
33 #
34 parser.add_argument('--quantize', type=str2bool, default=None, help='Quantize the model')
35 parser.add_argument('--histogram_range', type=str2bool, default=None, help='run only evaluation and no training')
36 parser.add_argument('--per_channel_q', type=str2bool, default=None, help='run only evaluation and no training')
37 parser.add_argument('--bias_calibration', type=str2bool, default=None, help='run only evaluation and no training')
38 parser.add_argument('--bitwidth_weights', type=int, default=None, help='bitwidth for weight quantization')
39 parser.add_argument('--bitwidth_activations', type=int, default=None, help='bitwidth for activation quantization')
40 parser.add_argument('--multi_task_type', type=str, default=None, help='multi_task_type')
41 #
42 parser.add_argument('--freeze_bn', type=str2bool, default=None, help='freeze the bn stats or not')
43 cmds = parser.parse_args()
45 ################################
46 # taken care first, since this has to be done before importing pytorch
47 if 'gpus' in vars(cmds):
48     value = getattr(cmds, 'gpus')
49     if (value is not None) and ("CUDA_VISIBLE_DEVICES" not in os.environ):
50         os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(v) for v in value])
51     #
52 #
54 ################################
55 # to avoid hangs in data loader with multi threads
56 # this was observed after using cv2 image processing functions
57 # https://github.com/pytorch/pytorch/issues/1355
58 cv2.setNumThreads(0)
61 ################################
62 #import of torch should be after CUDA_VISIBLE_DEVICES for it to take effect
63 import torch
64 from pytorch_jacinto_ai.engine import train_pixel2pixel
66 #Create the parse and set default arguments
67 args = train_pixel2pixel.get_config()
70 #Modify arguments
72 args.model_name = 'deeplabv3lite_mobilenetv2_ericsun' #'deeplabv3lite_mobilenetv2_mi4'
74 args.dataset_name =  'cityscapes_depth_semantic_motion_multi_input' #'cityscapes_flow_depth_segmentation_image_pair' #cityscapes_segmentation #'cityscapes_segmentation_dual'
76 #args.save_path = './data/checkpoints'
78 args.data_path = './data/datasets/cityscapes/data'  #./data/pascal-voc/VOCdevkit/VOC2012
80 args.pretrained = './data/modelzoo/pytorch/semantic_segmentation/cityscapes/jacinto_ai/deeplabv3lite_mobilenetv2_ericsun_resize768x384_best.pth.tar'
81                             #'./data/checkpoints/cityscapes_depth_semantic_five_class_motion_image_dof_conf/0p9_release/2019-06-27-13-50-10_cityscapes_depth_semantic_five_class_motion_image_dof_conf_deeplabv3lite_mobilenetv2_ericsun_mi4_resize768x384_traincrop768x384/model_best.pth.tar'
82                             #'./data/modelzoo/pretrained/pytorch/cityscapes_segmentation/v0.9-2018-12-07-19:38:26_cityscapes_segmentation_deeplabv3lite_mobilenetv2_relu_resize768x384_traincrop768x384_(68.9%)/model_best.pth.tar'
83                             #'./data/checkpoints/store/saved/cityscapes_segmentation/v0.7-2018-10-25-13:07:38_cityscapes_segmentation_deeplabv3lite_mobilenetv2_relu_resize1024x512_traincrop512x512_(71.5%)/model_best.pth.tar'
84                             #'./data/modelzoo/pretrained/pytorch/imagenet_classification/pytorch_jacinto_ai.xvision/resnet50-19c8e357.pth'
85                             #'./data/modelzoo/pretrained/pytorch/imagenet_classification/ericsun99/MobileNet-V2-Pytorch/mobilenetv2_Top1_71.806_Top2_90.410.pth.tar'
87 # args.resume = '/user/a0132471/Files/pytorch/pytorch-jacinto-models/checkpoints/cityscapes_depth_semantic_five_class_motion_image_dof_conf/2019-08-13-13-49-29_cityscapes_depth_semantic_five_class_motion_image_dof_conf_deeplabv3lite_mobilenetv2_ericsun_mi4_resize768x384_traincrop768x384/checkpoint.pth.tar'
89 args.model_config.input_channels = (3,3)
90 args.model_config.output_type = ['depth', 'segmentation', 'segmentation']
91 args.model_config.output_channels = None             #this can be found out from the dataset
92 args.losses = [['supervised_loss', 'scale_loss'], ['segmentation_loss'], ['segmentation_loss']]
93 args.metrics = [['supervised_relative_error_rng3to80'], ['segmentation_metrics'], ['segmentation_metrics']]
94 args.is_flow = [[True,False],[False,False,False]]
97 # TODO: this is not clean - fix it later
98 args.model_config.multi_task_type = cmds.multi_task_type #None, 'adaptive', 'learned', 'uncertainty'
99 args.model_config.multi_task = True if cmds.multi_task_type is not None else False
100 args.model_config.multi_task_factors = (1.0, 1.0, 1.0) #None #[1.291, 6.769, 6.852] (0.169, 1.279, 1.553)
101 args.multi_decoder = True
105 args.solver = 'adam'                    #'sgd' #'adam'
106 args.epochs = 250                       #200
107 args.epoch_size = 0                     #0 #0.5
108 args.scheduler = 'step'                 #'poly' #'step'
109 args.multistep_gamma = 0.5              #only for step scheduler
110 args.milestones = (100, 200)            #only for step scheduler
111 args.polystep_power = 0.9               #only for poly scheduler
112 args.iter_size = 1                      #2
113 args.evaluate_start = False
116 args.lr = 1e-4                          #1e-4 #0.01 #7e-3 #1e-4 #2e-4
117 args.batch_size = 16                    #12 #16 #32 #64
118 args.weight_decay = 1e-4                #4e-5 #1e-5
120 args.img_resize = (384, 768)            #(512, 1024) #(1024, 2048)
121 args.rand_scale = (1.0, 2.0)            #(1.0,2.0)
122 args.rand_crop = (384, 768)             #(512,512) #(512,1024)
123 args.output_size = (1024, 2048)          #for unflow loss only, output_size must match img_size
125 args.transform_rotation = 5             #0  #rotation degrees
127 args.model_config.aspp_dil = (2, 4, 6)
128 #args.model_config.use_aspp = True
130 # TODO: this is not clean - fix it later
131 args.dataset_config.image_folders = cmds.image_folders
133 args.save_onnx = False
134 #args.phase = 'validation'
135 #args.quantize = True
137 args.model_config.normalize_gradients = True
139 args.pivot_task_idx = 2
141 ################################
142 for key in vars(cmds):
143     if (key == 'gpus')| (key == 'image_folders') | (key == 'multi_task_type'):
144         pass # already taken care above, since this has to be done before importing pytorch
145     elif hasattr(args, key):
146         value = getattr(cmds, key)
147         if value != 'None' and value is not None:
148             setattr(args, key, value)
149     else:
150         assert False, f'invalid argument {key}'
153 ################################
154 # Run the given phase
155 train_pixel2pixel.main(args)
157 ################################
158 # In addition run a quantization aware training, starting from the trained model
159 if 'training' in args.phase and (not args.quantize):
160     save_path = train_pixel2pixel.get_save_path(args)
161     args.pretrained = os.path.join(save_path, 'model_best.pth.tar') if (args.epochs>0) else args.pretrained
162     args.phase = 'training_quantize'
163     args.quantize = True
164     args.lr = 1e-5
165     args.epochs = 50
166     train_pixel2pixel.main(args)
169 ################################
170 # In addition run a separate validation
171 if 'training' in args.phase or 'calibration' in args.phase:
172     save_path = train_pixel2pixel.get_save_path(args)
173     args.pretrained = os.path.join(save_path, 'model_best.pth.tar')
174     if 'training' in args.phase:
175         # DataParallel isn't enabled for QuantCalibrateModule and QuantTestModule.
176         # If the previous phase was training, then it is likely that the batch_size was high and won't fit in a single gpu - reduce it.
177         num_gpus = len(str(os.environ["CUDA_VISIBLE_DEVICES"]).split(',')) if ("CUDA_VISIBLE_DEVICES" in os.environ) else None
178         args.batch_size = max(args.batch_size//num_gpus, 1) if (num_gpus is not None) else args.batch_size
179     #
180     args.phase = 'validation'
181     args.quantize = True
182     train_pixel2pixel.main(args)