import os
import shutil
import time
import math
import copy

import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.onnx
import onnx

import datetime
from tensorboardX import SummaryWriter
import numpy as np
import random
import cv2
from colorama import Fore
import progiter
from packaging import version
import warnings

from .. import xnn
from .. import vision
from .infer_pixel2pixel import compute_accuracy


##################################################
warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)


##################################################
def get_config():
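    """Create and return the default configuration for pixel2pixel training.

    Returns an xnn.utils.ConfigNode populated with dataset, model, solver,
    augmentation and quantization defaults; callers typically override
    individual fields before passing the config to main().
    """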
    args = xnn.utils.ConfigNode()

    args.dataset_config = xnn.utils.ConfigNode()
    args.dataset_config.split_name = 'val'
    args.dataset_config.max_depth_bfr_scaling = 80
    args.dataset_config.depth_scale = 1
    args.dataset_config.train_depth_log = 1
    args.use_semseg_for_depth = False

    # model config
    args.model_config = xnn.utils.ConfigNode()
    args.model_config.output_type = ['segmentation']    # task type(s) the network predicts: segmentation, depth, flow etc.
    args.model_config.output_channels = None            # number of output channels
    args.model_config.input_channels = None             # number of input channels
    args.model_config.output_range = None               # max range of output
    args.model_config.num_decoders = None               # number of decoders to use. [options: 0, 1, None]
    args.model_config.freeze_encoder = False            # do not update encoder weights
    args.model_config.freeze_decoder = False            # do not update decoder weights
    args.model_config.multi_task_type = 'learned'       # find out loss multiplier by learning, choices=[None, 'learned', 'uncertainty', 'grad_norm', 'dwa_grad_norm']
    args.model_config.target_input_ratio = 1            # keep target size same as input size

    args.model = None                                   # the model itself can be given from outside
    args.model_name = 'deeplabv2lite_mobilenetv2'       # model architecture, overwritten if pretrained is specified
    args.dataset_name = 'cityscapes_segmentation'       # dataset type
    args.transforms = None                              # the transforms themselves can be given from outside

    args.data_path = './data/cityscapes'                # path to dataset
    args.save_path = None                               # checkpoints save path
    args.phase = 'training'                             # training/calibration/validation
    args.date = None                                    # date to add to save path. if this is None, current date will be added.

    args.logger = None                                  # logger stream to output into
    args.show_gpu_usage = False                         # shows gpu usage at the beginning of each training epoch

    args.split_file = None                              # train_val split file
    args.split_files = None                             # split list files. eg: train.txt val.txt
    args.split_value = None                             # train/test split proportion (between 0 (only test) and 1 (only train))

    args.solver = 'adam'                                # solver algorithms, choices=['adam','sgd']
    args.scheduler = 'step'                             # scheduler algorithms, choices=['step','poly', 'cosine']
    args.workers = 8                                    # number of data loading workers

    args.epochs = 250                                   # number of total epochs to run
    args.start_epoch = 0                                # manual epoch number (useful on restarts)

    args.epoch_size = 0                                 # manual epoch size (will match dataset size if not specified)
    args.epoch_size_val = 0                             # manual epoch size (will match dataset size if not specified)
    args.batch_size = 12                                # mini_batch size
    args.total_batch_size = None                        # accumulated batch size. total_batch_size = batch_size*iter_size
    args.iter_size = 1                                  # iteration size. total_batch_size = batch_size*iter_size

    args.lr = 1e-4                                      # initial learning rate
    args.lr_clips = None                                # use args.lr itself if it is None
    args.lr_calib = 0.1                                 # lr for bias calibration
    args.warmup_epochs = 5                              # number of epochs to warmup

    args.momentum = 0.9                                 # momentum for sgd, alpha parameter for adam
    args.beta = 0.999                                   # beta parameter for adam
    args.weight_decay = 1e-4                            # weight decay
    args.bias_decay = None                              # bias decay

    args.sparse = True                                  # avoid invalid/ignored target pixels from loss computation, use NEAREST for interpolation

    args.tensorboard_num_imgs = 5                       # number of imgs to display in tensorboard
    args.pretrained = None                              # path to pre_trained model
    args.resume = None                                  # path to latest checkpoint (default: none)
    args.no_date = False                                # don't append date timestamp to folder
    args.print_freq = 100                               # print frequency (default: 100)

    args.milestones = (100, 200)                        # epochs at which learning rate is divided by 2

    args.losses = ['segmentation_loss']                 # loss functions to use for training
    args.metrics = ['segmentation_metrics']             # metric/measurement/error functions for train/validation
    args.multi_task_factors = None                      # loss mult factors
    args.class_weights = None                           # class weights

    args.loss_mult_factors = None                       # fixed loss mult factors - per loss - note: this is different from multi_task_factors (which is per task)

    args.multistep_gamma = 0.5                          # gamma for step scheduler
    args.polystep_power = 1.0                           # power for polynomial scheduler

    args.rand_seed = 1                                  # random seed
    args.img_border_crop = None                         # image border crop rectangle. can be relative or absolute
    args.target_mask = None                             # mask rectangle. can be relative or absolute. last value is the mask value

    args.rand_resize = None                             # random image size to be resized to during training
    args.rand_output_size = None                        # output size to be resized to during training
    args.rand_scale = (1.0, 2.0)                        # random scale range for training
    args.rand_crop = None                               # image size to be cropped to

    args.img_resize = None                              # image size to be resized to during evaluation
    args.output_size = None                             # target output size to be resized to

    args.count_flops = True                             # count flops and report

    args.shuffle = True                                 # shuffle or not
    args.shuffle_val = False                            # shuffle val dataset or not

    args.transform_rotation = 0.                        # apply rotation augmentation. value is rotation in degrees. 0 indicates no rotation
    args.is_flow = None                                 # whether entries in images and targets lists are optical flow or not

    args.upsample_mode = 'bilinear'                     # upsample mode to use, choices=['nearest','bilinear']

    args.image_prenorm = True                           # whether normalization is done before all other transforms
    args.image_mean = (128.0,)                          # image mean for input image normalization
    args.image_scale = (1.0 / (0.25 * 256),)            # image scaling/mult for input image normalization

    args.max_depth = 80                                 # maximum depth to be used for visualization

    args.pivot_task_idx = 0                             # task id to select best model

    args.parallel_model = True                          # use data parallel for model
    args.parallel_criterion = True                      # use data parallel for loss and metric

    args.evaluate_start = True                          # evaluate right at the beginning of training or not
    args.generate_onnx = True                           # export an onnx model or not
    args.print_model = False                            # print the model to text
    args.run_soon = True                                # to start training after generating configs/models

    args.quantize = False                               # apply quantized inference or not
    #args.model_surgery = None                          # replace activations with PAct2 activation module. Helpful in quantized training.
    args.bitwidth_weights = 8                           # bitwidth for weights
    args.bitwidth_activations = 8                       # bitwidth for activations
    args.histogram_range = True                         # histogram range for calibration
    args.bias_calibration = True                        # apply bias correction during quantized inference calibration
    args.per_channel_q = False                          # apply separate quantization factor for each channel in depthwise or not

    args.save_mod_files = False                         # saves modified files after last commit. Also stores commit id.
    args.make_score_zero_mean = False                   # make score zero mean while learning
    args.no_q_for_dws_layer_idx = 0                     # no_q_for_dws_layer_idx

    args.viz_colormap = 'rainbow'                       # colormap for tensorboard: 'rainbow', 'plasma', 'magma', 'bone'

    args.freeze_bn = False                              # freeze the statistics of bn
    args.tensorboard_enable = True                      # enable/disable tensorboard writing
    args.print_train_class_iou = False
    args.print_val_class_iou = False

    return args

# ################################################
# to avoid hangs in data loader with multi threads
# this was observed after using cv2 image processing functions
# https://github.com/pytorch/pytorch/issues/1355
cv2.setNumThreads(0)


# ################################################
def main(args):
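    """Run one full training/calibration/validation session described by args.

    Sets up datasets, data loaders, the model (optionally wrapped for
    quantization), losses, metrics, optimizer and scheduler, then loops over
    epochs calling train() and validate(), saving checkpoints along the way.
    """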
    # ensure pytorch version is 1.1 or higher
    assert version.parse(torch.__version__) >= version.parse('1.1'), \
        'torch version must be 1.1 or higher, due to the change in scheduler.step() and optimizer.step() call order'

    assert (not hasattr(args, 'evaluate')), 'args.evaluate is deprecated. use args.phase=training or calibration or validation'
    assert is_valid_phase(args.phase), f'invalid phase {args.phase}'
    assert not hasattr(args, 'model_surgery'), 'the argument model_surgery is deprecated, it is not needed now - remove it'
    if (args.phase == 'validation' and args.bias_calibration):
        args.bias_calibration = False
        warnings.warn('switching off bias calibration in validation')
    #

    #################################################
    args.rand_resize = args.img_resize if args.rand_resize is None else args.rand_resize
    args.rand_crop = args.img_resize if args.rand_crop is None else args.rand_crop
    args.output_size = args.img_resize if args.output_size is None else args.output_size
    # resume has higher priority
    args.pretrained = None if (args.resume is not None) else args.pretrained

    if args.save_path is None:
        save_path = get_save_path(args)
    else:
        save_path = args.save_path
    #
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    if args.save_mod_files:
        # store all the files modified after the last commit
        mod_files_path = save_path+'/mod_files'
        os.makedirs(mod_files_path)

        cmd = "git ls-files --modified | xargs -i cp {} {}".format("{}", mod_files_path)
        print("cmd:", cmd)
        os.system(cmd)

        # store the last commit id
        cmd = "git log -n 1  >> {}".format(mod_files_path + '/commit_id.txt')
        print("cmd:", cmd)
        os.system(cmd)
226     #################################################
227     if args.logger is None:
228         log_file = os.path.splitext(os.path.basename(__file__))[0] + '.log'
229         args.logger = xnn.utils.TeeLogger(filename=os.path.join(save_path,log_file))
231     #################################################
232     # global settings. rand seeds for repeatability
233     random.seed(args.rand_seed)
234     np.random.seed(args.rand_seed)
235     torch.manual_seed(args.rand_seed)
236     torch.cuda.manual_seed(args.rand_seed)
238     ################################
239     # args check and config
240     if args.iter_size != 1 and args.total_batch_size is not None:
241         warnings.warn("only one of --iter_size or --total_batch_size must be set")
242     #
243     if args.total_batch_size is not None:
244         args.iter_size = args.total_batch_size//args.batch_size
245     else:
246         args.total_batch_size = args.batch_size*args.iter_size
248     #################################################
249     # set some global flags and initializations
250     # keep it in args for now - although they don't belong here strictly
251     # using pin_memory is seen to cause issues, especially when when lot of memory is used.
252     args.use_pinned_memory = False
253     args.n_iter = 0
254     args.best_metric = -1
255     cudnn.benchmark = True
256     # torch.autograd.set_detect_anomaly(True)
258     ################################
259     # reset character color, in case it is different
260     print('{}'.format(Fore.RESET))
261     # print everything for log
262     print('=> args: {}'.format(args))
263     print('\n'.join("%s: %s" % item for item in sorted(vars(args).items())))
265     print('=> will save everything to {}'.format(save_path))

    #################################################
    train_writer = SummaryWriter(os.path.join(save_path,'train')) if args.tensorboard_enable else None
    val_writer = SummaryWriter(os.path.join(save_path,'val')) if args.tensorboard_enable else None
    transforms = get_transforms(args) if args.transforms is None else args.transforms
    assert isinstance(transforms, (list,tuple)) and len(transforms) == 2, 'incorrect transforms were given'

    print("=> fetching images in '{}'".format(args.data_path))
    split_arg = args.split_file if args.split_file else (args.split_files if args.split_files else args.split_value)
    train_dataset, val_dataset = vision.datasets.pixel2pixel.__dict__[args.dataset_name](args.dataset_config, args.data_path, split=split_arg, transforms=transforms)

    #################################################
    train_sampler = None
    val_sampler = None
    print('=> {} samples found, {} train samples and {} test samples'.format(len(train_dataset)+len(val_dataset),
        len(train_dataset), len(val_dataset)))
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
        num_workers=args.workers, pin_memory=args.use_pinned_memory, sampler=train_sampler, shuffle=args.shuffle)

    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
        num_workers=args.workers, pin_memory=args.use_pinned_memory, sampler=val_sampler, shuffle=args.shuffle_val)

    #################################################
    if (args.model_config.input_channels is None):
        args.model_config.input_channels = (3,)
        print("=> input channels is not given - setting to {}".format(args.model_config.input_channels))

    if (args.model_config.output_channels is None):
        if ('num_classes' in dir(train_dataset)):
            args.model_config.output_channels = train_dataset.num_classes()
        else:
            args.model_config.output_channels = (2 if args.model_config.output_type == 'flow' else args.model_config.output_channels)
            xnn.utils.print_yellow("=> output channels is not given - setting to {} - this may not work".format(args.model_config.output_channels))
        #
        if not isinstance(args.model_config.output_channels,(list,tuple)):
            args.model_config.output_channels = [args.model_config.output_channels]

    if (args.class_weights is None) and ('class_weights' in dir(train_dataset)):
        args.class_weights = train_dataset.class_weights()
        if not isinstance(args.class_weights, (list,tuple)):
            args.class_weights = [args.class_weights]
        #
        print("=> class weights available for dataset: {}".format(args.class_weights))

    #################################################
    pretrained_data = None
    model_surgery_quantize = False
    if args.pretrained and args.pretrained != "None":
        if isinstance(args.pretrained, dict):
            pretrained_data = args.pretrained
        else:
            if args.pretrained.startswith('http://') or args.pretrained.startswith('https://'):
                pretrained_file = vision.datasets.utils.download_url(args.pretrained, './data/downloads')
            else:
                pretrained_file = args.pretrained
            #
            print(f'=> using pre-trained weights from: {args.pretrained}')
            pretrained_data = torch.load(pretrained_file)
        #
        model_surgery_quantize = pretrained_data['quantize'] if 'quantize' in pretrained_data else False
    #

    #################################################
    # create model
    if args.model is not None:
        model, change_names_dict = args.model if isinstance(args.model, (list, tuple)) else (args.model, None)
        assert isinstance(model, torch.nn.Module), 'args.model, if provided, must be a valid torch.nn.Module'
    else:
        xnn.utils.print_yellow("=> creating model '{}'".format(args.model_name))
        model = vision.models.pixel2pixel.__dict__[args.model_name](args.model_config)
        # check if we got the model as well as parameters to change the names in pretrained
        model, change_names_dict = model if isinstance(model, (list,tuple)) else (model,None)
    #

    if args.quantize:
        # dummy input is used by quantized models to analyze graph
        is_cuda = next(model.parameters()).is_cuda
        dummy_input = create_rand_inputs(args, is_cuda=is_cuda)
        #
        if 'training' in args.phase:
            model = xnn.quantize.QuantTrainModule(model, per_channel_q=args.per_channel_q,
                        histogram_range=args.histogram_range, bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations, dummy_input=dummy_input)
        elif 'calibration' in args.phase:
            model = xnn.quantize.QuantCalibrateModule(model, per_channel_q=args.per_channel_q,
                        bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
                        histogram_range=args.histogram_range, bias_calibration=args.bias_calibration, lr_calib=args.lr_calib, dummy_input=dummy_input)
        elif 'validation' in args.phase:
            # note: bias_calibration is not enabled
            model = xnn.quantize.QuantTestModule(model, per_channel_q=args.per_channel_q,
                        bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
                        histogram_range=args.histogram_range, model_surgery_quantize=model_surgery_quantize,
                        dummy_input=dummy_input)
        else:
            assert False, f'invalid phase {args.phase}'
    #

    # load pretrained model
    if pretrained_data is not None:
        xnn.utils.load_weights(get_model_orig(model), pretrained=pretrained_data, change_names_dict=change_names_dict)
    #

    #################################################
    if args.count_flops:
        count_flops(args, model)

    #################################################
    if args.generate_onnx and (any(p in args.phase for p in ('training','calibration')) or (not args.run_soon)):
        write_onnx_model(args, get_model_orig(model), save_path)
    #

    #################################################
    if args.print_model:
        print(model)
        print('\n')
    else:
        args.logger.debug(str(model))
        args.logger.debug('\n')

    #################################################
    if (not args.run_soon):
        print("Training not needed for now")
        close(args)
        exit()

    #################################################
    # multi gpu mode does not work for calibration/training for quantization
    # so use it only when args.quantize is False
    if args.parallel_model and (not args.quantize):
        model = torch.nn.DataParallel(model)

    #################################################
    model = model.cuda()

    #################################################
    # for help in debug/print
    for name, module in model.named_modules():
        module.name = name

    #################################################
    args.loss_modules = copy.deepcopy(args.losses)
    for task_dx, task_losses in enumerate(args.losses):
        for loss_idx, loss_fn in enumerate(task_losses):
            kw_args = {}
            loss_args = vision.losses.__dict__[loss_fn].args()
            for arg in loss_args:
                if arg == 'weight' and (args.class_weights is not None):
                    kw_args.update({arg:args.class_weights[task_dx]})
                elif arg == 'num_classes':
                    kw_args.update({arg:args.model_config.output_channels[task_dx]})
                elif arg == 'sparse':
                    kw_args.update({arg:args.sparse})
                #
            #
            loss_fn_raw = vision.losses.__dict__[loss_fn](**kw_args)
            if args.parallel_criterion:
                loss_fn = torch.nn.DataParallel(loss_fn_raw).cuda()
                loss_fn.info = loss_fn_raw.info
                loss_fn.clear = loss_fn_raw.clear
            else:
                loss_fn = loss_fn_raw.cuda()
            #
            args.loss_modules[task_dx][loss_idx] = loss_fn
    #

    args.metric_modules = copy.deepcopy(args.metrics)
    for task_dx, task_metrics in enumerate(args.metrics):
        for midx, metric_fn in enumerate(task_metrics):
            kw_args = {}
            metric_args = vision.losses.__dict__[metric_fn].args()
            for arg in metric_args:
                if arg == 'weight':
                    kw_args.update({arg:args.class_weights[task_dx]})
                elif arg == 'num_classes':
                    kw_args.update({arg:args.model_config.output_channels[task_dx]})
                elif arg == 'sparse':
                    kw_args.update({arg:args.sparse})

            metric_fn_raw = vision.losses.__dict__[metric_fn](**kw_args)
            if args.parallel_criterion:
                metric_fn = torch.nn.DataParallel(metric_fn_raw).cuda()
                metric_fn.info = metric_fn_raw.info
                metric_fn.clear = metric_fn_raw.clear
            else:
                metric_fn = metric_fn_raw.cuda()
            #
            args.metric_modules[task_dx][midx] = metric_fn
    #

    #################################################
    if args.phase == 'validation':
        with torch.no_grad():
            validate(args, val_dataset, val_loader, model, 0, val_writer)
        #
        close(args)
        return

    #################################################
    assert(args.solver in ['adam', 'sgd'])
    print('=> setting {} solver'.format(args.solver))
    if args.lr_clips is not None:
        learning_rate_clips = args.lr_clips if 'training' in args.phase else 0.0
        clips_decay = args.bias_decay if (args.bias_decay is not None and args.bias_decay != 0.0) else args.weight_decay
        clips_params = [p for n,p in model.named_parameters() if 'clips' in n]
        other_params = [p for n,p in model.named_parameters() if 'clips' not in n]
        param_groups = [{'params': clips_params, 'weight_decay': clips_decay, 'lr': learning_rate_clips},
                        {'params': other_params, 'weight_decay': args.weight_decay}]
    else:
        param_groups = [{'params': filter(lambda p: p.requires_grad, model.parameters()), 'weight_decay': args.weight_decay}]
    #

    learning_rate = args.lr if ('training' in args.phase) else 0.0
    if args.solver == 'adam':
        optimizer = torch.optim.Adam(param_groups, learning_rate, betas=(args.momentum, args.beta))
    elif args.solver == 'sgd':
        optimizer = torch.optim.SGD(param_groups, learning_rate, momentum=args.momentum)
    else:
        raise ValueError('unknown solver type: {}'.format(args.solver))
    #

    #################################################
    epoch_size = get_epoch_size(args, train_loader, args.epoch_size)
    max_iter = args.epochs * epoch_size
    scheduler = xnn.optim.lr_scheduler.SchedulerWrapper(args.scheduler, optimizer, args.epochs, args.start_epoch, \
                                                            args.warmup_epochs, max_iter=max_iter, polystep_power=args.polystep_power,
                                                            milestones=args.milestones, multistep_gamma=args.multistep_gamma)

    # optionally resume from a checkpoint
    if args.resume:
        if not os.path.isfile(args.resume):
            print("=> no checkpoint found at '{}'".format(args.resume))
        else:
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model = xnn.utils.load_weights(model, checkpoint)

            if args.start_epoch == 0:
                args.start_epoch = checkpoint['epoch']

            if 'best_metric' in list(checkpoint.keys()):
                args.best_metric = checkpoint['best_metric']

            if 'optimizer' in list(checkpoint.keys()):
                optimizer.load_state_dict(checkpoint['optimizer'])

            if 'scheduler' in list(checkpoint.keys()):
                scheduler.load_state_dict(checkpoint['scheduler'])

            if 'multi_task_factors' in list(checkpoint.keys()):
                args.multi_task_factors = checkpoint['multi_task_factors']

            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))

    #################################################
    if args.evaluate_start:
        with torch.no_grad():
            validate(args, val_dataset, val_loader, model, args.start_epoch, val_writer)

    for epoch in range(args.start_epoch, args.epochs):
        if train_sampler:
            train_sampler.set_epoch(epoch)
        if val_sampler:
            val_sampler.set_epoch(epoch)

        # train for one epoch
        train(args, train_dataset, train_loader, model, optimizer, epoch, train_writer, scheduler)

        # evaluate on validation set
        with torch.no_grad():
            val_metric, metric_name = validate(args, val_dataset, val_loader, model, epoch, val_writer)

        if args.best_metric < 0:
            args.best_metric = val_metric

        if "iou" in metric_name.lower() or "acc" in metric_name.lower():
            is_best = val_metric >= args.best_metric
            args.best_metric = max(val_metric, args.best_metric)
        elif "error" in metric_name.lower() or "diff" in metric_name.lower() or "norm" in metric_name.lower() \
                or "loss" in metric_name.lower() or "outlier" in metric_name.lower():
            is_best = val_metric <= args.best_metric
            args.best_metric = min(val_metric, args.best_metric)
        else:
            raise ValueError("Metric is not known. Best model could not be saved.")
        #

        checkpoint_dict = { 'epoch': epoch + 1, 'model_name': args.model_name,
                            'state_dict': get_model_orig(model).state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'best_metric': args.best_metric,
                            'multi_task_factors': args.multi_task_factors,
                            'quantize' : args.quantize}

        save_checkpoint(args, save_path, get_model_orig(model), checkpoint_dict, is_best)

        if args.tensorboard_enable:
            train_writer.file_writer.flush()
            val_writer.file_writer.flush()

        # adjust the learning rate using lr scheduler
        if 'training' in args.phase:
            scheduler.step()
        #
    #

    # close and cleanup
    close(args)

###################################################################
def is_valid_phase(phase):
    phases = ('training', 'calibration', 'validation')
    return any(p in phase for p in phases)

###################################################################
def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writer, scheduler):
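    """Train (or calibrate) the model for one epoch and return the pivot metric.

    Runs forward/backward passes with optional gradient accumulation
    (args.iter_size), updates running loss/metric averages and writes
    per-iteration summaries to tensorboard when enabled.
    """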
    batch_time = xnn.utils.AverageMeter()
    data_time = xnn.utils.AverageMeter()
    # if the loss/metric is already an average, no need to further average
    avg_loss = [xnn.utils.AverageMeter(print_avg=(not task_loss[0].info()['is_avg'])) for task_loss in args.loss_modules]
    avg_loss_orig = [xnn.utils.AverageMeter(print_avg=(not task_loss[0].info()['is_avg'])) for task_loss in args.loss_modules]
    avg_metric = [xnn.utils.AverageMeter(print_avg=(not task_metric[0].info()['is_avg'])) for task_metric in args.metric_modules]
    epoch_size = get_epoch_size(args, train_loader, args.epoch_size)

    ##########################
    # switch to train mode
    model.train()
    if args.freeze_bn:
        xnn.utils.freeze_bn(model)
    #

    ##########################
    for task_dx, task_losses in enumerate(args.loss_modules):
        for loss_idx, loss_fn in enumerate(task_losses):
            loss_fn.clear()
    for task_dx, task_metrics in enumerate(args.metric_modules):
        for midx, metric_fn in enumerate(task_metrics):
            metric_fn.clear()

    progress_bar = progiter.ProgIter(np.arange(epoch_size), chunksize=1)
    metric_name = "Metric"
    metric_ctx = [None] * len(args.metric_modules)
    end_time = time.time()
    writer_idx = 0
    last_update_iter = -1

    # change color to yellow for calibration
    progressbar_color = (Fore.YELLOW if (('calibration' in args.phase) or ('training' in args.phase and args.quantize)) else Fore.WHITE)
    print('{}'.format(progressbar_color), end='')

    ##########################
    for iter, (inputs, targets) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end_time)

        lr = scheduler.get_lr()[0]

        input_list = [img.cuda() for img in inputs]
        target_list = [tgt.cuda(non_blocking=True) for tgt in targets]
        target_sizes = [tgt.shape for tgt in target_list]
        batch_size_cur = target_sizes[0][0]

        ##########################
        # compute output
        task_outputs = model(input_list)

        task_outputs = task_outputs if isinstance(task_outputs,(list,tuple)) else [task_outputs]
        # upsample output to target resolution
        task_outputs = upsample_tensors(task_outputs, target_sizes, args.upsample_mode)

        if args.model_config.multi_task_type is not None and len(args.model_config.output_channels) > 1:
            args.multi_task_factors, args.multi_task_offsets = xnn.layers.get_loss_scales(model)
        else:
            args.multi_task_factors = None
            args.multi_task_offsets = None

        loss_total, loss_list, loss_names, loss_types, loss_list_orig = \
            compute_task_objectives(args, args.loss_modules, input_list, task_outputs, target_list,
                         task_mults=args.multi_task_factors, task_offsets=args.multi_task_offsets,
                         loss_mult_factors=args.loss_mult_factors)

        if args.print_train_class_iou:
            metric_total, metric_list, metric_names, metric_types, _, confusion_matrix = \
                compute_task_objectives(args, args.metric_modules, input_list, task_outputs, target_list,
                get_confusion_matrix=args.print_train_class_iou)
        else:
            metric_total, metric_list, metric_names, metric_types, _ = \
                compute_task_objectives(args, args.metric_modules, input_list, task_outputs, target_list,
                get_confusion_matrix=args.print_train_class_iou)

        if args.model_config.multi_task_type is not None and len(args.model_config.output_channels) > 1:
            xnn.layers.set_losses(model, loss_list_orig)

        if 'training' in args.phase:
            # zero gradients so that we can accumulate gradients
            if (iter % args.iter_size) == 0:
                optimizer.zero_grad()

            # accumulate gradients
            loss_total.backward()
            # optimization step
            if ((iter+1) % args.iter_size) == 0:
                optimizer.step()
        #

        # record loss.
        for task_idx, task_losses in enumerate(args.loss_modules):
            avg_loss[task_idx].update(float(loss_list[task_idx].cpu()), batch_size_cur)
            avg_loss_orig[task_idx].update(float(loss_list_orig[task_idx].cpu()), batch_size_cur)
            if args.tensorboard_enable:
                train_writer.add_scalar('Training/Task{}_{}_Loss_Iter'.format(task_idx,loss_names[task_idx]), float(loss_list[task_idx]), args.n_iter)
                if args.model_config.multi_task_type is not None and len(args.model_config.output_channels) > 1:
                    train_writer.add_scalar('Training/multi_task_Factor_Task{}_{}'.format(task_idx,loss_names[task_idx]), float(args.multi_task_factors[task_idx]), args.n_iter)

        # record error/accuracy.
        for task_idx, task_metrics in enumerate(args.metric_modules):
            avg_metric[task_idx].update(float(metric_list[task_idx].cpu()), batch_size_cur)

        ##########################
        if args.tensorboard_enable:
            write_output(args, 'Training_', epoch_size, iter, epoch, train_dataset, train_writer, input_list, task_outputs, target_list, metric_name, writer_idx)

        if ((iter % args.print_freq) == 0) or (iter == (epoch_size-1)):
            output_string = ''
            for task_idx, task_metrics in enumerate(args.metric_modules):
                output_string += '[{}={}]'.format(metric_names[task_idx], str(avg_metric[task_idx]))

            epoch_str = '{}/{}'.format(epoch + 1, args.epochs)
            progress_bar.set_description("{}=> {}  ".format(progressbar_color, args.phase))
            multi_task_factors_print = ['{:.3f}'.format(float(lmf)) for lmf in args.multi_task_factors] if args.multi_task_factors is not None else None
            progress_bar.set_postfix(Epoch=epoch_str, LR=lr, DataTime=str(data_time), LossMult=multi_task_factors_print, Loss=avg_loss, Output=output_string)
            progress_bar.update(iter-last_update_iter)
            last_update_iter = iter

        args.n_iter += 1
        end_time = time.time()
        writer_idx = (writer_idx + 1) % args.tensorboard_num_imgs

        # add onnx graph to tensorboard
        # commenting out due to issues in transitioning to pytorch 0.4
        # (bilinear mode in upsampling causes hang or crash - may be due to align_borders change, nearest is fine)
        #if epoch == 0 and iter == 0:
        #    input_zero = torch.zeros(input_var.shape)
        #    train_writer.add_graph(model, input_zero)
        # this cache operation slows down training
        #torch.cuda.empty_cache()

        if iter >= epoch_size:
            break

    if args.print_train_class_iou:
        print_class_iou(args=args, confusion_matrix=confusion_matrix, task_idx=task_idx)

    progress_bar.close()

    # to print a new line - do not provide end=''
    print('{}'.format(Fore.RESET), end='')

    if args.tensorboard_enable:
        for task_idx, task_losses in enumerate(args.loss_modules):
            train_writer.add_scalar('Training/Task{}_{}_Loss_Epoch'.format(task_idx,loss_names[task_idx]), float(avg_loss[task_idx]), epoch)

        for task_idx, task_metrics in enumerate(args.metric_modules):
            train_writer.add_scalar('Training/Task{}_{}_Metric_Epoch'.format(task_idx,metric_names[task_idx]), float(avg_metric[task_idx]), epoch)

    output_name = metric_names[args.pivot_task_idx]
    output_metric = float(avg_metric[args.pivot_task_idx])

    ##########################
    if args.quantize:
        def debug_format(v):
            return ('{:.3f}'.format(v) if v is not None else 'None')
        #
        clips_act = [m.get_clips_act()[1] for n,m in model.named_modules() if isinstance(m,xnn.layers.PAct2)]
        if len(clips_act) > 0:
            args.logger.debug('\nclips_act : ' + ' '.join(map(debug_format, clips_act)))
            args.logger.debug('')
    #
    return output_metric, output_name

###################################################################
def validate(args, val_dataset, val_loader, model, epoch, val_writer):
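    """Evaluate the model on the validation set and return the pivot metric.

    Mirrors train() without the backward pass: computes task outputs,
    accumulates metrics and optionally writes visualizations to tensorboard.
    """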
    data_time = xnn.utils.AverageMeter()
    # if the loss/metric is already an average, no need to further average
    avg_metric = [xnn.utils.AverageMeter(print_avg=(not task_metric[0].info()['is_avg'])) for task_metric in args.metric_modules]
    epoch_size = get_epoch_size(args, val_loader, args.epoch_size_val)

    ##########################
    # switch to evaluate mode
    model.eval()

    ##########################
    for task_dx, task_metrics in enumerate(args.metric_modules):
        for midx, metric_fn in enumerate(task_metrics):
            metric_fn.clear()

    metric_name = "Metric"
    end_time = time.time()
    writer_idx = 0
    last_update_iter = -1
    metric_ctx = [None] * len(args.metric_modules)
    progress_bar = progiter.ProgIter(np.arange(epoch_size), chunksize=1)

    # change color to green
    print('{}'.format(Fore.GREEN), end='')

    ##########################
    for iter, (inputs, targets) in enumerate(val_loader):
        data_time.update(time.time() - end_time)
        input_list = [j.cuda() for j in inputs]
        target_list = [j.cuda(non_blocking=True) for j in targets]
        target_sizes = [tgt.shape for tgt in target_list]
        batch_size_cur = target_sizes[0][0]

        # compute output
        task_outputs = model(input_list)

        task_outputs = task_outputs if isinstance(task_outputs,(list,tuple)) else [task_outputs]
        task_outputs = upsample_tensors(task_outputs, target_sizes, args.upsample_mode)

        if args.print_val_class_iou:
            metric_total, metric_list, metric_names, metric_types, _, confusion_matrix = \
                compute_task_objectives(args, args.metric_modules, input_list, task_outputs, target_list,
                get_confusion_matrix=args.print_val_class_iou)
        else:
            metric_total, metric_list, metric_names, metric_types, _ = \
                compute_task_objectives(args, args.metric_modules, input_list, task_outputs, target_list,
                get_confusion_matrix=args.print_val_class_iou)

        # record error/accuracy.
        for task_idx, task_metrics in enumerate(args.metric_modules):
            avg_metric[task_idx].update(float(metric_list[task_idx].cpu()), batch_size_cur)

        if args.tensorboard_enable:
            write_output(args, 'Validation_', epoch_size, iter, epoch, val_dataset, val_writer, input_list, task_outputs, target_list, metric_names, writer_idx)

        if ((iter % args.print_freq) == 0) or (iter == (epoch_size-1)):
            output_string = ''
            for task_idx, task_metrics in enumerate(args.metric_modules):
                output_string += '[{}={}]'.format(metric_names[task_idx], str(avg_metric[task_idx]))

            epoch_str = '{}/{}'.format(epoch + 1, args.epochs)
            progress_bar.set_description("=> validation")
            progress_bar.set_postfix(Epoch=epoch_str, DataTime=data_time, Output="{}".format(output_string))
            progress_bar.update(iter-last_update_iter)
            last_update_iter = iter

        end_time = time.time()
        writer_idx = (writer_idx + 1) % args.tensorboard_num_imgs

        if iter >= epoch_size:
            break

    if args.print_val_class_iou:
        print_class_iou(args=args, confusion_matrix=confusion_matrix, task_idx=task_idx)

    #print_conf_matrix(conf_matrix=conf_matrix, en=False)
    progress_bar.close()

    # to print a new line - do not provide end=''
    print('{}'.format(Fore.RESET), end='')

    if args.tensorboard_enable:
        for task_idx, task_metrics in enumerate(args.metric_modules):
            val_writer.add_scalar('Validation/Task{}_{}_Metric_Epoch'.format(task_idx,metric_names[task_idx]), float(avg_metric[task_idx]), epoch)

    output_name = metric_names[args.pivot_task_idx]
    output_metric = float(avg_metric[args.pivot_task_idx])
    return output_metric, output_name
839 ###################################################################
840 def close(args):
841     if args.logger is not None:
842         del args.logger
843         args.logger = None
844     #
845     args.best_metric = -1

def get_save_path(args, phase=None):
    date = args.date if args.date else datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    save_path = os.path.join('./data/checkpoints', args.dataset_name, date + '_' + args.dataset_name + '_' + args.model_name)
    save_path += '_resize{}x{}_traincrop{}x{}'.format(args.img_resize[1], args.img_resize[0], args.rand_crop[1], args.rand_crop[0])
    phase = phase if (phase is not None) else args.phase
    save_path = os.path.join(save_path, phase)
    return save_path

def get_model_orig(model):
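    """Return the underlying model, unwrapping DataParallel/DistributedDataParallel
    and the quantization wrapper modules."""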
    is_parallel_model = isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel))
    model_orig = (model.module if is_parallel_model else model)
    model_orig = (model_orig.module if isinstance(model_orig, (xnn.quantize.QuantBaseModule)) else model_orig)
    return model_orig

def create_rand_inputs(args, is_cuda):
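    """Create a list of random input tensors shaped like the network inputs.

    Used as dummy input for graph analysis in the quantization wrappers,
    for flop counting and for onnx export.
    """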
    dummy_input = []
    for i_ch in args.model_config.input_channels:
        x = torch.rand((1, i_ch, args.img_resize[0], args.img_resize[1]))
        x = x.cuda() if is_cuda else x
        dummy_input.append(x)
    #
    return dummy_input

def count_flops(args, model):
    is_cuda = next(model.parameters()).is_cuda
    dummy_input = create_rand_inputs(args, is_cuda)
    #
    model.eval()
    flops = xnn.utils.forward_count_flops(model, dummy_input)
    gflops = flops/1e9
    print('=> Size = {}, GFLOPs = {}, GMACs = {}'.format(args.img_resize, gflops, gflops/2))

def derive_node_name(input_name):
    # take the last entry of the input names for deciding the node name
    node_name = input_name[-1].rsplit('.', 1)[0]
    return node_name

# torch onnx export does not update names. Do it using onnx.save
def add_node_names(onnx_model_name):
    onnx_model = onnx.load(onnx_model_name)
    for i in range(len(onnx_model.graph.node)):
        for j in range(len(onnx_model.graph.node[i].input)):
            onnx_model.graph.node[i].input[j] = onnx_model.graph.node[i].input[j].split(':')[0]
            onnx_model.graph.node[i].name = derive_node_name(onnx_model.graph.node[i].input)
        #
    #
    # update model inplace
    onnx.save(onnx_model, onnx_model_name)

def write_onnx_model(args, model, save_path, name='checkpoint.onnx'):
    is_cuda = next(model.parameters()).is_cuda
    input_list = create_rand_inputs(args, is_cuda=is_cuda)
    #
    model.eval()
    torch.onnx.export(model, input_list, os.path.join(save_path, name), export_params=True, verbose=False)
    # torch onnx export does not update names. Do it using onnx.save
    add_node_names(onnx_model_name=os.path.join(save_path, name))

###################################################################
def write_output(args, prefix, val_epoch_size, iter, epoch, dataset, output_writer, input_images, task_outputs, task_targets, metric_names, writer_idx):
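    """Write a random sample of inputs, outputs and targets to tensorboard.

    Images are written with a probability chosen so that roughly
    args.tensorboard_num_imgs samples are logged per epoch; the
    visualization depends on the task's output_type.
    """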
    write_freq = (args.tensorboard_num_imgs / float(val_epoch_size))
    write_prob = np.random.random()
    if (write_prob > write_freq):
        return

    batch_size = input_images[0].shape[0]
    b_index = random.randint(0, batch_size - 1)

    input_image = None
    for img_idx, img in enumerate(input_images):
        input_image = input_images[img_idx][b_index].cpu().numpy().transpose((1, 2, 0))
        # convert back to original input range (0-255)
        input_image = input_image / args.image_scale + args.image_mean
        if args.is_flow and args.is_flow[0][img_idx]:
            # input corresponding to flow is assumed to have been generated by adding 128
            flow = input_image - 128
            flow_hsv = xnn.utils.flow2hsv(flow.transpose(2, 0, 1), confidence=False).transpose(2, 0, 1)
            #flow_hsv = (flow_hsv / 255.0).clip(0, 1) #TODO: check this
            output_writer.add_image(prefix +'Input{}/{}'.format(img_idx, writer_idx), flow_hsv, epoch)
        else:
            input_image = (input_image/255.0).clip(0,1) #.astype(np.uint8)
            output_writer.add_image(prefix + 'Input{}/{}'.format(img_idx, writer_idx), input_image.transpose((2,0,1)), epoch)

    # for sparse data, chroma blending does not look good
    for task_idx, output_type in enumerate(args.model_config.output_type):
        # metric_name = metric_names[task_idx]
        output = task_outputs[task_idx]
        target = task_targets[task_idx]
        if (output_type == 'segmentation') and hasattr(dataset, 'decode_segmap'):
            segmentation_target = dataset.decode_segmap(target[b_index,0].cpu().numpy())
            segmentation_output = output.max(dim=1,keepdim=True)[1].data.cpu().numpy() if(output.shape[1]>1) else output.data.cpu().numpy()
            segmentation_output = dataset.decode_segmap(segmentation_output[b_index,0])
            segmentation_output_blend = xnn.utils.chroma_blend(input_image, segmentation_output)
            #
            output_writer.add_image(prefix+'Task{}_{}_GT/{}'.format(task_idx,output_type,writer_idx), segmentation_target.transpose(2,0,1), epoch)
            if not args.sparse:
                segmentation_target_blend = xnn.utils.chroma_blend(input_image, segmentation_target)
                output_writer.add_image(prefix + 'Task{}_{}_GT_ColorBlend/{}'.format(task_idx, output_type, writer_idx), segmentation_target_blend.transpose(2, 0, 1), epoch)
            #
            output_writer.add_image(prefix+'Task{}_{}_Output/{}'.format(task_idx,output_type,writer_idx), segmentation_output.transpose(2,0,1), epoch)
            output_writer.add_image(prefix+'Task{}_{}_Output_ColorBlend/{}'.format(task_idx,output_type,writer_idx), segmentation_output_blend.transpose(2,0,1), epoch)
        elif (output_type in ('depth', 'disparity')):
            depth_chanidx = 0
            output_writer.add_image(prefix+'Task{}_{}_GT_Color_Visualization/{}'.format(task_idx,output_type,writer_idx), xnn.utils.tensor2array(target[b_index][depth_chanidx].cpu(), max_value=args.max_depth, colormap=args.viz_colormap).transpose(2,0,1), epoch)
            if not args.sparse:
                output_writer.add_image(prefix + 'Task{}_{}_GT_ColorBlend_Visualization/{}'.format(task_idx, output_type, writer_idx), xnn.utils.tensor2array(target[b_index][depth_chanidx].cpu(), max_value=args.max_depth, colormap=args.viz_colormap, input_blend=input_image).transpose(2, 0, 1), epoch)
            #
            output_writer.add_image(prefix+'Task{}_{}_Output_Color_Visualization/{}'.format(task_idx,output_type,writer_idx), xnn.utils.tensor2array(output.data[b_index][depth_chanidx].cpu(), max_value=args.max_depth, colormap=args.viz_colormap).transpose(2,0,1), epoch)
            output_writer.add_image(prefix + 'Task{}_{}_Output_ColorBlend_Visualization/{}'.format(task_idx, output_type, writer_idx), xnn.utils.tensor2array(output.data[b_index][depth_chanidx].cpu(), max_value=args.max_depth, colormap=args.viz_colormap, input_blend=input_image).transpose(2, 0, 1), epoch)
        elif (output_type == 'flow'):
            max_value_flow = 10.0 # only for visualization
            output_writer.add_image(prefix+'Task{}_{}_GT/{}'.format(task_idx,output_type,writer_idx), xnn.utils.flow2hsv(target[b_index][:2].cpu().numpy(), max_value=max_value_flow).transpose(2,0,1), epoch)
            output_writer.add_image(prefix+'Task{}_{}_Output/{}'.format(task_idx,output_type,writer_idx), xnn.utils.flow2hsv(output.data[b_index][:2].cpu().numpy(), max_value=max_value_flow).transpose(2,0,1), epoch)
        elif (output_type == 'interest_pt'):
            score_chanidx = 0
            target_score_to_write = target[b_index][score_chanidx].cpu()
            output_score_to_write = output.data[b_index][score_chanidx].cpu()

            # if score is learnt as zero mean, add offset to make it [0-255]
            if args.make_score_zero_mean:
                # target_score_to_write!=0 : value 0 indicates GT unavailable. Leave them to be 0.
                target_score_to_write[target_score_to_write!=0] += 128.0
                output_score_to_write += 128.0

            max_value_score = float(torch.max(target_score_to_write)) #0.002
            output_writer.add_image(prefix+'Task{}_{}_GT_Bone_Visualization/{}'.format(task_idx,output_type,writer_idx), xnn.utils.tensor2array(target_score_to_write, max_value=max_value_score, colormap='bone').transpose(2,0,1), epoch)
            output_writer.add_image(prefix+'Task{}_{}_Output_Bone_Visualization/{}'.format(task_idx,output_type,writer_idx), xnn.utils.tensor2array(output_score_to_write, max_value=max_value_score, colormap='bone').transpose(2,0,1), epoch)
        #

def print_conf_matrix(conf_matrix=[], en=False):
    if not en:
        return
    num_rows = conf_matrix.shape[0]
    num_cols = conf_matrix.shape[1]
    print("-"*64)
    num_ele = 1
    for r_idx in range(num_rows):
        print("\n")
        for c_idx in range(0,num_cols,num_ele):
            print(conf_matrix[r_idx][c_idx:c_idx+num_ele], end="")
    print("\n")
    print("-" * 64)

def compute_task_objectives(args, objective_fns, input_var, task_outputs, task_targets, task_mults=None,
                            task_offsets=None, loss_mult_factors=None, get_confusion_matrix=False):
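    """Compute the combined objective (loss or metric) for all tasks.

    For each task, the individual objective values are scaled by the
    per-loss loss_mult_factors and summed (NaN values are treated as 0);
    the per-task sums are then combined as
    sum(task_mult * task_objective + task_offset) to form the total.
    Returns the total, per-task values, names and types, and the confusion
    matrix when requested.
    """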

    ##########################
    objective_total = torch.zeros_like(task_outputs[0].view(-1)[0])
    objective_list = []
    objective_list_orig = []
    objective_names = []
    objective_types = []
    for task_idx, task_objectives in enumerate(objective_fns):
        output_type = args.model_config.output_type[task_idx]
        objective_sum_value = torch.zeros_like(task_outputs[task_idx].view(-1)[0])
        objective_sum_name = ''
        objective_sum_type = ''

        task_mult = task_mults[task_idx] if task_mults is not None else 1.0
        task_offset = task_offsets[task_idx] if task_offsets is not None else 0.0

        for oidx, objective_fn in enumerate(task_objectives):
            objective_batch = objective_fn(input_var, task_outputs[task_idx], task_targets[task_idx])
            objective_batch = objective_batch.mean() if isinstance(objective_fn, torch.nn.DataParallel) else objective_batch
            objective_name = objective_fn.info()['name']
            objective_type = objective_fn.info()['is_avg']
            if get_confusion_matrix:
                confusion_matrix = objective_fn.info()['confusion_matrix']

            loss_mult = loss_mult_factors[task_idx][oidx] if (loss_mult_factors is not None) else 1.0
            # --
            objective_batch_not_nan = (objective_batch if not torch.isnan(objective_batch) else 0.0)
            objective_sum_value = objective_batch_not_nan*loss_mult + objective_sum_value
            objective_sum_name += (objective_name if (objective_sum_name == '') else ('+' + objective_name))
            assert (objective_sum_type == '' or objective_sum_type == objective_type), 'metric types (avg/val) for a given task should match'
            objective_sum_type = objective_type

        objective_list.append(objective_sum_value)
        objective_list_orig.append(objective_sum_value)
        objective_names.append(objective_sum_name)
        objective_types.append(objective_sum_type)

        objective_total = objective_sum_value*task_mult + task_offset + objective_total

    return_list = [objective_total, objective_list, objective_names, objective_types, objective_list_orig]
    if get_confusion_matrix:
        return_list.append(confusion_matrix)

    return return_list

def save_checkpoint(args, save_path, model, checkpoint_dict, is_best, filename='checkpoint.pth.tar'):
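    """Save the checkpoint dict and, if is_best, copy it to model_best.pth.tar.

    When args.generate_onnx is set, a matching onnx model is exported
    alongside each saved checkpoint.
    """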
    torch.save(checkpoint_dict, os.path.join(save_path,filename))
    if is_best:
        shutil.copyfile(os.path.join(save_path,filename), os.path.join(save_path,'model_best.pth.tar'))
    #
    if args.generate_onnx:
        write_onnx_model(args, model, save_path, name='checkpoint.onnx')
        if is_best:
            write_onnx_model(args, model, save_path, name='model_best.onnx')
    #

def get_epoch_size(args, loader, args_epoch_size):
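    """Resolve the effective epoch size from args_epoch_size.

    0 means the full loader length, a fraction in (0,1) means that fraction
    of the loader, and any larger value is an absolute iteration count
    capped at the loader length.
    """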
    if args_epoch_size == 0:
        epoch_size = len(loader)
    elif args_epoch_size < 1:
        epoch_size = int(len(loader) * args_epoch_size)
    else:
        epoch_size = min(len(loader), int(args_epoch_size))
    return epoch_size

def get_train_transform(args):
    # image normalization can be at the beginning of transforms or at the end
    image_mean = np.array(args.image_mean, dtype=np.float32)
    image_scale = np.array(args.image_scale, dtype=np.float32)
    image_prenorm = vision.transforms.image_transforms.NormalizeMeanScale(mean=image_mean, scale=image_scale) if args.image_prenorm else None
    image_postnorm = vision.transforms.image_transforms.NormalizeMeanScale(mean=image_mean, scale=image_scale) if (not image_prenorm) else None

    # crop size used only for training
    image_train_output_scaling = vision.transforms.image_transforms.Scale(args.rand_resize, target_size=args.rand_output_size, is_flow=args.is_flow) \
        if (args.rand_output_size is not None and args.rand_output_size != args.rand_resize) else None
    train_transform = vision.transforms.image_transforms.Compose([
        image_prenorm,
        vision.transforms.image_transforms.AlignImages(),
        vision.transforms.image_transforms.MaskTarget(args.target_mask, 0),
        vision.transforms.image_transforms.CropRect(args.img_border_crop),
        vision.transforms.image_transforms.RandomRotate(args.transform_rotation, is_flow=args.is_flow) if args.transform_rotation else None,
        vision.transforms.image_transforms.RandomScaleCrop(args.rand_resize, scale_range=args.rand_scale, is_flow=args.is_flow),
        vision.transforms.image_transforms.RandomHorizontalFlip(is_flow=args.is_flow),
        vision.transforms.image_transforms.RandomCrop(args.rand_crop),
        vision.transforms.image_transforms.RandomColor2Gray(is_flow=args.is_flow, random_threshold=0.5) if 'tiad' in args.dataset_name else None,
        image_train_output_scaling,
        image_postnorm,
        vision.transforms.image_transforms.ConvertToTensor()
        ])
    return train_transform

def get_validation_transform(args):
    # image normalization can be at the beginning of transforms or at the end
    image_mean = np.array(args.image_mean, dtype=np.float32)
    image_scale = np.array(args.image_scale, dtype=np.float32)
    image_prenorm = vision.transforms.image_transforms.NormalizeMeanScale(mean=image_mean, scale=image_scale) if args.image_prenorm else None
    image_postnorm = vision.transforms.image_transforms.NormalizeMeanScale(mean=image_mean, scale=image_scale) if (not image_prenorm) else None

    # prediction is resized to output_size before evaluation.
    val_transform = vision.transforms.image_transforms.Compose([
        image_prenorm,
        vision.transforms.image_transforms.AlignImages(),
        vision.transforms.image_transforms.MaskTarget(args.target_mask, 0),
        vision.transforms.image_transforms.CropRect(args.img_border_crop),
        vision.transforms.image_transforms.Scale(args.img_resize, target_size=args.output_size, is_flow=args.is_flow),
        image_postnorm,
        vision.transforms.image_transforms.ConvertToTensor()
        ])
    return val_transform

def get_transforms(args):
    # provision to train with val transform - provide rand_scale as (0, 0)
    # fixing the train-test resolution discrepancy, https://arxiv.org/abs/1906.06423
    always_use_val_transform = (args.rand_scale[0] == 0)
    train_transform = get_validation_transform(args) if always_use_val_transform else get_train_transform(args)
    val_transform = get_validation_transform(args)
    return train_transform, val_transform

def _upsample_impl(tensor, output_size, upsample_mode):
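    """Resize a single tensor to output_size using the given upsample mode.

    Long tensors (e.g. segmentation labels) are not supported by
    interpolate, so they are temporarily converted to float and resized
    with 'nearest' mode to keep the label values valid.
    """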
    # upsample of long tensor is not supported currently. convert to float, just to avoid error.
    # we can do this only in the case of nearest mode, otherwise output will have invalid values.
    convert_to_float = False
    if isinstance(tensor, (torch.LongTensor,torch.cuda.LongTensor)):
        convert_to_float = True
        original_dtype = tensor.dtype
        tensor = tensor.float()
        upsample_mode = 'nearest'

    dim_added = False
    if len(tensor.shape) < 4:
        tensor = tensor[np.newaxis,...]
        dim_added = True

    if (tensor.size()[-2:] != output_size):
        tensor = torch.nn.functional.interpolate(tensor, output_size, mode=upsample_mode)

    if dim_added:
        tensor = tensor[0,...]

    if convert_to_float:
        tensor = tensor.long() #tensor.astype(original_dtype)

    return tensor

def upsample_tensors(tensors, output_sizes, upsample_mode):
    if isinstance(tensors, (list,tuple)):
        for tidx, tensor in enumerate(tensors):
            tensors[tidx] = _upsample_impl(tensor, output_sizes[tidx][-2:], upsample_mode)
        #
    else:
        tensors = _upsample_impl(tensors, output_sizes[0][-2:], upsample_mode)
    return tensors

# print IoU for each class
def print_class_iou(args=None, confusion_matrix=None, task_idx=0):
    n_classes = args.model_config.output_channels[task_idx]
    [accuracy, mean_iou, iou, f1_score] = compute_accuracy(args, confusion_matrix, n_classes)
    print("\n Class IoU: [", end="")
    for class_iou in iou:
        print("{:0.3f}".format(class_iou), end=",")
    print("]")
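
# Minimal usage sketch (not part of the original script): the defaults from
# get_config() can be overridden before calling main(). The field names below
# all exist in get_config() above; the specific values (paths, sizes) and the
# import path are illustrative assumptions, not recommendations.
#
#     from pytorch_jacinto_ai.engine import train_pixel2pixel
#     args = train_pixel2pixel.get_config()
#     args.model_name = 'deeplabv2lite_mobilenetv2'   # any key in vision.models.pixel2pixel
#     args.dataset_name = 'cityscapes_segmentation'
#     args.data_path = './data/cityscapes'            # illustrative path
#     args.img_resize = (384, 768)                    # (height, width), illustrative
#     args.batch_size = 12
#     args.phase = 'training'                         # or 'calibration'/'validation'
#     train_pixel2pixel.main(args)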

if __name__ == '__main__':
    train_args = get_config()
    main(train_args)