import os
import shutil
import time
import math
import copy

import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.onnx
import onnx

import datetime
from tensorboardX import SummaryWriter
import numpy as np
import random
import cv2
from colorama import Fore
import progiter
from packaging import version
import warnings

from .. import xnn
from .. import vision
from .infer_pixel2pixel import compute_accuracy


##################################################
warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)

##################################################
def get_config():
    args = xnn.utils.ConfigNode()

    args.dataset_config = xnn.utils.ConfigNode()
    args.dataset_config.split_name = 'val'
    args.dataset_config.max_depth_bfr_scaling = 80
    args.dataset_config.depth_scale = 1
    args.dataset_config.train_depth_log = 1
    args.use_semseg_for_depth = False

    # model config
    args.model_config = xnn.utils.ConfigNode()
    args.model_config.output_type = ['segmentation']    # output types that the network predicts, e.g. segmentation, depth, flow
    args.model_config.output_channels = None            # number of output channels
    args.model_config.input_channels = None             # number of input channels
    args.model_config.output_range = None               # max range of output
    args.model_config.num_decoders = None               # number of decoders to use. [options: 0, 1, None]
    args.model_config.freeze_encoder = False            # do not update encoder weights
    args.model_config.freeze_decoder = False            # do not update decoder weights
    args.model_config.multi_task_type = 'learned'       # find out loss multiplier by learning, choices=[None, 'learned', 'uncertainty', 'grad_norm', 'dwa_grad_norm']
    args.model_config.target_input_ratio = 1            # keep target size same as input size
    args.model_config.input_nv12 = False
    args.model = None                                   # the model itself can be given from outside
    args.model_name = 'deeplabv2lite_mobilenetv2'       # model architecture, overwritten if pretrained is specified
    args.dataset_name = 'cityscapes_segmentation'       # dataset type
    args.transforms = None                              # the transforms themselves can be given from outside

    args.data_path = './data/cityscapes'                # path to dataset
    args.save_path = None                               # checkpoints save path
    args.phase = 'training'                             # training/calibration/validation
    args.date = None                                    # date to add to save path. if this is None, current date will be added.

    args.logger = None                                  # logger stream to output into

    args.show_gpu_usage = False                         # show gpu usage at the beginning of each training epoch

    args.split_file = None                              # train_val split file
    args.split_files = None                             # split list files. eg: train.txt val.txt
    args.split_value = None                             # test_val split proportion (between 0 (only test) and 1 (only train))

    args.solver = 'adam'                                # solver algorithm, choices=['adam','sgd']
    args.scheduler = 'step'                             # scheduler algorithm, choices=['step','poly', 'cosine']
    args.workers = 8                                    # number of data loading workers

    args.epochs = 250                                   # number of total epochs to run
    args.start_epoch = 0                                # manual epoch number (useful on restarts)

    args.epoch_size = 0                                 # manual epoch size (will match dataset size if not specified)
    args.epoch_size_val = 0                             # manual epoch size (will match dataset size if not specified)
    args.batch_size = 12                                # mini_batch size
    args.total_batch_size = None                        # accumulated batch size. total_batch_size = batch_size*iter_size
    args.iter_size = 1                                  # iteration size. total_batch_size = batch_size*iter_size

    args.lr = 1e-4                                      # initial learning rate
    args.lr_clips = None                                # use args.lr itself if it is None
    args.lr_calib = 0.05                                # lr for bias calibration
    args.warmup_epochs = 5                              # number of epochs to warmup

    args.momentum = 0.9                                 # momentum for sgd, alpha parameter for adam
    args.beta = 0.999                                   # beta parameter for adam
    args.weight_decay = 1e-4                            # weight decay
    args.bias_decay = None                              # bias decay

    args.sparse = True                                  # exclude invalid/ignored target pixels from loss computation, use NEAREST for interpolation

    args.tensorboard_num_imgs = 5                       # number of imgs to display in tensorboard
    args.pretrained = None                              # path to pre_trained model
    args.resume = None                                  # path to latest checkpoint (default: none)
    args.no_date = False                                # don't append date timestamp to folder
    args.print_freq = 100                               # print frequency (default: 100)

    args.milestones = (100, 200)                        # epochs at which learning rate is divided by 2

    args.losses = ['segmentation_loss']                 # loss functions to use for training
    args.metrics = ['segmentation_metrics']             # metric/measurement/error functions for train/validation
    args.multi_task_factors = None                      # loss mult factors
    args.class_weights = None                           # class weights

    args.loss_mult_factors = None                       # fixed loss mult factors - per loss - note: this is different from multi_task_factors (which is per task)

    args.multistep_gamma = 0.5                          # lr multiplier applied at each milestone for the step scheduler
    args.polystep_power = 1.0                           # power for polynomial scheduler

    args.rand_seed = 1                                  # random seed
    args.img_border_crop = None                         # image border crop rectangle. can be relative or absolute
    args.target_mask = None                             # mask rectangle. can be relative or absolute. last value is the mask value

    args.rand_resize = None                             # random image size to be resized to during training
    args.rand_output_size = None                        # output size to be resized to during training
    args.rand_scale = (1.0, 2.0)                        # random scale range for training
    args.rand_crop = None                               # image size to be cropped to

    args.img_resize = None                              # image size to be resized to during evaluation
    args.output_size = None                             # target output size to be resized to

    args.count_flops = True                             # count flops and report

    args.shuffle = True                                 # shuffle the train dataset or not
    args.shuffle_val = True                             # shuffle the val dataset or not

    args.transform_rotation = 0.                        # apply rotation augmentation. value is rotation in degrees. 0 indicates no rotation
    args.is_flow = None                                 # whether entries in images and targets lists are optical flow or not

    args.upsample_mode = 'bilinear'                     # upsample mode to use, choices=['nearest','bilinear']

    args.image_prenorm = True                           # whether normalization is done before all the other transforms
    args.image_mean = (128.0,)                          # image mean for input image normalization
    args.image_scale = (1.0 / (0.25 * 256),)            # image scaling/mult for input image normalization

    args.max_depth = 80                                 # maximum depth to be used for visualization

    args.pivot_task_idx = 0                             # task id to select best model

    args.parallel_model = True                          # use data parallel for model
    args.parallel_criterion = True                      # use data parallel for loss and metric

    args.evaluate_start = True                          # evaluate right at the beginning of training or not
    args.save_onnx = True                               # export an onnx model or not
    args.print_model = False                            # print the model to text
    args.run_soon = True                                # start training after generating configs/models

    args.quantize = False                               # apply quantized inference or not
    #args.model_surgery = None                           # replace activations with PAct2 activation module. Helpful in quantized training.
    args.bitwidth_weights = 8                           # bitwidth for weights
    args.bitwidth_activations = 8                       # bitwidth for activations
    args.histogram_range = True                         # use histogram based range estimation for calibration
    args.bias_calibration = True                        # apply bias correction during quantized inference calibration
    args.per_channel_q = False                          # apply a separate quantization factor for each channel in depthwise layers or not

    args.save_mod_files = False                         # save files modified after the last commit. also stores the commit id.
    args.make_score_zero_mean = False                   # make score zero mean while learning
    args.no_q_for_dws_layer_idx = 0                     # no_q_for_dws_layer_idx

    args.viz_colormap = 'rainbow'                       # colormap for tensorboard: 'rainbow', 'plasma', 'magma', 'bone'

    args.freeze_bn = False                              # freeze the statistics of bn
    args.tensorboard_enable = True                      # enable/disable tensorboard writing
    args.print_train_class_iou = False
    args.print_val_class_iou = False
    args.freeze_layers = None

    args.opset_version = 9                              # onnx opset_version

    return args
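

# Example usage (a minimal sketch - the run scripts that import this module may
# wire things up differently; the resize/crop sizes below are just illustrative
# values in (height, width) order):
#   args = get_config()
#   args.dataset_name = 'cityscapes_segmentation'
#   args.img_resize = (384, 768)
#   args.rand_crop = (384, 768)
#   main(args)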


# ################################################
# to avoid hangs in the data loader with multiple threads;
# this was observed after using cv2 image processing functions
# https://github.com/pytorch/pytorch/issues/1355
cv2.setNumThreads(0)


# ################################################
def main(args):
    # ensure pytorch version is 1.1 or higher
    assert version.parse(torch.__version__) >= version.parse('1.1'), \
        'torch version must be 1.1 or higher, due to the change in scheduler.step() and optimizer.step() call order'

    assert (not hasattr(args, 'evaluate')), 'args.evaluate is deprecated. use args.phase=training or calibration or validation'
    assert is_valid_phase(args.phase), f'invalid phase {args.phase}'
    assert not hasattr(args, 'model_surgery'), 'the argument model_surgery is deprecated, it is not needed now - remove it'

    if (args.phase == 'validation' and args.bias_calibration):
        args.bias_calibration = False
        warnings.warn('switching off bias calibration in validation')
    #

    #################################################
    args.rand_resize = args.img_resize if args.rand_resize is None else args.rand_resize
    args.rand_crop = args.img_resize if args.rand_crop is None else args.rand_crop
    args.output_size = args.img_resize if args.output_size is None else args.output_size
    # resume has higher priority
    args.pretrained = None if (args.resume is not None) else args.pretrained

    if args.save_path is None:
        save_path = get_save_path(args)
    else:
        save_path = args.save_path
    #
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    if args.save_mod_files:
        # store all the files modified after the last commit
        mod_files_path = save_path + '/mod_files'
        os.makedirs(mod_files_path)

        cmd = "git ls-files --modified | xargs -i cp {} {}".format("{}", mod_files_path)
        print("cmd:", cmd)
        os.system(cmd)

        # store the last commit id
        cmd = "git log -n 1 >> {}".format(mod_files_path + '/commit_id.txt')
        print("cmd:", cmd)
        os.system(cmd)

    #################################################
    if args.logger is None:
        log_file = os.path.splitext(os.path.basename(__file__))[0] + '.log'
        args.logger = xnn.utils.TeeLogger(filename=os.path.join(save_path,log_file))

    #################################################
    # global settings. rand seeds for repeatability
    random.seed(args.rand_seed)
    np.random.seed(args.rand_seed)
    torch.manual_seed(args.rand_seed)
    torch.cuda.manual_seed(args.rand_seed)

    ################################
    # args check and config
    if args.iter_size != 1 and args.total_batch_size is not None:
        warnings.warn("only one of args.iter_size or args.total_batch_size should be set")
    #
    if args.total_batch_size is not None:
        args.iter_size = args.total_batch_size//args.batch_size
    else:
        args.total_batch_size = args.batch_size*args.iter_size
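    # e.g. with batch_size=12 and iter_size=2, total_batch_size becomes 24:
    # gradients are accumulated over 2 iterations before each optimizer.step(),
    # emulating a larger batch without the extra memory.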

    #################################################
    # set some global flags and initializations
    # keep it in args for now - although they don't belong here strictly
    # using pin_memory is seen to cause issues, especially when a lot of memory is used.
    args.use_pinned_memory = False
    args.n_iter = 0
    args.best_metric = -1
    cudnn.benchmark = True
    # torch.autograd.set_detect_anomaly(True)

    ################################
    # reset character color, in case it is different
    print('{}'.format(Fore.RESET))
    # print everything for log
    print('=> args: {}'.format(args))
    print('\n'.join("%s: %s" % item for item in sorted(vars(args).items())))

    print('=> will save everything to {}'.format(save_path))

    #################################################
    train_writer = SummaryWriter(os.path.join(save_path,'train')) if args.tensorboard_enable else None
    val_writer = SummaryWriter(os.path.join(save_path,'val')) if args.tensorboard_enable else None
    transforms = get_transforms(args) if args.transforms is None else args.transforms
    assert isinstance(transforms, (list,tuple)) and len(transforms) == 2, 'incorrect transforms were given'

    print("=> fetching images in '{}'".format(args.data_path))
    split_arg = args.split_file if args.split_file else (args.split_files if args.split_files else args.split_value)
    train_dataset, val_dataset = vision.datasets.pixel2pixel.__dict__[args.dataset_name](args.dataset_config, args.data_path, split=split_arg, transforms=transforms)

    #################################################
    print('=> {} samples found: {} train samples and {} test samples'.format(len(train_dataset)+len(val_dataset),
        len(train_dataset), len(val_dataset)))
    train_sampler = get_dataset_sampler(train_dataset, args.epoch_size) if args.epoch_size != 0 else None
    shuffle_train = args.shuffle and (train_sampler is None)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
        num_workers=args.workers, pin_memory=args.use_pinned_memory, sampler=train_sampler, shuffle=shuffle_train)

    val_sampler = get_dataset_sampler(val_dataset, args.epoch_size_val) if args.epoch_size_val != 0 else None
    shuffle_val = args.shuffle_val and (val_sampler is None)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
        num_workers=args.workers, pin_memory=args.use_pinned_memory, sampler=val_sampler, shuffle=shuffle_val)

    #################################################
    if (args.model_config.input_channels is None):
        args.model_config.input_channels = (3,)
        print("=> input channels is not given - setting to {}".format(args.model_config.input_channels))

    if (args.model_config.output_channels is None):
        if ('num_classes' in dir(train_dataset)):
            args.model_config.output_channels = train_dataset.num_classes()
        else:
            args.model_config.output_channels = (2 if args.model_config.output_type == 'flow' else args.model_config.output_channels)
            xnn.utils.print_yellow("=> output channels is not given - setting to {} - this may not work".format(args.model_config.output_channels))
        #
        if not isinstance(args.model_config.output_channels,(list,tuple)):
            args.model_config.output_channels = [args.model_config.output_channels]

    if (args.class_weights is None) and ('class_weights' in dir(train_dataset)):
        args.class_weights = train_dataset.class_weights()
        if not isinstance(args.class_weights, (list,tuple)):
            args.class_weights = [args.class_weights]
        #
        print("=> class weights available for dataset: {}".format(args.class_weights))

    #################################################
    pretrained_data = None
    model_surgery_quantize = False
    if args.pretrained and args.pretrained != "None":
        pretrained_data = []
        pretrained_files = args.pretrained if isinstance(args.pretrained,(list,tuple)) else [args.pretrained]
        for p in pretrained_files:
            if isinstance(p, dict):
                p_data = p
            else:
                if p.startswith('http://') or p.startswith('https://'):
                    p_file = vision.datasets.utils.download_url(p, './data/downloads')
                else:
                    p_file = p
                #
                print(f'=> loading pretrained weights file: {p}')
                p_data = torch.load(p_file)
            #
            pretrained_data.append(p_data)
            model_surgery_quantize = p_data['quantize'] if 'quantize' in p_data else False
    #

    #################################################
    # create model
    is_onnx_model = False
    if isinstance(args.model, torch.nn.Module):
        model, change_names_dict = args.model if isinstance(args.model, (list, tuple)) else (args.model, None)
        assert isinstance(model, torch.nn.Module), 'args.model, if provided, must be a valid torch.nn.Module'
    elif isinstance(args.model, str) and args.model.endswith('.onnx'):
        model = xnn.onnx.import_onnx(args.model)
        is_onnx_model = True
    else:
        xnn.utils.print_yellow("=> creating model '{}'".format(args.model_name))
        model = vision.models.pixel2pixel.__dict__[args.model_name](args.model_config)
        # check if we got the model as well as parameters to change the names in pretrained
        model, change_names_dict = model if isinstance(model, (list,tuple)) else (model, None)
    #

    if args.quantize:
        # dummy input is used by quantized models to analyze the graph
        is_cuda = next(model.parameters()).is_cuda
        dummy_input = create_rand_inputs(args, is_cuda=is_cuda)
        #
        if 'training' in args.phase:
            model = xnn.quantize.QuantTrainModule(model, per_channel_q=args.per_channel_q,
                        histogram_range=args.histogram_range, bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations, dummy_input=dummy_input)
        elif 'calibration' in args.phase:
            model = xnn.quantize.QuantCalibrateModule(model, per_channel_q=args.per_channel_q,
                        bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
                        histogram_range=args.histogram_range, bias_calibration=args.bias_calibration,
                        dummy_input=dummy_input, lr_calib=args.lr_calib)
        elif 'validation' in args.phase:
            # note: bias_calibration is not enabled here
            model = xnn.quantize.QuantTestModule(model, per_channel_q=args.per_channel_q,
                        bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
                        histogram_range=args.histogram_range, dummy_input=dummy_input,
                        model_surgery_quantize=model_surgery_quantize)
        else:
            assert False, f'invalid phase {args.phase}'
    #
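
    # summary of the quantization wrappers selected above: QuantTrainModule is used
    # for quantization aware training, QuantCalibrateModule for post-training
    # calibration (optionally with bias correction), and QuantTestModule to evaluate
    # with quantized inference (no calibration).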

    # load pretrained model
    if pretrained_data is not None and not is_onnx_model:
        for (p_data, p_file) in zip(pretrained_data, pretrained_files):
            print("=> using pretrained weights from: {}".format(p_file))
            xnn.utils.load_weights(get_model_orig(model), pretrained=p_data, change_names_dict=change_names_dict)
    #

    #################################################
    if args.count_flops:
        count_flops(args, model)

    #################################################
    if args.save_onnx and (any(p in args.phase for p in ('training','calibration')) or (not args.run_soon)):
        write_onnx_model(args, get_model_orig(model), save_path, save_traced_model=False)
    #

    #################################################
    if args.print_model:
        print(model)
        print('\n')
    else:
        args.logger.debug(str(model))
        args.logger.debug('\n')

    #################################################
    if (not args.run_soon):
        print("Training not needed for now")
        close(args)
        exit()

    #################################################
    # DataParallel does not work for QuantCalibrateModule or QuantTestModule
    if args.parallel_model and (not isinstance(model, (xnn.quantize.QuantCalibrateModule, xnn.quantize.QuantTestModule))):
        model = torch.nn.DataParallel(model)

    #################################################
    model = model.cuda()

    #################################################
    # for help in debug/print
    for name, module in model.named_modules():
        module.name = name
421     args.loss_modules = copy.deepcopy(args.losses)
422     for task_dx, task_losses in enumerate(args.losses):
423         for loss_idx, loss_fn in enumerate(task_losses):
424             kw_args = {}
425             loss_args = vision.losses.__dict__[loss_fn].args()
426             for arg in loss_args:
427                 if arg == 'weight' and (args.class_weights is not None):
428                     kw_args.update({arg:args.class_weights[task_dx]})
429                 elif arg == 'num_classes':
430                     kw_args.update({arg:args.model_config.output_channels[task_dx]})
431                 elif arg == 'sparse':
432                     kw_args.update({arg:args.sparse})
433                 #
434             #
435             loss_fn_raw = vision.losses.__dict__[loss_fn](**kw_args)
436             if args.parallel_criterion:
437                 loss_fn = torch.nn.DataParallel(loss_fn_raw).cuda() if args.parallel_criterion else loss_fn_raw.cuda()
438                 loss_fn.info = loss_fn_raw.info
439                 loss_fn.clear = loss_fn_raw.clear
440             else:
441                 loss_fn = loss_fn_raw.cuda()
442             #
443             args.loss_modules[task_dx][loss_idx] = loss_fn
444     #
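
    # note: the nested loops above imply that args.losses is a list of lists - one
    # inner list of loss names per task, e.g. [['segmentation_loss']] for a single
    # segmentation task; a flat list of strings would be iterated character by
    # character, so the flat default in get_config is expected to be overridden.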

    args.metric_modules = copy.deepcopy(args.metrics)
    for task_dx, task_metrics in enumerate(args.metrics):
        for midx, metric_fn in enumerate(task_metrics):
            kw_args = {}
            metric_args = vision.losses.__dict__[metric_fn].args()
            for arg in metric_args:
                if arg == 'weight':
                    kw_args.update({arg:args.class_weights[task_dx]})
                elif arg == 'num_classes':
                    kw_args.update({arg:args.model_config.output_channels[task_dx]})
                elif arg == 'sparse':
                    kw_args.update({arg:args.sparse})
                #
            #
            metric_fn_raw = vision.losses.__dict__[metric_fn](**kw_args)
            if args.parallel_criterion:
                metric_fn = torch.nn.DataParallel(metric_fn_raw).cuda()
                metric_fn.info = metric_fn_raw.info
                metric_fn.clear = metric_fn_raw.clear
            else:
                metric_fn = metric_fn_raw.cuda()
            #
            args.metric_modules[task_dx][midx] = metric_fn
    #

    #################################################
    if args.phase == 'validation':
        with torch.no_grad():
            validate(args, val_dataset, val_loader, model, 0, val_writer)
        #
        close(args)
        return

    #################################################
    assert (args.solver in ['adam', 'sgd'])
    print('=> setting {} solver'.format(args.solver))
    if args.lr_clips is not None:
        learning_rate_clips = args.lr_clips if 'training' in args.phase else 0.0
        clips_decay = args.bias_decay if (args.bias_decay is not None and args.bias_decay != 0.0) else args.weight_decay
        clips_params = [p for n,p in model.named_parameters() if 'clips' in n]
        other_params = [p for n,p in model.named_parameters() if 'clips' not in n]
        param_groups = [{'params': clips_params, 'weight_decay': clips_decay, 'lr': learning_rate_clips},
                        {'params': other_params, 'weight_decay': args.weight_decay}]
    else:
        param_groups = [{'params': filter(lambda p: p.requires_grad, model.parameters()), 'weight_decay': args.weight_decay}]
    #

    learning_rate = args.lr if ('training' in args.phase) else 0.0
    if args.solver == 'adam':
        optimizer = torch.optim.Adam(param_groups, learning_rate, betas=(args.momentum, args.beta))
    elif args.solver == 'sgd':
        optimizer = torch.optim.SGD(param_groups, learning_rate, momentum=args.momentum)
    else:
        raise ValueError('Unknown optimizer type: {}'.format(args.solver))
    #

    #################################################
    max_iter = args.epochs * len(train_loader)
    scheduler = xnn.optim.lr_scheduler.SchedulerWrapper(args.scheduler, optimizer, args.epochs, args.start_epoch,
                                                        args.warmup_epochs, max_iter=max_iter, polystep_power=args.polystep_power,
                                                        milestones=args.milestones, multistep_gamma=args.multistep_gamma)
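    # with the defaults (scheduler='step', milestones=(100, 200), multistep_gamma=0.5),
    # the learning rate is halved at epochs 100 and 200; warmup_epochs controls an
    # initial warmup phase (assuming SchedulerWrapper follows these arguments as named).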

    # optionally resume from a checkpoint
    if args.resume:
        if not os.path.isfile(args.resume):
            print("=> no checkpoint found at '{}'".format(args.resume))
        else:
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model = xnn.utils.load_weights(model, checkpoint)

            if args.start_epoch == 0:
                args.start_epoch = checkpoint['epoch']

            if 'best_metric' in list(checkpoint.keys()):
                args.best_metric = checkpoint['best_metric']

            if 'optimizer' in list(checkpoint.keys()):
                optimizer.load_state_dict(checkpoint['optimizer'])

            if 'scheduler' in list(checkpoint.keys()):
                scheduler.load_state_dict(checkpoint['scheduler'])

            if 'multi_task_factors' in list(checkpoint.keys()):
                args.multi_task_factors = checkpoint['multi_task_factors']

            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))

    #################################################
    if args.evaluate_start:
        with torch.no_grad():
            validate(args, val_dataset, val_loader, model, args.start_epoch, val_writer)

    for epoch in range(args.start_epoch, args.epochs):
        # epoch is needed to seed shuffling in DistributedSampler, every epoch.
        # otherwise a seed of 0 is used every epoch, which seems incorrect.
        if train_sampler and isinstance(train_sampler, torch.utils.data.DistributedSampler):
            train_sampler.set_epoch(epoch)
        if val_sampler and isinstance(val_sampler, torch.utils.data.DistributedSampler):
            val_sampler.set_epoch(epoch)

        # train for one epoch
        train(args, train_dataset, train_loader, model, optimizer, epoch, train_writer, scheduler)

        # evaluate on validation set
        with torch.no_grad():
            val_metric, metric_name = validate(args, val_dataset, val_loader, model, epoch, val_writer)

        if args.best_metric < 0:
            args.best_metric = val_metric

        if "iou" in metric_name.lower() or "acc" in metric_name.lower():
            is_best = val_metric >= args.best_metric
            args.best_metric = max(val_metric, args.best_metric)
        elif "error" in metric_name.lower() or "diff" in metric_name.lower() or "norm" in metric_name.lower() \
                or "loss" in metric_name.lower() or "outlier" in metric_name.lower():
            is_best = val_metric <= args.best_metric
            args.best_metric = min(val_metric, args.best_metric)
        else:
            raise ValueError("Metric is not known. Best model could not be saved.")
        #

        checkpoint_dict = {'epoch': epoch + 1, 'model_name': args.model_name,
                           'state_dict': get_model_orig(model).state_dict(),
                           'optimizer': optimizer.state_dict(),
                           'scheduler': scheduler.state_dict(),
                           'best_metric': args.best_metric,
                           'multi_task_factors': args.multi_task_factors,
                           'quantize': args.quantize}

        save_checkpoint(args, save_path, get_model_orig(model), checkpoint_dict, is_best)

        if args.tensorboard_enable:
            train_writer.file_writer.flush()
            val_writer.file_writer.flush()

        # adjust the learning rate using the lr scheduler
        if 'training' in args.phase:
            scheduler.step()
        #
    #

    # close and cleanup
    close(args)


###################################################################
def is_valid_phase(phase):
    phases = ('training', 'calibration', 'validation')
    return any(p in phase for p in phases)
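
# e.g. is_valid_phase('training') -> True; since only a substring match is
# required, composite phase strings that embed one of these words also pass.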


###################################################################
def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writer, scheduler):
    batch_time = xnn.utils.AverageMeter()
    data_time = xnn.utils.AverageMeter()
    # if the loss/metric is already an average, no need to further average
    avg_loss = [xnn.utils.AverageMeter(print_avg=(not task_loss[0].info()['is_avg'])) for task_loss in args.loss_modules]
    avg_loss_orig = [xnn.utils.AverageMeter(print_avg=(not task_loss[0].info()['is_avg'])) for task_loss in args.loss_modules]
    avg_metric = [xnn.utils.AverageMeter(print_avg=(not task_metric[0].info()['is_avg'])) for task_metric in args.metric_modules]

    ##########################
    # switch to train mode
    model.train()
    if args.freeze_bn:
        xnn.utils.freeze_bn(model)
    #

    # freeze layers
    if args.freeze_layers is not None:
        # 'freeze_layer_name' could be part of 'name', i.e. 'name' need not be exactly the same as 'freeze_layer_name'
        # e.g. if freeze_layer_name = 'encoder.0', then all layers such as 'encoder.0.0', 'encoder.0.1' will be frozen
        for freeze_layer_name in args.freeze_layers:
            for name, module in model.named_modules():
                if freeze_layer_name in name:
                    xnn.utils.print_once("Freezing the module : {}".format(name))
                    module.eval()
                    for param in module.parameters():
                        param.requires_grad = False

    ##########################
    for task_dx, task_losses in enumerate(args.loss_modules):
        for loss_idx, loss_fn in enumerate(task_losses):
            loss_fn.clear()
    for task_dx, task_metrics in enumerate(args.metric_modules):
        for midx, metric_fn in enumerate(task_metrics):
            metric_fn.clear()

    num_iter = len(train_loader)
    progress_bar = progiter.ProgIter(np.arange(num_iter), chunksize=1)
    metric_name = "Metric"
    metric_ctx = [None] * len(args.metric_modules)
    end_time = time.time()
    writer_idx = 0
    last_update_iter = -1

    # change color to yellow for calibration
    progressbar_color = (Fore.YELLOW if (('calibration' in args.phase) or ('training' in args.phase and args.quantize)) else Fore.WHITE)
    print('{}'.format(progressbar_color), end='')

    ##########################
    for iter_id, (inputs, targets) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end_time)

        lr = scheduler.get_lr()[0]

        input_list = [[jj.cuda() for jj in img] if isinstance(img,(list,tuple)) else img.cuda() for img in inputs]
        target_list = [tgt.cuda(non_blocking=True) for tgt in targets]
        target_sizes = [tgt.shape for tgt in target_list]
        batch_size_cur = target_sizes[0][0]

        ##########################
        # compute output
        task_outputs = model(input_list)

        task_outputs = task_outputs if isinstance(task_outputs,(list,tuple)) else [task_outputs]
        # upsample output to target resolution
        if args.upsample_mode is not None:
            task_outputs = upsample_tensors(task_outputs, target_sizes, args.upsample_mode)

        if args.model_config.multi_task_type is not None and len(args.model_config.output_channels) > 1:
            args.multi_task_factors, args.multi_task_offsets = xnn.layers.get_loss_scales(model)
        else:
            args.multi_task_factors = None
            args.multi_task_offsets = None

        loss_total, loss_list, loss_names, loss_types, loss_list_orig = \
            compute_task_objectives(args, args.loss_modules, input_list, task_outputs, target_list,
                         task_mults=args.multi_task_factors, task_offsets=args.multi_task_offsets,
                         loss_mult_factors=args.loss_mult_factors)

        if args.print_train_class_iou:
            metric_total, metric_list, metric_names, metric_types, _, confusion_matrix = \
                compute_task_objectives(args, args.metric_modules, input_list, task_outputs, target_list,
                    get_confusion_matrix=args.print_train_class_iou)
        else:
            metric_total, metric_list, metric_names, metric_types, _ = \
                compute_task_objectives(args, args.metric_modules, input_list, task_outputs, target_list,
                    get_confusion_matrix=args.print_train_class_iou)

        if args.model_config.multi_task_type is not None and len(args.model_config.output_channels) > 1:
            xnn.layers.set_losses(model, loss_list_orig)

        if 'training' in args.phase:
            # zero gradients so that we can accumulate gradients
            if (iter_id % args.iter_size) == 0:
                optimizer.zero_grad()

            # accumulate gradients
            loss_total.backward()
            # optimization step
            if ((iter_id+1) % args.iter_size) == 0:
                optimizer.step()
        #

        # record loss
        for task_idx, task_losses in enumerate(args.loss_modules):
            avg_loss[task_idx].update(float(loss_list[task_idx].cpu()), batch_size_cur)
            avg_loss_orig[task_idx].update(float(loss_list_orig[task_idx].cpu()), batch_size_cur)
            if args.tensorboard_enable:
                train_writer.add_scalar('Training/Task{}_{}_Loss_Iter'.format(task_idx,loss_names[task_idx]), float(loss_list[task_idx]), args.n_iter)
                if args.model_config.multi_task_type is not None and len(args.model_config.output_channels) > 1:
                    train_writer.add_scalar('Training/multi_task_Factor_Task{}_{}'.format(task_idx,loss_names[task_idx]), float(args.multi_task_factors[task_idx]), args.n_iter)

        # record error/accuracy
        for task_idx, task_metrics in enumerate(args.metric_modules):
            avg_metric[task_idx].update(float(metric_list[task_idx].cpu()), batch_size_cur)

        ##########################
        if args.tensorboard_enable:
            write_output(args, 'Training_', num_iter, iter_id, epoch, train_dataset, train_writer, input_list, task_outputs, target_list, metric_name, writer_idx)

        if ((iter_id % args.print_freq) == 0) or (iter_id == (num_iter-1)):
            output_string = ''
            for task_idx, task_metrics in enumerate(args.metric_modules):
                output_string += '[{}={}]'.format(metric_names[task_idx], str(avg_metric[task_idx]))

            epoch_str = '{}/{}'.format(epoch + 1, args.epochs)
            progress_bar.set_description("{}=> {}  ".format(progressbar_color, args.phase))
            multi_task_factors_print = ['{:.3f}'.format(float(lmf)) for lmf in args.multi_task_factors] if args.multi_task_factors is not None else None
            progress_bar.set_postfix(Epoch=epoch_str, LR=lr, DataTime=str(data_time), LossMult=multi_task_factors_print, Loss=avg_loss, Output=output_string)
            progress_bar.update(iter_id-last_update_iter)
            last_update_iter = iter_id

        args.n_iter += 1
        end_time = time.time()
        writer_idx = (writer_idx + 1) % args.tensorboard_num_imgs

        # add onnx graph to tensorboard
        # commenting out due to issues in transitioning to pytorch 0.4
        # (bilinear mode in upsampling causes hang or crash - may be due to align_borders change, nearest is fine)
        #if epoch == 0 and iter_id == 0:
        #    input_zero = torch.zeros(input_var.shape)
        #    train_writer.add_graph(model, input_zero)
        # this cache operation slows down training
        #torch.cuda.empty_cache()
    #

    if args.print_train_class_iou:
        print_class_iou(args=args, confusion_matrix=confusion_matrix, task_idx=task_idx)
    #
    progress_bar.close()

    # to print a new line - do not provide end=''
    print('{}'.format(Fore.RESET))

    if args.tensorboard_enable:
        for task_idx, task_losses in enumerate(args.loss_modules):
            train_writer.add_scalar('Training/Task{}_{}_Loss_Epoch'.format(task_idx,loss_names[task_idx]), float(avg_loss[task_idx]), epoch)
        for task_idx, task_metrics in enumerate(args.metric_modules):
            train_writer.add_scalar('Training/Task{}_{}_Metric_Epoch'.format(task_idx,metric_names[task_idx]), float(avg_metric[task_idx]), epoch)

    output_name = metric_names[args.pivot_task_idx]
    output_metric = float(avg_metric[args.pivot_task_idx])

    ##########################
    if args.quantize:
        def debug_format(v):
            return ('{:.3f}'.format(v) if v is not None else 'None')
        #
        clips_act = [m.get_clips_act()[1] for n,m in model.named_modules() if isinstance(m, xnn.layers.PAct2)]
        if len(clips_act) > 0:
            args.logger.debug('\nclips_act : ' + ' '.join(map(debug_format, clips_act)))
            args.logger.debug('')
    #
    return output_metric, output_name


###################################################################
def validate(args, val_dataset, val_loader, model, epoch, val_writer):
    data_time = xnn.utils.AverageMeter()
    # if the loss/metric is already an average, no need to further average
    avg_metric = [xnn.utils.AverageMeter(print_avg=(not task_metric[0].info()['is_avg'])) for task_metric in args.metric_modules]

    ##########################
    # switch to evaluate mode
    model.eval()

    ##########################
    for task_dx, task_metrics in enumerate(args.metric_modules):
        for midx, metric_fn in enumerate(task_metrics):
            metric_fn.clear()

    metric_name = "Metric"
    end_time = time.time()
    writer_idx = 0
    last_update_iter = -1
    metric_ctx = [None] * len(args.metric_modules)

    num_iter = len(val_loader)
    progress_bar = progiter.ProgIter(np.arange(num_iter), chunksize=1)

    # change color to green
    print('{}'.format(Fore.GREEN), end='')

    ##########################
    for iter_id, (inputs, targets) in enumerate(val_loader):
        data_time.update(time.time() - end_time)
        input_list = [[jj.cuda() for jj in img] if isinstance(img,(list,tuple)) else img.cuda() for img in inputs]
        target_list = [j.cuda(non_blocking=True) for j in targets]
        target_sizes = [tgt.shape for tgt in target_list]
        batch_size_cur = target_sizes[0][0]

        # compute output
        task_outputs = model(input_list)

        task_outputs = task_outputs if isinstance(task_outputs, (list, tuple)) else [task_outputs]
        if args.upsample_mode is not None:
            task_outputs = upsample_tensors(task_outputs, target_sizes, args.upsample_mode)

        if args.print_val_class_iou:
            metric_total, metric_list, metric_names, metric_types, _, confusion_matrix = \
                compute_task_objectives(args, args.metric_modules, input_list, task_outputs, target_list,
                    get_confusion_matrix=args.print_val_class_iou)
        else:
            metric_total, metric_list, metric_names, metric_types, _ = \
                compute_task_objectives(args, args.metric_modules, input_list, task_outputs, target_list,
                    get_confusion_matrix=args.print_val_class_iou)

        # record error/accuracy
        for task_idx, task_metrics in enumerate(args.metric_modules):
            avg_metric[task_idx].update(float(metric_list[task_idx].cpu()), batch_size_cur)

        if args.tensorboard_enable:
            write_output(args, 'Validation_', num_iter, iter_id, epoch, val_dataset, val_writer, input_list, task_outputs, target_list, metric_names, writer_idx)

        if ((iter_id % args.print_freq) == 0) or (iter_id == (num_iter-1)):
            output_string = ''
            for task_idx, task_metrics in enumerate(args.metric_modules):
                output_string += '[{}={}]'.format(metric_names[task_idx], str(avg_metric[task_idx]))

            epoch_str = '{}/{}'.format(epoch + 1, args.epochs)
            progress_bar.set_description("=> validation")
            progress_bar.set_postfix(Epoch=epoch_str, DataTime=data_time, Output="{}".format(output_string))
            progress_bar.update(iter_id-last_update_iter)
            last_update_iter = iter_id
        #

        end_time = time.time()
        writer_idx = (writer_idx + 1) % args.tensorboard_num_imgs
    #

    if args.print_val_class_iou:
        print_class_iou(args=args, confusion_matrix=confusion_matrix, task_idx=task_idx)
    #

    #print_conf_matrix(conf_matrix=conf_matrix, en=False)
    progress_bar.close()

    # to print a new line - do not provide end=''
    print('{}'.format(Fore.RESET))

    if args.tensorboard_enable:
        for task_idx, task_metrics in enumerate(args.metric_modules):
            val_writer.add_scalar('Validation/Task{}_{}_Metric_Epoch'.format(task_idx,metric_names[task_idx]), float(avg_metric[task_idx]), epoch)

    output_name = metric_names[args.pivot_task_idx]
    output_metric = float(avg_metric[args.pivot_task_idx])
    return output_metric, output_name


###################################################################
def close(args):
    if args.logger is not None:
        del args.logger
        args.logger = None
    #
    args.best_metric = -1


def get_save_path(args, phase=None):
    date = args.date if args.date else datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    save_path = os.path.join('./data/checkpoints', args.dataset_name, date + '_' + args.dataset_name + '_' + args.model_name)
    save_path += '_resize{}x{}_traincrop{}x{}'.format(args.img_resize[1], args.img_resize[0], args.rand_crop[1], args.rand_crop[0])
    phase = phase if (phase is not None) else args.phase
    save_path = os.path.join(save_path, phase)
    return save_path


def get_model_orig(model):
    is_parallel_model = isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel))
    model_orig = (model.module if is_parallel_model else model)
    model_orig = (model_orig.module if isinstance(model_orig, xnn.quantize.QuantBaseModule) else model_orig)
    return model_orig


def create_rand_inputs(args, is_cuda):
    dummy_input = []
    if not args.model_config.input_nv12:
        for i_ch in args.model_config.input_channels:
            x = torch.rand((1, i_ch, args.img_resize[0], args.img_resize[1]))
            x = x.cuda() if is_cuda else x
            dummy_input.append(x)
    else: # nv12
        for i_ch in args.model_config.input_channels:
            y = torch.rand((1, 1, args.img_resize[0], args.img_resize[1]))
            uv = torch.rand((1, 1, args.img_resize[0]//2, args.img_resize[1]))
            y = y.cuda() if is_cuda else y
            uv = uv.cuda() if is_cuda else uv
            dummy_input.append([y, uv])
    #
    return dummy_input
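
# note on the nv12 branch above: NV12 stores a full-resolution luma (Y) plane plus
# a half-height plane of interleaved chroma (UV) samples, which is why the uv dummy
# tensor has half the height but the full width of the y tensor.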


def count_flops(args, model):
    is_cuda = next(model.parameters()).is_cuda
    dummy_input = create_rand_inputs(args, is_cuda)
    #
    model.eval()
    flops = xnn.utils.forward_count_flops(model, dummy_input)
    gflops = flops/1e9
    print('=> Size = {}, GFLOPs = {}, GMACs = {}'.format(args.img_resize, gflops, gflops/2))
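
# (one multiply-accumulate is conventionally counted as two FLOPs, hence GMACs = GFLOPs/2 above)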


def derive_node_name(input_name):
    # take the last entry of the input names for deciding the node name
    #print("input_name[-1]: ", input_name[-1])
    node_name = input_name[-1].rsplit('.', 1)[0]
    #print("formed node_name: ", node_name)
    return node_name


# torch onnx export does not update names. Do it using onnx.save
def add_node_names(onnx_model_name):
    onnx_model = onnx.load(onnx_model_name)
    for i in range(len(onnx_model.graph.node)):
        for j in range(len(onnx_model.graph.node[i].input)):
            #print('-'*60)
            #print("name: ", onnx_model.graph.node[i].name)
            #print("input: ", onnx_model.graph.node[i].input)
            #print("output: ", onnx_model.graph.node[i].output)
            onnx_model.graph.node[i].input[j] = onnx_model.graph.node[i].input[j].split(':')[0]
            onnx_model.graph.node[i].name = derive_node_name(onnx_model.graph.node[i].input)
        #
    #
    # update model inplace
    onnx.save(onnx_model, onnx_model_name)
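
# e.g. a node whose last input is 'encoder.0.conv.weight' would be named
# 'encoder.0.conv': derive_node_name() strips the final dotted suffix, and the
# ':' split above drops any trailing index suffix from the input names.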


def write_onnx_model(args, model, save_path, name='checkpoint.onnx', save_traced_model=False):
    is_cuda = next(model.parameters()).is_cuda
    input_list = create_rand_inputs(args, is_cuda=is_cuda)
    #
    model.eval()
    torch.onnx.export(model, input_list, os.path.join(save_path, name), export_params=True, verbose=False,
                      do_constant_folding=True, opset_version=args.opset_version)

    # torch onnx export does not update names. Do it using onnx.save
    add_node_names(onnx_model_name=os.path.join(save_path, name))

    if save_traced_model:
        traced_model = torch.jit.trace(model, (input_list,))
        traced_save_path = os.path.join(save_path, 'traced_model.pth')
        torch.jit.save(traced_model, traced_save_path)
    #


###################################################################
def write_output(args, prefix, val_epoch_size, iter_id, epoch, dataset, output_writer, input_images, task_outputs, task_targets, metric_names, writer_idx):
    write_freq = (args.tensorboard_num_imgs / float(val_epoch_size))
    write_prob = np.random.random()
    if (write_prob > write_freq):
        return
    if args.model_config.input_nv12:
        batch_size = input_images[0][0].shape[0]
    else:
        batch_size = input_images[0].shape[0]
    b_index = random.randint(0, batch_size - 1)

    input_image = None
    for img_idx, img in enumerate(input_images):
        if args.model_config.input_nv12:
            # convert NV12 to BGR for tensorboard
            input_image = vision.transforms.image_transforms_xv12.nv12_to_bgr_image(Y=input_images[img_idx][0][b_index], UV=input_images[img_idx][1][b_index],
                                   image_scale=args.image_scale, image_mean=args.image_mean)
        else:
            input_image = input_images[img_idx][b_index].cpu().numpy().transpose((1, 2, 0))
            # convert back to the original input range (0-255)
            input_image = input_image / args.image_scale + args.image_mean

        if args.is_flow and args.is_flow[0][img_idx]:
            # input corresponding to flow is assumed to have been generated by adding 128
            flow = input_image - 128
            flow_hsv = xnn.utils.flow2hsv(flow.transpose(2, 0, 1), confidence=False).transpose(2, 0, 1)
            #flow_hsv = (flow_hsv / 255.0).clip(0, 1) #TODO: check this
            output_writer.add_image(prefix + 'Input{}/{}'.format(img_idx, writer_idx), flow_hsv, epoch)
        else:
            input_image = (input_image/255.0).clip(0, 1) #.astype(np.uint8)
            output_writer.add_image(prefix + 'Input{}/{}'.format(img_idx, writer_idx), input_image.transpose((2, 0, 1)), epoch)
998     # for sparse data, chroma blending does not look good
999     for task_idx, output_type in enumerate(args.model_config.output_type):
1000         # metric_name = metric_names[task_idx]
1001         output = task_outputs[task_idx]
1002         target = task_targets[task_idx]
1003         if (output_type == 'segmentation') and hasattr(dataset, 'decode_segmap'):
1004             segmentation_target = dataset.decode_segmap(target[b_index,0].cpu().numpy())
1005             segmentation_output = output.max(dim=1,keepdim=True)[1].data.cpu().numpy() if(output.shape[1]>1) else output.data.cpu().numpy()
1006             segmentation_output = dataset.decode_segmap(segmentation_output[b_index,0])
1007             segmentation_output_blend = xnn.utils.chroma_blend(input_image, segmentation_output)
1008             #
            output_writer.add_image(prefix+'Task{}_{}_GT/{}'.format(task_idx,output_type,writer_idx), segmentation_target.transpose(2,0,1), epoch)
            if not args.sparse:
                segmentation_target_blend = xnn.utils.chroma_blend(input_image, segmentation_target)
                output_writer.add_image(prefix + 'Task{}_{}_GT_ColorBlend/{}'.format(task_idx, output_type, writer_idx), segmentation_target_blend.transpose(2, 0, 1), epoch)
            #
            output_writer.add_image(prefix+'Task{}_{}_Output/{}'.format(task_idx,output_type,writer_idx), segmentation_output.transpose(2,0,1), epoch)
            output_writer.add_image(prefix+'Task{}_{}_Output_ColorBlend/{}'.format(task_idx,output_type,writer_idx), segmentation_output_blend.transpose(2,0,1), epoch)
        elif (output_type in ('depth', 'disparity')):
            depth_chanidx = 0
            output_writer.add_image(prefix+'Task{}_{}_GT_Color_Visualization/{}'.format(task_idx,output_type,writer_idx), xnn.utils.tensor2array(target[b_index][depth_chanidx].cpu(), max_value=args.max_depth, colormap=args.viz_colormap).transpose(2,0,1), epoch)
            if not args.sparse:
                output_writer.add_image(prefix + 'Task{}_{}_GT_ColorBlend_Visualization/{}'.format(task_idx, output_type, writer_idx), xnn.utils.tensor2array(target[b_index][depth_chanidx].cpu(), max_value=args.max_depth, colormap=args.viz_colormap, input_blend=input_image).transpose(2, 0, 1), epoch)
            #
            output_writer.add_image(prefix+'Task{}_{}_Output_Color_Visualization/{}'.format(task_idx,output_type,writer_idx), xnn.utils.tensor2array(output.data[b_index][depth_chanidx].cpu(), max_value=args.max_depth, colormap=args.viz_colormap).transpose(2,0,1), epoch)
            output_writer.add_image(prefix + 'Task{}_{}_Output_ColorBlend_Visualization/{}'.format(task_idx, output_type, writer_idx), xnn.utils.tensor2array(output.data[b_index][depth_chanidx].cpu(), max_value=args.max_depth, colormap=args.viz_colormap, input_blend=input_image).transpose(2, 0, 1), epoch)
        elif (output_type == 'flow'):
            max_value_flow = 10.0  # only for visualization
            output_writer.add_image(prefix+'Task{}_{}_GT/{}'.format(task_idx,output_type,writer_idx), xnn.utils.flow2hsv(target[b_index][:2].cpu().numpy(), max_value=max_value_flow).transpose(2,0,1), epoch)
            output_writer.add_image(prefix+'Task{}_{}_Output/{}'.format(task_idx,output_type,writer_idx), xnn.utils.flow2hsv(output.data[b_index][:2].cpu().numpy(), max_value=max_value_flow).transpose(2,0,1), epoch)
        elif (output_type == 'interest_pt'):
            score_chanidx = 0
            target_score_to_write = target[b_index][score_chanidx].cpu()
            output_score_to_write = output.data[b_index][score_chanidx].cpu()

            # if the score is learned as zero-mean, add an offset to bring it into the [0, 255] range
            if args.make_score_zero_mean:
                # target_score_to_write != 0 : a value of 0 indicates that GT is unavailable - leave those entries at 0
                target_score_to_write[target_score_to_write!=0] += 128.0
                output_score_to_write += 128.0

            max_value_score = float(torch.max(target_score_to_write))  #0.002
            output_writer.add_image(prefix+'Task{}_{}_GT_Bone_Visualization/{}'.format(task_idx,output_type,writer_idx), xnn.utils.tensor2array(target_score_to_write, max_value=max_value_score, colormap='bone').transpose(2,0,1), epoch)
            output_writer.add_image(prefix+'Task{}_{}_Output_Bone_Visualization/{}'.format(task_idx,output_type,writer_idx), xnn.utils.tensor2array(output_score_to_write, max_value=max_value_score, colormap='bone').transpose(2,0,1), epoch)
        #

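# The visualization helpers used above (xnn.utils.tensor2array, xnn.utils.flow2hsv)
# return HWC image arrays, while tensorboardX's SummaryWriter.add_image expects CHW by
# default - hence the .transpose(2, 0, 1) calls. A minimal sketch of the convention
# (illustrative only; the array below is a hypothetical placeholder, not produced by
# the training flow):
#     hwc = np.zeros((256, 512, 3), dtype=np.uint8)   # H x W x C visualization image
#     chw = hwc.transpose(2, 0, 1)                    # C x H x W, as add_image expects
#     output_writer.add_image('Example/0', chw, epoch)
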
def print_conf_matrix(conf_matrix=None, en=False):
    if (not en) or (conf_matrix is None):
        return
    num_rows = conf_matrix.shape[0]
    num_cols = conf_matrix.shape[1]
    print("-"*64)
    num_ele = 1
    for r_idx in range(num_rows):
        print("\n")
        for c_idx in range(0, num_cols, num_ele):
            print(conf_matrix[r_idx][c_idx:c_idx+num_ele], end="")
    print("\n")
    print("-" * 64)

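# Example usage (hypothetical values - printing is off by default and must be enabled
# explicitly with en=True):
#     conf = np.array([[90, 10],
#                      [5, 95]])
#     print_conf_matrix(conf, en=True)
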
def compute_task_objectives(args, objective_fns, input_var, task_outputs, task_targets, task_mults=None,
                            task_offsets=None, loss_mult_factors=None, get_confusion_matrix=False):
    ##########################
    objective_total = torch.zeros_like(task_outputs[0].view(-1)[0])
    objective_list = []
    objective_list_orig = []
    objective_names = []
    objective_types = []
    for task_idx, task_objectives in enumerate(objective_fns):
        output_type = args.model_config.output_type[task_idx]
        objective_sum_value = torch.zeros_like(task_outputs[task_idx].view(-1)[0])
        objective_sum_name = ''
        objective_sum_type = ''

        task_mult = task_mults[task_idx] if task_mults is not None else 1.0
        task_offset = task_offsets[task_idx] if task_offsets is not None else 0.0

        for oidx, objective_fn in enumerate(task_objectives):
            objective_batch = objective_fn(input_var, task_outputs[task_idx], task_targets[task_idx])
            objective_batch = objective_batch.mean() if isinstance(objective_fn, torch.nn.DataParallel) else objective_batch
            objective_name = objective_fn.info()['name']
            objective_type = objective_fn.info()['is_avg']
            if get_confusion_matrix:
                confusion_matrix = objective_fn.info()['confusion_matrix']

            loss_mult = loss_mult_factors[task_idx][oidx] if (loss_mult_factors is not None) else 1.0
            # replace a NaN objective with 0 so that one bad batch does not poison the accumulated sum
            objective_batch_not_nan = (objective_batch if not torch.isnan(objective_batch) else 0.0)
            objective_sum_value = objective_batch_not_nan*loss_mult + objective_sum_value
            objective_sum_name += (objective_name if (objective_sum_name == '') else ('+' + objective_name))
            assert (objective_sum_type == '' or objective_sum_type == objective_type), 'metric types (avg/val) for a given task should match'
            objective_sum_type = objective_type

        objective_list.append(objective_sum_value)
        objective_list_orig.append(objective_sum_value)
        objective_names.append(objective_sum_name)
        objective_types.append(objective_sum_type)

        objective_total = objective_sum_value*task_mult + task_offset + objective_total

    return_list = [objective_total, objective_list, objective_names, objective_types, objective_list_orig]
    if get_confusion_matrix:
        return_list.append(confusion_matrix)

    return return_list

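# A minimal usage sketch (hypothetical weights): the returned objective_total is the
# task-weighted sum over tasks t of (task_mults[t] * loss_t + task_offsets[t]), while
# objective_list carries the per-task sums for logging:
#     total, per_task, names, types, per_task_orig = compute_task_objectives(
#         args, objective_fns, input_var, task_outputs, task_targets,
#         task_mults=[1.0, 0.5], task_offsets=[0.0, 0.0])
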
def save_checkpoint(args, save_path, model, checkpoint_dict, is_best, filename='checkpoint.pth'):
    torch.save(checkpoint_dict, os.path.join(save_path, filename))
    if is_best:
        shutil.copyfile(os.path.join(save_path, filename), os.path.join(save_path, 'model_best.pth'))
    #
    if args.save_onnx:
        write_onnx_model(args, model, save_path, name='checkpoint.onnx')
        if is_best:
            write_onnx_model(args, model, save_path, name='model_best.onnx')
    #

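# Example usage (a hypothetical sketch - the exact checkpoint keys used by this engine
# are defined by the caller elsewhere in this file):
#     checkpoint_dict = {'epoch': epoch, 'state_dict': model.state_dict(),
#                        'optimizer': optimizer.state_dict()}
#     save_checkpoint(args, save_path, model, checkpoint_dict, is_best=(metric > best_metric))
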
def get_dataset_sampler(dataset_object, epoch_size):
    print('=> creating a random sampler as epoch_size is specified')
    num_samples = len(dataset_object)
    epoch_size = int(epoch_size * num_samples) if epoch_size < 1 else int(epoch_size)
    dataset_sampler = torch.utils.data.sampler.RandomSampler(data_source=dataset_object, replacement=True, num_samples=epoch_size)
    return dataset_sampler

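# Example (hypothetical numbers): an epoch_size below 1 is treated as a fraction of the
# dataset, so for a 1000-sample dataset epoch_size=0.25 draws 250 random samples per
# epoch (with replacement), while epoch_size=250 draws 250 samples directly:
#     sampler = get_dataset_sampler(train_dataset, epoch_size=0.25)
#     loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
#                                          sampler=sampler, shuffle=False)
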
def get_train_transform(args):
    # image normalization can be at the beginning of transforms or at the end
    image_mean = np.array(args.image_mean, dtype=np.float32)
    image_scale = np.array(args.image_scale, dtype=np.float32)
    image_prenorm = vision.transforms.image_transforms.NormalizeMeanScale(mean=image_mean, scale=image_scale) if args.image_prenorm else None
    image_postnorm = vision.transforms.image_transforms.NormalizeMeanScale(mean=image_mean, scale=image_scale) if (not image_prenorm) else None

    # this output scaling is used only for training
    image_train_output_scaling = vision.transforms.image_transforms.Scale(args.rand_resize, target_size=args.rand_output_size, is_flow=args.is_flow) \
        if (args.rand_output_size is not None and args.rand_output_size != args.rand_resize) else None
    train_transform = vision.transforms.image_transforms.Compose([
        image_prenorm,
        vision.transforms.image_transforms.AlignImages(),
        vision.transforms.image_transforms.MaskTarget(args.target_mask, 0),
        vision.transforms.image_transforms.CropRect(args.img_border_crop),
        vision.transforms.image_transforms.RandomRotate(args.transform_rotation, is_flow=args.is_flow) if args.transform_rotation else None,
        vision.transforms.image_transforms.RandomScaleCrop(args.rand_resize, scale_range=args.rand_scale, is_flow=args.is_flow),
        vision.transforms.image_transforms.RandomHorizontalFlip(is_flow=args.is_flow),
        vision.transforms.image_transforms.RandomCrop(args.rand_crop),
        vision.transforms.image_transforms.RandomColor2Gray(is_flow=args.is_flow, random_threshold=0.5) if 'tiad' in args.dataset_name else None,
        image_train_output_scaling,
        image_postnorm,
        vision.transforms.image_transforms.ConvertToTensor()
        ])
    return train_transform

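# Note: several entries in the Compose list above can be None (e.g. RandomRotate when
# rotation is disabled); the Compose used here is assumed to skip such entries. A
# minimal sketch of a None-tolerant compose (illustrative only - not this repository's
# implementation, and the (images, targets) call convention is an assumption):
#     class ComposeSkipNone:
#         def __init__(self, transforms):
#             self.transforms = [t for t in transforms if t is not None]
#         def __call__(self, images, targets):
#             for t in self.transforms:
#                 images, targets = t(images, targets)
#             return images, targets
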
def get_validation_transform(args):
    # image normalization can be at the beginning of transforms or at the end
    image_mean = np.array(args.image_mean, dtype=np.float32)
    image_scale = np.array(args.image_scale, dtype=np.float32)
    image_prenorm = vision.transforms.image_transforms.NormalizeMeanScale(mean=image_mean, scale=image_scale) if args.image_prenorm else None
    image_postnorm = vision.transforms.image_transforms.NormalizeMeanScale(mean=image_mean, scale=image_scale) if (not image_prenorm) else None

    # the prediction is resized to output_size before evaluation
    val_transform = vision.transforms.image_transforms.Compose([
        image_prenorm,
        vision.transforms.image_transforms.AlignImages(),
        vision.transforms.image_transforms.MaskTarget(args.target_mask, 0),
        vision.transforms.image_transforms.CropRect(args.img_border_crop),
        vision.transforms.image_transforms.Scale(args.img_resize, target_size=args.output_size, is_flow=args.is_flow),
        image_postnorm,
        vision.transforms.image_transforms.ConvertToTensor()
        ])
    return val_transform

def get_transforms(args):
    # Provision to train with the validation transform - provide rand_scale as (0, 0).
    # This helps in fixing the train-test resolution discrepancy, https://arxiv.org/abs/1906.06423
    always_use_val_transform = (args.rand_scale[0] == 0)
    train_transform = get_validation_transform(args) if always_use_val_transform else get_train_transform(args)
    val_transform = get_validation_transform(args)
    return train_transform, val_transform

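# Example (a hypothetical FixRes-style fine-tuning step): setting rand_scale to (0, 0)
# makes get_transforms() return the validation transform for training as well, so the
# network trains at the evaluation resolution:
#     args.rand_scale = (0, 0)
#     train_transform, val_transform = get_transforms(args)
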
def _upsample_impl(tensor, output_size, upsample_mode):
    # upsampling of a long tensor is not supported currently - convert to float to avoid the error.
    # this is valid only in 'nearest' mode; other modes would produce invalid (interpolated) values.
    convert_to_float = False
    if isinstance(tensor, (torch.LongTensor, torch.cuda.LongTensor)):
        convert_to_float = True
        original_dtype = tensor.dtype
        tensor = tensor.float()
        upsample_mode = 'nearest'

    dim_added = False
    if len(tensor.shape) < 4:
        tensor = tensor[np.newaxis, ...]
        dim_added = True

    if (tensor.size()[-2:] != output_size):
        tensor = torch.nn.functional.interpolate(tensor, output_size, mode=upsample_mode)

    if dim_added:
        tensor = tensor[0, ...]

    if convert_to_float:
        tensor = tensor.long()  # i.e. back to original_dtype

    return tensor

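# Example (an illustrative sketch): nearest-neighbor upsampling of an integer label map.
# Because the input dtype is long, the mode is forced to 'nearest' internally, so the
# upsampled map still contains only valid class ids:
#     labels = torch.zeros(1, 1, 64, 128, dtype=torch.long)
#     labels_up = _upsample_impl(labels, (128, 256), upsample_mode='bilinear')
#     assert labels_up.dtype == torch.long and labels_up.shape[-2:] == (128, 256)
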
def upsample_tensors(tensors, output_sizes, upsample_mode):
    if isinstance(tensors, (list, tuple)):
        # item assignment below needs a mutable sequence - convert a tuple to a list
        tensors = list(tensors) if isinstance(tensors, tuple) else tensors
        for tidx, tensor in enumerate(tensors):
            tensors[tidx] = _upsample_impl(tensor, output_sizes[tidx][-2:], upsample_mode)
        #
    else:
        tensors = _upsample_impl(tensors, output_sizes[0][-2:], upsample_mode)
    return tensors

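# Example usage (a hypothetical sketch): bring a list of per-task predictions back to
# the corresponding target resolutions before computing metrics:
#     task_outputs = upsample_tensors(task_outputs, [t.shape for t in task_targets], 'bilinear')
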
# print the IoU for each class
def print_class_iou(args=None, confusion_matrix=None, task_idx=0):
    n_classes = args.model_config.output_channels[task_idx]
    [accuracy, mean_iou, iou, f1_score] = compute_accuracy(args, confusion_matrix, n_classes)
    print("\n Class IoU: [", end="")
    for class_iou in iou:
        print("{:0.3f}".format(class_iou), end=",")
    print("]")

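# Example output (hypothetical IoU values) for a 3-class segmentation task:
#     print_class_iou(args, confusion_matrix, task_idx=0)
# prints:
#      Class IoU: [0.975,0.812,0.664,]
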
if __name__ == '__main__':
    train_args = get_config()
    main(train_args)