[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / engine / train_pixel2pixel.py
1 import os
2 import shutil
3 import time
4 import math
5 import copy
7 import torch
8 import torch.nn.parallel
9 import torch.backends.cudnn as cudnn
10 import torch.optim
11 import torch.utils.data
12 import torch.onnx
13 import onnx
15 import datetime
16 from tensorboardX import SummaryWriter
17 import numpy as np
18 import random
19 import cv2
20 from colorama import Fore
21 import progiter
22 from packaging import version
23 import warnings
25 from .. import xnn
26 from .. import vision
27 from . infer_pixel2pixel import compute_accuracy
##################################################
# silence TracerWarning noise from torch.jit; tracing happens during the
# ONNX export done in main() (write_onnx_model) and in the quantize modules
warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)

##################################################
def get_config():
    """Build and return the default training configuration.

    Returns an ``xnn.utils.ConfigNode`` holding every tunable knob used by
    :func:`main`/:func:`train`/``validate``. Callers typically obtain this,
    override selected fields, then pass the node to :func:`main`.
    """
    args = xnn.utils.ConfigNode()

    # dataset-specific sub-config (consumed by the dataset factory in main())
    args.dataset_config = xnn.utils.ConfigNode()
    args.dataset_config.split_name = 'val'
    args.dataset_config.max_depth_bfr_scaling = 80
    args.dataset_config.depth_scale = 1
    args.dataset_config.train_depth_log = 1
    args.use_semseg_for_depth = False

    # model config
    args.model_config = xnn.utils.ConfigNode()
    args.model_config.output_type = ['segmentation']    # the network is used to predict flow or depth or sceneflow
    args.model_config.output_channels = None            # number of output channels
    args.model_config.input_channels = None             # number of input channels
    args.model_config.output_range = None               # max range of output
    args.model_config.num_decoders = None               # number of decoders to use. [options: 0, 1, None]
    args.model_config.freeze_encoder = False            # do not update encoder weights
    args.model_config.freeze_decoder = False            # do not update decoder weights
    args.model_config.multi_task_type = 'learned'       # find out loss multiplier by learning, choices=[None, 'learned', 'uncertainty', 'grad_norm', 'dwa_grad_norm']
    args.model_config.target_input_ratio = 1            # keep target size same as input size
    args.model_config.input_nv12 = False
    args.model = None                                   # the model itself can be given from outside
    args.model_name = 'deeplabv2lite_mobilenetv2'       # model architecture, overwritten if pretrained is specified
    args.dataset_name = 'cityscapes_segmentation'       # dataset type
    args.transforms = None                              # the transforms themselves can be given from outside

    args.data_path = './data/cityscapes'                # path to dataset
    args.save_path = None                               # checkpoints save path
    args.phase = 'training'                             # training/calibration/validation
    args.date = None                                    # date to add to save path. if this is None, current date will be added.

    args.logger = None                                  # logger stream to output into
    args.show_gpu_usage = False                         # shows gpu usage at the beginning of each training epoch

    args.split_file = None                              # train_val split file
    args.split_files = None                             # split list files. eg: train.txt val.txt
    args.split_value = None                             # test_val split proportion (between 0 (only test) and 1 (only train))

    args.solver = 'adam'                                # solver algorithms, choices=['adam','sgd']
    args.scheduler = 'step'                             # scheduler algorithms, choices=['step','poly', 'cosine']
    args.workers = 8                                    # number of data loading workers

    args.epochs = 250                                   # number of total epochs to run
    args.start_epoch = 0                                # manual epoch number (useful on restarts)

    args.epoch_size = 0                                 # manual epoch size (will match dataset size if not specified)
    args.epoch_size_val = 0                             # manual epoch size (will match dataset size if not specified)
    args.batch_size = 12                                # mini_batch size
    args.total_batch_size = None                        # accumulated batch size. total_batch_size = batch_size*iter_size
    args.iter_size = 1                                  # iteration size. total_batch_size = batch_size*iter_size

    args.lr = 1e-4                                      # initial learning rate
    args.lr_clips = None                                # lr for PAct2 'clips' parameters; use args.lr itself if it is None
    args.lr_calib = 0.05                                # lr for bias calibration
    args.warmup_epochs = 5                              # number of epochs to warmup

    args.momentum = 0.9                                 # momentum for sgd, alpha parameter for adam
    args.beta = 0.999                                   # beta parameter for adam
    args.weight_decay = 1e-4                            # weight decay
    args.bias_decay = None                              # bias decay

    args.sparse = True                                  # avoid invalid/ignored target pixels from loss computation, use NEAREST for interpolation

    args.tensorboard_num_imgs = 5                       # number of imgs to display in tensorboard
    args.pretrained = None                              # path to pre_trained model
    args.resume = None                                  # path to latest checkpoint (default: none)
    args.no_date = False                                # don't append date timestamp to folder
    args.print_freq = 100                               # print frequency (default: 100)

    args.milestones = (100, 200)                        # epochs at which learning rate is divided by 2

    # NOTE: the original comment here was garbled; these are per-task loss
    # function names looked up in vision.losses by main()
    args.losses = ['segmentation_loss']                 # loss functions, one list entry per task
    args.metrics = ['segmentation_metrics']             # metric/measurement/error functions for train/validation
    args.multi_task_factors = None                      # loss mult factors
    args.class_weights = None                           # class weights

    args.loss_mult_factors = None                       # fixed loss mult factors - per loss - note: this is different from multi_task_factors (which is per task)

    args.multistep_gamma = 0.5                          # gamma for step scheduler
    args.polystep_power = 1.0                           # power for polynomial scheduler

    args.rand_seed = 1                                  # random seed
    args.img_border_crop = None                         # image border crop rectangle. can be relative or absolute
    args.target_mask = None                             # mask rectangle. can be relative or absolute. last value is the mask value

    args.rand_resize = None                             # random image size to be resized to during training
    args.rand_output_size = None                        # output size to be resized to during training
    args.rand_scale = (1.0, 2.0)                        # random scale range for training
    args.rand_crop = None                               # image size to be cropped to

    args.img_resize = None                              # image size to be resized to during evaluation
    args.output_size = None                             # target output size to be resized to

    args.count_flops = True                             # count flops and report

    args.shuffle = True                                 # shuffle train dataset or not
    args.shuffle_val = False                            # shuffle val dataset or not

    args.transform_rotation = 0.                        # apply rotation augmentation. value is rotation in degrees. 0 indicates no rotation
    args.is_flow = None                                 # whether entries in images and targets lists are optical flow or not

    args.upsample_mode = 'bilinear'                     # upsample mode to use, choices=['nearest','bilinear']

    args.image_prenorm = True                           # whether normalization is done before all the other transforms
    args.image_mean = (128.0,)                          # image mean for input image normalization
    args.image_scale = (1.0 / (0.25 * 256),)            # image scaling/mult for input image normalization

    args.max_depth = 80                                 # maximum depth to be used for visualization

    args.pivot_task_idx = 0                             # task id to select best model

    args.parallel_model = True                          # use data parallel for model
    args.parallel_criterion = True                      # use data parallel for loss and metric

    args.evaluate_start = True                          # evaluate right at the beginning of training or not
    args.save_onnx = True                               # export the model to ONNX or not (original comment was a copy-paste of the quantize one)
    args.print_model = False                            # print the model to text
    args.run_soon = True                                # to start training after generating configs/models

    args.quantize = False                               # apply quantized inference or not
    #args.model_surgery = None                          # deprecated (main() asserts it is absent); kept for reference
    args.bitwidth_weights = 8                           # bitwidth for weights
    args.bitwidth_activations = 8                       # bitwidth for activations
    args.histogram_range = True                         # histogram range for calibration
    args.bias_calibration = True                        # apply bias correction during quantized inference calibration
    args.per_channel_q = False                          # apply separate quantization factor for each channel in depthwise or not

    args.save_mod_files = False                         # saves modified files after last commit. also stores commit id.
    args.make_score_zero_mean = False                   # make score zero mean while learning
    args.no_q_for_dws_layer_idx = 0                     # no_q_for_dws_layer_idx

    args.viz_colormap = 'rainbow'                       # colormap for tensorboard: 'rainbow', 'plasma', 'magma', 'bone'

    args.freeze_bn = False                              # freeze the statistics of bn
    args.tensorboard_enable = True                      # en/disable of TB writing
    args.print_train_class_iou = False
    args.print_val_class_iou = False
    args.freeze_layers = None                           # list of name fragments; matching modules are frozen in train()

    args.opset_version = 9                              # onnx opset_version

    return args
# ################################################
# to avoid hangs in data loader with multi threads
# this was observed after using cv2 image processing functions
# https://github.com/pytorch/pytorch/issues/1355
# (OpenCV's internal thread pool conflicts with DataLoader worker processes)
cv2.setNumThreads(0)
185 # ################################################
def main(args):
    """Entry point: build datasets/model/losses from *args* and run the requested phase.

    Depending on ``args.phase`` this trains, calibrates (for quantization) or
    validates. Side effects: creates ``save_path`` (log files, tensorboard
    events, checkpoints, optional ONNX export), may shell out to git, and
    moves the model to CUDA. Returns None.
    """
    # NOTE(review): the original comment claimed "1.2 or higher" but the code
    # (and its message) check 1.1 — the code is kept as-is.
    assert version.parse(torch.__version__) >= version.parse('1.1'), \
        'torch version must be 1.1 or higher, due to the change in scheduler.step() and optimiser.step() call order'

    assert (not hasattr(args, 'evaluate')), 'args.evaluate is deprecated. use args.phase=training or calibration or validation'
    assert is_valid_phase(args.phase), f'invalid phase {args.phase}'
    assert not hasattr(args, 'model_surgery'), 'the argument model_surgery is deprecated, it is not needed now - remove it'

    # bias calibration makes no sense when only validating
    if (args.phase == 'validation' and args.bias_calibration):
        args.bias_calibration = False
        warnings.warn('switching off bias calibration in validation')
    #

    #################################################
    # fall back to the evaluation resize for any training size not given
    args.rand_resize = args.img_resize if args.rand_resize is None else args.rand_resize
    args.rand_crop = args.img_resize if args.rand_crop is None else args.rand_crop
    args.output_size = args.img_resize if args.output_size is None else args.output_size
    # resume has higher priority than pretrained
    args.pretrained = None if (args.resume is not None) else args.pretrained

    if args.save_path is None:
        save_path = get_save_path(args)
    else:
        save_path = args.save_path
    #
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    if args.save_mod_files:
        # store all the files modified after the last commit, for reproducibility
        mod_files_path = save_path+'/mod_files'
        os.makedirs(mod_files_path)

        cmd = "git ls-files --modified | xargs -i cp {} {}".format("{}", mod_files_path)
        print("cmd:", cmd)
        os.system(cmd)

        # store last commit id
        cmd = "git log -n 1 >> {}".format(mod_files_path + '/commit_id.txt')
        print("cmd:", cmd)
        os.system(cmd)

    #################################################
    # tee stdout into a log file alongside the checkpoints
    if args.logger is None:
        log_file = os.path.splitext(os.path.basename(__file__))[0] + '.log'
        args.logger = xnn.utils.TeeLogger(filename=os.path.join(save_path,log_file))

    #################################################
    # global settings. rand seeds for repeatability
    random.seed(args.rand_seed)
    np.random.seed(args.rand_seed)
    torch.manual_seed(args.rand_seed)
    torch.cuda.manual_seed(args.rand_seed)

    ################################
    # args check and config: iter_size and total_batch_size are two ways of
    # specifying the same thing (gradient accumulation)
    if args.iter_size != 1 and args.total_batch_size is not None:
        warnings.warn("only one of --iter_size or --total_batch_size must be set")
    #
    if args.total_batch_size is not None:
        args.iter_size = args.total_batch_size//args.batch_size
    else:
        args.total_batch_size = args.batch_size*args.iter_size

    #################################################
    # set some global flags and initializations
    # keep it in args for now - although they don't belong here strictly
    # using pin_memory is seen to cause issues, especially when a lot of memory is used.
    args.use_pinned_memory = False
    args.n_iter = 0
    args.best_metric = -1
    cudnn.benchmark = True
    # torch.autograd.set_detect_anomaly(True)

    ################################
    # reset character color, in case it is different
    print('{}'.format(Fore.RESET))
    # print everything for log
    print('=> args: {}'.format(args))
    print('\n'.join("%s: %s" % item for item in sorted(vars(args).items())))
    print('=> will save everything to {}'.format(save_path))

    #################################################
    train_writer = SummaryWriter(os.path.join(save_path,'train')) if args.tensorboard_enable else None
    val_writer = SummaryWriter(os.path.join(save_path,'val')) if args.tensorboard_enable else None
    transforms = get_transforms(args) if args.transforms is None else args.transforms
    assert isinstance(transforms, (list,tuple)) and len(transforms) == 2, 'incorrect transforms were given'

    print("=> fetching images in '{}'".format(args.data_path))
    # priority: explicit split file > list of split files > split proportion
    split_arg = args.split_file if args.split_file else (args.split_files if args.split_files else args.split_value)
    train_dataset, val_dataset = vision.datasets.pixel2pixel.__dict__[args.dataset_name](args.dataset_config, args.data_path, split=split_arg, transforms=transforms)

    #################################################
    train_sampler = None
    val_sampler = None
    print('=> {} samples found, {} train samples and {} test samples '.format(len(train_dataset)+len(val_dataset),
        len(train_dataset), len(val_dataset)))
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
        num_workers=args.workers, pin_memory=args.use_pinned_memory, sampler=train_sampler, shuffle=args.shuffle)

    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
        num_workers=args.workers, pin_memory=args.use_pinned_memory, sampler=val_sampler, shuffle=args.shuffle_val)

    #################################################
    # fill in channel counts from the dataset when not given explicitly
    if (args.model_config.input_channels is None):
        args.model_config.input_channels = (3,)
        print("=> input channels is not given - setting to {}".format(args.model_config.input_channels))

    if (args.model_config.output_channels is None):
        if ('num_classes' in dir(train_dataset)):
            args.model_config.output_channels = train_dataset.num_classes()
        else:
            args.model_config.output_channels = (2 if args.model_config.output_type == 'flow' else args.model_config.output_channels)
        xnn.utils.print_yellow("=> output channels is not given - setting to {} - not sure to work".format(args.model_config.output_channels))
    #
    if not isinstance(args.model_config.output_channels,(list,tuple)):
        args.model_config.output_channels = [args.model_config.output_channels]

    if (args.class_weights is None) and ('class_weights' in dir(train_dataset)):
        args.class_weights = train_dataset.class_weights()
        if not isinstance(args.class_weights, (list,tuple)):
            args.class_weights = [args.class_weights]
    #
    print("=> class weights available for dataset: {}".format(args.class_weights))

    #################################################
    # load pretrained weight blobs (local path, URL, or an already-loaded dict)
    pretrained_data = None
    model_surgery_quantize = False
    pretrained_data = None
    if args.pretrained and args.pretrained != "None":
        pretrained_data = []
        pretrained_files = args.pretrained if isinstance(args.pretrained,(list,tuple)) else [args.pretrained]
        for p in pretrained_files:
            if isinstance(p, dict):
                p_data = p
            else:
                if p.startswith('http://') or p.startswith('https://'):
                    p_file = vision.datasets.utils.download_url(p, './data/downloads')
                else:
                    p_file = p
                #
                print(f'=> loading pretrained weights file: {p}')
                p_data = torch.load(p_file)
            #
            pretrained_data.append(p_data)
            # NOTE: only the last file's 'quantize' flag survives this loop
            model_surgery_quantize = p_data['quantize'] if 'quantize' in p_data else False
        #
    #

    #################################################
    # create model: either the one supplied by the caller or a named architecture
    if args.model is not None:
        model, change_names_dict = args.model if isinstance(args.model, (list, tuple)) else (args.model, None)
        assert isinstance(model, torch.nn.Module), 'args.model, if provided must be a valid torch.nn.Module'
    else:
        xnn.utils.print_yellow("=> creating model '{}'".format(args.model_name))
        model = vision.models.pixel2pixel.__dict__[args.model_name](args.model_config)
        # check if we got the model as well as parameters to change the names in pretrained
        model, change_names_dict = model if isinstance(model, (list,tuple)) else (model,None)
    #

    if args.quantize:
        # dummy input is used by quantized models to analyze graph
        is_cuda = next(model.parameters()).is_cuda
        dummy_input = create_rand_inputs(args, is_cuda=is_cuda)
        #
        if 'training' in args.phase:
            model = xnn.quantize.QuantTrainModule(model, per_channel_q=args.per_channel_q,
                histogram_range=args.histogram_range, bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations, dummy_input=dummy_input)
        elif 'calibration' in args.phase:
            model = xnn.quantize.QuantCalibrateModule(model, per_channel_q=args.per_channel_q,
                bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
                histogram_range=args.histogram_range, bias_calibration=args.bias_calibration,
                dummy_input=dummy_input, lr_calib=args.lr_calib)
        elif 'validation' in args.phase:
            # note: bias_calibration is not enabled here
            model = xnn.quantize.QuantTestModule(model, per_channel_q=args.per_channel_q,
                bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
                histogram_range=args.histogram_range, dummy_input=dummy_input,
                model_surgery_quantize=model_surgery_quantize)
        else:
            assert False, f'invalid phase {args.phase}'
    #

    # load pretrained model (into the wrapped model's inner module if quantized)
    if pretrained_data is not None:
        for (p_data,p_file) in zip(pretrained_data, pretrained_files):
            print("=> using pretrained weights from: {}".format(p_file))
            xnn.utils.load_weights(get_model_orig(model), pretrained=p_data, change_names_dict=change_names_dict)
    #

    #################################################
    if args.count_flops:
        count_flops(args, model)

    #################################################
    # NOTE(review): 'args.phase in p' tests whether the phase string is a
    # substring of 'training'/'calibration' — the inverse of the 'p in phase'
    # convention used by is_valid_phase(). Looks inverted; confirm intent
    # before changing (it works for the exact strings 'training'/'calibration').
    if args.save_onnx and (any(args.phase in p for p in ('training','calibration')) or (args.run_soon == False)):
        write_onnx_model(args, get_model_orig(model), save_path, save_traced_model=False)
    #

    #################################################
    if args.print_model:
        print(model)
        print('\n')
    else:
        args.logger.debug(str(model))
        args.logger.debug('\n')

    #################################################
    # config/model generation only requested - stop before training
    if (not args.run_soon):
        print("Training not needed for now")
        close(args)
        exit()

    #################################################
    # multi gpu mode does not work for calibration/training for quantization
    # so use it only when args.quantize is False
    if args.parallel_model and ((not args.quantize)):
        model = torch.nn.DataParallel(model)

    #################################################
    model = model.cuda()

    #################################################
    # for help in debug/print
    for name, module in model.named_modules():
        module.name = name

    #################################################
    # instantiate one loss module per (task, loss-name); kwargs are filled from
    # the signature advertised by each loss class via .args()
    args.loss_modules = copy.deepcopy(args.losses)
    for task_dx, task_losses in enumerate(args.losses):
        for loss_idx, loss_fn in enumerate(task_losses):
            kw_args = {}
            loss_args = vision.losses.__dict__[loss_fn].args()
            for arg in loss_args:
                if arg == 'weight' and (args.class_weights is not None):
                    kw_args.update({arg:args.class_weights[task_dx]})
                elif arg == 'num_classes':
                    kw_args.update({arg:args.model_config.output_channels[task_dx]})
                elif arg == 'sparse':
                    kw_args.update({arg:args.sparse})
                #
            #
            loss_fn_raw = vision.losses.__dict__[loss_fn](**kw_args)
            if args.parallel_criterion:
                # NOTE(review): the inner conditional is redundant - this branch
                # is only reached when args.parallel_criterion is already true
                loss_fn = torch.nn.DataParallel(loss_fn_raw).cuda() if args.parallel_criterion else loss_fn_raw.cuda()
                # DataParallel hides the wrapped module's attributes; re-expose them
                loss_fn.info = loss_fn_raw.info
                loss_fn.clear = loss_fn_raw.clear
            else:
                loss_fn = loss_fn_raw.cuda()
            #
            args.loss_modules[task_dx][loss_idx] = loss_fn
    #

    # same construction for the metric modules (note: metrics also live in vision.losses)
    args.metric_modules = copy.deepcopy(args.metrics)
    for task_dx, task_metrics in enumerate(args.metrics):
        for midx, metric_fn in enumerate(task_metrics):
            kw_args = {}
            loss_args = vision.losses.__dict__[metric_fn].args()
            for arg in loss_args:
                if arg == 'weight':
                    kw_args.update({arg:args.class_weights[task_dx]})
                elif arg == 'num_classes':
                    kw_args.update({arg:args.model_config.output_channels[task_dx]})
                elif arg == 'sparse':
                    kw_args.update({arg:args.sparse})

            metric_fn_raw = vision.losses.__dict__[metric_fn](**kw_args)
            if args.parallel_criterion:
                metric_fn = torch.nn.DataParallel(metric_fn_raw).cuda()
                metric_fn.info = metric_fn_raw.info
                metric_fn.clear = metric_fn_raw.clear
            else:
                metric_fn = metric_fn_raw.cuda()
            #
            args.metric_modules[task_dx][midx] = metric_fn
    #

    #################################################
    # validation-only phase: run one validation pass and exit
    if args.phase=='validation':
        with torch.no_grad():
            validate(args, val_dataset, val_loader, model, 0, val_writer)
        #
        close(args)
        return

    #################################################
    assert(args.solver in ['adam', 'sgd'])
    print('=> setting {} solver'.format(args.solver))
    if args.lr_clips is not None:
        # put PAct2 'clips' parameters into their own param group with a
        # separate lr/decay (lr forced to 0 outside training)
        learning_rate_clips = args.lr_clips if 'training' in args.phase else 0.0
        clips_decay = args.bias_decay if (args.bias_decay is not None and args.bias_decay != 0.0) else args.weight_decay
        clips_params = [p for n,p in model.named_parameters() if 'clips' in n]
        other_params = [p for n,p in model.named_parameters() if 'clips' not in n]
        param_groups = [{'params': clips_params, 'weight_decay': clips_decay, 'lr': learning_rate_clips},
                        {'params': other_params, 'weight_decay': args.weight_decay}]
    else:
        param_groups = [{'params': filter(lambda p: p.requires_grad, model.parameters()), 'weight_decay': args.weight_decay}]
    #

    # lr 0.0 in calibration: optimizer runs but must not move the weights
    learning_rate = args.lr if ('training'in args.phase) else 0.0
    if args.solver == 'adam':
        optimizer = torch.optim.Adam(param_groups, learning_rate, betas=(args.momentum, args.beta))
    elif args.solver == 'sgd':
        optimizer = torch.optim.SGD(param_groups, learning_rate, momentum=args.momentum)
    else:
        raise ValueError('Unknown optimizer type{}'.format(args.solver))
    #

    #################################################
    epoch_size = get_epoch_size(args, train_loader, args.epoch_size)
    max_iter = args.epochs * epoch_size
    scheduler = xnn.optim.lr_scheduler.SchedulerWrapper(args.scheduler, optimizer, args.epochs, args.start_epoch, \
        args.warmup_epochs, max_iter=max_iter, polystep_power=args.polystep_power,
        milestones=args.milestones, multistep_gamma=args.multistep_gamma)

    # optionally resume from a checkpoint (weights + optimizer/scheduler state)
    if args.resume:
        if not os.path.isfile(args.resume):
            print("=> no checkpoint found at '{}'".format(args.resume))
        else:
            print("=> loading checkpoint '{}'".format(args.resume))

            checkpoint = torch.load(args.resume)
            model = xnn.utils.load_weights(model, checkpoint)

            # an explicitly given start_epoch wins over the checkpoint's
            if args.start_epoch == 0:
                args.start_epoch = checkpoint['epoch']

            if 'best_metric' in list(checkpoint.keys()):
                args.best_metric = checkpoint['best_metric']

            if 'optimizer' in list(checkpoint.keys()):
                optimizer.load_state_dict(checkpoint['optimizer'])

            if 'scheduler' in list(checkpoint.keys()):
                scheduler.load_state_dict(checkpoint['scheduler'])

            if 'multi_task_factors' in list(checkpoint.keys()):
                args.multi_task_factors = checkpoint['multi_task_factors']

            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))

    #################################################
    if args.evaluate_start:
        with torch.no_grad():
            validate(args, val_dataset, val_loader, model, args.start_epoch, val_writer)

    for epoch in range(args.start_epoch, args.epochs):
        if train_sampler:
            train_sampler.set_epoch(epoch)
        if val_sampler:
            val_sampler.set_epoch(epoch)

        # train for one epoch
        train(args, train_dataset, train_loader, model, optimizer, epoch, train_writer, scheduler)

        # evaluate on validation set
        with torch.no_grad():
            val_metric, metric_name = validate(args, val_dataset, val_loader, model, epoch, val_writer)

        if args.best_metric < 0:
            args.best_metric = val_metric

        # the metric name decides whether higher or lower is better
        if "iou" in metric_name.lower() or "acc" in metric_name.lower():
            is_best = val_metric >= args.best_metric
            args.best_metric = max(val_metric, args.best_metric)
        elif "error" in metric_name.lower() or "diff" in metric_name.lower() or "norm" in metric_name.lower() \
                or "loss" in metric_name.lower() or "outlier" in metric_name.lower():
            is_best = val_metric <= args.best_metric
            args.best_metric = min(val_metric, args.best_metric)
        else:
            raise ValueError("Metric is not known. Best model could not be saved.")
        #

        checkpoint_dict = { 'epoch': epoch + 1, 'model_name': args.model_name,
                            'state_dict': get_model_orig(model).state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'best_metric': args.best_metric,
                            'multi_task_factors': args.multi_task_factors,
                            'quantize' : args.quantize}

        save_checkpoint(args, save_path, get_model_orig(model), checkpoint_dict, is_best)

        if args.tensorboard_enable:
            train_writer.file_writer.flush()
            val_writer.file_writer.flush()

        # adjust the learning rate using lr scheduler (after optimizer.step(),
        # as required from pytorch 1.1 onwards - see the version assert above)
        if 'training' in args.phase:
            scheduler.step()
        #
    #

    # close and cleanup
    close(args)
#
586 ###################################################################
def is_valid_phase(phase):
    """Return True when *phase* contains one of the recognised run modes.

    A phase string is considered valid if 'training', 'calibration' or
    'validation' occurs anywhere inside it (so variants such as
    'quant_training' are accepted too).
    """
    for known_mode in ('training', 'calibration', 'validation'):
        if known_mode in phase:
            return True
    return False
592 ###################################################################
def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writer, scheduler):
    """Run one training (or calibration) epoch over *train_loader*.

    Performs forward/backward with gradient accumulation every
    ``args.iter_size`` iterations, updates running loss/metric meters,
    writes tensorboard scalars/images, and drives a console progress bar.
    Returns ``(output_metric, output_name)`` for the pivot task -
    the epoch-average metric value and its name.
    """
    batch_time = xnn.utils.AverageMeter()
    data_time = xnn.utils.AverageMeter()
    # if the loss/metric is already an average, no need to further average
    avg_loss = [xnn.utils.AverageMeter(print_avg=(not task_loss[0].info()['is_avg'])) for task_loss in args.loss_modules]
    avg_loss_orig = [xnn.utils.AverageMeter(print_avg=(not task_loss[0].info()['is_avg'])) for task_loss in args.loss_modules]
    avg_metric = [xnn.utils.AverageMeter(print_avg=(not task_metric[0].info()['is_avg'])) for task_metric in args.metric_modules]
    epoch_size = get_epoch_size(args, train_loader, args.epoch_size)

    ##########################
    # switch to train mode
    model.train()
    if args.freeze_bn:
        xnn.utils.freeze_bn(model)
    #

    # freeze layers
    if args.freeze_layers is not None:
        # 'freeze_layer_name' could be part of 'name', i.e. 'name' need not be exact same as 'freeze_layer_name'
        # e.g. freeze_layer_name = 'encoder.0' then all layers like, 'encoder.0.0' 'encoder.0.1' will be frozen
        for freeze_layer_name in args.freeze_layers:
            for name, module in model.named_modules():
                if freeze_layer_name in name:
                    xnn.utils.print_once("Freezing the module : {}".format(name))
                    module.eval()
                    for param in module.parameters():
                        param.requires_grad = False

    ##########################
    # reset the running state of every loss and metric module
    for task_dx, task_losses in enumerate(args.loss_modules):
        for loss_idx, loss_fn in enumerate(task_losses):
            loss_fn.clear()
    for task_dx, task_metrics in enumerate(args.metric_modules):
        for midx, metric_fn in enumerate(task_metrics):
            metric_fn.clear()

    progress_bar = progiter.ProgIter(np.arange(epoch_size), chunksize=1)
    metric_name = "Metric"
    metric_ctx = [None] * len(args.metric_modules)
    end_time = time.time()
    writer_idx = 0
    last_update_iter = -1

    # change color to yellow for calibration (and quantized training)
    progressbar_color = (Fore.YELLOW if (('calibration' in args.phase) or ('training' in args.phase and args.quantize)) else Fore.WHITE)
    print('{}'.format(progressbar_color), end='')

    ##########################
    # NOTE: 'iter' shadows the builtin - kept as in the original
    for iter, (inputs, targets) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end_time)

        lr = scheduler.get_lr()[0]

        # move inputs/targets to GPU; an input entry may itself be a list of tensors
        input_list = [[jj.cuda() for jj in img] if isinstance(img,(list,tuple)) else img.cuda() for img in inputs]
        target_list = [tgt.cuda(non_blocking=True) for tgt in targets]
        target_sizes = [tgt.shape for tgt in target_list]
        batch_size_cur = target_sizes[0][0]

        ##########################
        # compute output
        task_outputs = model(input_list)

        task_outputs = task_outputs if isinstance(task_outputs,(list,tuple)) else [task_outputs]
        # upsample output to target resolution
        if args.upsample_mode is not None:
            task_outputs = upsample_tensors(task_outputs, target_sizes, args.upsample_mode)

        # multi-task: query the learned per-task loss scales from the model
        if args.model_config.multi_task_type is not None and len(args.model_config.output_channels) > 1:
            args.multi_task_factors, args.multi_task_offsets = xnn.layers.get_loss_scales(model)
        else:
            args.multi_task_factors = None
            args.multi_task_offsets = None

        loss_total, loss_list, loss_names, loss_types, loss_list_orig = \
            compute_task_objectives(args, args.loss_modules, input_list, task_outputs, target_list,
                task_mults=args.multi_task_factors, task_offsets=args.multi_task_offsets,
                loss_mult_factors=args.loss_mult_factors)

        if args.print_train_class_iou:
            metric_total, metric_list, metric_names, metric_types, _, confusion_matrix = \
                compute_task_objectives(args, args.metric_modules, input_list, task_outputs, target_list,
                    get_confusion_matrix=args.print_train_class_iou)
        else:
            metric_total, metric_list, metric_names, metric_types, _ = \
                compute_task_objectives(args, args.metric_modules, input_list, task_outputs, target_list,
                    get_confusion_matrix=args.print_train_class_iou)

        if args.model_config.multi_task_type is not None and len(args.model_config.output_channels) > 1:
            xnn.layers.set_losses(model, loss_list_orig)

        if 'training' in args.phase:
            # zero gradients so that we can accumulate gradients
            if (iter % args.iter_size) == 0:
                optimizer.zero_grad()

            # accumulate gradients
            loss_total.backward()
            # optimization step every iter_size iterations
            if ((iter+1) % args.iter_size) == 0:
                optimizer.step()
        #

        # record loss.
        for task_idx, task_losses in enumerate(args.loss_modules):
            avg_loss[task_idx].update(float(loss_list[task_idx].cpu()), batch_size_cur)
            avg_loss_orig[task_idx].update(float(loss_list_orig[task_idx].cpu()), batch_size_cur)
            if args.tensorboard_enable:
                train_writer.add_scalar('Training/Task{}_{}_Loss_Iter'.format(task_idx,loss_names[task_idx]), float(loss_list[task_idx]), args.n_iter)
                if args.model_config.multi_task_type is not None and len(args.model_config.output_channels) > 1:
                    train_writer.add_scalar('Training/multi_task_Factor_Task{}_{}'.format(task_idx,loss_names[task_idx]), float(args.multi_task_factors[task_idx]), args.n_iter)

        # record error/accuracy.
        for task_idx, task_metrics in enumerate(args.metric_modules):
            avg_metric[task_idx].update(float(metric_list[task_idx].cpu()), batch_size_cur)

        ##########################
        if args.tensorboard_enable:
            write_output(args, 'Training_', epoch_size, iter, epoch, train_dataset, train_writer, input_list, task_outputs, target_list, metric_name, writer_idx)

        if ((iter % args.print_freq) == 0) or (iter == (epoch_size-1)):
            output_string = ''
            for task_idx, task_metrics in enumerate(args.metric_modules):
                output_string += '[{}={}]'.format(metric_names[task_idx], str(avg_metric[task_idx]))

            epoch_str = '{}/{}'.format(epoch + 1, args.epochs)
            progress_bar.set_description("{}=> {} ".format(progressbar_color, args.phase))
            multi_task_factors_print = ['{:.3f}'.format(float(lmf)) for lmf in args.multi_task_factors] if args.multi_task_factors is not None else None
            progress_bar.set_postfix(Epoch=epoch_str, LR=lr, DataTime=str(data_time), LossMult=multi_task_factors_print, Loss=avg_loss, Output=output_string)
            progress_bar.update(iter-last_update_iter)
            last_update_iter = iter

        args.n_iter += 1
        end_time = time.time()
        writer_idx = (writer_idx + 1) % args.tensorboard_num_imgs

        # add onnx graph to tensorboard
        # commenting out due to issues in transitioning to pytorch 0.4
        # (bilinear mode in upsampling causes hang or crash - may be due to align_borders change, nearest is fine)
        #if epoch == 0 and iter == 0:
        #    input_zero = torch.zeros(input_var.shape)
        #    train_writer.add_graph(model, input_zero)
        #This cache operation slows down training
        #torch.cuda.empty_cache()

        # NOTE(review): the check is '>=', and it runs after the full body, so
        # when epoch_size < len(train_loader) one extra iteration is processed
        # before breaking - confirm whether '>' vs '>=' was intended
        if iter >= epoch_size:
            break

    # NOTE(review): relies on 'confusion_matrix' and 'task_idx' leaking out of
    # the loop above; fails with NameError if the loader was empty
    if args.print_train_class_iou:
        print_class_iou(args=args, confusion_matrix=confusion_matrix, task_idx=task_idx)

    progress_bar.close()

    # to print a new line - do not provide end=''
    print('{}'.format(Fore.RESET), end='')

    if args.tensorboard_enable:
        for task_idx, task_losses in enumerate(args.loss_modules):
            train_writer.add_scalar('Training/Task{}_{}_Loss_Epoch'.format(task_idx,loss_names[task_idx]), float(avg_loss[task_idx]), epoch)

        for task_idx, task_metrics in enumerate(args.metric_modules):
            train_writer.add_scalar('Training/Task{}_{}_Metric_Epoch'.format(task_idx,metric_names[task_idx]), float(avg_metric[task_idx]), epoch)

    output_name = metric_names[args.pivot_task_idx]
    output_metric = float(avg_metric[args.pivot_task_idx])

    ##########################
    # in quantized runs, log the learned activation clip values for debugging
    if args.quantize:
        def debug_format(v):
            return ('{:.3f}'.format(v) if v is not None else 'None')
        #
        clips_act = [m.get_clips_act()[1] for n,m in model.named_modules() if isinstance(m,xnn.layers.PAct2)]
        if len(clips_act) > 0:
            args.logger.debug('\nclips_act : ' + ' '.join(map(debug_format, clips_act)))
            args.logger.debug('')
    #
    return output_metric, output_name
772 ###################################################################
def validate(args, val_dataset, val_loader, model, epoch, val_writer):
    """Run one evaluation pass over val_loader and return the pivot-task metric.

    Mirrors the training loop's bookkeeping: per-task running-average metrics,
    progress-bar updates, optional tensorboard logging of images/scalars and an
    optional per-class IoU printout.

    Returns:
        (output_metric, output_name): value and name of the metric of the
        pivot task (args.pivot_task_idx).
    """
    data_time = xnn.utils.AverageMeter()
    # if the loss/ metric is already an average, no need to further average
    avg_metric = [xnn.utils.AverageMeter(print_avg=(not task_metric[0].info()['is_avg'])) for task_metric in args.metric_modules]
    epoch_size = get_epoch_size(args, val_loader, args.epoch_size_val)

    ##########################
    # switch to evaluate mode
    model.eval()

    ##########################
    # reset any state the metric modules accumulated in a previous epoch
    for task_dx, task_metrics in enumerate(args.metric_modules):
        for midx, metric_fn in enumerate(task_metrics):
            metric_fn.clear()

    metric_name = "Metric"
    end_time = time.time()
    writer_idx = 0  # rotating slot index for tensorboard images
    last_update_iter = -1
    metric_ctx = [None] * len(args.metric_modules)
    progress_bar = progiter.ProgIter(np.arange(epoch_size), chunksize=1)

    # change color to green
    print('{}'.format(Fore.GREEN), end='')

    ##########################
    for iter, (inputs, targets) in enumerate(val_loader):
        data_time.update(time.time() - end_time)
        # move to GPU; an input may itself be a list/tuple (e.g. NV12 [y, uv] planes)
        input_list = [[jj.cuda() for jj in img] if isinstance(img,(list,tuple)) else img.cuda() for img in inputs]
        target_list = [j.cuda(non_blocking=True) for j in targets]
        target_sizes = [tgt.shape for tgt in target_list]
        batch_size_cur = target_sizes[0][0]

        # compute output
        task_outputs = model(input_list)

        task_outputs = task_outputs if isinstance(task_outputs, (list, tuple)) else [task_outputs]
        # optionally bring predictions to target resolution before measuring
        if args.upsample_mode is not None:
            task_outputs = upsample_tensors(task_outputs, target_sizes, args.upsample_mode)

        if args.print_val_class_iou:
            metric_total, metric_list, metric_names, metric_types, _, confusion_matrix = \
                compute_task_objectives(args, args.metric_modules, input_list, task_outputs, target_list,
                get_confusion_matrix = args.print_val_class_iou)
        else:
            metric_total, metric_list, metric_names, metric_types, _ = \
                compute_task_objectives(args, args.metric_modules, input_list, task_outputs, target_list,
                get_confusion_matrix = args.print_val_class_iou)

        # record error/accuracy.
        for task_idx, task_metrics in enumerate(args.metric_modules):
            avg_metric[task_idx].update(float(metric_list[task_idx].cpu()), batch_size_cur)

        if args.tensorboard_enable:
            write_output(args, 'Validation_', epoch_size, iter, epoch, val_dataset, val_writer, input_list, task_outputs, target_list, metric_names, writer_idx)

        if ((iter % args.print_freq) == 0) or (iter == (epoch_size-1)):
            output_string = ''
            for task_idx, task_metrics in enumerate(args.metric_modules):
                output_string += '[{}={}]'.format(metric_names[task_idx], str(avg_metric[task_idx]))

            epoch_str = '{}/{}'.format(epoch + 1, args.epochs)
            progress_bar.set_description("=> validation")
            progress_bar.set_postfix(Epoch=epoch_str, DataTime=data_time, Output="{}".format(output_string))
            progress_bar.update(iter-last_update_iter)
            last_update_iter = iter

        end_time = time.time()
        writer_idx = (writer_idx + 1) % args.tensorboard_num_imgs

        # epoch_size may be smaller than len(val_loader) - stop early in that case
        if iter >= epoch_size:
            break

    if args.print_val_class_iou:
        # NOTE(review): task_idx is the loop variable leaked from the loops above,
        # i.e. effectively the index of the last task - confirm this is intended.
        print_class_iou(args = args, confusion_matrix = confusion_matrix, task_idx=task_idx)

    #print_conf_matrix(conf_matrix=conf_matrix, en=False)
    progress_bar.close()

    # to print a new line - do not provide end=''
    print('{}'.format(Fore.RESET), end='')

    if args.tensorboard_enable:
        for task_idx, task_metrics in enumerate(args.metric_modules):
            val_writer.add_scalar('Validation/Task{}_{}_Metric_Epoch'.format(task_idx,metric_names[task_idx]), float(avg_metric[task_idx]), epoch)

    output_name = metric_names[args.pivot_task_idx]
    output_metric = float(avg_metric[args.pivot_task_idx])
    return output_metric, output_name
865 ###################################################################
def close(args):
    """Finalize a run: release the logger (if any) and reset the best metric."""
    has_logger = args.logger is not None
    if has_logger:
        del args.logger
        args.logger = None
    # a fresh run must not inherit the previous run's best metric
    args.best_metric = -1
def get_save_path(args, phase=None):
    """Return the checkpoint directory for this run.

    Layout:
    ./data/checkpoints/<dataset>/<date>_<dataset>_<model>_resize<WxH>_traincrop<WxH>/<phase>
    The date comes from args.date when set, otherwise the current timestamp.
    """
    run_date = args.date if args.date else datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    run_name = '{}_{}_{}'.format(run_date, args.dataset_name, args.model_name)
    # note: the format strings take (width, height) - i.e. index 1 before index 0
    run_name += '_resize{}x{}_traincrop{}x{}'.format(args.img_resize[1], args.img_resize[0],
                                                     args.rand_crop[1], args.rand_crop[0])
    phase_name = args.phase if phase is None else phase
    return os.path.join('./data/checkpoints', args.dataset_name, run_name, phase_name)
def get_model_orig(model):
    """Unwrap (Distributed)DataParallel and quantization wrappers to reach the plain model."""
    if isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
        model = model.module
    # a quantization wrapper may sit below the parallel wrapper - unwrap that too
    if isinstance(model, (xnn.quantize.QuantBaseModule)):
        model = model.module
    return model
def create_rand_inputs(args, is_cuda):
    """Create random dummy inputs matching the model's expected input layout.

    Regular inputs: one (1, C, H, W) tensor per entry of
    args.model_config.input_channels. NV12 inputs: a [y, uv] pair per entry,
    where the uv plane has half the height of the y plane.
    """
    to_device = (lambda t: t.cuda()) if is_cuda else (lambda t: t)
    height, width = args.img_resize[0], args.img_resize[1]
    dummy_input = []
    if args.model_config.input_nv12:
        for _ in args.model_config.input_channels:
            y_plane = to_device(torch.rand((1, 1, height, width)))
            uv_plane = to_device(torch.rand((1, 1, height // 2, width)))
            dummy_input.append([y_plane, uv_plane])
    else:
        for num_ch in args.model_config.input_channels:
            dummy_input.append(to_device(torch.rand((1, num_ch, height, width))))
    return dummy_input
def count_flops(args, model):
    """Measure and print the model's complexity at the configured input size.

    Runs a forward pass on random dummy inputs (on the model's device) and
    prints GFLOPs/GMACs.

    Returns:
        gflops (float): the measured GFLOPs. Previously the value was computed
        and printed but discarded; returning it lets callers log it. Existing
        callers that ignore the return value are unaffected.
    """
    is_cuda = next(model.parameters()).is_cuda
    dummy_input = create_rand_inputs(args, is_cuda)
    #
    model.eval()
    flops = xnn.utils.forward_count_flops(model, dummy_input)
    gflops = flops/1e9
    # one MAC is conventionally counted as two FLOPs, hence gflops/2
    print('=> Size = {}, GFLOPs = {}, GMACs = {}'.format(args.img_resize, gflops, gflops/2))
    return gflops
def derive_node_name(input_name):
    """Derive an onnx node name from the node's input-name list.

    The last input name is used, with its final '.'-separated suffix removed
    (e.g. 'conv1.weight' -> 'conv1'); a name without a '.' is kept unchanged.
    """
    last_input = input_name[-1]
    stem, sep, _suffix = last_input.rpartition('.')
    return stem if sep else last_input
# torch onnx export does not update node names - rewrite them using onnx load/save
def add_node_names(onnx_model_name):
    """Normalize input names and assign node names in an onnx file, in place.

    Strips any ':<suffix>' from every node input name, then names each node
    after its (cleaned) last input. The file at onnx_model_name is overwritten.
    """
    onnx_model = onnx.load(onnx_model_name)
    for node in onnx_model.graph.node:
        for j in range(len(node.input)):
            node.input[j] = node.input[j].split(':')[0]
        # nodes without inputs keep whatever name they already have
        if len(node.input) > 0:
            node.name = derive_node_name(node.input)
    # update model inplace
    onnx.save(onnx_model, onnx_model_name)
def write_onnx_model(args, model, save_path, name='checkpoint.onnx', save_traced_model=False):
    """Export the model to onnx under save_path (optionally also a traced model).

    Random dummy inputs (matching the configured input layout) drive the
    export; node names are rewritten afterwards since torch's exporter does
    not set them.
    """
    is_cuda = next(model.parameters()).is_cuda
    input_list = create_rand_inputs(args, is_cuda=is_cuda)
    #
    model.eval()
    onnx_path = os.path.join(save_path, name)
    torch.onnx.export(model, input_list, onnx_path, export_params=True, verbose=False,
                      do_constant_folding=True, opset_version=args.opset_version)

    # torch onnx export does not update names. Do it using onnx.save
    add_node_names(onnx_model_name=onnx_path)

    if save_traced_model:
        traced_model = torch.jit.trace(model, (input_list,))
        torch.jit.save(traced_model, os.path.join(save_path, 'traced_model.pth'))
    #
961 ###################################################################
def write_output(args, prefix, val_epoch_size, iter, epoch, dataset, output_writer, input_images, task_outputs, task_targets, metric_names, writer_idx):
    """Write a random batch sample's inputs/targets/outputs to tensorboard as images.

    Writes only probabilistically - roughly args.tensorboard_num_imgs samples
    per epoch - and picks a random element of the batch. The visualization
    depends on each task's output_type: 'segmentation', 'depth'/'disparity',
    'flow' or 'interest_pt'.
    """
    # write with probability tensorboard_num_imgs/epoch_size so roughly that
    # many images get logged per epoch
    write_freq = (args.tensorboard_num_imgs / float(val_epoch_size))
    write_prob = np.random.random()
    if (write_prob > write_freq):
        return
    if args.model_config.input_nv12:
        batch_size = input_images[0][0].shape[0]
    else:
        batch_size = input_images[0].shape[0]
    b_index = random.randint(0, batch_size - 1)

    input_image = None
    for img_idx, img in enumerate(input_images):
        if args.model_config.input_nv12:
            #convert NV12 to BGR for tensorboard
            input_image = vision.transforms.image_transforms_xv12.nv12_to_bgr_image(Y = input_images[img_idx][0][b_index], UV = input_images[img_idx][1][b_index],
                          image_scale=args.image_scale, image_mean=args.image_mean)
        else:
            # CHW tensor -> HWC numpy image
            input_image = input_images[img_idx][b_index].cpu().numpy().transpose((1, 2, 0))
            # convert back to original input range (0-255)
            input_image = input_image / args.image_scale + args.image_mean

        if args.is_flow and args.is_flow[0][img_idx]:
            #input corresponding to flow is assumed to have been generated by adding 128
            flow = input_image - 128
            flow_hsv = xnn.utils.flow2hsv(flow.transpose(2, 0, 1), confidence=False).transpose(2, 0, 1)
            #flow_hsv = (flow_hsv / 255.0).clip(0, 1) #TODO: check this
            output_writer.add_image(prefix +'Input{}/{}'.format(img_idx, writer_idx), flow_hsv, epoch)
        else:
            input_image = (input_image/255.0).clip(0,1) #.astype(np.uint8)
            output_writer.add_image(prefix + 'Input{}/{}'.format(img_idx, writer_idx), input_image.transpose((2,0,1)), epoch)

    # for sparse data, chroma blending does not look good
    for task_idx, output_type in enumerate(args.model_config.output_type):
        # metric_name = metric_names[task_idx]
        output = task_outputs[task_idx]
        target = task_targets[task_idx]
        if (output_type == 'segmentation') and hasattr(dataset, 'decode_segmap'):
            # class-id maps are decoded to color images for display
            segmentation_target = dataset.decode_segmap(target[b_index,0].cpu().numpy())
            # multi-channel output: argmax over classes; single channel: use as-is
            segmentation_output = output.max(dim=1,keepdim=True)[1].data.cpu().numpy() if(output.shape[1]>1) else output.data.cpu().numpy()
            segmentation_output = dataset.decode_segmap(segmentation_output[b_index,0])
            segmentation_output_blend = xnn.utils.chroma_blend(input_image, segmentation_output)
            #
            output_writer.add_image(prefix+'Task{}_{}_GT/{}'.format(task_idx,output_type,writer_idx), segmentation_target.transpose(2,0,1), epoch)
            if not args.sparse:
                segmentation_target_blend = xnn.utils.chroma_blend(input_image, segmentation_target)
                output_writer.add_image(prefix + 'Task{}_{}_GT_ColorBlend/{}'.format(task_idx, output_type, writer_idx), segmentation_target_blend.transpose(2, 0, 1), epoch)
            #
            output_writer.add_image(prefix+'Task{}_{}_Output/{}'.format(task_idx,output_type,writer_idx), segmentation_output.transpose(2,0,1), epoch)
            output_writer.add_image(prefix+'Task{}_{}_Output_ColorBlend/{}'.format(task_idx,output_type,writer_idx), segmentation_output_blend.transpose(2,0,1), epoch)
        elif (output_type in ('depth', 'disparity')):
            depth_chanidx = 0
            output_writer.add_image(prefix+'Task{}_{}_GT_Color_Visualization/{}'.format(task_idx,output_type,writer_idx), xnn.utils.tensor2array(target[b_index][depth_chanidx].cpu(), max_value=args.max_depth, colormap=args.viz_colormap).transpose(2,0,1), epoch)
            if not args.sparse:
                output_writer.add_image(prefix + 'Task{}_{}_GT_ColorBlend_Visualization/{}'.format(task_idx, output_type, writer_idx), xnn.utils.tensor2array(target[b_index][depth_chanidx].cpu(), max_value=args.max_depth, colormap=args.viz_colormap, input_blend=input_image).transpose(2, 0, 1), epoch)
            #
            output_writer.add_image(prefix+'Task{}_{}_Output_Color_Visualization/{}'.format(task_idx,output_type,writer_idx), xnn.utils.tensor2array(output.data[b_index][depth_chanidx].cpu(), max_value=args.max_depth, colormap=args.viz_colormap).transpose(2,0,1), epoch)
            output_writer.add_image(prefix + 'Task{}_{}_Output_ColorBlend_Visualization/{}'.format(task_idx, output_type, writer_idx),xnn.utils.tensor2array(output.data[b_index][depth_chanidx].cpu(), max_value=args.max_depth, colormap=args.viz_colormap, input_blend=input_image).transpose(2, 0, 1), epoch)
        elif (output_type == 'flow'):
            max_value_flow = 10.0 # only for visualization
            output_writer.add_image(prefix+'Task{}_{}_GT/{}'.format(task_idx,output_type,writer_idx), xnn.utils.flow2hsv(target[b_index][:2].cpu().numpy(), max_value=max_value_flow).transpose(2,0,1), epoch)
            output_writer.add_image(prefix+'Task{}_{}_Output/{}'.format(task_idx,output_type,writer_idx), xnn.utils.flow2hsv(output.data[b_index][:2].cpu().numpy(), max_value=max_value_flow).transpose(2,0,1), epoch)
        elif (output_type == 'interest_pt'):
            score_chanidx = 0
            target_score_to_write = target[b_index][score_chanidx].cpu()
            output_score_to_write = output.data[b_index][score_chanidx].cpu()

            #if score is learnt as zero mean add offset to make it [0-255]
            if args.make_score_zero_mean:
                # target_score_to_write!=0 : value 0 indicates GT unavailble. Leave them to be 0.
                target_score_to_write[target_score_to_write!=0] += 128.0
                output_score_to_write += 128.0

            max_value_score = float(torch.max(target_score_to_write)) #0.002
            output_writer.add_image(prefix+'Task{}_{}_GT_Bone_Visualization/{}'.format(task_idx,output_type,writer_idx), xnn.utils.tensor2array(target_score_to_write, max_value=max_value_score, colormap='bone').transpose(2,0,1), epoch)
            output_writer.add_image(prefix+'Task{}_{}_Output_Bone_Visualization/{}'.format(task_idx,output_type,writer_idx), xnn.utils.tensor2array(output_score_to_write, max_value=max_value_score, colormap='bone').transpose(2,0,1), epoch)
        #
def print_conf_matrix(conf_matrix=None, en=False):
    """Pretty-print a confusion matrix, one element per chunk (debug utility).

    Args:
        conf_matrix: 2-D array-like with a ``.shape`` (e.g. a numpy array).
        en: enable flag; when False, the function returns without printing.

    The default used to be a mutable ``[]`` - replaced with ``None`` to avoid
    the shared-mutable-default pitfall. Behavior is unchanged for all real
    calls: the old default was only reached with ``en=True``, which would have
    crashed on ``[].shape`` anyway; now it returns cleanly.
    """
    if not en or conf_matrix is None:
        return
    num_rows = conf_matrix.shape[0]
    num_cols = conf_matrix.shape[1]
    print("-" * 64)
    num_ele = 1  # number of elements printed per chunk
    for r_idx in range(num_rows):
        print("\n")
        for c_idx in range(0, num_cols, num_ele):
            print(conf_matrix[r_idx][c_idx:c_idx + num_ele], end="")
        print("\n")
    print("-" * 64)
def compute_task_objectives(args, objective_fns, input_var, task_outputs, task_targets, task_mults=None,
                            task_offsets=None, loss_mult_factors=None, get_confusion_matrix = False):
    """Evaluate per-task objective functions (losses or metrics).

    For each task, every objective in objective_fns[task] is evaluated on the
    task's (output, target) pair and summed, each term optionally scaled by
    loss_mult_factors[task][oidx]; NaN terms are excluded. The grand total
    accumulates the per-task sums scaled by task_mults and shifted by
    task_offsets.

    Returns:
        [objective_total, objective_list, objective_names, objective_types,
        objective_list_orig]; when get_confusion_matrix is True, the confusion
        matrix of the last objective evaluated is appended.

    NOTE(review): objective_list_orig receives the same (already
    loss_mult-scaled) tensors as objective_list - if an unscaled copy was
    intended here, this needs a fix; confirm against the training loop.
    """

    ##########################
    # scalar accumulator on the same device/dtype as the first task output
    objective_total = torch.zeros_like(task_outputs[0].view(-1)[0])
    objective_list = []
    objective_list_orig = []
    objective_names = []
    objective_types = []
    for task_idx, task_objectives in enumerate(objective_fns):
        output_type = args.model_config.output_type[task_idx]
        objective_sum_value = torch.zeros_like(task_outputs[task_idx].view(-1)[0])
        objective_sum_name = ''
        objective_sum_type = ''

        task_mult = task_mults[task_idx] if task_mults is not None else 1.0
        task_offset = task_offsets[task_idx] if task_offsets is not None else 0.0

        for oidx, objective_fn in enumerate(task_objectives):
            objective_batch = objective_fn(input_var, task_outputs[task_idx], task_targets[task_idx])
            # a DataParallel-wrapped objective returns one value per replica - reduce to a scalar
            objective_batch = objective_batch.mean() if isinstance(objective_fn, torch.nn.DataParallel) else objective_batch
            objective_name = objective_fn.info()['name']
            objective_type = objective_fn.info()['is_avg']
            if get_confusion_matrix:
                confusion_matrix = objective_fn.info()['confusion_matrix']

            loss_mult = loss_mult_factors[task_idx][oidx] if (loss_mult_factors is not None) else 1.0
            # --
            # drop NaN contributions so one bad batch does not poison the sum
            objective_batch_not_nan = (objective_batch if not torch.isnan(objective_batch) else 0.0)
            objective_sum_value = objective_batch_not_nan*loss_mult + objective_sum_value
            objective_sum_name += (objective_name if (objective_sum_name == '') else ('+' + objective_name))
            assert (objective_sum_type == '' or objective_sum_type == objective_type), 'metric types (avg/val) for a given task should match'
            objective_sum_type = objective_type

        objective_list.append(objective_sum_value)
        objective_list_orig.append(objective_sum_value)
        objective_names.append(objective_sum_name)
        objective_types.append(objective_sum_type)

        objective_total = objective_sum_value*task_mult + task_offset + objective_total

    return_list = [objective_total, objective_list, objective_names, objective_types, objective_list_orig]
    if get_confusion_matrix:
        # NOTE(review): confusion_matrix would be unbound here if objective_fns
        # contained no objectives - confirm callers always pass at least one.
        return_list.append(confusion_matrix)

    return return_list
def save_checkpoint(args, save_path, model, checkpoint_dict, is_best, filename='checkpoint.pth.tar'):
    """Save a training checkpoint dict; keep a 'model_best' copy when is_best.

    When args.save_onnx is set, also export the model to onnx (and a best-model
    onnx copy alongside the best checkpoint).
    """
    ckpt_path = os.path.join(save_path, filename)
    torch.save(checkpoint_dict, ckpt_path)
    if is_best:
        shutil.copyfile(ckpt_path, os.path.join(save_path, 'model_best.pth.tar'))
    #
    if args.save_onnx:
        write_onnx_model(args, model, save_path, name='checkpoint.onnx')
        if is_best:
            write_onnx_model(args, model, save_path, name='model_best.onnx')
    #
def get_epoch_size(args, loader, args_epoch_size):
    """Resolve the effective number of iterations per epoch.

    args_epoch_size semantics: 0 -> full loader length; a value below 1 ->
    that fraction of the loader; otherwise an absolute iteration count,
    capped at the loader length.
    """
    full_size = len(loader)
    if args_epoch_size == 0:
        return full_size
    if args_epoch_size < 1:
        return int(full_size * args_epoch_size)
    return min(full_size, int(args_epoch_size))
def get_train_transform(args):
    """Build the training-time preprocessing/augmentation pipeline.

    Mean/scale normalization is applied either at the beginning (prenorm) or
    at the end (postnorm) of the chain, controlled by args.image_prenorm -
    exactly one of the two is active.
    """
    # image normalization can be at the beginning of transforms or at the end
    image_mean = np.array(args.image_mean, dtype=np.float32)
    image_scale = np.array(args.image_scale, dtype=np.float32)
    image_prenorm = vision.transforms.image_transforms.NormalizeMeanScale(mean=image_mean, scale=image_scale) if args.image_prenorm else None
    image_postnorm = vision.transforms.image_transforms.NormalizeMeanScale(mean=image_mean, scale=image_scale) if (not image_prenorm) else None

    # crop size used only for training
    # extra output scaling only when rand_output_size differs from rand_resize
    image_train_output_scaling = vision.transforms.image_transforms.Scale(args.rand_resize, target_size=args.rand_output_size, is_flow=args.is_flow) \
        if (args.rand_output_size is not None and args.rand_output_size != args.rand_resize) else None
    train_transform = vision.transforms.image_transforms.Compose([
        image_prenorm,
        vision.transforms.image_transforms.AlignImages(),
        vision.transforms.image_transforms.MaskTarget(args.target_mask, 0),
        vision.transforms.image_transforms.CropRect(args.img_border_crop),
        # rotation augmentation only when a rotation range is configured
        vision.transforms.image_transforms.RandomRotate(args.transform_rotation, is_flow=args.is_flow) if args.transform_rotation else None,
        vision.transforms.image_transforms.RandomScaleCrop(args.rand_resize, scale_range=args.rand_scale, is_flow=args.is_flow),
        vision.transforms.image_transforms.RandomHorizontalFlip(is_flow=args.is_flow),
        vision.transforms.image_transforms.RandomCrop(args.rand_crop),
        # grayscale augmentation is applied only for 'tiad' datasets
        vision.transforms.image_transforms.RandomColor2Gray(is_flow=args.is_flow, random_threshold=0.5) if 'tiad' in args.dataset_name else None,
        image_train_output_scaling,
        image_postnorm,
        vision.transforms.image_transforms.ConvertToTensor()
    ])
    return train_transform
def get_validation_transform(args):
    """Build the validation-time preprocessing pipeline (no random augmentation).

    Mean/scale normalization is applied either at the beginning (prenorm) or
    at the end (postnorm) of the chain, controlled by args.image_prenorm -
    exactly one of the two is active.
    """
    # image normalization can be at the beginning of transforms or at the end
    image_mean = np.array(args.image_mean, dtype=np.float32)
    image_scale = np.array(args.image_scale, dtype=np.float32)
    image_prenorm = vision.transforms.image_transforms.NormalizeMeanScale(mean=image_mean, scale=image_scale) if args.image_prenorm else None
    image_postnorm = vision.transforms.image_transforms.NormalizeMeanScale(mean=image_mean, scale=image_scale) if (not image_prenorm) else None

    # prediction is resized to output_size before evaluation.
    val_transform = vision.transforms.image_transforms.Compose([
        image_prenorm,
        vision.transforms.image_transforms.AlignImages(),
        vision.transforms.image_transforms.MaskTarget(args.target_mask, 0),
        vision.transforms.image_transforms.CropRect(args.img_border_crop),
        vision.transforms.image_transforms.Scale(args.img_resize, target_size=args.output_size, is_flow=args.is_flow),
        image_postnorm,
        vision.transforms.image_transforms.ConvertToTensor()
    ])
    return val_transform
def get_transforms(args):
    """Return (train_transform, val_transform) for the configured run.

    Provision to train with the validation transform - provide rand_scale as
    (0, 0). This follows the train-test resolution discrepancy fix,
    https://arxiv.org/abs/1906.06423
    """
    use_val_for_train = (args.rand_scale[0] == 0)
    if use_val_for_train:
        train_transform = get_validation_transform(args)
    else:
        train_transform = get_train_transform(args)
    return train_transform, get_validation_transform(args)
1180 def _upsample_impl(tensor, output_size, upsample_mode):
1181 # upsample of long tensor is not supported currently. covert to float, just to avoid error.
1182 # we can do thsi only in the case of nearest mode, otherwise output will have invalid values.
1183 convert_to_float = False
1184 if isinstance(tensor, (torch.LongTensor,torch.cuda.LongTensor)):
1185 convert_to_float = True
1186 original_dtype = tensor.dtype
1187 tensor = tensor.float()
1188 upsample_mode = 'nearest'
1190 dim_added = False
1191 if len(tensor.shape) < 4:
1192 tensor = tensor[np.newaxis,...]
1193 dim_added = True
1195 if (tensor.size()[-2:] != output_size):
1196 tensor = torch.nn.functional.interpolate(tensor, output_size, mode=upsample_mode)
1198 if dim_added:
1199 tensor = tensor[0,...]
1201 if convert_to_float:
1202 tensor = tensor.long() #tensor.astype(original_dtype)
1204 return tensor
def upsample_tensors(tensors, output_sizes, upsample_mode):
    """Resize tensor(s) to the (H, W) of the matching output_sizes entry.

    A list input is modified in place element-wise; a single tensor is resized
    against output_sizes[0]. Returns the (possibly in-place updated) input.
    """
    if not isinstance(tensors, (list, tuple)):
        return _upsample_impl(tensors, output_sizes[0][-2:], upsample_mode)
    for idx in range(len(tensors)):
        tensors[idx] = _upsample_impl(tensors[idx], output_sizes[idx][-2:], upsample_mode)
    return tensors
# print IoU for each class
def print_class_iou(args = None, confusion_matrix = None, task_idx = 0):
    """Print per-class IoU values for one task, computed from its confusion matrix."""
    n_classes = args.model_config.output_channels[task_idx]
    accuracy, mean_iou, iou, f1_score = compute_accuracy(args, confusion_matrix, n_classes)
    # same output as the original element-wise prints: "[v1,v2,...,]" with a
    # trailing comma after each value
    formatted = "".join("{:0.3f},".format(class_iou) for class_iou in iou)
    print("\n Class IoU: [" + formatted + "]")
# script entry point: build the default configuration and run training
# (main is presumably defined earlier in this file - not visible here)
if __name__ == '__main__':
    train_args = get_config()
    main(train_args)