c5e7ee05bcc37db8b9ab9bbf782f4640fd18ebcd
1 #!/usr/bin/env python
3 import sys
4 import os
5 import cv2
6 import argparse
7 import datetime
9 ################################
10 #sys.path.insert(0, os.path.abspath('./modules'))
13 ################################
14 from pytorch_jacinto_ai.xnn.utils import str2bool
# command-line overrides: every default is None so that only options the user
# actually passed will override the values configured further below
parser = argparse.ArgumentParser()
parser.add_argument('--save_path', type=str, default=None, help='checkpoint save folder')
parser.add_argument('--gpus', type=int, nargs='*', default=None, help='GPU ids to use (sets CUDA_VISIBLE_DEVICES)')
parser.add_argument('--batch_size', type=int, default=None, help='Batch size')
parser.add_argument('--lr', type=float, default=None, help='Base learning rate')
parser.add_argument('--lr_clips', type=float, default=None, help='Learning rate for clips in PAct2')
parser.add_argument('--lr_calib', type=float, default=None, help='Learning rate for calibration')
parser.add_argument('--model_name', type=str, default=None, help='model name')
parser.add_argument('--dataset_name', type=str, default=None, help='dataset name')
parser.add_argument('--data_path', type=str, default=None, help='data path')
parser.add_argument('--epoch_size', type=float, default=None, help='epoch size. using a fraction will reduce the data used for one epoch')
parser.add_argument('--epochs', type=int, default=None, help='number of epochs')
parser.add_argument('--warmup_epochs', type=int, default=None, help='number of epochs for the learning rate to increase and reach base value')
parser.add_argument('--milestones', type=int, nargs='*', default=None, help='change lr at these milestones')
parser.add_argument('--img_resize', type=int, nargs=2, default=None, help='img_resize size. for training this will be modified according to rand_scale')
parser.add_argument('--rand_scale', type=float, nargs=2, default=None, help='random scale factors for training')
parser.add_argument('--rand_crop', type=int, nargs=2, default=None, help='random crop for training')
parser.add_argument('--output_size', type=int, nargs=2, default=None, help='output size of the evaluation - prediction/groundtruth. this is not used while training as it blows up memory requirement')
parser.add_argument('--pretrained', type=str, default=None, help='pretrained model')
parser.add_argument('--resume', type=str, default=None, help='resume an unfinished training from this model')
parser.add_argument('--phase', type=str, default=None, help='training/calibration/validation')
parser.add_argument('--evaluate_start', type=str2bool, default=None, help='Whether to run validation before the training')
#
parser.add_argument('--quantize', type=str2bool, default=None, help='Quantize the model')
parser.add_argument('--histogram_range', type=str2bool, default=None, help='use histogram based range estimation for quantization')
parser.add_argument('--per_channel_q', type=str2bool, default=None, help='use per-channel quantization')
parser.add_argument('--bias_calibration', type=str2bool, default=None, help='enable bias calibration for quantization')
parser.add_argument('--bitwidth_weights', type=int, default=None, help='bitwidth for weight quantization')
parser.add_argument('--bitwidth_activations', type=int, default=None, help='bitwidth for activation quantization')
#
parser.add_argument('--freeze_bn', type=str2bool, default=None, help='freeze the bn stats or not')
cmds = parser.parse_args()
48 ################################
49 # taken care first, since this has to be done before importing pytorch
50 if 'gpus' in vars(cmds):
51 value = getattr(cmds, 'gpus')
52 if value is not None:
53 os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(v) for v in value])
54 #
56 # to avoid hangs in data loader with multi threads
57 # this was observed after using cv2 image processing functions
58 # https://github.com/pytorch/pytorch/issues/1355
59 cv2.setNumThreads(0)
61 ################################
62 from pytorch_jacinto_ai.engine import train_pixel2pixel
64 # Create the parser and set default arguments
65 args = train_pixel2pixel.get_config()
67 ################################
68 #Modify arguments
69 args.model_name = 'deeplabv3lite_mobilenetv2_tv' #'deeplabv3lite_mobilenetv2_tv' #'fpn_pixel2pixel_aspp_mobilenetv2_tv' #'fpn_pixel2pixel_aspp_resnet50'
71 args.dataset_name = 'kitti_depth' #'kitti_depth' #'kitti_depth' #'kitti_depth2'
73 #args.save_path = './data/checkpoints'
75 args.data_path = './data/datasets/kitti/kitti_depth/data'
76 args.split_files = (args.data_path+'/train.txt', args.data_path+'/val.txt')
78 args.pretrained = './data/modelzoo/semantic_segmentation/cityscapes/deeplabv3lite-mobilenetv2/cityscapes_segmentation_deeplabv3lite-mobilenetv2_2019-06-26-08-59-32.pth'
79 # 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth'
80 # './data/modelzoo/pretrained/pytorch/imagenet_classification/ericsun99/MobileNet-V2-Pytorch/mobilenetv2_Top1_71.806_Top2_90.410.pth.tar'
81 # 'https://download.pytorch.org/models/resnet50-19c8e357.pth'
83 args.model_config.input_channels = (3,) # [3,3]
84 args.model_config.output_type = ['depth']
85 args.model_config.output_channels = [1]
86 args.losses = [['supervised_loss', 'scale_loss', 'supervised_error_var']] #[['supervised_loss', 'scale_loss']]
87 args.loss_mult_factors = [[0.125, 0.125, 4.0]]
89 args.metrics = [['supervised_relative_error_x100']] #[['supervised_root_mean_squared_error']]
91 args.solver = 'adam' #'sgd' #'adam'
92 args.epochs = 250 #300
93 args.epoch_size = 0.125 #0 #0.25
94 args.scheduler = 'step' #'poly' #'step' #'cosine'
95 args.multistep_gamma = 0.25 #only for step scheduler
96 args.milestones = (100, 200) #only for step scheduler
97 args.polystep_power = 0.9 #only for poly scheduler
98 args.iter_size = 1 #2
100 args.lr = 4e-4 #4e-4 #1e-4
101 args.batch_size = 32 #8 #12 #16 #32 #64
102 args.weight_decay = 1e-4 #4e-5 #1e-5
104 args.img_resize = (384, 768) #(256,512) #(512,512) #(512,1024) #(1024, 2048)
105 args.output_size = (374, 1242) #(512, 1024) #(720, 1280) #target output size for evaluation
107 args.transform_rotation = 5 #0 #rotation degrees
109 args.workers = 12 # more workers may speedup
111 #args.phase = 'validation'
112 #args.quantize = True
113 #args.print_model = True
114 #args.generate_onnx = False
115 #args.run_soon = False
116 #args.evaluate_start = False
118 #args.quantize = True
119 #args.per_channel_q = True
120 #args.phase = 'validation'
121 #args.parallel_model=False
123 #args.viz_colormap = 'plasma' # colormap for tensorboard: 'rainbow', 'plasma', 'magma', 'bone'
125 # defining date from outside can help to write multiple pahses into the same folder
126 args.date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
129 ################################
130 for key in vars(cmds):
131 if key == 'gpus':
132 pass # already taken care above, since this has to be done before importing pytorch
133 elif hasattr(args, key):
134 value = getattr(cmds, key)
135 if value != 'None' and value is not None:
136 setattr(args, key, value)
137 else:
138 assert False, f'invalid argument {key}'
139 #
141 ################################
142 # Run the given phase
143 train_pixel2pixel.main(args)
145 ################################
146 # In addition run a quantized calibration, starting from the trained model
147 if 'training'in args.phase and (not args.quantize):
148 save_path = train_pixel2pixel.get_save_path(args)
149 args.pretrained = os.path.join(save_path, 'model_best.pth.tar')
150 args.phase = 'training_quantize'
151 args.quantize = True
152 args.lr = 5e-5
153 args.epochs = min(args.epochs,25)
154 # quantized training will use only one GPU in the engine - so reduce the batch_size
155 num_gpus = len(str(os.environ["CUDA_VISIBLE_DEVICES"]).split(','))
156 args.batch_size = args.batch_size//num_gpus
157 train_pixel2pixel.main(args)
158 #
160 ################################
161 # In addition run a separate validation, starting from the calibrated model - to estimate the quantized accuracy accurately
162 if 'training' in args.phase or 'calibration' in args.phase:
163 save_path = train_pixel2pixel.get_save_path(args)
164 args.pretrained = os.path.join(save_path, 'model_best.pth.tar')
165 args.phase = 'validation'
166 args.quantize = True
167 train_pixel2pixel.main(args)
168 #