[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / engine / train_pixel2pixel.py
1 # Copyright (c) 2018-2021, Texas Instruments
2 # All Rights Reserved.
3 #
4 # Redistribution and use in source and binary forms, with or without
5 # modification, are permitted provided that the following conditions are met:
6 #
7 # * Redistributions of source code must retain the above copyright notice, this
8 # list of conditions and the following disclaimer.
9 #
10 # * Redistributions in binary form must reproduce the above copyright notice,
11 # this list of conditions and the following disclaimer in the documentation
12 # and/or other materials provided with the distribution.
13 #
14 # * Neither the name of the copyright holder nor the names of its
15 # contributors may be used to endorse or promote products derived from
16 # this software without specific prior written permission.
17 #
18 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 import os
30 import shutil
31 import time
32 import math
33 import copy
35 import torch
36 import torch.nn.parallel
37 import torch.backends.cudnn as cudnn
38 import torch.optim
39 import torch.utils.data
40 import torch.onnx
41 import onnx
43 import datetime
44 from torch.utils.tensorboard import SummaryWriter
45 import numpy as np
46 import random
47 import cv2
48 from colorama import Fore
49 import progiter
50 from packaging import version
51 import warnings
53 from .. import xnn
54 from .. import xvision
55 from . infer_pixel2pixel import compute_accuracy
##################################################
# silence TracerWarning noise from torch.jit tracing
# (presumably emitted when the model is traced for onnx export below — TODO confirm)
warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)
61 ##################################################
def get_config():
    """Build and return the default configuration for pixel2pixel training.

    Returns an ``xnn.utils.ConfigNode`` pre-populated with dataset, model,
    optimizer, augmentation, quantization and logging defaults. Callers are
    expected to override fields before passing the node to ``main(args)``.
    """
    args = xnn.utils.ConfigNode()

    # dataset config
    args.dataset_config = xnn.utils.ConfigNode()
    args.dataset_config.split_name = 'val'
    args.dataset_config.max_depth_bfr_scaling = 80
    args.dataset_config.depth_scale = 1
    args.dataset_config.train_depth_log = 1
    args.use_semseg_for_depth = False

    # model config
    args.model_config = xnn.utils.ConfigNode()
    args.model_config.output_type = ['segmentation']    # the network is used to predict flow or depth or sceneflow
    args.model_config.output_channels = None            # number of output channels
    args.model_config.prediction_channels = None        # intermediate number of channels before final output_channels
    args.model_config.input_channels = None             # number of input channels
    args.model_config.final_upsample = True             # use final upsample to input resolution or not
    args.model_config.output_range = None               # max range of output
    args.model_config.num_decoders = None               # number of decoders to use. [options: 0, 1, None]
    args.model_config.freeze_encoder = False            # do not update encoder weights
    args.model_config.freeze_decoder = False            # do not update decoder weights
    args.model_config.multi_task_type = 'learned'       # find out loss multiplier by learning, choices=[None, 'learned', 'uncertainty', 'grad_norm', 'dwa_grad_norm']
    args.model_config.target_input_ratio = 1            # keep target size same as input size
    args.model_config.input_nv12 = False                # convert input to nv12 format
    args.model_config.enable_fp16 = False               # faster training if the GPU supports fp16

    args.model = None                                   # the model itself can be given from outside
    args.model_name = 'deeplabv2lite_mobilenetv2'       # model architecture, overwritten if pretrained is specified
    args.dataset_name = 'cityscapes_segmentation'       # dataset type
    args.transforms = None                              # the transforms itself can be given from outside
    args.input_channel_reverse = False                  # reverse input channels, for example RGB to BGR

    args.data_path = './data/cityscapes'                # path to dataset
    args.save_path = None                               # checkpoints save path
    args.phase = 'training'                             # training/calibration/validation
    args.date = None                                    # date to add to save path. if this is None, current date will be added.

    args.logger = None                                  # logger stream to output into
    args.show_gpu_usage = False                         # shows gpu usage at the beginning of each training epoch

    args.split_file = None                              # train_val split file
    args.split_files = None                             # split list files. eg: train.txt val.txt
    args.split_value = None                             # test_val split proportion (between 0 (only test) and 1 (only train))

    args.optimizer = 'adam'                             # optimizer algorithms, choices=['adam','sgd']
    args.scheduler = 'step'                             # scheduler algorithms, choices=['step','poly', 'cosine']
    args.workers = 8                                    # number of data loading workers

    args.epochs = 250                                   # number of total epochs to run
    args.start_epoch = 0                                # manual epoch number (useful on restarts)

    args.epoch_size = 0                                 # manual epoch size (will match dataset size if not specified)
    args.epoch_size_val = 0                             # manual epoch size (will match dataset size if not specified)
    args.batch_size = 12                                # mini_batch size
    args.total_batch_size = None                        # accumulated batch size. total_batch_size = batch_size*iter_size
    args.iter_size = 1                                  # iteration size. total_batch_size = batch_size*iter_size

    args.lr = 1e-4                                      # initial learning rate
    args.lr_clips = None                                # lr for PAct2 'clips' parameters; use args.lr itself if it is None
    args.lr_calib = 0.05                                # lr for bias calibration
    args.warmup_epochs = 5                              # number of epochs to warmup
    args.warmup_factor = 1e-3                           # max lr allowed for the first epoch during warmup (as a factor of initial lr)

    args.momentum = 0.9                                 # momentum for sgd, alpha parameter for adam
    args.beta = 0.999                                   # beta parameter for adam
    args.weight_decay = 1e-4                            # weight decay
    args.bias_decay = None                              # bias decay

    args.sparse = True                                  # avoid invalid/ignored target pixels from loss computation, use NEAREST for interpolation

    args.tensorboard_num_imgs = 5                       # number of imgs to display in tensorboard
    args.pretrained = None                              # path to pre_trained model
    args.resume = None                                  # path to latest checkpoint (default: none)
    args.no_date = False                                # don't append date timestamp to folder
    args.print_freq = 100                               # print frequency (default: 100)

    args.milestones = (100, 200)                        # epochs at which learning rate is divided by 2

    # NOTE(review): the training loop indexes these as losses[task][idx],
    # so per-task nesting (e.g. [['segmentation_loss']]) appears to be the
    # expected override shape — confirm against calling scripts.
    args.losses = ['segmentation_loss']                 # loss functions used for training
    args.metrics = ['segmentation_metrics']             # metric/measurement/error functions for train/validation
    args.multi_task_factors = None                      # loss mult factors
    args.class_weights = None                           # class weights

    args.loss_mult_factors = None                       # fixed loss mult factors - per loss - note: this is different from multi_task_factors (which is per task)

    args.multistep_gamma = 0.5                          # steps for step scheduler
    args.polystep_power = 1.0                           # power for polynomial scheduler

    args.rand_seed = 1                                  # random seed
    args.img_border_crop = None                         # image border crop rectangle. can be relative or absolute
    args.target_mask = None                             # mask rectangle. can be relative or absolute. last value is the mask value

    args.rand_resize = None                             # random image size to be resized to during training
    args.rand_output_size = None                        # output size to be resized to during training
    args.rand_scale = (1.0, 2.0)                        # random scale range for training
    args.rand_crop = None                               # image size to be cropped to

    args.img_resize = None                              # image size to be resized to during evaluation
    args.output_size = None                             # target output size to be resized to

    args.count_flops = True                             # count flops and report

    args.shuffle = True                                 # shuffle or not
    args.shuffle_val = True                             # shuffle val dataset or not

    args.transform_rotation = 0.                        # apply rotation augmentation. value is rotation in degrees. 0 indicates no rotation
    args.is_flow = None                                 # whether entries in images and targets lists are optical flow or not

    args.upsample_mode = 'bilinear'                     # upsample mode to use, choices=['nearest','bilinear']

    args.image_prenorm = True                           # whether normalization is done before all other the transforms
    args.image_mean = (128.0,)                          # image mean for input image normalization
    args.image_scale = (1.0 / (0.25 * 256),)            # image scaling/mult for input image normalization

    args.max_depth = 80                                 # maximum depth to be used for visualization

    args.pivot_task_idx = 0                             # task id to select best model

    args.parallel_model = True                          # use data parallel for model
    args.parallel_criterion = True                      # use data parallel for loss and metric

    args.evaluate_start = True                          # evaluate right at the beginning of training or not
    args.save_onnx = True                               # export an onnx copy of the model or not
    args.print_model = False                            # print the model to text
    args.run_soon = True                                # to start training after generating configs/models

    args.quantize = False                               # apply quantized inference or not
    #args.model_surgery = None                          # replace activations with PAct2 activation module. Helpful in quantized training.
    args.bitwidth_weights = 8                           # bitwidth for weights
    args.bitwidth_activations = 8                       # bitwidth for activations
    args.histogram_range = True                         # histogram range for calibration
    args.bias_calibration = True                        # apply bias correction during quantized inference calibration
    args.per_channel_q = False                          # apply separate quantization factor for each channel in depthwise or not
    args.constrain_bias = None                          # constrain bias according to the constraints of convolution engine

    args.save_mod_files = False                         # saves modified files after last commit. Also stores commit id.
    args.make_score_zero_mean = False                   # make score zero mean while learning
    args.no_q_for_dws_layer_idx = 0                     # no_q_for_dws_layer_idx

    args.viz_colormap = 'rainbow'                       # colormap for tensorboard: 'rainbow', 'plasma', 'magma', 'bone'

    args.freeze_bn = False                              # freeze the statistics of bn
    args.tensorboard_enable = True                      # en/disable of TB writing
    args.print_train_class_iou = False                  # print per-class iou during training
    args.print_val_class_iou = False                    # print per-class iou during validation
    args.freeze_layers = None                           # list of layer-name substrings to freeze
    args.opset_version = 11                             # onnx opset_version
    args.prob_color_to_gray = (0.0, 0.0)                # this will be used for controlling color 2 gray augmentation

    args.interpolation = None                           # interpolation method to be used for resize. one of cv2.INTER_
    return args
# ################################################
# disable OpenCV's internal threading to avoid hangs in the multi-worker
# data loader; this was observed after using cv2 image processing functions
# see https://github.com/pytorch/pytorch/issues/1355
cv2.setNumThreads(0)
221 # ################################################
# ################################################
def main(args):
    """Run one full pixel2pixel session (training / calibration / validation).

    Builds datasets, loaders, model (optionally wrapped for quantization),
    losses/metrics, optimizer and scheduler from *args*, optionally resumes
    from a checkpoint, then either runs validation only or the full epoch
    loop with per-epoch validation and checkpointing.
    """
    # minimum pytorch version check: scheduler.step()/optimizer.step()
    # call order changed in torch 1.1
    assert version.parse(torch.__version__) >= version.parse('1.1'), \
        'torch version must be 1.1 or higher, due to the change in scheduler.step() and optimiser.step() call order'

    assert (not hasattr(args, 'evaluate')), 'args.evaluate is deprecated. use args.phase=training or calibration or validation'
    assert is_valid_phase(args.phase), f'invalid phase {args.phase}'
    assert not hasattr(args, 'model_surgery'), 'the argument model_surgery is deprecated, it is not needed now - remove it'

    # bias calibration only makes sense while calibrating/training
    if (args.phase == 'validation' and args.bias_calibration):
        args.bias_calibration = False
        warnings.warn('switching off bias calibration in validation')
    #

    #################################################
    # fill unspecified sizes from the evaluation resize size
    args.rand_resize = args.img_resize if args.rand_resize is None else args.rand_resize
    args.rand_crop = args.img_resize if args.rand_crop is None else args.rand_crop
    args.output_size = args.img_resize if args.output_size is None else args.output_size
    # resume has higher priority
    args.pretrained = None if (args.resume is not None) else args.pretrained

    # prob_color_to_gray will be used for controlling color 2 gray augmentation
    if 'tiad' in args.dataset_name and args.prob_color_to_gray == (0.0, 0.0):
        # override in case of 'tiad' if default values are used
        args.prob_color_to_gray = (0.5, 0.0)

    if args.save_path is None:
        save_path = get_save_path(args)
    else:
        save_path = args.save_path
    #
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    if args.save_mod_files:
        # store all the files modified after the last commit.
        mod_files_path = save_path+'/mod_files'
        os.makedirs(mod_files_path)

        cmd = "git ls-files --modified | xargs -i cp {} {}".format("{}", mod_files_path)
        print("cmd:", cmd)
        os.system(cmd)

        # store last commit id.
        cmd = "git log -n 1 >> {}".format(mod_files_path + '/commit_id.txt')
        print("cmd:", cmd)
        os.system(cmd)

    #################################################
    # default logger: tee stdout to a log file next to the checkpoints
    if args.logger is None:
        log_file = os.path.splitext(os.path.basename(__file__))[0] + '.log'
        args.logger = xnn.utils.TeeLogger(filename=os.path.join(save_path,log_file))

    #################################################
    # global settings. rand seeds for repeatability
    random.seed(args.rand_seed)
    np.random.seed(args.rand_seed)
    torch.manual_seed(args.rand_seed)
    torch.cuda.manual_seed(args.rand_seed)

    ################################
    # args check and config
    if args.iter_size != 1 and args.total_batch_size is not None:
        warnings.warn("only one of --iter_size or --total_batch_size must be set")
    #
    if args.total_batch_size is not None:
        args.iter_size = args.total_batch_size//args.batch_size
    else:
        args.total_batch_size = args.batch_size*args.iter_size

    #################################################
    # set some global flags and initializations
    # keep it in args for now - although they don't belong here strictly
    # using pin_memory is seen to cause issues, especially when lot of memory is used.
    args.use_pinned_memory = False
    args.n_iter = 0
    args.best_metric = -1
    cudnn.benchmark = True
    # torch.autograd.set_detect_anomaly(True)

    ################################
    # reset character color, in case it is different
    print('{}'.format(Fore.RESET))
    # print everything for log
    print('=> args: {}'.format(args))
    print('\n'.join("%s: %s" % item for item in sorted(vars(args).items())))

    print('=> will save everything to {}'.format(save_path))

    #################################################
    train_writer = SummaryWriter(os.path.join(save_path,'train')) if args.tensorboard_enable else None
    val_writer = SummaryWriter(os.path.join(save_path,'val')) if args.tensorboard_enable else None
    transforms = get_transforms(args) if args.transforms is None else args.transforms
    assert isinstance(transforms, (list,tuple)) and len(transforms) == 2, 'incorrect transforms were given'

    print("=> fetching images in '{}'".format(args.data_path))
    # split priority: explicit file > list of files > train/val proportion
    split_arg = args.split_file if args.split_file else (args.split_files if args.split_files else args.split_value)
    train_dataset, val_dataset = xvision.datasets.pixel2pixel.__dict__[args.dataset_name](args.dataset_config, args.data_path, split=split_arg, transforms=transforms)

    #################################################
    print('=> {} samples found, {} train samples and {} test samples '.format(len(train_dataset)+len(val_dataset),
        len(train_dataset), len(val_dataset)))
    # a sampler is used only when a manual epoch size is requested
    train_sampler = get_dataset_sampler(train_dataset, args.epoch_size) if args.epoch_size != 0 else None
    shuffle_train = args.shuffle and (train_sampler is None)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
        num_workers=args.workers, pin_memory=args.use_pinned_memory, sampler=train_sampler, shuffle=shuffle_train)

    val_sampler = get_dataset_sampler(val_dataset, args.epoch_size_val) if args.epoch_size_val != 0 else None
    shuffle_val = args.shuffle_val and (val_sampler is None)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
        num_workers=args.workers, pin_memory=args.use_pinned_memory, sampler=val_sampler, shuffle=shuffle_val)

    #################################################
    # infer channel counts from the dataset when not given explicitly
    if (args.model_config.input_channels is None):
        args.model_config.input_channels = (3,)
        print("=> input channels is not given - setting to {}".format(args.model_config.input_channels))

    if (args.model_config.output_channels is None):
        if ('num_classes' in dir(train_dataset)):
            args.model_config.output_channels = train_dataset.num_classes()
        else:
            args.model_config.output_channels = (2 if args.model_config.output_type == 'flow' else args.model_config.output_channels)
        xnn.utils.print_yellow("=> output channels is not given - setting to {} - not sure to work".format(args.model_config.output_channels))
    #
    if not isinstance(args.model_config.output_channels,(list,tuple)):
        args.model_config.output_channels = [args.model_config.output_channels]

    if (args.class_weights is None) and ('class_weights' in dir(train_dataset)):
        args.class_weights = train_dataset.class_weights()
        if not isinstance(args.class_weights, (list,tuple)):
            args.class_weights = [args.class_weights]
        #
        print("=> class weights available for dataset: {}".format(args.class_weights))

    #################################################
    # load the pretrained weight file(s) (local path, url, or an already
    # loaded state dict)
    pretrained_data = None
    model_surgery_quantize = False
    # NOTE(review): duplicate initialization of pretrained_data - harmless,
    # but the line above already sets it
    pretrained_data = None
    if args.pretrained and args.pretrained != "None":
        pretrained_data = []
        pretrained_files = args.pretrained if isinstance(args.pretrained,(list,tuple)) else [args.pretrained]
        for p in pretrained_files:
            if isinstance(p, dict):
                # already a loaded checkpoint/state dict
                p_data = p
            else:
                if p.startswith('http://') or p.startswith('https://'):
                    p_file = xvision.datasets.utils.download_url(p, './data/downloads')
                else:
                    p_file = p
                #
                print(f'=> loading pretrained weights file: {p}')
                p_data = torch.load(p_file)
            #
            pretrained_data.append(p_data)
            # a checkpoint saved from quantized training carries this flag
            model_surgery_quantize = p_data['quantize'] if 'quantize' in p_data else False
    #

    #################################################
    # create model: user-supplied module, onnx file, or by model name
    is_onnx_model = False
    if isinstance(args.model, torch.nn.Module):
        # NOTE(review): the list/tuple check below can never be True inside
        # this isinstance(..., torch.nn.Module) branch - confirm intent
        model, change_names_dict = args.model if isinstance(args.model, (list, tuple)) else (args.model, None)
        assert isinstance(model, torch.nn.Module), 'args.model, if provided must be a valid torch.nn.Module'
    elif isinstance(args.model, str) and args.model.endswith('.onnx'):
        model = xnn.onnx.import_onnx(args.model)
        is_onnx_model = True
    else:
        xnn.utils.print_yellow("=> creating model '{}'".format(args.model_name))
        model = xvision.models.pixel2pixel.__dict__[args.model_name](args.model_config)
        # check if we got the model as well as parameters to change the names in pretrained
        model, change_names_dict = model if isinstance(model, (list,tuple)) else (model,None)
    #

    if args.quantize:
        # dummy input is used by quantized models to analyze graph
        is_cuda = next(model.parameters()).is_cuda
        dummy_input = create_rand_inputs(args, is_cuda=is_cuda)
        #
        if 'training' in args.phase:
            model = xnn.quantize.QuantTrainModule(model, per_channel_q=args.per_channel_q,
                histogram_range=args.histogram_range, bitwidth_weights=args.bitwidth_weights,
                bitwidth_activations=args.bitwidth_activations, constrain_bias=args.constrain_bias,
                dummy_input=dummy_input)
        elif 'calibration' in args.phase:
            model = xnn.quantize.QuantCalibrateModule(model, per_channel_q=args.per_channel_q,
                bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
                histogram_range=args.histogram_range, constrain_bias=args.constrain_bias,
                bias_calibration=args.bias_calibration, dummy_input=dummy_input, lr_calib=args.lr_calib)
        elif 'validation' in args.phase:
            # Note: bias_calibration is not enabled
            model = xnn.quantize.QuantTestModule(model, per_channel_q=args.per_channel_q,
                bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
                histogram_range=args.histogram_range, constrain_bias=args.constrain_bias,
                dummy_input=dummy_input, model_surgery_quantize=model_surgery_quantize)
        else:
            assert False, f'invalid phase {args.phase}'
    #

    # load pretrained model
    if pretrained_data is not None and not is_onnx_model:
        model_orig = get_model_orig(model)
        for (p_data,p_file) in zip(pretrained_data, pretrained_files):
            print("=> using pretrained weights from: {}".format(p_file))
            if hasattr(model_orig, 'load_weights'):
                model_orig.load_weights(pretrained=p_data, change_names_dict=change_names_dict)
            else:
                xnn.utils.load_weights(get_model_orig(model), pretrained=p_data, change_names_dict=change_names_dict)
            #
        #
    #

    #################################################
    if args.count_flops:
        count_flops(args, model)

    #################################################
    if args.save_onnx:
        write_onnx_model(args, get_model_orig(model), save_path, save_traced_model=False)
    #

    #################################################
    if args.print_model:
        print(model)
        print('\n')
    else:
        args.logger.debug(str(model))
        args.logger.debug('\n')

    #################################################
    # config-generation-only mode: stop before training starts
    if (not args.run_soon):
        print("Training not needed for now")
        close(args)
        exit()

    #################################################
    # DataParallel does not work for QuantCalibrateModule or QuantTestModule
    if args.parallel_model and (not isinstance(model, (xnn.quantize.QuantCalibrateModule, xnn.quantize.QuantTestModule))):
        model = torch.nn.DataParallel(model)

    #################################################
    model = model.cuda()

    #################################################
    # for help in debug/print
    for name, module in model.named_modules():
        module.name = name

    #################################################
    # instantiate loss modules per task (losses is a per-task nested list)
    args.loss_modules = copy.deepcopy(args.losses)
    for task_dx, task_losses in enumerate(args.losses):
        for loss_idx, loss_fn in enumerate(task_losses):
            kw_args = {}
            # each loss class declares which constructor kwargs it accepts
            loss_args = xvision.losses.__dict__[loss_fn].args()
            for arg in loss_args:
                if arg == 'weight' and (args.class_weights is not None):
                    kw_args.update({arg:args.class_weights[task_dx]})
                elif arg == 'num_classes':
                    kw_args.update({arg:args.model_config.output_channels[task_dx]})
                elif arg == 'sparse':
                    kw_args.update({arg:args.sparse})
                elif arg == 'enable_fp16':
                    kw_args.update({arg:args.model_config.enable_fp16})
                #
            #
            loss_fn_raw = xvision.losses.__dict__[loss_fn](**kw_args)
            if args.parallel_criterion:
                loss_fn = torch.nn.DataParallel(loss_fn_raw).cuda() if args.parallel_criterion else loss_fn_raw.cuda()
                # DataParallel hides the wrapped module's helpers - re-expose them
                loss_fn.info = loss_fn_raw.info
                loss_fn.clear = loss_fn_raw.clear
            else:
                loss_fn = loss_fn_raw.cuda()
            #
            args.loss_modules[task_dx][loss_idx] = loss_fn
        #
    #

    # instantiate metric modules per task (same scheme as losses above)
    args.metric_modules = copy.deepcopy(args.metrics)
    for task_dx, task_metrics in enumerate(args.metrics):
        for midx, metric_fn in enumerate(task_metrics):
            kw_args = {}
            loss_args = xvision.losses.__dict__[metric_fn].args()
            for arg in loss_args:
                if arg == 'weight':
                    kw_args.update({arg:args.class_weights[task_dx]})
                elif arg == 'num_classes':
                    kw_args.update({arg:args.model_config.output_channels[task_dx]})
                elif arg == 'sparse':
                    kw_args.update({arg:args.sparse})
                elif arg == 'enable_fp16':
                    kw_args.update({arg:args.model_config.enable_fp16})
                #
            #
            metric_fn_raw = xvision.losses.__dict__[metric_fn](**kw_args)
            if args.parallel_criterion:
                metric_fn = torch.nn.DataParallel(metric_fn_raw).cuda()
                metric_fn.info = metric_fn_raw.info
                metric_fn.clear = metric_fn_raw.clear
            else:
                metric_fn = metric_fn_raw.cuda()
            #
            args.metric_modules[task_dx][midx] = metric_fn
        #
    #

    #################################################
    # validation-only mode: one validate pass, then cleanup and return
    if args.phase=='validation':
        with torch.no_grad():
            validate(args, val_dataset, val_loader, model, 0, val_writer)
        #
        close(args)
        return

    #################################################
    assert(args.optimizer in ['adam', 'sgd'])
    print('=> setting {} optimizer'.format(args.optimizer))
    if args.lr_clips is not None:
        # separate param group (own lr/decay) for quantization 'clips' params
        learning_rate_clips = args.lr_clips if 'training' in args.phase else 0.0
        clips_decay = args.bias_decay if (args.bias_decay is not None and args.bias_decay != 0.0) else args.weight_decay
        clips_params = [p for n,p in model.named_parameters() if 'clips' in n]
        other_params = [p for n,p in model.named_parameters() if 'clips' not in n]
        param_groups = [{'params': clips_params, 'weight_decay': clips_decay, 'lr': learning_rate_clips},
                        {'params': other_params, 'weight_decay': args.weight_decay}]
    else:
        param_groups = [{'params': filter(lambda p: p.requires_grad, model.parameters()), 'weight_decay': args.weight_decay}]
    #

    # lr is forced to 0 outside of training (i.e. during calibration)
    learning_rate = args.lr if ('training'in args.phase) else 0.0
    if args.optimizer == 'adam':
        optimizer = torch.optim.Adam(param_groups, learning_rate, betas=(args.momentum, args.beta))
    elif args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(param_groups, learning_rate, momentum=args.momentum)
    else:
        raise ValueError('Unknown optimizer type{}'.format(args.optimizer))
    #

    #################################################
    max_iter = args.epochs * len(train_loader)
    scheduler = xnn.optim.lr_scheduler.SchedulerWrapper(scheduler_type=args.scheduler, optimizer=optimizer,
                    epochs=args.epochs, start_epoch=args.start_epoch,
                    warmup_epochs=args.warmup_epochs, warmup_factor=args.warmup_factor,
                    max_iter=max_iter, polystep_power=args.polystep_power,
                    milestones=args.milestones, multistep_gamma=args.multistep_gamma)

    # optionally resume from a checkpoint
    if args.resume:
        if not os.path.isfile(args.resume):
            print("=> no checkpoint found at '{}'".format(args.resume))
        else:
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model = xnn.utils.load_weights(model, checkpoint)

            # an explicitly given start_epoch overrides the checkpoint's
            if args.start_epoch == 0:
                args.start_epoch = checkpoint['epoch']

            if 'best_metric' in list(checkpoint.keys()):
                args.best_metric = checkpoint['best_metric']

            if 'optimizer' in list(checkpoint.keys()):
                optimizer.load_state_dict(checkpoint['optimizer'])

            if 'scheduler' in list(checkpoint.keys()):
                scheduler.load_state_dict(checkpoint['scheduler'])

            if 'multi_task_factors' in list(checkpoint.keys()):
                args.multi_task_factors = checkpoint['multi_task_factors']

            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))

    #################################################
    # optional baseline validation before any training
    if args.evaluate_start:
        with torch.no_grad():
            validate(args, val_dataset, val_loader, model, args.start_epoch, val_writer)

    grad_scaler = torch.cuda.amp.GradScaler() if args.model_config.enable_fp16 else None

    for epoch in range(args.start_epoch, args.epochs):
        # epoch is needed to seed shuffling in DistributedSampler, every epoch.
        # otherwise seed of 0 is used every epoch, which seems incorrect.
        if train_sampler and isinstance(train_sampler, torch.utils.data.DistributedSampler):
            train_sampler.set_epoch(epoch)
        if val_sampler and isinstance(val_sampler, torch.utils.data.DistributedSampler):
            val_sampler.set_epoch(epoch)

        # train for one epoch
        train(args, train_dataset, train_loader, model, optimizer, epoch, train_writer, scheduler, grad_scaler)

        # evaluate on validation set
        with torch.no_grad():
            val_metric, metric_name = validate(args, val_dataset, val_loader, model, epoch, val_writer)

        if args.best_metric < 0:
            args.best_metric = val_metric

        # direction of "best" depends on the metric name:
        # iou/acc improve upwards; error/diff/norm/loss/outlier improve downwards
        if "iou" in metric_name.lower() or "acc" in metric_name.lower():
            is_best = val_metric >= args.best_metric
            args.best_metric = max(val_metric, args.best_metric)
        elif "error" in metric_name.lower() or "diff" in metric_name.lower() or "norm" in metric_name.lower() \
                or "loss" in metric_name.lower() or "outlier" in metric_name.lower():
            is_best = val_metric <= args.best_metric
            args.best_metric = min(val_metric, args.best_metric)
        else:
            raise ValueError("Metric is not known. Best model could not be saved.")
        #

        checkpoint_dict = { 'epoch': epoch + 1, 'model_name': args.model_name,
                            'state_dict': get_model_orig(model).state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'best_metric': args.best_metric,
                            'multi_task_factors': args.multi_task_factors,
                            'quantize' : args.quantize}

        save_checkpoint(args, save_path, get_model_orig(model), checkpoint_dict, is_best)

        if args.tensorboard_enable:
            train_writer.file_writer.flush()
            val_writer.file_writer.flush()

        # adjust the learning rate using lr scheduler
        if 'training' in args.phase:
            scheduler.step()
        #
    #

    # close and cleanup
    close(args)
#
650 ###################################################################
def is_valid_phase(phase):
    """Return True if *phase* contains one of the supported phase names
    ('training', 'calibration' or 'validation') as a substring."""
    for known_phase in ('training', 'calibration', 'validation'):
        if known_phase in phase:
            return True
    return False
656 ###################################################################
657 def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writer, scheduler, grad_scaler):
658 batch_time = xnn.utils.AverageMeter()
659 data_time = xnn.utils.AverageMeter()
660 # if the loss/ metric is already an average, no need to further average
661 avg_loss = [xnn.utils.AverageMeter(print_avg=(not task_loss[0].info()['is_avg'])) for task_loss in args.loss_modules]
662 avg_loss_orig = [xnn.utils.AverageMeter(print_avg=(not task_loss[0].info()['is_avg'])) for task_loss in args.loss_modules]
663 avg_metric = [xnn.utils.AverageMeter(print_avg=(not task_metric[0].info()['is_avg'])) for task_metric in args.metric_modules]
665 ##########################
666 # switch to train mode
667 model.train()
669 # freeze bn and range after some epochs during quantization
670 if args.freeze_bn or (args.quantize and epoch > 2 and epoch >= ((args.epochs//2)-1)):
671 xnn.utils.print_once('Freezing BN for subsequent epochs')
672 xnn.utils.freeze_bn(model)
673 #
674 if (args.quantize and epoch > 4 and epoch >= ((args.epochs//2)+1)):
675 xnn.utils.print_once('Freezing ranges for subsequent epochs')
676 xnn.layers.freeze_quant_range(model)
677 #
679 #freeze layers
680 if args.freeze_layers is not None:
681 # 'freeze_layer_name' could be part of 'name', i.e. 'name' need not be exact same as 'freeze_layer_name'
682 # e.g. freeze_layer_name = 'encoder.0' then all layers like, 'encoder.0.0' 'encoder.0.1' will be frozen
683 for freeze_layer_name in args.freeze_layers:
684 for name, module in model.named_modules():
685 if freeze_layer_name in name:
686 xnn.utils.print_once("Freezing the module : {}".format(name))
687 module.eval()
688 for param in module.parameters():
689 param.requires_grad = False
691 ##########################
692 for task_dx, task_losses in enumerate(args.loss_modules):
693 for loss_idx, loss_fn in enumerate(task_losses):
694 loss_fn.clear()
695 for task_dx, task_metrics in enumerate(args.metric_modules):
696 for midx, metric_fn in enumerate(task_metrics):
697 metric_fn.clear()
699 num_iter = len(train_loader)
700 progress_bar = progiter.ProgIter(np.arange(num_iter), chunksize=1)
701 metric_name = "Metric"
702 metric_ctx = [None] * len(args.metric_modules)
703 end_time = time.time()
704 writer_idx = 0
705 last_update_iter = -1
707 # change color to yellow for calibration
708 progressbar_color = (Fore.YELLOW if (('calibration' in args.phase) or ('training' in args.phase and args.quantize)) else Fore.WHITE)
709 print('{}'.format(progressbar_color), end='')
711 ##########################
712 for iter_id, (inputs, targets) in enumerate(train_loader):
713 # measure data loading time
714 data_time.update(time.time() - end_time)
716 lr = scheduler.get_lr()[0]
718 input_list = [[jj.cuda() for jj in img] if isinstance(img,(list,tuple)) else img.cuda() for img in inputs]
719 target_list = [tgt.cuda(non_blocking=True) for tgt in targets]
720 target_sizes = [tgt.shape for tgt in target_list]
721 batch_size_cur = target_sizes[0][0]
723 ##########################
724 # compute output
725 task_outputs = model(input_list)
727 task_outputs = task_outputs if isinstance(task_outputs,(list,tuple)) else [task_outputs]
728 # upsample output to target resolution
729 if args.upsample_mode is not None:
730 task_outputs = upsample_tensors(task_outputs, target_sizes, args.upsample_mode)
732 if args.model_config.multi_task_type is not None and len(args.model_config.output_channels) > 1:
733 args.multi_task_factors, args.multi_task_offsets = xnn.layers.get_loss_scales(model)
734 else:
735 args.multi_task_factors = None
736 args.multi_task_offsets = None
738 loss_total, loss_list, loss_names, loss_types, loss_list_orig = \
739 compute_task_objectives(args, args.loss_modules, input_list, task_outputs, target_list,
740 task_mults=args.multi_task_factors, task_offsets=args.multi_task_offsets,
741 loss_mult_factors=args.loss_mult_factors)
743 if args.print_train_class_iou:
744 metric_total, metric_list, metric_names, metric_types, _, confusion_matrix = \
745 compute_task_objectives(args, args.metric_modules, input_list, task_outputs, target_list,
746 get_confusion_matrix=args.print_train_class_iou)
747 else:
748 metric_total, metric_list, metric_names, metric_types, _ = \
749 compute_task_objectives(args, args.metric_modules, input_list, task_outputs, target_list,
750 get_confusion_matrix=args.print_train_class_iou)
752 if args.model_config.multi_task_type is not None and len(args.model_config.output_channels) > 1:
753 xnn.layers.set_losses(model, loss_list_orig)
755 if 'training' in args.phase:
756 # accumulate gradients
757 if args.model_config.enable_fp16:
758 grad_scaler.scale(loss_total).backward()
759 else:
760 loss_total.backward()
761 #
763 # optimization step
764 if ((iter_id+1) % args.iter_size) == 0:
765 if args.model_config.enable_fp16:
766 grad_scaler.step(optimizer)
767 grad_scaler.update()
768 else:
769 optimizer.step()
770 #
771 # zero gradients so that we can accumulate gradients
772 # setting grad=None is a faster alternative instead of optimizer.zero_grad()
773 xnn.utils.clear_grad(model)
774 #
775 #
777 # record loss.
778 for task_idx, task_losses in enumerate(args.loss_modules):
779 avg_loss[task_idx].update(float(loss_list[task_idx].cpu()), batch_size_cur)
780 avg_loss_orig[task_idx].update(float(loss_list_orig[task_idx].cpu()), batch_size_cur)
781 if args.tensorboard_enable:
782 train_writer.add_scalar('Training/Task{}_{}_Loss_Iter'.format(task_idx,loss_names[task_idx]), float(loss_list[task_idx]), args.n_iter)
783 if args.model_config.multi_task_type is not None and len(args.model_config.output_channels) > 1:
784 train_writer.add_scalar('Training/multi_task_Factor_Task{}_{}'.format(task_idx,loss_names[task_idx]), float(args.multi_task_factors[task_idx]), args.n_iter)
786 # record error/accuracy.
787 for task_idx, task_metrics in enumerate(args.metric_modules):
788 avg_metric[task_idx].update(float(metric_list[task_idx].cpu()), batch_size_cur)
790 ##########################
791 if args.tensorboard_enable:
792 write_output(args, 'Training_', num_iter, iter_id, epoch, train_dataset, train_writer, input_list, task_outputs, target_list, metric_name, writer_idx)
794 if ((iter_id % args.print_freq) == 0) or (iter_id == (num_iter-1)):
795 output_string = ''
796 for task_idx, task_metrics in enumerate(args.metric_modules):
797 output_string += '[{}={}]'.format(metric_names[task_idx], str(avg_metric[task_idx]))
799 epoch_str = '{}/{}'.format(epoch + 1, args.epochs)
800 progress_bar.set_description("{}=> {} ".format(progressbar_color, args.phase))
801 multi_task_factors_print = ['{:.3f}'.format(float(lmf)) for lmf in args.multi_task_factors] if args.multi_task_factors is not None else None
802 progress_bar.set_postfix(Epoch=epoch_str, LR=lr, DataTime=str(data_time), LossMult=multi_task_factors_print, Loss=avg_loss, Output=output_string)
803 progress_bar.update(iter_id-last_update_iter)
804 last_update_iter = iter_id
806 args.n_iter += 1
807 end_time = time.time()
808 writer_idx = (writer_idx + 1) % args.tensorboard_num_imgs
810 # add onnx graph to tensorboard
811 # commenting out due to issues in transitioning to pytorch 0.4
812 # (bilinear mode in upsampling causes hang or crash - may be due to align_borders change, nearest is fine)
813 #if epoch == 0 and iter_id == 0:
814 # input_zero = torch.zeros(input_var.shape)
815 # train_writer.add_graph(model, input_zero)
816 #This cache operation slows down tranining
817 #torch.cuda.empty_cache()
818 #
820 if args.print_train_class_iou:
821 print_class_iou(args=args, confusion_matrix=confusion_matrix, task_idx=task_idx)
823 progress_bar.close()
825 # to print a new line - do not provide end=''
826 print('{}'.format(Fore.RESET), end='')
828 if args.tensorboard_enable:
829 for task_idx, task_losses in enumerate(args.loss_modules):
830 train_writer.add_scalar('Training/Task{}_{}_Loss_Epoch'.format(task_idx,loss_names[task_idx]), float(avg_loss[task_idx]), epoch)
832 for task_idx, task_metrics in enumerate(args.metric_modules):
833 train_writer.add_scalar('Training/Task{}_{}_Metric_Epoch'.format(task_idx,metric_names[task_idx]), float(avg_metric[task_idx]), epoch)
835 output_name = metric_names[args.pivot_task_idx]
836 output_metric = float(avg_metric[args.pivot_task_idx])
838 ##########################
839 if args.quantize:
840 def debug_format(v):
841 return ('{:.3f}'.format(v) if v is not None else 'None')
842 #
843 clips_act = [m.get_clips_act()[1] for n,m in model.named_modules() if isinstance(m,xnn.layers.PAct2)]
844 if len(clips_act) > 0:
845 args.logger.debug('\nclips_act : ' + ' '.join(map(debug_format, clips_act)))
846 args.logger.debug('')
847 #
848 return output_metric, output_name
851 ###################################################################
def validate(args, val_dataset, val_loader, model, epoch, val_writer):
    """Run one evaluation epoch over val_loader.

    Puts the model in eval mode, evaluates every configured metric module per
    batch, optionally logs sample images/metrics to tensorboard, and prints a
    progress bar (green while validating).

    Fix: the batch loop is wrapped in torch.no_grad() - validation never calls
    backward(), so building autograd graphs only wasted memory and time.

    Returns:
        (metric_value, metric_name) for the pivot task (args.pivot_task_idx).
    """
    data_time = xnn.utils.AverageMeter()
    # if the loss/metric is already an average, no need to further average
    avg_metric = [xnn.utils.AverageMeter(print_avg=(not task_metric[0].info()['is_avg'])) for task_metric in args.metric_modules]

    ##########################
    # switch to evaluate mode
    model.eval()

    ##########################
    # reset any state accumulated inside the metric modules (e.g. confusion matrices)
    for task_dx, task_metrics in enumerate(args.metric_modules):
        for midx, metric_fn in enumerate(task_metrics):
            metric_fn.clear()

    end_time = time.time()
    writer_idx = 0
    last_update_iter = -1

    num_iter = len(val_loader)
    progress_bar = progiter.ProgIter(np.arange(num_iter), chunksize=1)

    # change color to green
    print('{}'.format(Fore.GREEN), end='')

    ##########################
    # inference only - disable autograd to save memory and time
    with torch.no_grad():
        for iter_id, (inputs, targets) in enumerate(val_loader):
            data_time.update(time.time() - end_time)
            # an input entry may be a single image tensor or a [Y, UV] pair (nv12)
            input_list = [[jj.cuda() for jj in img] if isinstance(img,(list,tuple)) else img.cuda() for img in inputs]
            target_list = [j.cuda(non_blocking=True) for j in targets]
            target_sizes = [tgt.shape for tgt in target_list]
            batch_size_cur = target_sizes[0][0]

            # compute output
            task_outputs = model(input_list)

            task_outputs = task_outputs if isinstance(task_outputs, (list, tuple)) else [task_outputs]
            # upsample output to target resolution before computing metrics
            if args.upsample_mode is not None:
                task_outputs = upsample_tensors(task_outputs, target_sizes, args.upsample_mode)

            if args.print_val_class_iou:
                metric_total, metric_list, metric_names, metric_types, _, confusion_matrix = \
                    compute_task_objectives(args, args.metric_modules, input_list, task_outputs, target_list,
                    get_confusion_matrix=args.print_val_class_iou)
            else:
                metric_total, metric_list, metric_names, metric_types, _ = \
                    compute_task_objectives(args, args.metric_modules, input_list, task_outputs, target_list,
                    get_confusion_matrix=args.print_val_class_iou)

            # record error/accuracy.
            for task_idx, task_metrics in enumerate(args.metric_modules):
                avg_metric[task_idx].update(float(metric_list[task_idx].cpu()), batch_size_cur)

            if args.tensorboard_enable:
                write_output(args, 'Validation_', num_iter, iter_id, epoch, val_dataset, val_writer, input_list, task_outputs, target_list, metric_names, writer_idx)

            if ((iter_id % args.print_freq) == 0) or (iter_id == (num_iter-1)):
                output_string = ''
                for task_idx, task_metrics in enumerate(args.metric_modules):
                    output_string += '[{}={}]'.format(metric_names[task_idx], str(avg_metric[task_idx]))

                epoch_str = '{}/{}'.format(epoch + 1, args.epochs)
                progress_bar.set_description("=> validation")
                progress_bar.set_postfix(Epoch=epoch_str, DataTime=data_time, Output="{}".format(output_string))
                progress_bar.update(iter_id-last_update_iter)
                last_update_iter = iter_id
            #

            end_time = time.time()
            writer_idx = (writer_idx + 1) % args.tensorboard_num_imgs
        #
    #
    if args.print_val_class_iou:
        # NOTE(review): task_idx here is the leftover loop variable from the metric
        # loop above (i.e. the last task) - confirm that is the intended task
        print_class_iou(args=args, confusion_matrix=confusion_matrix, task_idx=task_idx)
    #

    #print_conf_matrix(conf_matrix=conf_matrix, en=False)
    progress_bar.close()

    # to print a new line - do not provide end=''
    print('{}'.format(Fore.RESET), end='')

    if args.tensorboard_enable:
        for task_idx, task_metrics in enumerate(args.metric_modules):
            val_writer.add_scalar('Validation/Task{}_{}_Metric_Epoch'.format(task_idx,metric_names[task_idx]), float(avg_metric[task_idx]), epoch)

    output_name = metric_names[args.pivot_task_idx]
    output_metric = float(avg_metric[args.pivot_task_idx])
    return output_metric, output_name
944 ###################################################################
def close(args):
    """Tear down per-run state: drop the logger held on args and reset best_metric."""
    if args.logger is not None:
        # delete the attribute first so any cleanup tied to the logger object runs promptly
        del args.logger
        args.logger = None
    args.best_metric = -1
def get_save_path(args, phase=None):
    """Build the checkpoint directory for this run.

    Layout: ./data/checkpoints/<dataset>/<date>_<dataset>_<model>_resizeWxH_traincropWxH/<phase>.
    Uses args.date when set, otherwise the current timestamp; phase defaults
    to args.phase when not given.
    """
    if args.date:
        run_date = args.date
    else:
        run_date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    leaf = run_date + '_' + args.dataset_name + '_' + args.model_name
    # note the (width, height) print order: index [1] before [0]
    leaf += '_resize{}x{}_traincrop{}x{}'.format(args.img_resize[1], args.img_resize[0], args.rand_crop[1], args.rand_crop[0])
    run_phase = args.phase if phase is None else phase
    return os.path.join('./data/checkpoints', args.dataset_name, leaf, run_phase)
def get_model_orig(model):
    """Unwrap (Distributed)DataParallel and quantization wrappers to reach the raw model."""
    if isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
        model_orig = model.module
    else:
        model_orig = model
    if isinstance(model_orig, (xnn.quantize.QuantBaseModule)):
        model_orig = model_orig.module
    return model_orig
def create_rand_inputs(args, is_cuda):
    """Create random dummy input tensors matching the model's expected shapes.

    Regular inputs: one (1, C, H, W) tensor per entry of input_channels.
    nv12 inputs: one [Y, UV] pair per entry - Y is (1, 1, H, W) and UV is
    (1, 1, H//2, W).
    """
    height, width = args.img_resize[0], args.img_resize[1]
    dummy_input = []
    if args.model_config.input_nv12:
        for _ in args.model_config.input_channels:
            y_plane = torch.rand((1, 1, height, width))
            uv_plane = torch.rand((1, 1, height//2, width))
            if is_cuda:
                y_plane = y_plane.cuda()
                uv_plane = uv_plane.cuda()
            dummy_input.append([y_plane, uv_plane])
    else:
        for num_ch in args.model_config.input_channels:
            t = torch.rand((1, num_ch, height, width))
            dummy_input.append(t.cuda() if is_cuda else t)
    return dummy_input
def count_flops(args, model):
    """Measure and print the model's GFLOPs/GMACs at the configured input resolution."""
    is_cuda = next(model.parameters()).is_cuda
    dummy_input = create_rand_inputs(args, is_cuda)
    model.eval()
    total_flops = xnn.utils.forward_count_flops(model, dummy_input)
    gflops = total_flops / 1e9
    # one MAC is counted as two FLOPs, hence GMACs = GFLOPs / 2
    print('=> Size = {}, GFLOPs = {}, GMACs = {}'.format(args.img_resize, gflops, gflops/2))
def derive_node_name(input_name):
    """Derive an onnx node name from its input-name list.

    Takes the last input name and strips the final '.suffix' (if any).
    """
    last_input = input_name[-1]
    head, sep, _tail = last_input.rpartition('.')
    # no '.' present: keep the name unchanged
    return head if sep else last_input
1005 #torch onnx export does not update names. Do it using onnx.save
#torch onnx export does not update names. Do it using onnx.save
def add_node_names(onnx_model_name):
    """Rewrite the onnx file in place: strip ':...' suffixes from node input
    names and give every node a name derived from its (last) input."""
    onnx_model = onnx.load(onnx_model_name)
    for node in onnx_model.graph.node:
        for idx in range(len(node.input)):
            node.input[idx] = node.input[idx].split(':')[0]
            node.name = derive_node_name(node.input)
    # update model inplace
    onnx.save(onnx_model, onnx_model_name)
def write_onnx_model(args, model, save_path, name='checkpoint.onnx', save_traced_model=False):
    """Export the model to onnx at save_path/name and post-process node names.

    save_traced_model is currently unused (torch.jit traced-model export is
    disabled).
    """
    is_cuda = next(model.parameters()).is_cuda
    input_list = create_rand_inputs(args, is_cuda=is_cuda)
    model.eval()
    onnx_path = os.path.join(save_path, name)
    torch.onnx.export(model, input_list, onnx_path, export_params=True, verbose=False,
                      do_constant_folding=True, opset_version=args.opset_version)
    # torch onnx export does not assign node names - patch the saved file
    add_node_names(onnx_model_name=onnx_path)
1040 ###################################################################
def write_output(args, prefix, val_epoch_size, iter_id, epoch, dataset, output_writer, input_images, task_outputs, task_targets, metric_names, writer_idx):
    """Write one randomly chosen sample (inputs, per-task outputs and targets) to tensorboard.

    Called once per iteration; writes are skipped probabilistically so that
    roughly args.tensorboard_num_imgs samples get logged per epoch. Handles
    segmentation, depth/disparity, flow and interest-point visualizations.
    prefix distinguishes 'Training_' vs 'Validation_' image groups.
    """
    # probability of writing this iteration so that ~tensorboard_num_imgs land per epoch
    write_freq = (args.tensorboard_num_imgs / float(val_epoch_size))
    write_prob = np.random.random()
    if (write_prob > write_freq):
        return
    if args.model_config.input_nv12:
        batch_size = input_images[0][0].shape[0]
    else:
        batch_size = input_images[0].shape[0]
    # pick one sample at random from the batch
    b_index = random.randint(0, batch_size - 1)

    input_image = None
    for img_idx, img in enumerate(input_images):
        if args.model_config.input_nv12:
            #convert NV12 to BGR for tensorboard
            input_image = xvision.transforms.image_transforms_xv12.nv12_to_bgr_image(Y = input_images[img_idx][0][b_index], UV = input_images[img_idx][1][b_index],
                image_scale=args.image_scale, image_mean=args.image_mean)
        else:
            # CHW tensor -> HWC numpy image
            input_image = input_images[img_idx][b_index].cpu().numpy().transpose((1, 2, 0))
            # convert back to original input range (0-255)
            input_image = input_image / args.image_scale + args.image_mean

        if args.is_flow and args.is_flow[0][img_idx]:
            #input corresponding to flow is assumed to have been generated by adding 128
            flow = input_image - 128
            flow_hsv = xnn.utils.flow2hsv(flow.transpose(2, 0, 1), confidence=False).transpose(2, 0, 1)
            #flow_hsv = (flow_hsv / 255.0).clip(0, 1) #TODO: check this
            output_writer.add_image(prefix +'Input{}/{}'.format(img_idx, writer_idx), flow_hsv, epoch)
        else:
            input_image = (input_image/255.0).clip(0,1) #.astype(np.uint8)
            output_writer.add_image(prefix + 'Input{}/{}'.format(img_idx, writer_idx), input_image.transpose((2,0,1)), epoch)

    # for sparse data, chroma blending does not look good
    for task_idx, output_type in enumerate(args.model_config.output_type):
        # metric_name = metric_names[task_idx]
        output = task_outputs[task_idx]
        target = task_targets[task_idx]
        if (output_type == 'segmentation') and hasattr(dataset, 'decode_segmap'):
            # class-id maps are decoded to color images before writing;
            # argmax over channels only when the output has more than one channel
            segmentation_target = dataset.decode_segmap(target[b_index,0].cpu().numpy())
            segmentation_output = output.max(dim=1,keepdim=True)[1].data.cpu().numpy() if(output.shape[1]>1) else output.data.cpu().numpy()
            segmentation_output = dataset.decode_segmap(segmentation_output[b_index,0])
            segmentation_output_blend = xnn.utils.chroma_blend(input_image, segmentation_output)
            #
            output_writer.add_image(prefix+'Task{}_{}_GT/{}'.format(task_idx,output_type,writer_idx), segmentation_target.transpose(2,0,1), epoch)
            if not args.sparse:
                segmentation_target_blend = xnn.utils.chroma_blend(input_image, segmentation_target)
                output_writer.add_image(prefix + 'Task{}_{}_GT_ColorBlend/{}'.format(task_idx, output_type, writer_idx), segmentation_target_blend.transpose(2, 0, 1), epoch)
            #
            output_writer.add_image(prefix+'Task{}_{}_Output/{}'.format(task_idx,output_type,writer_idx), segmentation_output.transpose(2,0,1), epoch)
            output_writer.add_image(prefix+'Task{}_{}_Output_ColorBlend/{}'.format(task_idx,output_type,writer_idx), segmentation_output_blend.transpose(2,0,1), epoch)
        elif (output_type in ('depth', 'disparity')):
            depth_chanidx = 0
            output_writer.add_image(prefix+'Task{}_{}_GT_Color_Visualization/{}'.format(task_idx,output_type,writer_idx), xnn.utils.tensor2array(target[b_index][depth_chanidx].cpu(), max_value=args.max_depth, colormap=args.viz_colormap).transpose(2,0,1), epoch)
            if not args.sparse:
                output_writer.add_image(prefix + 'Task{}_{}_GT_ColorBlend_Visualization/{}'.format(task_idx, output_type, writer_idx), xnn.utils.tensor2array(target[b_index][depth_chanidx].cpu(), max_value=args.max_depth, colormap=args.viz_colormap, input_blend=input_image).transpose(2, 0, 1), epoch)
            #
            output_writer.add_image(prefix+'Task{}_{}_Output_Color_Visualization/{}'.format(task_idx,output_type,writer_idx), xnn.utils.tensor2array(output.data[b_index][depth_chanidx].cpu(), max_value=args.max_depth, colormap=args.viz_colormap).transpose(2,0,1), epoch)
            output_writer.add_image(prefix + 'Task{}_{}_Output_ColorBlend_Visualization/{}'.format(task_idx, output_type, writer_idx),xnn.utils.tensor2array(output.data[b_index][depth_chanidx].cpu(), max_value=args.max_depth, colormap=args.viz_colormap, input_blend=input_image).transpose(2, 0, 1), epoch)
        elif (output_type == 'flow'):
            max_value_flow = 10.0 # only for visualization
            output_writer.add_image(prefix+'Task{}_{}_GT/{}'.format(task_idx,output_type,writer_idx), xnn.utils.flow2hsv(target[b_index][:2].cpu().numpy(), max_value=max_value_flow).transpose(2,0,1), epoch)
            output_writer.add_image(prefix+'Task{}_{}_Output/{}'.format(task_idx,output_type,writer_idx), xnn.utils.flow2hsv(output.data[b_index][:2].cpu().numpy(), max_value=max_value_flow).transpose(2,0,1), epoch)
        elif (output_type == 'interest_pt'):
            score_chanidx = 0
            target_score_to_write = target[b_index][score_chanidx].cpu()
            output_score_to_write = output.data[b_index][score_chanidx].cpu()

            #if score is learnt as zero mean add offset to make it [0-255]
            if args.make_score_zero_mean:
                # target_score_to_write!=0 : value 0 indicates GT unavailable. Leave them to be 0.
                target_score_to_write[target_score_to_write!=0] += 128.0
                output_score_to_write += 128.0

            max_value_score = float(torch.max(target_score_to_write)) #0.002
            output_writer.add_image(prefix+'Task{}_{}_GT_Bone_Visualization/{}'.format(task_idx,output_type,writer_idx), xnn.utils.tensor2array(target_score_to_write, max_value=max_value_score, colormap='bone').transpose(2,0,1), epoch)
            output_writer.add_image(prefix+'Task{}_{}_Output_Bone_Visualization/{}'.format(task_idx,output_type,writer_idx), xnn.utils.tensor2array(output_score_to_write, max_value=max_value_score, colormap='bone').transpose(2,0,1), epoch)
        #
def print_conf_matrix(conf_matrix=None, en=False):
    """Print a confusion matrix row by row, one element per chunk.

    Args:
        conf_matrix: 2D array-like (e.g. numpy array) holding the matrix.
            Fix: the previous default was a mutable list literal ([]), a
            classic Python pitfall; it is now a None sentinel.
        en: enable flag; when False (the default) printing is skipped, which
            lets call sites stay in place while the output is silenced.
    """
    if not en:
        return
    if conf_matrix is None:
        # nothing to print - the old default ([]) would have crashed on .shape
        return
    num_rows = conf_matrix.shape[0]
    num_cols = conf_matrix.shape[1]
    print("-"*64)
    num_ele = 1  # elements printed per chunk
    for r_idx in range(num_rows):
        print("\n")
        for c_idx in range(0, num_cols, num_ele):
            print(conf_matrix[r_idx][c_idx:c_idx+num_ele], end="")
        print("\n")
    print("-" * 64)
def compute_task_objectives(args, objective_fns, input_var, task_outputs, task_targets, task_mults=None,
        task_offsets=None, loss_mult_factors=None, get_confusion_matrix = False):
    """Evaluate per-task objective (loss or metric) modules.

    For each task, every objective in that task's list is evaluated and summed
    (each weighted by the matching loss_mult_factors entry); the per-task sums
    are then combined into one scalar, weighted by task_mults and shifted by
    task_offsets.

    Returns [objective_total, objective_list, objective_names, objective_types,
    objective_list_orig], with the confusion matrix appended as a sixth element
    when get_confusion_matrix is True.
    """

    ##########################
    # scalar accumulator created on the same device/dtype as the model outputs
    objective_total = torch.zeros_like(task_outputs[0].view(-1)[0])
    objective_list = []
    objective_list_orig = []
    objective_names = []
    objective_types = []
    for task_idx, task_objectives in enumerate(objective_fns):
        output_type = args.model_config.output_type[task_idx]
        objective_sum_value = torch.zeros_like(task_outputs[task_idx].view(-1)[0])
        objective_sum_name = ''
        objective_sum_type = ''

        task_mult = task_mults[task_idx] if task_mults is not None else 1.0
        task_offset = task_offsets[task_idx] if task_offsets is not None else 0.0

        for oidx, objective_fn in enumerate(task_objectives):
            objective_batch = objective_fn(input_var, task_outputs[task_idx], task_targets[task_idx])
            # DataParallel returns one value per replica - reduce to a scalar
            objective_batch = objective_batch.mean() if isinstance(objective_fn, torch.nn.DataParallel) else objective_batch
            objective_name = objective_fn.info()['name']
            objective_type = objective_fn.info()['is_avg']
            if get_confusion_matrix:
                # NOTE(review): only the last objective's confusion matrix survives
                # the loops; it is also unbound if objective_fns is empty - confirm
                confusion_matrix = objective_fn.info()['confusion_matrix']

            loss_mult = loss_mult_factors[task_idx][oidx] if (loss_mult_factors is not None) else 1.0
            # guard against a NaN batch so it does not poison the accumulated sum
            objective_batch_not_nan = (objective_batch if not torch.isnan(objective_batch) else 0.0)
            objective_sum_value = objective_batch_not_nan*loss_mult + objective_sum_value
            objective_sum_name += (objective_name if (objective_sum_name == '') else ('+' + objective_name))
            assert (objective_sum_type == '' or objective_sum_type == objective_type), 'metric types (avg/val) for a given task should match'
            objective_sum_type = objective_type

        objective_list.append(objective_sum_value)
        # NOTE(review): the "orig" list stores the same object as objective_list here
        # (task_mult/task_offset are only applied to objective_total below) - confirm intended
        objective_list_orig.append(objective_sum_value)
        objective_names.append(objective_sum_name)
        objective_types.append(objective_sum_type)

        objective_total = objective_sum_value*task_mult + task_offset + objective_total

    return_list = [objective_total, objective_list, objective_names, objective_types, objective_list_orig]
    if get_confusion_matrix:
        return_list.append(confusion_matrix)

    return return_list
def save_checkpoint(args, save_path, model, checkpoint_dict, is_best, filename='checkpoint.pth'):
    """Write checkpoint_dict to save_path/filename; mirror it to model_best.pth when is_best.

    When args.save_onnx is set, a matching onnx export is written for each
    .pth file as well.
    """
    checkpoint_path = os.path.join(save_path, filename)
    torch.save(checkpoint_dict, checkpoint_path)
    if is_best:
        shutil.copyfile(checkpoint_path, os.path.join(save_path, 'model_best.pth'))
    if args.save_onnx:
        write_onnx_model(args, model, save_path, name='checkpoint.onnx')
        if is_best:
            write_onnx_model(args, model, save_path, name='model_best.onnx')
def get_dataset_sampler(dataset_object, epoch_size):
    """Create a with-replacement RandomSampler drawing epoch_size samples per epoch.

    epoch_size < 1 is treated as a fraction of the dataset length;
    epoch_size >= 1 as an absolute sample count.
    """
    print('=> creating a random sampler as epoch_size is specified')
    dataset_len = len(dataset_object)
    if epoch_size < 1:
        num_draws = int(epoch_size * dataset_len)
    else:
        num_draws = int(epoch_size)
    return torch.utils.data.sampler.RandomSampler(data_source=dataset_object, replacement=True, num_samples=num_draws)
def get_train_transform(args):
    """Build the training-time augmentation pipeline.

    Normalization runs either before the geometric transforms (prenorm) or
    after them (postnorm) - exactly one of the two is active. None entries in
    the list are placeholders for disabled steps.
    """
    mean = np.array(args.image_mean, dtype=np.float32)
    scale = np.array(args.image_scale, dtype=np.float32)
    prenorm = xvision.transforms.image_transforms.NormalizeMeanScale(mean=mean, scale=scale) if args.image_prenorm else None
    postnorm = None if prenorm else xvision.transforms.image_transforms.NormalizeMeanScale(mean=mean, scale=scale)
    reverse_channels = xvision.transforms.image_transforms.ReverseImageChannels() if args.input_channel_reverse else None
    color_to_gray = xvision.transforms.image_transforms.RandomColor2Gray(is_flow=args.is_flow, random_threshold=args.prob_color_to_gray[0]) if args.prob_color_to_gray[0] != 0.0 else None

    # optional rescale of the randomly cropped training sample to a different output size
    needs_output_scaling = (args.rand_output_size is not None and args.rand_output_size != args.rand_resize)
    output_scaling = xvision.transforms.image_transforms.Scale(args.rand_resize, target_size=args.rand_output_size, is_flow=args.is_flow) if needs_output_scaling else None

    rotation = xvision.transforms.image_transforms.RandomRotate(args.transform_rotation, is_flow=args.is_flow) if args.transform_rotation else None
    steps = [
        reverse_channels,
        prenorm,
        xvision.transforms.image_transforms.AlignImages(interpolation=args.interpolation),
        xvision.transforms.image_transforms.MaskTarget(args.target_mask, 0),
        xvision.transforms.image_transforms.CropRect(args.img_border_crop),
        rotation,
        xvision.transforms.image_transforms.RandomScaleCrop(args.rand_resize, scale_range=args.rand_scale, is_flow=args.is_flow, interpolation=args.interpolation),
        xvision.transforms.image_transforms.RandomHorizontalFlip(is_flow=args.is_flow),
        xvision.transforms.image_transforms.RandomCrop(args.rand_crop),
        color_to_gray,
        output_scaling,
        postnorm,
        xvision.transforms.image_transforms.ConvertToTensor()
    ]
    return xvision.transforms.image_transforms.Compose(steps)
def get_validation_transform(args):
    """Build the validation transform pipeline.

    Images are scaled to args.img_resize (prediction is later resized to
    args.output_size for evaluation). Normalization runs either first
    (prenorm) or last (postnorm), never both.
    """
    mean = np.array(args.image_mean, dtype=np.float32)
    scale = np.array(args.image_scale, dtype=np.float32)
    prenorm = xvision.transforms.image_transforms.NormalizeMeanScale(mean=mean, scale=scale) if args.image_prenorm else None
    postnorm = None if prenorm else xvision.transforms.image_transforms.NormalizeMeanScale(mean=mean, scale=scale)
    reverse_channels = xvision.transforms.image_transforms.ReverseImageChannels() if args.input_channel_reverse else None
    color_to_gray = xvision.transforms.image_transforms.RandomColor2Gray(is_flow=args.is_flow, random_threshold=args.prob_color_to_gray[1]) if args.prob_color_to_gray[1] != 0.0 else None

    steps = [
        reverse_channels,
        prenorm,
        xvision.transforms.image_transforms.AlignImages(interpolation=args.interpolation),
        xvision.transforms.image_transforms.MaskTarget(args.target_mask, 0),
        xvision.transforms.image_transforms.CropRect(args.img_border_crop),
        xvision.transforms.image_transforms.Scale(args.img_resize, target_size=args.output_size, is_flow=args.is_flow, interpolation=args.interpolation),
        color_to_gray,
        postnorm,
        xvision.transforms.image_transforms.ConvertToTensor()
    ]
    return xvision.transforms.image_transforms.Compose(steps)
def get_transforms(args):
    """Return (train_transform, val_transform).

    When args.rand_scale[0] == 0, the validation transform is used for
    training as well - a way to fix the train-test resolution discrepancy
    (https://arxiv.org/abs/1906.06423).
    """
    use_val_for_train = (args.rand_scale[0] == 0)
    train_transform = get_validation_transform(args) if use_val_for_train else get_train_transform(args)
    val_transform = get_validation_transform(args)
    return train_transform, val_transform
1264 def _upsample_impl(tensor, output_size, upsample_mode):
1265 # upsample of long tensor is not supported currently. covert to float, just to avoid error.
1266 # we can do thsi only in the case of nearest mode, otherwise output will have invalid values.
1267 convert_to_float = False
1268 if isinstance(tensor, (torch.LongTensor,torch.cuda.LongTensor)):
1269 convert_to_float = True
1270 original_dtype = tensor.dtype
1271 tensor = tensor.float()
1272 upsample_mode = 'nearest'
1274 dim_added = False
1275 if len(tensor.shape) < 4:
1276 tensor = tensor[np.newaxis,...]
1277 dim_added = True
1279 if (tensor.size()[-2:] != output_size):
1280 tensor = torch.nn.functional.interpolate(tensor, output_size, mode=upsample_mode)
1282 if dim_added:
1283 tensor = tensor[0,...]
1285 if convert_to_float:
1286 tensor = tensor.long() #tensor.astype(original_dtype)
1288 return tensor
def upsample_tensors(tensors, output_sizes, upsample_mode):
    """Upsample one tensor, or a list of tensors (mutated in place), to the
    trailing (H, W) of the corresponding shape in output_sizes."""
    if not isinstance(tensors, (list, tuple)):
        return _upsample_impl(tensors, output_sizes[0][-2:], upsample_mode)
    for idx in range(len(tensors)):
        tensors[idx] = _upsample_impl(tensors[idx], output_sizes[idx][-2:], upsample_mode)
    return tensors
1300 #print IoU for each class
#print IoU for each class
def print_class_iou(args = None, confusion_matrix = None, task_idx = 0):
    """Print the per-class IoU values for one task, derived from its confusion matrix."""
    n_classes = args.model_config.output_channels[task_idx]
    accuracy, mean_iou, iou, f1_score = compute_accuracy(args, confusion_matrix, n_classes)
    print("\n Class IoU: [", end = "")
    for class_iou in iou:
        print("{:0.3f}".format(class_iou), end=",")
    print("]")
if __name__ == '__main__':
    # standalone entry point: build the configuration and launch training
    # NOTE(review): get_config/main are expected to be defined earlier in this
    # file (not visible in this chunk) - confirm
    train_args = get_config()
    main(train_args)