torch.nn.ReLU is the recommended activation module; removed the custom-defined module.
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / engine / train_pixel2pixel.py
1 import os
2 import shutil
3 import time
4 import math
5 import copy
7 import torch
8 import torch.nn.parallel
9 import torch.backends.cudnn as cudnn
10 import torch.optim
11 import torch.utils.data
12 import torch.onnx
13 import onnx
15 import datetime
16 from tensorboardX import SummaryWriter
17 import numpy as np
18 import random
19 import cv2
20 from colorama import Fore
21 import progiter
22 from packaging import version
23 import warnings
25 from .. import xnn
26 from .. import vision
27 from . infer_pixel2pixel import compute_accuracy
30 ##################################################
31 warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)
33 ##################################################
def get_config():
    """Build and return the default training configuration.

    Returns an ``xnn.utils.ConfigNode``; callers typically override individual
    fields (dataset, model, solver, augmentation, quantization, ...) before
    passing the config to :func:`main`.
    """
    args = xnn.utils.ConfigNode()

    # dataset config
    args.dataset_config = xnn.utils.ConfigNode()
    args.dataset_config.split_name = 'val'
    args.dataset_config.max_depth_bfr_scaling = 80
    args.dataset_config.depth_scale = 1
    args.dataset_config.train_depth_log = 1
    args.use_semseg_for_depth = False

    # model config
    args.model_config = xnn.utils.ConfigNode()
    args.model_config.output_type = ['segmentation']    # the network is used to predict flow or depth or sceneflow
    args.model_config.output_channels = None            # number of output channels
    args.model_config.input_channels = None             # number of input channels
    args.model_config.output_range = None               # max range of output
    args.model_config.num_decoders = None               # number of decoders to use. [options: 0, 1, None]
    args.model_config.freeze_encoder = False            # do not update encoder weights
    args.model_config.freeze_decoder = False            # do not update decoder weights
    args.model_config.multi_task_type = 'learned'       # find out loss multiplier by learning, choices=[None, 'learned', 'uncertainty', 'grad_norm', 'dwa_grad_norm']
    args.model_config.target_input_ratio = 1            # keep target size same as input size
    args.model_config.input_nv12 = False

    args.model = None                                   # the model itself can be given from outside
    args.model_name = 'deeplabv2lite_mobilenetv2'       # model architecture, overwritten if pretrained is specified
    args.dataset_name = 'cityscapes_segmentation'       # dataset type
    args.transforms = None                              # the transforms itself can be given from outside

    args.data_path = './data/cityscapes'                # path to dataset
    args.save_path = None                               # checkpoints save path
    args.phase = 'training'                             # training/calibration/validation
    args.date = None                                    # date to add to save path. if this is None, current date will be added.

    args.logger = None                                  # logger stream to output into
    args.show_gpu_usage = False                         # shows gpu usage at the beginning of each training epoch

    args.split_file = None                              # train_val split file
    args.split_files = None                             # split list files. eg: train.txt val.txt
    args.split_value = None                             # test_val split proportion (between 0 (only test) and 1 (only train))

    args.solver = 'adam'                                # solver algorithms, choices=['adam','sgd']
    args.scheduler = 'step'                             # scheduler algorithms, choices=['step','poly', 'cosine']
    args.workers = 8                                    # number of data loading workers

    args.epochs = 250                                   # number of total epochs to run
    args.start_epoch = 0                                # manual epoch number (useful on restarts)

    args.epoch_size = 0                                 # manual epoch size (will match dataset size if not specified)
    args.epoch_size_val = 0                             # manual epoch size (will match dataset size if not specified)
    args.batch_size = 12                                # mini_batch size
    args.total_batch_size = None                        # accumulated batch size. total_batch_size = batch_size*iter_size
    args.iter_size = 1                                  # iteration size. total_batch_size = batch_size*iter_size

    args.lr = 1e-4                                      # initial learning rate
    args.lr_clips = None                                # use args.lr itself if it is None
    args.lr_calib = 0.1                                 # lr for bias calibration
    args.warmup_epochs = 5                              # number of epochs to warmup

    args.momentum = 0.9                                 # momentum for sgd, alpha parameter for adam
    args.beta = 0.999                                   # beta parameter for adam
    args.weight_decay = 1e-4                            # weight decay
    args.bias_decay = None                              # bias decay

    args.sparse = True                                  # avoid invalid/ignored target pixels from loss computation, use NEAREST for interpolation

    args.tensorboard_num_imgs = 5                       # number of imgs to display in tensorboard
    args.pretrained = None                              # path to pre_trained model
    args.resume = None                                  # path to latest checkpoint (default: none)
    args.no_date = False                                # don't append date timestamp to folder
    args.print_freq = 100                               # print frequency (default: 100)

    args.milestones = (100, 200)                        # epochs at which learning rate is divided (by multistep_gamma)

    args.losses = ['segmentation_loss']                 # loss functions to use, one list per task
    args.metrics = ['segmentation_metrics']             # metric/measurement/error functions for train/validation
    args.multi_task_factors = None                      # loss mult factors
    args.class_weights = None                           # class weights

    args.loss_mult_factors = None                       # fixed loss mult factors - per loss - note: this is different from multi_task_factors (which is per task)

    args.multistep_gamma = 0.5                          # steps for step scheduler
    args.polystep_power = 1.0                           # power for polynomial scheduler

    args.rand_seed = 1                                  # random seed
    args.img_border_crop = None                         # image border crop rectangle. can be relative or absolute
    args.target_mask = None                             # mask rectangle. can be relative or absolute. last value is the mask value

    args.rand_resize = None                             # random image size to be resized to during training
    args.rand_output_size = None                        # output size to be resized to during training
    args.rand_scale = (1.0, 2.0)                        # random scale range for training
    args.rand_crop = None                               # image size to be cropped to

    args.img_resize = None                              # image size to be resized to during evaluation
    args.output_size = None                             # target output size to be resized to

    args.count_flops = True                             # count flops and report

    args.shuffle = True                                 # shuffle or not
    args.shuffle_val = False                            # shuffle val dataset or not

    args.transform_rotation = 0.                        # apply rotation augmentation. value is rotation in degrees. 0 indicates no rotation
    args.is_flow = None                                 # whether entries in images and targets lists are optical flow or not

    args.upsample_mode = 'bilinear'                     # upsample mode to use, choices=['nearest','bilinear']

    args.image_prenorm = True                           # whether normalization is done before all other the transforms
    args.image_mean = (128.0,)                          # image mean for input image normalization
    args.image_scale = (1.0 / (0.25 * 256),)            # image scaling/mult for input image normalization

    args.max_depth = 80                                 # maximum depth to be used for visualization

    args.pivot_task_idx = 0                             # task id to select best model

    args.parallel_model = True                          # use data parallel for model
    args.parallel_criterion = True                      # use data parallel for loss and metric

    args.evaluate_start = True                          # evaluate right at the beginning of training or not
    args.generate_onnx = True                           # export an onnx model of the network or not
    args.print_model = False                            # print the model to text
    args.run_soon = True                                # to start training after generating configs/models

    args.quantize = False                               # apply quantized inference or not
    #args.model_surgery = None                          # replace activations with PAct2 activation module. Helpful in quantized training.
    args.bitwidth_weights = 8                           # bitwidth for weights
    args.bitwidth_activations = 8                       # bitwidth for activations
    args.histogram_range = True                         # histogram range for calibration
    args.bias_calibration = True                        # apply bias correction during quantized inference calibration
    args.per_channel_q = False                          # apply separate quantization factor for each channel in depthwise or not

    args.save_mod_files = False                         # saves modified files after last commit. also stores commit id.
    args.make_score_zero_mean = False                   # make score zero mean while learning
    args.no_q_for_dws_layer_idx = 0                     # no_q_for_dws_layer_idx

    args.viz_colormap = 'rainbow'                       # colormap for tensorboard: 'rainbow', 'plasma', 'magma', 'bone'

    args.freeze_bn = False                              # freeze the statistics of bn
    args.tensorboard_enable = True                      # en/disable of TB writing
    args.print_train_class_iou = False                  # report per-class iou during training
    args.print_val_class_iou = False                    # report per-class iou during validation
    args.freeze_layers = None                           # list of layer-name substrings to freeze during training

    return args
# ################################################
# to avoid hangs in the data loader with multiple worker threads;
# this was observed after using cv2 image processing functions
# https://github.com/pytorch/pytorch/issues/1355
cv2.setNumThreads(0)
def main(args):
    """Train, calibrate or validate a pixel2pixel model as selected by args.phase.

    Builds datasets/dataloaders, the model (optionally wrapped for quantized
    training/calibration/test), loss and metric modules, optimizer and
    scheduler, then runs the epoch loop, saving checkpoints and tensorboard
    summaries under the save path. args is a ConfigNode from get_config().
    """
    # ensure a sufficiently recent pytorch version - the required
    # scheduler.step()/optimizer.step() call order changed in 1.1
    assert version.parse(torch.__version__) >= version.parse('1.1'), \
        'torch version must be 1.1 or higher, due to the change in scheduler.step() and optimiser.step() call order'

    assert (not hasattr(args, 'evaluate')), 'args.evaluate is deprecated. use args.phase=training or calibration or validation'
    assert is_valid_phase(args.phase), f'invalid phase {args.phase}'
    assert not hasattr(args, 'model_surgery'), 'the argument model_surgery is deprecated, it is not needed now - remove it'

    # bias calibration is meaningful only during calibration/training
    if (args.phase == 'validation' and args.bias_calibration):
        args.bias_calibration = False
        warnings.warn('switching off bias calibration in validation')
    #

    #################################################
    # default the train-time sizes to the eval-time resize when not given
    args.rand_resize = args.img_resize if args.rand_resize is None else args.rand_resize
    args.rand_crop = args.img_resize if args.rand_crop is None else args.rand_crop
    args.output_size = args.img_resize if args.output_size is None else args.output_size
    # resume has higher priority
    args.pretrained = None if (args.resume is not None) else args.pretrained

    if args.save_path is None:
        save_path = get_save_path(args)
    else:
        save_path = args.save_path
    #
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    if args.save_mod_files:
        # store all the files modified after the last commit
        mod_files_path = save_path+'/mod_files'
        os.makedirs(mod_files_path)

        cmd = "git ls-files --modified | xargs -i cp {} {}".format("{}", mod_files_path)
        print("cmd:", cmd)
        os.system(cmd)

        # store the last commit id
        cmd = "git log -n 1  >> {}".format(mod_files_path + '/commit_id.txt')
        print("cmd:", cmd)
        os.system(cmd)

    #################################################
    # tee stdout to a log file next to the checkpoints, unless a logger was given
    if args.logger is None:
        log_file = os.path.splitext(os.path.basename(__file__))[0] + '.log'
        args.logger = xnn.utils.TeeLogger(filename=os.path.join(save_path,log_file))

    #################################################
    # global settings. rand seeds for repeatability
    random.seed(args.rand_seed)
    np.random.seed(args.rand_seed)
    torch.manual_seed(args.rand_seed)
    torch.cuda.manual_seed(args.rand_seed)

    ################################
    # args check and config
    if args.iter_size != 1 and args.total_batch_size is not None:
        warnings.warn("only one of --iter_size or --total_batch_size must be set")
    #
    # keep total_batch_size == batch_size * iter_size consistent either way
    if args.total_batch_size is not None:
        args.iter_size = args.total_batch_size//args.batch_size
    else:
        args.total_batch_size = args.batch_size*args.iter_size

    #################################################
    # set some global flags and initializations
    # keep it in args for now - although they don't belong here strictly
    # using pin_memory is seen to cause issues, especially when a lot of memory is used.
    args.use_pinned_memory = False
    args.n_iter = 0
    args.best_metric = -1
    cudnn.benchmark = True
    # torch.autograd.set_detect_anomaly(True)

    ################################
    # reset character color, in case it is different
    print('{}'.format(Fore.RESET))
    # print everything for log
    print('=> args: {}'.format(args))
    print('\n'.join("%s: %s" % item for item in sorted(vars(args).items())))

    print('=> will save everything to {}'.format(save_path))

    #################################################
    train_writer = SummaryWriter(os.path.join(save_path,'train')) if args.tensorboard_enable else None
    val_writer = SummaryWriter(os.path.join(save_path,'val')) if args.tensorboard_enable else None
    transforms = get_transforms(args) if args.transforms is None else args.transforms
    assert isinstance(transforms, (list,tuple)) and len(transforms) == 2, 'incorrect transforms were given'

    print("=> fetching images in '{}'".format(args.data_path))
    # split priority: split_file > split_files > split_value
    split_arg = args.split_file if args.split_file else (args.split_files if args.split_files else args.split_value)
    train_dataset, val_dataset = vision.datasets.pixel2pixel.__dict__[args.dataset_name](args.dataset_config, args.data_path, split=split_arg, transforms=transforms)

    #################################################
    train_sampler = None
    val_sampler = None
    print('=> {} samples found, {} train samples and {} test samples '.format(len(train_dataset)+len(val_dataset),
        len(train_dataset), len(val_dataset)))
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
        num_workers=args.workers, pin_memory=args.use_pinned_memory, sampler=train_sampler, shuffle=args.shuffle)

    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
        num_workers=args.workers, pin_memory=args.use_pinned_memory, sampler=val_sampler, shuffle=args.shuffle_val)

    #################################################
    # fill in input/output channels from the dataset when not specified
    if (args.model_config.input_channels is None):
        args.model_config.input_channels = (3,)
        print("=> input channels is not given - setting to {}".format(args.model_config.input_channels))

    if (args.model_config.output_channels is None):
        if ('num_classes' in dir(train_dataset)):
            args.model_config.output_channels = train_dataset.num_classes()
        else:
            args.model_config.output_channels = (2 if args.model_config.output_type == 'flow' else args.model_config.output_channels)
            xnn.utils.print_yellow("=> output channels is not given - setting to {} - not sure to work".format(args.model_config.output_channels))
        #
        if not isinstance(args.model_config.output_channels,(list,tuple)):
            args.model_config.output_channels = [args.model_config.output_channels]

    if (args.class_weights is None) and ('class_weights' in dir(train_dataset)):
        args.class_weights = train_dataset.class_weights()
        if not isinstance(args.class_weights, (list,tuple)):
            args.class_weights = [args.class_weights]
        #
        print("=> class weights available for dataset: {}".format(args.class_weights))

    #################################################
    # load pretrained weight file(s)/dict(s), downloading them when URLs are given
    pretrained_data = None
    model_surgery_quantize = False
    pretrained_data = None
    if args.pretrained and args.pretrained != "None":
        pretrained_data = []
        pretrained_files = args.pretrained if isinstance(args.pretrained,(list,tuple)) else [args.pretrained]
        for p in pretrained_files:
            if isinstance(p, dict):
                # a state dict was passed in directly
                p_data = p
            else:
                if p.startswith('http://') or p.startswith('https://'):
                    p_file = vision.datasets.utils.download_url(p, './data/downloads')
                else:
                    p_file = p
                #
                print(f'=> loading pretrained weights file: {p}')
                p_data = torch.load(p_file)
            #
            pretrained_data.append(p_data)
            # if any checkpoint was saved from a quantized model, QuantTestModule
            # below needs to do the matching model surgery
            model_surgery_quantize = p_data['quantize'] if 'quantize' in p_data else False
    #

    #################################################
    # create model
    if args.model is not None:
        model, change_names_dict = args.model if isinstance(args.model, (list, tuple)) else (args.model, None)
        assert isinstance(model, torch.nn.Module), 'args.model, if provided must be a valid torch.nn.Module'
    else:
        xnn.utils.print_yellow("=> creating model '{}'".format(args.model_name))
        model = vision.models.pixel2pixel.__dict__[args.model_name](args.model_config)
        # check if we got the model as well as parameters to change the names in pretrained
        model, change_names_dict = model if isinstance(model, (list,tuple)) else (model,None)
    #

    if args.quantize:
        # dummy input is used by quantized models to analyze graph
        is_cuda = next(model.parameters()).is_cuda
        dummy_input = create_rand_inputs(args, is_cuda=is_cuda)
        #
        if 'training' in args.phase:
            model = xnn.quantize.QuantTrainModule(model, per_channel_q=args.per_channel_q,
                        histogram_range=args.histogram_range, bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations, dummy_input=dummy_input)
        elif 'calibration' in args.phase:
            model = xnn.quantize.QuantCalibrateModule(model, per_channel_q=args.per_channel_q,
                        bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
                        histogram_range=args.histogram_range, bias_calibration=args.bias_calibration, lr_calib=args.lr_calib, dummy_input=dummy_input)
        elif 'validation' in args.phase:
            # note: bias_calibration is not enabled here
            model = xnn.quantize.QuantTestModule(model, per_channel_q=args.per_channel_q,
                        bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
                        histogram_range=args.histogram_range, model_surgery_quantize=model_surgery_quantize,
                        dummy_input=dummy_input)
        else:
            assert False, f'invalid phase {args.phase}'
    #

    # load pretrained model
    if pretrained_data is not None:
        for (p_data,p_file) in zip(pretrained_data, pretrained_files):
            print("=> using pretrained weights from: {}".format(p_file))
            xnn.utils.load_weights(get_model_orig(model), pretrained=p_data, change_names_dict=change_names_dict)
    #

    #################################################
    if args.count_flops:
        count_flops(args, model)

    #################################################
    # NOTE(review): this membership test is reversed compared to is_valid_phase()
    # ('args.phase in p' instead of 'p in args.phase'); it works for the exact
    # phase names but not for composite phase strings - confirm intent.
    if args.generate_onnx and (any(args.phase in p for p in ('training','calibration')) or (args.run_soon == False)):
        write_onnx_model(args, get_model_orig(model), save_path)
    #

    #################################################
    if args.print_model:
        print(model)
        print('\n')
    else:
        args.logger.debug(str(model))
        args.logger.debug('\n')

    #################################################
    # config/model generation only - stop before training
    if (not args.run_soon):
        print("Training not needed for now")
        close(args)
        exit()

    #################################################
    # multi gpu mode does not work for calibration/training for quantization
    # so use it only when args.quantize is False
    if args.parallel_model and ((not args.quantize)):
        model = torch.nn.DataParallel(model)

    #################################################
    model = model.cuda()

    #################################################
    # for help in debug/print
    for name, module in model.named_modules():
        module.name = name

    #################################################
    # instantiate the loss modules - one list per task, constructor kwargs
    # filled from args according to what each loss class declares in .args()
    args.loss_modules = copy.deepcopy(args.losses)
    for task_dx, task_losses in enumerate(args.losses):
        for loss_idx, loss_fn in enumerate(task_losses):
            kw_args = {}
            loss_args = vision.losses.__dict__[loss_fn].args()
            for arg in loss_args:
                if arg == 'weight' and (args.class_weights is not None):
                    kw_args.update({arg:args.class_weights[task_dx]})
                elif arg == 'num_classes':
                    kw_args.update({arg:args.model_config.output_channels[task_dx]})
                elif arg == 'sparse':
                    kw_args.update({arg:args.sparse})
                #
            #
            loss_fn_raw = vision.losses.__dict__[loss_fn](**kw_args)
            if args.parallel_criterion:
                # (the inner parallel_criterion check is redundant - already true here)
                loss_fn = torch.nn.DataParallel(loss_fn_raw).cuda() if args.parallel_criterion else loss_fn_raw.cuda()
                # DataParallel hides the raw module's helpers - re-expose them
                loss_fn.info = loss_fn_raw.info
                loss_fn.clear = loss_fn_raw.clear
            else:
                loss_fn = loss_fn_raw.cuda()
            #
            args.loss_modules[task_dx][loss_idx] = loss_fn
    #

    # instantiate the metric modules the same way as the losses
    args.metric_modules = copy.deepcopy(args.metrics)
    for task_dx, task_metrics in enumerate(args.metrics):
        for midx, metric_fn in enumerate(task_metrics):
            kw_args = {}
            loss_args = vision.losses.__dict__[metric_fn].args()
            for arg in loss_args:
                if arg == 'weight':
                    kw_args.update({arg:args.class_weights[task_dx]})
                elif arg == 'num_classes':
                    kw_args.update({arg:args.model_config.output_channels[task_dx]})
                elif arg == 'sparse':
                    kw_args.update({arg:args.sparse})

            metric_fn_raw = vision.losses.__dict__[metric_fn](**kw_args)
            if args.parallel_criterion:
                metric_fn = torch.nn.DataParallel(metric_fn_raw).cuda()
                metric_fn.info = metric_fn_raw.info
                metric_fn.clear = metric_fn_raw.clear
            else:
                metric_fn = metric_fn_raw.cuda()
            #
            args.metric_modules[task_dx][midx] = metric_fn
    #

    #################################################
    # validation-only phase: run one pass and return
    if args.phase=='validation':
        with torch.no_grad():
            validate(args, val_dataset, val_loader, model, 0, val_writer)
        #
        close(args)
        return

    #################################################
    assert(args.solver in ['adam', 'sgd'])
    print('=> setting {} solver'.format(args.solver))
    if args.lr_clips is not None:
        # separate param group (with its own lr/decay) for quantization clip
        # parameters - identified by 'clips' in the parameter name
        learning_rate_clips = args.lr_clips if 'training' in args.phase else 0.0
        clips_decay = args.bias_decay if (args.bias_decay is not None and args.bias_decay != 0.0) else args.weight_decay
        clips_params = [p for n,p in model.named_parameters() if 'clips' in n]
        other_params = [p for n,p in model.named_parameters() if 'clips' not in n]
        param_groups = [{'params': clips_params, 'weight_decay': clips_decay, 'lr': learning_rate_clips},
                        {'params': other_params, 'weight_decay': args.weight_decay}]
    else:
        param_groups = [{'params': filter(lambda p: p.requires_grad, model.parameters()), 'weight_decay': args.weight_decay}]
    #

    # lr is zero outside training (e.g. calibration) so the optimizer does not update weights
    learning_rate = args.lr if ('training'in args.phase) else 0.0
    if args.solver == 'adam':
        optimizer = torch.optim.Adam(param_groups, learning_rate, betas=(args.momentum, args.beta))
    elif args.solver == 'sgd':
        optimizer = torch.optim.SGD(param_groups, learning_rate, momentum=args.momentum)
    else:
        raise ValueError('Unknown optimizer type{}'.format(args.solver))
    #

    #################################################
    epoch_size = get_epoch_size(args, train_loader, args.epoch_size)
    max_iter = args.epochs * epoch_size
    scheduler = xnn.optim.lr_scheduler.SchedulerWrapper(args.scheduler, optimizer, args.epochs, args.start_epoch, \
                                                            args.warmup_epochs, max_iter=max_iter, polystep_power=args.polystep_power,
                                                            milestones=args.milestones, multistep_gamma=args.multistep_gamma)

    # optionally resume from a checkpoint
    if args.resume:
        if not os.path.isfile(args.resume):
            print("=> no checkpoint found at '{}'".format(args.resume))
        else:
            print("=> loading checkpoint '{}'".format(args.resume))

        # NOTE(review): the load below runs even on the "no checkpoint found"
        # branch above (torch.load will then raise) - it looks like it was
        # intended to be inside the else - confirm.
        checkpoint = torch.load(args.resume)
        model = xnn.utils.load_weights(model, checkpoint)

        if args.start_epoch == 0:
            args.start_epoch = checkpoint['epoch']

        if 'best_metric' in list(checkpoint.keys()):
            args.best_metric = checkpoint['best_metric']

        if 'optimizer' in list(checkpoint.keys()):
            optimizer.load_state_dict(checkpoint['optimizer'])

        if 'scheduler' in list(checkpoint.keys()):
            scheduler.load_state_dict(checkpoint['scheduler'])

        if 'multi_task_factors' in list(checkpoint.keys()):
            args.multi_task_factors = checkpoint['multi_task_factors']

        print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))

    #################################################
    if args.evaluate_start:
        with torch.no_grad():
            validate(args, val_dataset, val_loader, model, args.start_epoch, val_writer)

    for epoch in range(args.start_epoch, args.epochs):
        if train_sampler:
            train_sampler.set_epoch(epoch)
        if val_sampler:
            val_sampler.set_epoch(epoch)

        # train for one epoch
        train(args, train_dataset, train_loader, model, optimizer, epoch, train_writer, scheduler)

        # evaluate on validation set
        with torch.no_grad():
            val_metric, metric_name = validate(args, val_dataset, val_loader, model, epoch, val_writer)

        if args.best_metric < 0:
            args.best_metric = val_metric

        # whether higher or lower is better is inferred from the metric name
        if "iou" in metric_name.lower() or "acc" in metric_name.lower():
            is_best = val_metric >= args.best_metric
            args.best_metric = max(val_metric, args.best_metric)
        elif "error" in metric_name.lower() or "diff" in metric_name.lower() or "norm" in metric_name.lower() \
                or "loss" in metric_name.lower() or "outlier" in metric_name.lower():
            is_best = val_metric <= args.best_metric
            args.best_metric = min(val_metric, args.best_metric)
        else:
            raise ValueError("Metric is not known. Best model could not be saved.")
        #

        checkpoint_dict = { 'epoch': epoch + 1, 'model_name': args.model_name,
                            'state_dict': get_model_orig(model).state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'best_metric': args.best_metric,
                            'multi_task_factors': args.multi_task_factors,
                            'quantize' : args.quantize}

        save_checkpoint(args, save_path, get_model_orig(model), checkpoint_dict, is_best)

        if args.tensorboard_enable:
            train_writer.file_writer.flush()
            val_writer.file_writer.flush()

        # adjust the learning rate using lr scheduler
        if 'training' in args.phase:
            scheduler.step()
        #
    #

    # close and cleanup
    close(args)
583 ###################################################################
def is_valid_phase(phase):
    """Return True if *phase* contains one of the known phase keywords
    ('training', 'calibration' or 'validation') as a substring."""
    for known in ('training', 'calibration', 'validation'):
        if known in phase:
            return True
    return False
589 ###################################################################
def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writer, scheduler):
    """Run one epoch of training (or quantization calibration) over train_loader.

    Moves each batch to the GPU, runs the forward pass, computes the per-task
    losses and metrics, and (in the 'training' phase) backpropagates with
    gradient accumulation over args.iter_size iterations.

    Returns:
        (output_metric, output_name): the averaged metric value and metric name
        of the pivot task (args.pivot_task_idx).
    """
    batch_time = xnn.utils.AverageMeter()
    data_time = xnn.utils.AverageMeter()
    # if the loss/ metric is already an average, no need to further average
    avg_loss = [xnn.utils.AverageMeter(print_avg=(not task_loss[0].info()['is_avg'])) for task_loss in args.loss_modules]
    # NOTE(review): avg_loss_orig is updated below but never reported/returned - confirm it is still needed
    avg_loss_orig = [xnn.utils.AverageMeter(print_avg=(not task_loss[0].info()['is_avg'])) for task_loss in args.loss_modules]
    avg_metric = [xnn.utils.AverageMeter(print_avg=(not task_metric[0].info()['is_avg'])) for task_metric in args.metric_modules]
    epoch_size = get_epoch_size(args, train_loader, args.epoch_size)

    ##########################
    # switch to train mode
    model.train()
    if args.freeze_bn:
        xnn.utils.freeze_bn(model)
    #

    #freeze layers
    if args.freeze_layers is not None:
        # 'freeze_layer_name' could be part of 'name', i.e. 'name' need not be exact same as 'freeze_layer_name'
        # e.g. freeze_layer_name = 'encoder.0' then all layers like, 'encoder.0.0'  'encoder.0.1' will be frozen
        for freeze_layer_name in args.freeze_layers:
            for name, module in model.named_modules():
                if freeze_layer_name in name:
                    xnn.utils.print_once("Freezing the module : {}".format(name))
                    # eval() stops BN statistics updates; requires_grad=False stops weight updates
                    module.eval()
                    for param in module.parameters():
                        param.requires_grad = False

    ##########################
    # reset any accumulated state in the loss/metric modules before the epoch
    for task_dx, task_losses in enumerate(args.loss_modules):
        for loss_idx, loss_fn in enumerate(task_losses):
            loss_fn.clear()
    for task_dx, task_metrics in enumerate(args.metric_modules):
        for midx, metric_fn in enumerate(task_metrics):
            metric_fn.clear()

    progress_bar = progiter.ProgIter(np.arange(epoch_size), chunksize=1)
    metric_name = "Metric"
    metric_ctx = [None] * len(args.metric_modules)
    end_time = time.time()
    writer_idx = 0
    last_update_iter = -1

    # change color to yellow for calibration
    progressbar_color = (Fore.YELLOW if (('calibration' in args.phase) or ('training' in args.phase and args.quantize)) else Fore.WHITE)
    print('{}'.format(progressbar_color), end='')

    ##########################
    for iter, (inputs, targets) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end_time)

        lr = scheduler.get_lr()[0]

        # inputs may be a list of tensors, or a list of [Y, UV] pairs (NV12 mode)
        input_list = [[jj.cuda() for jj in img] if isinstance(img,(list,tuple)) else  img.cuda() for img in inputs]
        target_list = [tgt.cuda(non_blocking=True) for tgt in targets]
        target_sizes = [tgt.shape for tgt in target_list]
        batch_size_cur = target_sizes[0][0]

        ##########################
        # compute output
        task_outputs = model(input_list)

        task_outputs = task_outputs if isinstance(task_outputs,(list,tuple)) else [task_outputs]
        # upsample output to target resolution
        if args.upsample_mode is not None:
            task_outputs = upsample_tensors(task_outputs, target_sizes, args.upsample_mode)

        # multi-task loss balancing: fetch per-task scale factors/offsets from the model
        if args.model_config.multi_task_type is not None and len(args.model_config.output_channels) > 1:
            args.multi_task_factors, args.multi_task_offsets = xnn.layers.get_loss_scales(model)
        else:
            args.multi_task_factors = None
            args.multi_task_offsets = None

        loss_total, loss_list, loss_names, loss_types, loss_list_orig = \
            compute_task_objectives(args, args.loss_modules, input_list, task_outputs, target_list,
                         task_mults=args.multi_task_factors, task_offsets=args.multi_task_offsets,
                         loss_mult_factors=args.loss_mult_factors)

        # when get_confusion_matrix is True, an extra confusion_matrix element is returned
        if args.print_train_class_iou:
            metric_total, metric_list, metric_names, metric_types, _, confusion_matrix = \
                compute_task_objectives(args, args.metric_modules, input_list, task_outputs, target_list,
                get_confusion_matrix=args.print_train_class_iou)
        else:
            metric_total, metric_list, metric_names, metric_types, _ = \
                compute_task_objectives(args, args.metric_modules, input_list, task_outputs, target_list,
                get_confusion_matrix=args.print_train_class_iou)

        if args.model_config.multi_task_type is not None and len(args.model_config.output_channels) > 1:
            xnn.layers.set_losses(model, loss_list_orig)

        if 'training' in args.phase:
            # zero gradients so that we can accumulate gradients
            if (iter % args.iter_size) == 0:
                optimizer.zero_grad()

            # accumulate gradients
            loss_total.backward()
            # optimization step
            if ((iter+1) % args.iter_size) == 0:
                optimizer.step()
        #

        # record loss.
        for task_idx, task_losses in enumerate(args.loss_modules):
            avg_loss[task_idx].update(float(loss_list[task_idx].cpu()), batch_size_cur)
            avg_loss_orig[task_idx].update(float(loss_list_orig[task_idx].cpu()), batch_size_cur)
            if args.tensorboard_enable:
                train_writer.add_scalar('Training/Task{}_{}_Loss_Iter'.format(task_idx,loss_names[task_idx]), float(loss_list[task_idx]), args.n_iter)
                if args.model_config.multi_task_type is not None and len(args.model_config.output_channels) > 1:
                    train_writer.add_scalar('Training/multi_task_Factor_Task{}_{}'.format(task_idx,loss_names[task_idx]), float(args.multi_task_factors[task_idx]), args.n_iter)

        # record error/accuracy.
        for task_idx, task_metrics in enumerate(args.metric_modules):
            avg_metric[task_idx].update(float(metric_list[task_idx].cpu()), batch_size_cur)

        ##########################
        # NOTE(review): validate() passes metric_names (a list) here, while this passes
        # the string metric_name - the parameter is unused inside write_output, but confirm.
        if args.tensorboard_enable:
            write_output(args, 'Training_', epoch_size, iter, epoch, train_dataset, train_writer, input_list, task_outputs, target_list, metric_name, writer_idx)

        if ((iter % args.print_freq) == 0) or (iter == (epoch_size-1)):
            output_string = ''
            for task_idx, task_metrics in enumerate(args.metric_modules):
                output_string += '[{}={}]'.format(metric_names[task_idx], str(avg_metric[task_idx]))

            epoch_str = '{}/{}'.format(epoch + 1, args.epochs)
            progress_bar.set_description("{}=> {}  ".format(progressbar_color, args.phase))
            multi_task_factors_print = ['{:.3f}'.format(float(lmf)) for lmf in args.multi_task_factors] if args.multi_task_factors is not None else None
            progress_bar.set_postfix(Epoch=epoch_str, LR=lr, DataTime=str(data_time), LossMult=multi_task_factors_print, Loss=avg_loss, Output=output_string)
            progress_bar.update(iter-last_update_iter)
            last_update_iter = iter

        args.n_iter += 1
        end_time = time.time()
        writer_idx = (writer_idx + 1) % args.tensorboard_num_imgs

        # add onnx graph to tensorboard
        # commenting out due to issues in transitioning to pytorch 0.4
        # (bilinear mode in upsampling causes hang or crash - may be due to align_borders change, nearest is fine)
        #if epoch == 0 and iter == 0:
        #    input_zero = torch.zeros(input_var.shape)
        #    train_writer.add_graph(model, input_zero)
        #This cache operation slows down tranining
        #torch.cuda.empty_cache()

        # NOTE(review): this check runs after the batch is processed, so iteration
        # index epoch_size is also processed before breaking - confirm if intended.
        if iter >= epoch_size:
            break

    if args.print_train_class_iou:
        # task_idx here is whatever the last loop above left it at
        print_class_iou(args=args, confusion_matrix=confusion_matrix, task_idx=task_idx)

    progress_bar.close()

    # to print a new line - do not provide end=''
    print('{}'.format(Fore.RESET), end='')

    if args.tensorboard_enable:
        for task_idx, task_losses in enumerate(args.loss_modules):
            train_writer.add_scalar('Training/Task{}_{}_Loss_Epoch'.format(task_idx,loss_names[task_idx]), float(avg_loss[task_idx]), epoch)

        for task_idx, task_metrics in enumerate(args.metric_modules):
            train_writer.add_scalar('Training/Task{}_{}_Metric_Epoch'.format(task_idx,metric_names[task_idx]), float(avg_metric[task_idx]), epoch)

    output_name = metric_names[args.pivot_task_idx]
    output_metric = float(avg_metric[args.pivot_task_idx])

    ##########################
    # in quantized mode, log the PAct2 activation clipping values for debugging
    if args.quantize:
        def debug_format(v):
            return ('{:.3f}'.format(v) if v is not None else 'None')
        #
        clips_act = [m.get_clips_act()[1] for n,m in model.named_modules() if isinstance(m,xnn.layers.PAct2)]
        if len(clips_act) > 0:
            args.logger.debug('\nclips_act : ' + ' '.join(map(debug_format, clips_act)))
            args.logger.debug('')
    #
    return output_metric, output_name
769 ###################################################################
def validate(args, val_dataset, val_loader, model, epoch, val_writer):
    """Run one evaluation pass over val_loader and compute the per-task metrics.

    The model is switched to eval mode; no gradients are updated here.

    Returns:
        (output_metric, output_name): the averaged metric value and metric name
        of the pivot task (args.pivot_task_idx).
    """
    data_time = xnn.utils.AverageMeter()
    # if the loss/ metric is already an average, no need to further average
    avg_metric = [xnn.utils.AverageMeter(print_avg=(not task_metric[0].info()['is_avg'])) for task_metric in args.metric_modules]
    epoch_size = get_epoch_size(args, val_loader, args.epoch_size_val)

    ##########################
    # switch to evaluate mode
    model.eval()

    ##########################
    # reset any accumulated state in the metric modules before the pass
    for task_dx, task_metrics in enumerate(args.metric_modules):
        for midx, metric_fn in enumerate(task_metrics):
            metric_fn.clear()

    metric_name = "Metric"
    end_time = time.time()
    writer_idx = 0
    last_update_iter = -1
    metric_ctx = [None] * len(args.metric_modules)
    progress_bar = progiter.ProgIter(np.arange(epoch_size), chunksize=1)

    # change color to green
    print('{}'.format(Fore.GREEN), end='')

    ##########################
    for iter, (inputs, targets) in enumerate(val_loader):
        data_time.update(time.time() - end_time)
        # inputs may be a list of tensors, or a list of [Y, UV] pairs (NV12 mode)
        input_list = [[jj.cuda() for jj in img] if isinstance(img,(list,tuple)) else img.cuda() for img in inputs]
        target_list = [j.cuda(non_blocking=True) for j in targets]
        target_sizes = [tgt.shape for tgt in target_list]
        batch_size_cur = target_sizes[0][0]

        # compute output
        task_outputs = model(input_list)

        task_outputs = task_outputs if isinstance(task_outputs, (list, tuple)) else [task_outputs]
        # upsample output to target resolution before measuring the metrics
        if args.upsample_mode is not None:
           task_outputs = upsample_tensors(task_outputs, target_sizes, args.upsample_mode)

        # when get_confusion_matrix is True, an extra confusion_matrix element is returned
        if args.print_val_class_iou:
            metric_total, metric_list, metric_names, metric_types, _, confusion_matrix = \
                compute_task_objectives(args, args.metric_modules, input_list, task_outputs, target_list,
                get_confusion_matrix = args.print_val_class_iou)
        else:
            metric_total, metric_list, metric_names, metric_types, _ = \
                compute_task_objectives(args, args.metric_modules, input_list, task_outputs, target_list,
                get_confusion_matrix = args.print_val_class_iou)

        # record error/accuracy.
        for task_idx, task_metrics in enumerate(args.metric_modules):
            avg_metric[task_idx].update(float(metric_list[task_idx].cpu()), batch_size_cur)

        if args.tensorboard_enable:
            write_output(args, 'Validation_', epoch_size, iter, epoch, val_dataset, val_writer, input_list, task_outputs, target_list, metric_names, writer_idx)

        if ((iter % args.print_freq) == 0) or (iter == (epoch_size-1)):
            output_string = ''
            for task_idx, task_metrics in enumerate(args.metric_modules):
                output_string += '[{}={}]'.format(metric_names[task_idx], str(avg_metric[task_idx]))

            epoch_str = '{}/{}'.format(epoch + 1, args.epochs)
            progress_bar.set_description("=> validation")
            progress_bar.set_postfix(Epoch=epoch_str, DataTime=data_time, Output="{}".format(output_string))
            progress_bar.update(iter-last_update_iter)
            last_update_iter = iter

        end_time = time.time()
        writer_idx = (writer_idx + 1) % args.tensorboard_num_imgs

        # NOTE(review): this check runs after the batch is processed, so iteration
        # index epoch_size is also processed before breaking - confirm if intended.
        if iter >= epoch_size:
            break

    if args.print_val_class_iou:
        # task_idx here is whatever the last loop above left it at
        print_class_iou(args = args, confusion_matrix = confusion_matrix, task_idx=task_idx)

    #print_conf_matrix(conf_matrix=conf_matrix, en=False)
    progress_bar.close()

    # to print a new line - do not provide end=''
    print('{}'.format(Fore.RESET), end='')

    if args.tensorboard_enable:
        for task_idx, task_metrics in enumerate(args.metric_modules):
            val_writer.add_scalar('Validation/Task{}_{}_Metric_Epoch'.format(task_idx,metric_names[task_idx]), float(avg_metric[task_idx]), epoch)

    output_name = metric_names[args.pivot_task_idx]
    output_metric = float(avg_metric[args.pivot_task_idx])
    return output_metric, output_name
862 ###################################################################
def close(args):
    """Release the logger held in args and reset the best-metric tracker."""
    logger = args.logger
    if logger is not None:
        del args.logger
        args.logger = None
    #
    # reset so a subsequent run starts with a fresh best-metric
    args.best_metric = -1
def get_save_path(args, phase=None):
    """Build the checkpoint directory path for this run and phase.

    Layout: ./data/checkpoints/<dataset>/<date>_<dataset>_<model>_resizeWxH_traincropWxH/<phase>
    """
    run_date = args.date
    if not run_date:
        run_date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    #
    run_name = '{}_{}_{}'.format(run_date, args.dataset_name, args.model_name)
    run_name += '_resize{}x{}_traincrop{}x{}'.format(args.img_resize[1], args.img_resize[0], args.rand_crop[1], args.rand_crop[0])
    run_phase = args.phase if phase is None else phase
    return os.path.join('./data/checkpoints', args.dataset_name, run_name, run_phase)
def get_model_orig(model):
    """Unwrap (Distributed)DataParallel and quantization wrappers to get the underlying model."""
    model_orig = model
    if isinstance(model_orig, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
        model_orig = model_orig.module
    #
    if isinstance(model_orig, (xnn.quantize.QuantBaseModule)):
        model_orig = model_orig.module
    #
    return model_orig
def create_rand_inputs(args, is_cuda):
    """Create random dummy inputs matching the configured input shapes.

    Returns a list with one entry per input channel spec: a single 4D tensor
    in RGB/BGR mode, or a [Y, UV] pair of tensors in NV12 mode.
    """
    height, width = args.img_resize[0], args.img_resize[1]
    dummy_input = []
    for i_ch in args.model_config.input_channels:
        if not args.model_config.input_nv12:
            x = torch.rand((1, i_ch, height, width))
            dummy_input.append(x.cuda() if is_cuda else x)
        else:
            # NV12: full-resolution luma plane + half-height interleaved chroma plane
            y = torch.rand((1, 1, height, width))
            uv = torch.rand((1, 1, height // 2, width))
            if is_cuda:
                y = y.cuda()
                uv = uv.cuda()
            #
            dummy_input.append([y, uv])
        #
    #
    return dummy_input
def count_flops(args, model):
    """Print the model complexity (GFLOPs/GMACs) measured on a dummy forward pass."""
    on_cuda = next(model.parameters()).is_cuda
    dummy_input = create_rand_inputs(args, on_cuda)
    #
    model.eval()
    gflops = xnn.utils.forward_count_flops(model, dummy_input) / 1e9
    print('=> Size = {}, GFLOPs = {}, GMACs = {}'.format(args.img_resize, gflops, gflops/2))
def derive_node_name(input_name):
    """Derive an onnx node name from its input names.

    Uses the last input name with its trailing '.suffix' (if any) stripped.
    """
    last_input = input_name[-1]
    stem, sep, _suffix = last_input.rpartition('.')
    # rpartition returns an empty stem when there is no '.', keep the full name then
    return stem if sep else last_input
923 #torch onnx export does not update names. Do it using onnx.save
def add_node_names(onnx_model_name=None):
    """Rewrite node/input names of the onnx model at *onnx_model_name* in place.

    torch.onnx.export does not assign readable node names: strip the ':...'
    suffix torch appends to tensor names and derive each node's name from its
    inputs, then save the model back to the same file.

    Note: the previous default of a mutable list ([]) was a bug - a string
    path is expected; None keeps the call signature backward-compatible.
    """
    onnx_model = onnx.load(onnx_model_name)
    for i in range(len(onnx_model.graph.node)):
        for j in range(len(onnx_model.graph.node[i].input)):
            # drop the ':N' suffix that torch export appends to tensor names
            onnx_model.graph.node[i].input[j] = onnx_model.graph.node[i].input[j].split(':')[0]
            onnx_model.graph.node[i].name = derive_node_name(onnx_model.graph.node[i].input)
        #
    #
    #update model inplace
    onnx.save(onnx_model, onnx_model_name)
def write_onnx_model(args, model, save_path, name='checkpoint.onnx'):
    """Export *model* to onnx at save_path/name and post-process its node names."""
    onnx_path = os.path.join(save_path, name)
    on_cuda = next(model.parameters()).is_cuda
    input_list = create_rand_inputs(args, is_cuda=on_cuda)
    #
    model.eval()
    torch.onnx.export(model, input_list, onnx_path, export_params=True, verbose=False)
    #torch onnx export does not update names. Do it using onnx.save
    add_node_names(onnx_model_name=onnx_path)
949 ###################################################################
def write_output(args, prefix, val_epoch_size, iter, epoch, dataset, output_writer, input_images, task_outputs, task_targets, metric_names, writer_idx):
    """Write a randomly sampled input/target/output visualization to tensorboard.

    Writes with probability tensorboard_num_imgs/val_epoch_size per call, so that
    roughly tensorboard_num_imgs images are logged per epoch. One random batch
    element is visualized for each input image and each task output type.
    """
    write_freq = (args.tensorboard_num_imgs / float(val_epoch_size))
    write_prob = np.random.random()
    # skip most iterations - only a few random ones are visualized
    if (write_prob > write_freq):
        return
    if args.model_config.input_nv12:
        batch_size = input_images[0][0].shape[0]
    else:
        batch_size = input_images[0].shape[0]
    # pick one random element of the batch to visualize
    b_index = random.randint(0, batch_size - 1)

    input_image = None
    for img_idx, img in enumerate(input_images):
        if args.model_config.input_nv12:
            #convert NV12 to BGR for tensorboard
            input_image = vision.transforms.image_transforms_xv12.nv12_to_bgr_image(Y = input_images[img_idx][0][b_index], UV = input_images[img_idx][1][b_index],
                                   image_scale=args.image_scale, image_mean=args.image_mean)
        else:
            # CHW tensor -> HWC numpy image
            input_image = input_images[img_idx][b_index].cpu().numpy().transpose((1, 2, 0))
            # convert back to original input range (0-255)
            input_image = input_image / args.image_scale + args.image_mean

        if args.is_flow and args.is_flow[0][img_idx]:
            #input corresponding to flow is assumed to have been generated by adding 128
            flow = input_image - 128
            flow_hsv = xnn.utils.flow2hsv(flow.transpose(2, 0, 1), confidence=False).transpose(2, 0, 1)
            #flow_hsv = (flow_hsv / 255.0).clip(0, 1) #TODO: check this
            output_writer.add_image(prefix +'Input{}/{}'.format(img_idx, writer_idx), flow_hsv, epoch)
        else:
            input_image = (input_image/255.0).clip(0,1) #.astype(np.uint8)
            output_writer.add_image(prefix + 'Input{}/{}'.format(img_idx, writer_idx), input_image.transpose((2,0,1)), epoch)

    # for sparse data, chroma blending does not look good
    for task_idx, output_type in enumerate(args.model_config.output_type):
        # metric_name = metric_names[task_idx]
        output = task_outputs[task_idx]
        target = task_targets[task_idx]
        if (output_type == 'segmentation') and hasattr(dataset, 'decode_segmap'):
            # argmax over classes when the output is multi-channel; colorize via the dataset palette
            segmentation_target = dataset.decode_segmap(target[b_index,0].cpu().numpy())
            segmentation_output = output.max(dim=1,keepdim=True)[1].data.cpu().numpy() if(output.shape[1]>1) else output.data.cpu().numpy()
            segmentation_output = dataset.decode_segmap(segmentation_output[b_index,0])
            segmentation_output_blend = xnn.utils.chroma_blend(input_image, segmentation_output)
            #
            output_writer.add_image(prefix+'Task{}_{}_GT/{}'.format(task_idx,output_type,writer_idx), segmentation_target.transpose(2,0,1), epoch)
            if not args.sparse:
                segmentation_target_blend = xnn.utils.chroma_blend(input_image, segmentation_target)
                output_writer.add_image(prefix + 'Task{}_{}_GT_ColorBlend/{}'.format(task_idx, output_type, writer_idx), segmentation_target_blend.transpose(2, 0, 1), epoch)
            #
            output_writer.add_image(prefix+'Task{}_{}_Output/{}'.format(task_idx,output_type,writer_idx), segmentation_output.transpose(2,0,1), epoch)
            output_writer.add_image(prefix+'Task{}_{}_Output_ColorBlend/{}'.format(task_idx,output_type,writer_idx), segmentation_output_blend.transpose(2,0,1), epoch)
        elif (output_type in ('depth', 'disparity')):
            depth_chanidx = 0
            output_writer.add_image(prefix+'Task{}_{}_GT_Color_Visualization/{}'.format(task_idx,output_type,writer_idx), xnn.utils.tensor2array(target[b_index][depth_chanidx].cpu(), max_value=args.max_depth, colormap=args.viz_colormap).transpose(2,0,1), epoch)
            if not args.sparse:
                output_writer.add_image(prefix + 'Task{}_{}_GT_ColorBlend_Visualization/{}'.format(task_idx, output_type, writer_idx), xnn.utils.tensor2array(target[b_index][depth_chanidx].cpu(), max_value=args.max_depth, colormap=args.viz_colormap, input_blend=input_image).transpose(2, 0, 1), epoch)
            #
            output_writer.add_image(prefix+'Task{}_{}_Output_Color_Visualization/{}'.format(task_idx,output_type,writer_idx), xnn.utils.tensor2array(output.data[b_index][depth_chanidx].cpu(), max_value=args.max_depth, colormap=args.viz_colormap).transpose(2,0,1), epoch)
            output_writer.add_image(prefix + 'Task{}_{}_Output_ColorBlend_Visualization/{}'.format(task_idx, output_type, writer_idx),xnn.utils.tensor2array(output.data[b_index][depth_chanidx].cpu(), max_value=args.max_depth, colormap=args.viz_colormap, input_blend=input_image).transpose(2, 0, 1), epoch)
        elif (output_type == 'flow'):
            max_value_flow = 10.0 # only for visualization
            output_writer.add_image(prefix+'Task{}_{}_GT/{}'.format(task_idx,output_type,writer_idx), xnn.utils.flow2hsv(target[b_index][:2].cpu().numpy(), max_value=max_value_flow).transpose(2,0,1), epoch)
            output_writer.add_image(prefix+'Task{}_{}_Output/{}'.format(task_idx,output_type,writer_idx), xnn.utils.flow2hsv(output.data[b_index][:2].cpu().numpy(), max_value=max_value_flow).transpose(2,0,1), epoch)
        elif (output_type == 'interest_pt'):
            score_chanidx = 0
            target_score_to_write = target[b_index][score_chanidx].cpu()
            output_score_to_write = output.data[b_index][score_chanidx].cpu()

            #if score is learnt as zero mean add offset to make it [0-255]
            if args.make_score_zero_mean:
                # target_score_to_write!=0 : value 0 indicates GT unavailble. Leave them to be 0.
                target_score_to_write[target_score_to_write!=0] += 128.0
                output_score_to_write += 128.0

            max_value_score = float(torch.max(target_score_to_write)) #0.002
            output_writer.add_image(prefix+'Task{}_{}_GT_Bone_Visualization/{}'.format(task_idx,output_type,writer_idx), xnn.utils.tensor2array(target_score_to_write, max_value=max_value_score, colormap='bone').transpose(2,0,1), epoch)
            output_writer.add_image(prefix+'Task{}_{}_Output_Bone_Visualization/{}'.format(task_idx,output_type,writer_idx), xnn.utils.tensor2array(output_score_to_write, max_value=max_value_score, colormap='bone').transpose(2,0,1), epoch)
        #
def print_conf_matrix(conf_matrix=None, en=False):
    """Print a confusion matrix row by row, one element per print call.

    Fixes: the previous default of a mutable list ([]) was both a shared-object
    anti-pattern and crashed on `.shape` when called with en=True and no matrix;
    None is used instead and guarded explicitly.
    """
    # nothing to do when disabled or when no matrix was supplied
    if (not en) or (conf_matrix is None):
        return
    num_rows = conf_matrix.shape[0]
    num_cols = conf_matrix.shape[1]
    print("-"*64)
    num_ele = 1
    for r_idx in range(num_rows):
        print("\n")
        # print the row in chunks of num_ele elements
        for c_idx in range(0,num_cols,num_ele):
            print(conf_matrix[r_idx][c_idx:c_idx+num_ele], end="")
    print("\n")
    print("-" * 64)
def compute_task_objectives(args, objective_fns, input_var, task_outputs, task_targets, task_mults=None, 
  task_offsets=None, loss_mult_factors=None, get_confusion_matrix = False):
    """Evaluate per-task objectives (losses or metrics) and their weighted total.

    For each task, sums its objective functions (each scaled by the matching
    loss_mult_factors entry, NaN values replaced by 0), then accumulates
    objective_sum*task_mult + task_offset into the overall total.

    Returns:
        [objective_total, objective_list, objective_names, objective_types,
         objective_list_orig] and, when get_confusion_matrix is True, the
        confusion matrix of the last objective that provided one is appended.

    NOTE(review): if get_confusion_matrix is True but no objective function is
    iterated (empty objective_fns), confusion_matrix would be unbound - confirm
    callers always pass non-empty modules.
    """
    ##########################
    # scalar zero on the same device/dtype as the outputs
    objective_total = torch.zeros_like(task_outputs[0].view(-1)[0])
    objective_list = []
    objective_list_orig = []
    objective_names = []
    objective_types = []
    for task_idx, task_objectives in enumerate(objective_fns):
        # NOTE(review): output_type is looked up but not used in this function
        output_type = args.model_config.output_type[task_idx]
        objective_sum_value = torch.zeros_like(task_outputs[task_idx].view(-1)[0])
        objective_sum_name = ''
        objective_sum_type = ''

        task_mult = task_mults[task_idx] if task_mults is not None else 1.0
        task_offset = task_offsets[task_idx] if task_offsets is not None else 0.0

        for oidx, objective_fn in enumerate(task_objectives):
            objective_batch = objective_fn(input_var, task_outputs[task_idx], task_targets[task_idx])
            # DataParallel returns one value per GPU - reduce to a scalar
            objective_batch = objective_batch.mean() if isinstance(objective_fn, torch.nn.DataParallel) else objective_batch
            objective_name = objective_fn.info()['name']
            objective_type = objective_fn.info()['is_avg']
            if get_confusion_matrix:
                confusion_matrix = objective_fn.info()['confusion_matrix']

            loss_mult = loss_mult_factors[task_idx][oidx] if (loss_mult_factors is not None) else 1.0
            # --
            # guard against NaN losses poisoning the accumulated sum
            objective_batch_not_nan = (objective_batch if not torch.isnan(objective_batch) else 0.0)
            objective_sum_value = objective_batch_not_nan*loss_mult + objective_sum_value
            objective_sum_name += (objective_name if (objective_sum_name == '') else ('+' + objective_name))
            assert (objective_sum_type == '' or objective_sum_type == objective_type), 'metric types (avg/val) for a given task should match'
            objective_sum_type = objective_type

        objective_list.append(objective_sum_value)
        # NOTE(review): objective_list_orig currently holds the same objects as
        # objective_list (appended before task_mult scaling is applied to the total)
        objective_list_orig.append(objective_sum_value)
        objective_names.append(objective_sum_name)
        objective_types.append(objective_sum_type)

        objective_total = objective_sum_value*task_mult + task_offset + objective_total

    return_list = [objective_total, objective_list, objective_names, objective_types, objective_list_orig]
    if get_confusion_matrix:
        return_list.append(confusion_matrix)

    return return_list 
def save_checkpoint(args, save_path, model, checkpoint_dict, is_best, filename='checkpoint.pth.tar'):
    """Persist the checkpoint dict; copy it to model_best when is_best.

    Also exports onnx models alongside the checkpoints when args.generate_onnx is set.
    """
    checkpoint_path = os.path.join(save_path, filename)
    torch.save(checkpoint_dict, checkpoint_path)
    if is_best:
        shutil.copyfile(checkpoint_path, os.path.join(save_path, 'model_best.pth.tar'))
    #
    if args.generate_onnx:
        write_onnx_model(args, model, save_path, name='checkpoint.onnx')
        if is_best:
            write_onnx_model(args, model, save_path, name='model_best.onnx')
    #
def get_epoch_size(args, loader, args_epoch_size):
    """Resolve the effective number of iterations per epoch.

    args_epoch_size semantics: 0 -> full loader length; a fraction in (0, 1)
    -> that fraction of the loader; >= 1 -> that many iterations, capped at
    the loader length.
    """
    num_batches = len(loader)
    if args_epoch_size == 0:
        return num_batches
    elif args_epoch_size < 1:
        return int(num_batches * args_epoch_size)
    else:
        return min(num_batches, int(args_epoch_size))
def get_train_transform(args):
    """Compose the training-time transform pipeline (with random augmentations).

    None entries in the step list are placeholders for disabled steps; Compose
    receives them as-is, exactly as before.
    """
    # image normalization can be at the beginning of transforms or at the end
    mean_arr = np.array(args.image_mean, dtype=np.float32)
    scale_arr = np.array(args.image_scale, dtype=np.float32)
    normalizer = vision.transforms.image_transforms.NormalizeMeanScale(mean=mean_arr, scale=scale_arr)
    pre_normalize = normalizer if args.image_prenorm else None
    post_normalize = None if (pre_normalize is not None) else normalizer

    # optional resize of the random crop to a different output size (training only)
    if (args.rand_output_size is not None) and (args.rand_output_size != args.rand_resize):
        crop_rescale = vision.transforms.image_transforms.Scale(args.rand_resize, target_size=args.rand_output_size, is_flow=args.is_flow)
    else:
        crop_rescale = None
    #
    rotation = vision.transforms.image_transforms.RandomRotate(args.transform_rotation, is_flow=args.is_flow) if args.transform_rotation else None
    gray_augment = vision.transforms.image_transforms.RandomColor2Gray(is_flow=args.is_flow, random_threshold=0.5) if 'tiad' in args.dataset_name else None

    steps = [
        pre_normalize,
        vision.transforms.image_transforms.AlignImages(),
        vision.transforms.image_transforms.MaskTarget(args.target_mask, 0),
        vision.transforms.image_transforms.CropRect(args.img_border_crop),
        rotation,
        vision.transforms.image_transforms.RandomScaleCrop(args.rand_resize, scale_range=args.rand_scale, is_flow=args.is_flow),
        vision.transforms.image_transforms.RandomHorizontalFlip(is_flow=args.is_flow),
        vision.transforms.image_transforms.RandomCrop(args.rand_crop),
        gray_augment,
        crop_rescale,
        post_normalize,
        vision.transforms.image_transforms.ConvertToTensor()
        ]
    return vision.transforms.image_transforms.Compose(steps)
def get_validation_transform(args):
    """Compose the validation-time transform pipeline (deterministic, no random augmentation)."""
    # image normalization can be at the beginning of transforms or at the end
    mean_arr = np.array(args.image_mean, dtype=np.float32)
    scale_arr = np.array(args.image_scale, dtype=np.float32)
    normalizer = vision.transforms.image_transforms.NormalizeMeanScale(mean=mean_arr, scale=scale_arr)
    pre_normalize = normalizer if args.image_prenorm else None
    post_normalize = None if (pre_normalize is not None) else normalizer

    # prediction is resized to output_size before evaluation.
    steps = [
        pre_normalize,
        vision.transforms.image_transforms.AlignImages(),
        vision.transforms.image_transforms.MaskTarget(args.target_mask, 0),
        vision.transforms.image_transforms.CropRect(args.img_border_crop),
        vision.transforms.image_transforms.Scale(args.img_resize, target_size=args.output_size, is_flow=args.is_flow),
        post_normalize,
        vision.transforms.image_transforms.ConvertToTensor()
        ]
    return vision.transforms.image_transforms.Compose(steps)
def get_transforms(args):
    """Return the (train_transform, val_transform) pair for this run."""
    # Provision to train with val transform - provide rand_scale as (0, 0)
    # Fixing the train-test resolution discrepancy, https://arxiv.org/abs/1906.06423
    if args.rand_scale[0] == 0:
        train_transform = get_validation_transform(args)
    else:
        train_transform = get_train_transform(args)
    #
    val_transform = get_validation_transform(args)
    return train_transform, val_transform
1168 def _upsample_impl(tensor, output_size, upsample_mode):
1169     # upsample of long tensor is not supported currently. covert to float, just to avoid error.
1170     # we can do thsi only in the case of nearest mode, otherwise output will have invalid values.
1171     convert_to_float = False
1172     if isinstance(tensor, (torch.LongTensor,torch.cuda.LongTensor)):
1173         convert_to_float = True
1174         original_dtype = tensor.dtype
1175         tensor = tensor.float()
1176         upsample_mode = 'nearest'
1178     dim_added = False
1179     if len(tensor.shape) < 4:
1180         tensor = tensor[np.newaxis,...]
1181         dim_added = True
1183     if (tensor.size()[-2:] != output_size):
1184         tensor = torch.nn.functional.interpolate(tensor, output_size, mode=upsample_mode)
1186     if dim_added:
1187         tensor = tensor[0,...]
1189     if convert_to_float:
1190         tensor = tensor.long() #tensor.astype(original_dtype)
1192     return tensor
def upsample_tensors(tensors, output_sizes, upsample_mode):
    """Upsample a tensor, or a list of tensors, to the matching target sizes.

    For a list, each element is resized to output_sizes[i][-2:] and the list is
    updated in place; a single tensor uses output_sizes[0][-2:].
    """
    if not isinstance(tensors, (list,tuple)):
        return _upsample_impl(tensors, output_sizes[0][-2:], upsample_mode)
    #
    for idx in range(len(tensors)):
        tensors[idx] = _upsample_impl(tensors[idx], output_sizes[idx][-2:], upsample_mode)
    #
    return tensors
1204 #print IoU for each class
#print IoU for each class
def print_class_iou(args = None, confusion_matrix = None, task_idx = 0):
    """Print the per-class IoU derived from the confusion matrix of one task."""
    n_classes = args.model_config.output_channels[task_idx]
    accuracy, mean_iou, iou, f1_score = compute_accuracy(args, confusion_matrix, n_classes)
    print("\n Class IoU: [", end="")
    for class_iou in iou:
        print("{:0.3f}".format(class_iou), end=",")
    print("]")
if __name__ == '__main__':
    # stand-alone entry point: build the default config and launch training.
    # NOTE(review): main() is defined elsewhere in this file (outside this view).
    train_args = get_config()
    main(train_args)