# modules/pytorch_jacinto_ai/engine/infer_pixel2pixel.py
# Copyright (c) 2018-2021, Texas Instruments
# All Rights Reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
import time
import sys
import math
import copy
import warnings

import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import datetime
import numpy as np
import random
import cv2
import imageio  # used by infer_video()
import matplotlib.pyplot as plt

from .. import xnn
from .. import xvision
from .engine_utils import *

# ################################################
def get_config():
    args = xnn.utils.ConfigNode()

    args.dataset = None
    args.dataset_config = xnn.utils.ConfigNode()
    args.dataset_config.split_name = 'val'
    args.dataset_config.max_depth_bfr_scaling = 80
    args.dataset_config.depth_scale = 1
    args.dataset_config.train_depth_log = 1
    args.use_semseg_for_depth = False

    args.model = None
    args.model_config = xnn.utils.ConfigNode()
    args.model_config.enable_fp16 = False           # faster training/inference if the GPU supports fp16

    args.model_name = 'deeplabv2lite_mobilenetv2'   # model architecture, overwritten if pretrained is specified
    args.dataset_name = 'flying_chairs'             # dataset type
    args.transforms = None

    args.data_path = './data/datasets'              # path to dataset
    args.save_path = None                           # checkpoints save path
    args.pretrained = None                          # path to pre-trained model

    args.model_config.output_type = ['flow']        # the network is used to predict flow or depth or sceneflow
    args.model_config.output_channels = None        # number of output channels
    args.model_config.prediction_channels = None    # intermediate number of channels before final output_channels
    args.model_config.input_channels = None         # number of input channels
    args.model_config.num_classes = None            # number of classes (for segmentation)
    args.model_config.output_range = None           # max range of output

    args.model_config.num_decoders = None           # number of decoders to use. [options: 0, 1, None]
    args.sky_dir = False

    args.logger = None                              # logger stream to output into

    args.split_file = None                          # train_val split file
    args.split_files = None                         # split list files. eg: train.txt val.txt
    args.split_value = 0.8                          # test_val split proportion (between 0 (only test) and 1 (only train))

    args.workers = 8                                # number of data loading workers

    args.epoch_size = 0                             # manual epoch size (will match dataset size if not specified)
    args.epoch_size_val = 0                         # manual epoch size (will match dataset size if not specified)
    args.batch_size = 8                             # mini_batch_size
    args.total_batch_size = None                    # accumulated batch size. total_batch_size = batch_size*iter_size
    args.iter_size = 1                              # iteration size. total_batch_size = batch_size*iter_size

    args.tensorboard_num_imgs = 5                   # number of imgs to display in tensorboard
    args.phase = 'validation'                       # evaluate model on validation set
    args.date = None                                # don't append date timestamp to folder
    args.print_freq = 10                            # print frequency

    args.div_flow = 1.0                             # value by which flow will be divided. Original value is 20, but 1 with batchNorm gives good results
    args.losses = ['supervised_loss']               # loss functions to minimize
    args.metrics = ['supervised_error']             # metric/measurement/error functions for train/validation
    args.class_weights = None                       # class weights

    args.multistep_gamma = 0.5                      # steps for step scheduler
    args.polystep_power = 1.0                       # power for polynomial scheduler

    args.rand_seed = 1                              # random seed
    args.img_border_crop = None                     # image border crop rectangle. can be relative or absolute
    args.target_mask = None                         # mask rectangle. can be relative or absolute. last value is the mask value
    args.img_resize = None                          # image size to be resized to
    args.rand_scale = (1, 1.25)                     # random scale range for training
    args.rand_crop = None                           # image size to be cropped to
    args.output_size = None                         # target output size to be resized to

    args.count_flops = True                         # count flops and report

    args.shuffle = False                            # shuffle or not
    args.is_flow = None                             # whether entries in images and targets lists are optical flow or not

    args.multi_decoder = True                       # whether to use multiple decoders or a unified decoder

    args.create_video = False                       # whether to create a video out of the inferred images

    args.input_tensor_name = ['0']                  # list of input tensor names

    args.upsample_mode = 'nearest'                  # upsample mode to use. choices=['nearest', 'bilinear']

    args.image_prenorm = True                       # whether normalization is done before all the other transforms
    args.image_mean = [128.0]                       # image mean for input image normalization
    args.image_scale = [1.0/(0.25*256)]             # image scaling/mult for input image normalization
    args.quantize = False                           # apply quantized inference or not
    #args.model_surgery = None                      # replace activations with PAct2 activation module. Helpful in quantized training.
    args.bitwidth_weights = 8                       # bitwidth for weights
    args.bitwidth_activations = 8                   # bitwidth for activations
    args.histogram_range = True                     # histogram range for calibration
    args.per_channel_q = False                      # apply a separate quantization factor for each channel in depthwise or not
    args.bias_calibration = False                   # apply bias correction during quantized inference calibration

    args.frame_IOU = False                          # print mIOU for each frame
    args.make_score_zero_mean = False               # to make score and desc zero mean
    args.learn_scaled_values_interest_pt = True
    args.save_mod_files = False                     # saves files modified after the last commit. Also stores the commit id.
    args.gpu_mode = True                            # False will make inference run on the CPU
    args.write_layer_ip_op = False                  # True will tap inputs/outputs of layers
    args.write_layer_ip_op_names = None             # names of the layers to write out
    args.file_format = 'none'                       # file format for tapped layer inputs/outputs; 'none': nothing is written, but prints still appear
    args.save_onnx = True
    args.remove_ignore_lbls_in_pred = False         # True: do not visualize predictions where GT has the ignore label
    args.do_pred_cordi_f2r = False                  # True: do the f2r operation on detected locations for the interest point task
    args.depth_cmap_plasma = False
    args.visualize_gt = False                       # whether to visualize prediction or GT
    args.viz_depth_color_type = 'plasma'            # color type for depth visualization
    args.depth = [False]
    args.palette = None
    args.label_infer = False
    args.viz_op_type = None
    args.car_mask = None
    args.en_accuracy_measurement = True             # enabling accuracy measurement makes the whole operation sequential and hence slows down inference significantly
    args.opset_version = 9                          # onnx opset version
    args.prob_color_to_gray = 0.0                   # for color-to-gray augmentation during inference
    return args
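
# Illustrative usage sketch (the dataset name and checkpoint path below are
# placeholders, not guaranteed to exist in a given installation):
#
#   args = get_config()
#   args.model_name = 'deeplabv2lite_mobilenetv2'
#   args.dataset_name = 'cityscapes_segmentation'
#   args.pretrained = './data/checkpoints/model_best.pth.tar'
#   args.img_resize = (384, 768)     # (height, width) fed to the network
#   args.output_size = (1024, 2048)  # predictions are upsampled to this size
#   main(args)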

# ################################################
# to avoid hangs in the data loader with multiple threads;
# this was observed after using cv2 image processing functions
# https://github.com/pytorch/pytorch/issues/1355
cv2.setNumThreads(0)

##################################################
np.set_printoptions(precision=3)

# ################################################
def main(args):
    assert not hasattr(args, 'model_surgery'), 'the argument model_surgery is deprecated, it is not needed now - remove it'

    #################################################
    # global settings. rand seeds for repeatability
    random.seed(args.rand_seed)
    np.random.seed(args.rand_seed)
    torch.manual_seed(args.rand_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

    ################################
    # args check and config
    if args.iter_size != 1 and args.total_batch_size is not None:
        warnings.warn("only one of --iter_size or --total_batch_size must be set")
    #
    if args.total_batch_size is not None:
        args.iter_size = args.total_batch_size//args.batch_size
    else:
        args.total_batch_size = args.batch_size*args.iter_size
    #

    assert args.pretrained is not None, 'pretrained path must be provided'

    # onnx generation is failing for post-quantization modules
    # args.save_onnx = False if (args.quantize) else args.save_onnx

    #################################################
    # set some global flags and initializations
    # keep it in args for now - although they don't belong here strictly
    # using pin_memory is seen to cause issues, especially when a lot of memory is used.
    args.use_pinned_memory = False
    args.n_iter = 0
    args.best_metric = -1

    #################################################
    if args.save_path is None:
        save_path = get_save_path(args)
    else:
        save_path = args.save_path
    #
    print('=> will save everything to {}'.format(save_path))
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    #################################################
    if args.logger is None:
        log_file = os.path.splitext(os.path.basename(__file__))[0] + '.log'
        args.logger = xnn.utils.TeeLogger(filename=os.path.join(save_path, log_file))

    ################################
    # print everything for log
    print('=> args: ', args)

    if args.save_mod_files:
        # store all the files modified after the last commit
        mod_files_path = save_path+'/mod_files'
        os.makedirs(mod_files_path)

        cmd = "git ls-files --modified | xargs -i cp {} {}".format("{}", mod_files_path)
        print("cmd:", cmd)
        os.system(cmd)

        # store the last commit id
        cmd = "git log -n 1 >> {}".format(mod_files_path + '/commit_id.txt')
        print("cmd:", cmd)
        os.system(cmd)

    transforms = get_transforms(args) if args.transforms is None else args.transforms

    print("=> fetching img pairs in '{}'".format(args.data_path))
    split_arg = args.split_file if args.split_file else (args.split_files if args.split_files else args.split_value)

    if args.dataset is not None:
        dataset = args.dataset
    else:
        dataset = xvision.datasets.pixel2pixel.__dict__[args.dataset_name](args.dataset_config, args.data_path, split=split_arg, transforms=transforms)
    #

    # if a pair is given, take the second one
    val_dataset = (dataset[1] if (isinstance(dataset, (list, tuple)) and len(dataset) == 2) else dataset)

    print('=> {} val samples found'.format(len(val_dataset)))
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
                                             num_workers=args.workers, pin_memory=args.use_pinned_memory, shuffle=args.shuffle)

    #################################################
    if (args.model_config.input_channels is None):
        args.model_config.input_channels = (3,)
        print("=> input channels is not given - setting to {}".format(args.model_config.input_channels))

    if (args.model_config.output_channels is None):
        if ('num_classes' in dir(val_dataset)):
            args.model_config.output_channels = val_dataset.num_classes()
        else:
            args.model_config.output_channels = (2 if args.model_config.output_type == 'flow' else args.model_config.output_channels)
            xnn.utils.print_yellow("=> output channels is not given - setting to {} - this may not work".format(args.model_config.output_channels))
        #
    if not isinstance(args.model_config.output_channels, (list, tuple)):
        args.model_config.output_channels = [args.model_config.output_channels]

    #################################################
    pretrained_data = None
    model_surgery_quantize = False
    if args.pretrained and args.pretrained != "None":
        if isinstance(args.pretrained, dict):
            pretrained_data = args.pretrained
        else:
            if args.pretrained.startswith('http://') or args.pretrained.startswith('https://'):
                pretrained_file = xvision.datasets.utils.download_url(args.pretrained, './data/downloads')
            else:
                pretrained_file = args.pretrained
            #
            print(f'=> using pre-trained weights from: {args.pretrained}')
            pretrained_data = torch.load(pretrained_file)
        #
        model_surgery_quantize = pretrained_data['quantize'] if 'quantize' in pretrained_data else False
    #

    #################################################
    if args.model is not None:
        model, change_names_dict = args.model if isinstance(args.model, (list, tuple)) else (args.model, None)
        assert isinstance(model, torch.nn.Module), 'args.model, if provided must be a valid torch.nn.Module'
    else:
        model = xvision.models.pixel2pixel.__dict__[args.model_name](args.model_config)
        # check if we got the model as well as parameters to change the names in pretrained
        model, change_names_dict = model if isinstance(model, (list, tuple)) else (model, None)
    #

    #################################################
    if args.quantize:
        # dummy input is used by quantized models to analyze the graph
        is_cuda = next(model.parameters()).is_cuda
        dummy_input = create_rand_inputs(args, is_cuda=is_cuda)
        # Note: bias_calibration is not enabled in test
        model = xnn.quantize.QuantTestModule(model, per_channel_q=args.per_channel_q,
                                             bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
                                             histogram_range=args.histogram_range, dummy_input=dummy_input,
                                             model_surgery_quantize=model_surgery_quantize)

    # load pretrained weights
    xnn.utils.load_weights(get_model_orig(model), pretrained=pretrained_data, change_names_dict=change_names_dict)

    if args.save_onnx:
        write_onnx_model(args, model, save_path, name='model_best.onnx')

    #################################################
    # multi gpu mode is not yet supported with quantization in evaluate
    if args.gpu_mode and ('training' in args.phase):
        model = torch.nn.DataParallel(model)

    #################################################
    model = model.cuda()

    #################################################
    assign_write_layer_ip_op_hook(model=model, save_path=save_path, args=args, file_format='npy')

    args.loss_modules = copy.deepcopy(args.losses)
    for task_dx, task_losses in enumerate(args.losses):
        for loss_idx, loss_fn in enumerate(task_losses):
            kw_args = {}
            loss_args = xvision.losses.__dict__[loss_fn].args()
            for arg in loss_args:
                #if arg == 'weight':
                #    kw_args.update({arg: args.class_weights[task_dx]})
                if arg == 'num_classes':
                    kw_args.update({arg: args.model_config.output_channels[task_dx]})
                elif arg == 'sparse':
                    kw_args.update({arg: args.sparse})
                #
            #
            loss_fn = xvision.losses.__dict__[loss_fn](**kw_args)
            loss_fn = loss_fn.cuda()
            args.loss_modules[task_dx][loss_idx] = loss_fn

    args.metric_modules = copy.deepcopy(args.metrics)
    for task_dx, task_metrics in enumerate(args.metrics):
        for midx, metric_fn in enumerate(task_metrics):
            kw_args = {}
            loss_args = xvision.losses.__dict__[metric_fn].args()
            for arg in loss_args:
                if arg == 'weight':
                    kw_args.update({arg: args.class_weights[task_dx]})
                elif arg == 'num_classes':
                    kw_args.update({arg: args.model_config.output_channels[task_dx]})
                elif arg == 'sparse':
                    kw_args.update({arg: args.sparse})
                #
            #
            metric_fn = xvision.losses.__dict__[metric_fn](**kw_args)
            metric_fn = metric_fn.cuda()
            args.metric_modules[task_dx][midx] = metric_fn

    #################################################
    if args.palette:
        print('Creating palette')
        args.palette = val_dataset.create_palette()
        for i, p in enumerate(args.palette):
            args.palette[i] = np.array(p, dtype=np.uint8)
            args.palette[i] = args.palette[i][..., ::-1]  # RGB->BGR, since the palette is expected to be given in RGB format

    infer_path = []
    for i, p in enumerate(args.model_config.output_channels):
        infer_path.append(os.path.join(save_path, 'Task{}'.format(i)))
        if not os.path.exists(infer_path[i]):
            os.makedirs(infer_path[i])

    #################################################
    with torch.no_grad():
        validate(args, val_dataset, val_loader, model, 0, infer_path)

    if args.create_video:
        create_video(args, infer_path=infer_path)


def validate(args, val_dataset, val_loader, model, epoch, infer_path):
    data_time = xnn.utils.AverageMeter()
    avg_metric = [xnn.utils.AverageMeter(print_avg=(not task_metric[0].info()['is_avg'])) for task_metric in args.metric_modules]

    # switch to evaluate mode
    model.eval()
    metric_name = "Metric"
    end_time = time.time()
    writer_idx = 0
    last_update_iter = -1
    metric_ctx = [None] * len(args.metric_modules)

    confusion_matrix = []
    for n_cls in args.model_config.output_channels:
        confusion_matrix.append(np.zeros((n_cls, n_cls+1)))
    metric_txt = []
    ard_err = None
    for iter, (input_list, target_list, input_path, target_path) in enumerate(val_loader):
        file_name = input_path[-1][0]
        print("started inference of file_name:", file_name)
        data_time.update(time.time() - end_time)
        if args.gpu_mode:
            input_list = [img.cuda() for img in input_list]

        outputs = model(input_list)
        outputs = outputs if isinstance(outputs, (list, tuple)) else [outputs]

        if args.output_size is not None and target_list:
            target_sizes = [tgt.shape for tgt in target_list]
            outputs = upsample_tensors(outputs, target_sizes, args.upsample_mode)
        elif args.output_size is not None and not target_list:
            target_sizes = [args.output_size for _ in range(len(outputs))]
            outputs = upsample_tensors(outputs, target_sizes, args.upsample_mode)
        outputs = [out.cpu() for out in outputs]

        for task_index in range(len(outputs)):
            output = outputs[task_index]
            gt_target = target_list[task_index] if target_list else None
            if args.visualize_gt and target_list:
                if args.model_config.output_type[task_index] == 'depth':
                    output = gt_target
                else:
                    output = gt_target.to(dtype=torch.int8)

            if args.remove_ignore_lbls_in_pred and (args.model_config.output_type[task_index] != 'depth') and target_list:
                output[gt_target == 255] = args.palette[task_index-1].shape[0]-1
            for index in range(output.shape[0]):
                if args.frame_IOU:
                    confusion_matrix[task_index] = np.zeros((args.model_config.output_channels[task_index], args.model_config.output_channels[task_index] + 1))
                prediction = np.array(output[index])
                if len(prediction.shape) > 2 and prediction.shape[0] > 1:
                    prediction = np.argmax(prediction, axis=0)
                #
                prediction = np.squeeze(prediction)

                if target_list:
                    label = np.squeeze(np.array(target_list[task_index][index]))
                    if args.model_config.output_type[task_index] != 'depth':
                        if args.en_accuracy_measurement:
                            confusion_matrix[task_index] = eval_output(args, prediction, label, confusion_matrix[task_index], args.model_config.output_channels[task_index])
                            accuracy, mean_iou, iou, f1_score = compute_accuracy(args, confusion_matrix[task_index], args.model_config.output_channels[task_index])
                            temp_txt = []
                            temp_txt.append(input_path[-1][index])
                            temp_txt.extend(iou)
                            metric_txt.append(temp_txt)
                            print('{}/{} Inferred Frame {} mean_iou={},'.format((args.batch_size*iter+index+1), len(val_dataset), input_path[-1][index], mean_iou))
                            if index == output.shape[0]-1:
                                print('Task={},\npixel_accuracy={},\nmean_iou={},\niou={},\nf1_score = {}'.format(task_index, accuracy, mean_iou, iou, f1_score))
                                sys.stdout.flush()
                    elif args.model_config.output_type[task_index] == 'depth':
                        valid = (label != 0)
                        gt = torch.tensor(label[valid]).float()
                        inference = torch.tensor(prediction[valid]).float()
                        if len(gt) > 2:
                            if ard_err is None:
                                ard_err = [absreldiff_rng3to80(inference, gt).mean()]
                            else:
                                ard_err.append(absreldiff_rng3to80(inference, gt).mean())
                        elif len(gt) < 2:
                            if ard_err is None:
                                ard_err = [0.0]
                            else:
                                ard_err.append(0.0)

                        print('{}/{} ARD: {}'.format((args.batch_size * iter + index), len(val_dataset), torch.tensor(ard_err).mean()))

                seq = input_path[-1][index].split('/')[-4]
                base_file = os.path.basename(input_path[-1][index])

                if args.label_infer:
                    output_image = prediction
                    output_name = os.path.join(infer_path[task_index], seq + '_' + input_path[-1][index].split('/')[-3] + '_' + base_file)
                    cv2.imwrite(output_name, output_image)
                    print('{}/{}'.format((args.batch_size*iter+index), len(val_dataset)))

                if hasattr(args, 'interest_pt') and args.interest_pt[task_index]:
                    print('{}/{}'.format((args.batch_size * iter + index), len(val_dataset)))
                    output_name = os.path.join(infer_path[task_index], seq + '_' + input_path[-1][index].split('/')[-3] + '_' + base_file)
                    output_name_short = os.path.join(infer_path[task_index], os.path.basename(input_path[-1][index]))
                    wrapper_write_desc(args=args, target_list=target_list, task_index=task_index, outputs=outputs, index=index, output_name=output_name, output_name_short=output_name_short)

                if args.model_config.output_type[task_index] == 'depth':
                    output_name = os.path.join(infer_path[task_index], seq + '_' + input_path[-1][index].split('/')[-3] + '_' + base_file)
                    viz_depth(prediction=prediction, args=args, output_name=output_name, input_name=input_path[-1][task_index])
                    print('{}/{}'.format((args.batch_size * iter + index), len(val_dataset)))

                if args.viz_op_type is not None and args.viz_op_type[task_index] == 'blend':
                    prediction_size = (prediction.shape[0], prediction.shape[1], 3)
                    output_image = args.palette[task_index-1][prediction.ravel()].reshape(prediction_size)
                    input_bgr = cv2.imread(input_path[-1][index])  # read the actual RGB image
                    if args.img_border_crop is not None:
                        t, l, h, w = args.img_border_crop
                        input_bgr = input_bgr[t:t+h, l:l+w]
                    input_bgr = cv2.resize(input_bgr, dsize=(prediction.shape[1], prediction.shape[0]))
                    output_image = xnn.utils.chroma_blend(input_bgr, output_image)
                    output_name = os.path.join(infer_path[task_index], input_path[-1][index].split('/')[-4] + '_' + input_path[-1][index].split('/')[-3] + '_' + os.path.basename(input_path[-1][index]))
                    cv2.imwrite(output_name, output_image)
                    print('{}/{}'.format((args.batch_size*iter+index), len(val_dataset)))
                elif args.viz_op_type is not None and args.viz_op_type[task_index] == 'color':
                    prediction_size = (prediction.shape[0], prediction.shape[1], 3)
                    output_image = args.palette[task_index-1][prediction.ravel()].reshape(prediction_size)
                    output_name = os.path.join(infer_path[task_index], input_path[-1][index].split('/')[-4] + '_' + input_path[-1][index].split('/')[-3] + '_' + os.path.basename(input_path[-1][index]))
                    cv2.imwrite(output_name, output_image)
                    print('{}/{}'.format((args.batch_size*iter+index), len(val_dataset)))
                #
                if args.car_mask:  # generating car_mask (required for localization)
                    # np.logical_or only accepts two input arrays (its third argument is an
                    # output buffer), so np.isin is used for the multi-class membership test
                    car_mask = np.isin(prediction, (13, 14, 16, 17))
                    prediction[car_mask] = 255
                    prediction[np.invert(car_mask)] = 0
                    output_image = prediction
                    output_name = os.path.join(infer_path[task_index], os.path.basename(input_path[-1][index]))
                    cv2.imwrite(output_name, output_image)

    np.savetxt('metric.txt', metric_txt, fmt='%s')


###############################################################
def get_save_path(args, phase=None):
    date = args.date if args.date else datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    save_path = os.path.join('./data/checkpoints', args.dataset_name, date + '_' + args.dataset_name + '_' + args.model_name)
    save_path += '_resize{}x{}'.format(args.img_resize[1], args.img_resize[0])
    if args.rand_crop:
        save_path += '_crop{}x{}'.format(args.rand_crop[1], args.rand_crop[0])
    #
    phase = phase if (phase is not None) else args.phase
    save_path = os.path.join(save_path, phase)
    return save_path
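
# For example, with dataset_name='cityscapes_segmentation', the default model
# and img_resize=(384, 768), this produces a path of the form:
#   ./data/checkpoints/cityscapes_segmentation/<date>_cityscapes_segmentation_deeplabv2lite_mobilenetv2_resize768x384/validation
# (note the {width}x{height} order in the resize suffix).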


def get_model_orig(model):
    is_parallel_model = isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel))
    model_orig = (model.module if is_parallel_model else model)
    model_orig = (model_orig.module if isinstance(model_orig, (xnn.quantize.QuantBaseModule)) else model_orig)
    return model_orig


def create_rand_inputs(args, is_cuda):
    dummy_input = []
    for i_ch in args.model_config.input_channels:
        x = torch.rand((1, i_ch, args.img_resize[0], args.img_resize[1]))
        x = x.cuda() if is_cuda else x
        dummy_input.append(x)
    #
    return dummy_input
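
# With the defaults above (input_channels=(3,)) and, say, img_resize=(384, 768),
# this returns [torch.rand(1, 3, 384, 768)]; QuantTestModule and write_onnx_model()
# use it as a representative forward input for graph analysis and export.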


# FIX_ME:SN move to utils
def store_desc(args=[], output_name=[], write_dense=False, desc_tensor=[], prediction=[],
               scale_to_write_kp_loc_to_orig_res=[1.0, 1.0],
               learn_scaled_values=True):
    sys.path.insert(0, './scripts/')
    import write_desc as write_desc

    # hoisted out of the if below so that both branches can use it
    txt_file_name = output_name.replace(".png", ".txt")
    if args.write_desc_type != 'NONE':
        if write_dense:
            # write desc
            desc_tensor = desc_tensor.astype(np.int16)
            print("writing dense desc(64 ch) op: {} : {} : {} : {}".format(desc_tensor.shape, desc_tensor.dtype,
                                                                           desc_tensor.min(), desc_tensor.max()))
            desc_tensor_name = output_name.replace(".png", "_desc.npy")
            np.save(desc_tensor_name, desc_tensor)
            # utils_hist.comp_hist_tensor3d(x=desc_tensor, name='desc_64ch', en=True, dir='desc_64ch', log=True, ch_dim=0)

            # write score channel
            prediction = prediction.astype(np.int16)
            print("writing dense score ch op: {} : {} : {} : {}".format(prediction.shape, prediction.dtype,
                                                                        prediction.min(),
                                                                        prediction.max()))
            score_tensor_name = output_name.replace(".png", "_score.npy")
            np.save(score_tensor_name, prediction)
            # utils_hist.hist_tensor2D(x_ch=prediction, dir='score', name='score', en=True, log=True)
        else:
            prediction[prediction < 0.0] = 0.0

            if learn_scaled_values:
                img_interest_pt_cur = prediction.astype(np.uint16)
                score_th = 127
            else:
                img_interest_pt_cur = prediction
                score_th = 0.001

            # at boundary pixels the score/desc will be wrong, so do not write predicted quantities near the border
            guard_band = 32 if args.write_desc_type == 'PRED' else 0

            write_desc.write_score_desc_as_text(desc_tensor_cur=desc_tensor, img_interest_pt_cur=img_interest_pt_cur,
                                                txt_file_name=txt_file_name, score_th=score_th,
                                                skip_fac_for_reading_desc=1, en_nms=args.en_nms,
                                                scale_to_write_kp_loc_to_orig_res=scale_to_write_kp_loc_to_orig_res,
                                                recursive_nms=True, learn_scaled_values=learn_scaled_values,
                                                guard_band=guard_band, true_nms=True, f2r=args.do_pred_cordi_f2r)

            #utils_hist.hist_tensor2D(x_ch=prediction, dir='score', name='score', en=True, log=True)
    else:
        prediction[prediction < 0.0] = 0.0

        if learn_scaled_values:
            img_interest_pt_cur = prediction.astype(np.uint16)
            score_th = 127
        else:
            img_interest_pt_cur = prediction
            score_th = 0.001

        # at boundary pixels the score/desc will be wrong, so do not write predicted quantities near the border
        guard_band = 32 if args.write_desc_type == 'PRED' else 0

        write_desc.write_score_desc_as_text(desc_tensor_cur=desc_tensor, img_interest_pt_cur=img_interest_pt_cur,
                                            txt_file_name=txt_file_name, score_th=score_th, skip_fac_for_reading_desc=1, en_nms=args.en_nms,
                                            scale_to_write_kp_loc_to_orig_res=scale_to_write_kp_loc_to_orig_res,
                                            recursive_nms=True, learn_scaled_values=learn_scaled_values, guard_band=guard_band, true_nms=True, f2r=args.do_pred_cordi_f2r)


def viz_depth(prediction=[], args=[], output_name=[], input_name=[]):
    max_value_depth = args.max_depth
    output_image = torch.tensor(prediction)
    if args.viz_depth_color_type == 'rainbow':
        not_valid_indices = output_image == 0
        output_image = 255*xnn.utils.tensor2array(output_image, max_value=max_value_depth, colormap='rainbow')[:, :, ::-1]
        output_image[not_valid_indices] = 0
    elif args.viz_depth_color_type == 'rainbow_blend':
        print(max_value_depth)
        #scale_mul = 1 if args.visualize_gt else 255
        print(output_image.min())
        print(output_image.max())
        not_valid_indices = output_image == 0
        output_image = 255*xnn.utils.tensor2array(output_image, max_value=max_value_depth, colormap='rainbow')[:, :, ::-1]
        print(output_image.max())
        #output_image[label == 1] = 0
        input_bgr = cv2.imread(input_name)  # read the actual RGB image
        if args.img_border_crop is not None:
            t, l, h, w = args.img_border_crop
            input_bgr = input_bgr[t:t+h, l:l+w]
        input_bgr = cv2.resize(input_bgr, dsize=(prediction.shape[1], prediction.shape[0]))
        if args.sky_dir:
            # note: seq and base_file are not defined in this scope; args.sky_dir is
            # False by default, so this branch is normally skipped
            label_file = os.path.join(args.sky_dir, seq, seq + '_image_00_' + base_file)
            label = cv2.imread(label_file)
            label = cv2.resize(label, dsize=(prediction.shape[1], prediction.shape[0]),
                               interpolation=cv2.INTER_NEAREST)
            output_image[label == 1] = 0
        output_image[not_valid_indices] = 0
        output_image = xnn.utils.chroma_blend(input_bgr, output_image)

    elif args.viz_depth_color_type == 'bone':
        output_image = 255 * xnn.utils.tensor2array(output_image, max_value=max_value_depth, colormap='bone')
    elif args.viz_depth_color_type == 'raw_depth':
        output_image = np.array(output_image)
        output_image[output_image > max_value_depth] = max_value_depth
        output_image[output_image < 0] = 0
        scale = 2.0**16 - 1.0  # full 16-bit range
        output_image = (output_image / max_value_depth) * scale
        output_image = output_image.astype(np.uint16)
        # output_image[(label[:,:,0]==1)|(label[:,:,0]==4)]=255
    elif args.viz_depth_color_type == 'plasma':
        plt.imsave(output_name, output_image, cmap='plasma', vmin=0, vmax=max_value_depth)
    elif args.viz_depth_color_type == 'log_greys':
        plt.imsave(output_name, np.log10(output_image), cmap='Greys', vmin=0, vmax=np.log10(max_value_depth))
        #plt.imsave(output_name, output_image, cmap='Greys', vmin=0, vmax=max_value_depth)
    else:
        print("undefined color type for visualization")
        exit(0)

    if args.viz_depth_color_type not in ('plasma', 'log_greys'):
        # these color types are already written by plt.imsave above
        cv2.imwrite(output_name, output_image)
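
# Note on 'raw_depth': the 16-bit PNG stores p = depth * 65535 / max_depth, so a
# pixel converts back to metric depth as depth = p * max_depth / 65535; e.g. with
# max_depth=80, p=32768 corresponds to roughly 40m (values are illustrative).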


def wrapper_write_desc(args=[], target_list=None, task_index=0, outputs=[], index=0, output_name=[], output_name_short=[]):
    if args.write_desc_type == 'GT':
        # write GT desc
        tensor_to_write = target_list[task_index]
    elif args.write_desc_type == 'PRED':
        # write predicted desc
        tensor_to_write = outputs[task_index]

    interest_pt_score = np.array(tensor_to_write[index, 0, ...])

    if args.make_score_zero_mean:
        # visualization code assumes range [0,255]. Add 128 to make the range the same in the zero-mean case too.
        interest_pt_score += 128.0

    if args.write_desc_type == 'NONE':
        # scale + clip score between 0-255 and convert score_array to image
        # scale_range = 127.0/0.005
        # scale_range = 255.0/np.max(interest_pt_score)
        scale_range = 1.0
        interest_pt_score = np.clip(interest_pt_score * scale_range, 0.0, 255.0)
        interest_pt_score = np.asarray(interest_pt_score, 'uint8')

    interest_pt_descriptor = np.array(tensor_to_write[index, 1:, ...])

    # output_name = os.path.join(infer_path[task_index], seq + '_' + input_path[-1][index].split('/')[-3] + '_' + base_file)
    cv2.imwrite(output_name, interest_pt_score)

    # output_name_short = os.path.join(infer_path[task_index], os.path.basename(input_path[-1][index]))

    scale_to_write_kp_loc_to_orig_res = args.scale_to_write_kp_loc_to_orig_res
    if args.scale_to_write_kp_loc_to_orig_res[0] == -1:
        # the -1 auto-scale path expects input_list to be available from the caller's scope
        scale_to_write_kp_loc_to_orig_res[0] = input_list[task_index].shape[2] / target_list[task_index].shape[2]
        scale_to_write_kp_loc_to_orig_res[1] = scale_to_write_kp_loc_to_orig_res[0]

    print("scale_to_write_kp_loc_to_orig_res: ", scale_to_write_kp_loc_to_orig_res)
    store_desc(args=args, output_name=output_name_short, desc_tensor=interest_pt_descriptor,
               prediction=interest_pt_score,
               scale_to_write_kp_loc_to_orig_res=scale_to_write_kp_loc_to_orig_res,
               learn_scaled_values=args.learn_scaled_values_interest_pt,
               write_dense=False)


def get_transforms(args):
    # image normalization can be at the beginning of the transforms or at the end
    args.image_mean = np.array(args.image_mean, dtype=np.float32)
    args.image_scale = np.array(args.image_scale, dtype=np.float32)
    image_prenorm = xvision.transforms.image_transforms.NormalizeMeanScale(mean=args.image_mean, scale=args.image_scale) if args.image_prenorm else None
    image_postnorm = xvision.transforms.image_transforms.NormalizeMeanScale(mean=args.image_mean, scale=args.image_scale) if (not image_prenorm) else None
    color_2_gray = xvision.transforms.image_transforms.RandomColor2Gray(is_flow=args.is_flow, random_threshold=args.prob_color_to_gray[1]) if args.prob_color_to_gray != 0.0 else None

    # target size must be according to output_size. prediction will be resized to output_size before evaluation.
    test_transform = xvision.transforms.image_transforms.Compose([
        image_prenorm,
        xvision.transforms.image_transforms.AlignImages(),
        xvision.transforms.image_transforms.MaskTarget(args.target_mask, 0),
        xvision.transforms.image_transforms.CropRect(args.img_border_crop),
        xvision.transforms.image_transforms.Scale(args.img_resize, target_size=args.output_size, is_flow=args.is_flow),
        color_2_gray,
        image_postnorm,
        xvision.transforms.image_transforms.ConvertToTensor()
    ])

    return test_transform


def _upsample_impl(tensor, output_size, upsample_mode):
    # upsampling of a long tensor is not supported currently. convert to float, just to avoid the error.
    # we can do this only in the case of nearest mode, otherwise the output will have invalid values.
    convert_to_float = False
    if isinstance(tensor, (torch.LongTensor, torch.cuda.LongTensor)):
        convert_to_float = True
        tensor = tensor.float()
        upsample_mode = 'nearest'
    #

    dim_added = False
    if len(tensor.shape) < 4:
        tensor = tensor[np.newaxis, ...]
        dim_added = True
    #
    if (tensor.size()[-2:] != output_size):
        tensor = torch.nn.functional.interpolate(tensor, output_size, mode=upsample_mode)
    # --
    if dim_added:
        tensor = tensor[0, ...]
    #

    if convert_to_float:
        tensor = tensor.long()
    #
    return tensor


def upsample_tensors(tensors, output_sizes, upsample_mode):
    if not output_sizes:
        return tensors
    #
    if isinstance(tensors, (list, tuple)):
        for tidx, tensor in enumerate(tensors):
            tensors[tidx] = _upsample_impl(tensor, output_sizes[tidx][-2:], upsample_mode)
        #
    else:
        tensors = _upsample_impl(tensors, output_sizes[0][-2:], upsample_mode)
    return tensors
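
# Minimal sketch (assumed shapes): resize a [1, 19, 48, 96] logits tensor to a
# 1024x2048 target before evaluation:
#
#   logits = torch.rand(1, 19, 48, 96)
#   out = upsample_tensors(logits, [(1, 19, 1024, 2048)], 'bilinear')
#   assert out.shape[-2:] == (1024, 2048)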


def eval_output(args, output, label, confusion_matrix, n_classes):
    if len(label.shape) > 2:
        label = label[:, :, 0]
    gt_labels = label.ravel()
    det_labels = output.ravel().clip(0, n_classes)
    gt_labels_valid_ind = np.where(gt_labels != 255)
    gt_labels_valid = gt_labels[gt_labels_valid_ind]
    det_labels_valid = det_labels[gt_labels_valid_ind]
    for r in range(confusion_matrix.shape[0]):
        for c in range(confusion_matrix.shape[1]):
            confusion_matrix[r, c] += np.sum((gt_labels_valid == r) & (det_labels_valid == c))

    return confusion_matrix
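
# Toy example of the accumulation (illustrative values): with n_classes=3,
#
#   cm = np.zeros((3, 4))              # one extra column for clipped detections
#   pred = np.array([[0, 1], [2, 2]])
#   gt = np.array([[0, 1], [2, 255]])  # 255 = ignore label, dropped above
#   cm = eval_output(args, pred, gt, cm, 3)
#
# cm[r, c] then counts pixels whose GT class is r and detected class is c.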


def compute_accuracy(args, confusion_matrix, n_classes):
    num_selected_classes = n_classes
    tp = np.zeros(n_classes)
    population = np.zeros(n_classes)
    det = np.zeros(n_classes)
    iou = np.zeros(n_classes)

    for r in range(n_classes):
        for c in range(n_classes):
            population[r] += confusion_matrix[r][c]
            det[c] += confusion_matrix[r][c]
            if r == c:
                tp[r] += confusion_matrix[r][c]

    for cls in range(num_selected_classes):
        intersection = tp[cls]
        union = population[cls] + det[cls] - tp[cls]
        iou[cls] = (intersection / union) if union else 0  # for the caffe-jacinto script
        #iou[cls] = (intersection / (union + np.finfo(np.float32).eps))  # for the pytorch-jacinto script

    num_nonempty_classes = 0
    for pop in population:
        if pop > 0:
            num_nonempty_classes += 1

    mean_iou = np.sum(iou) / num_nonempty_classes if num_nonempty_classes else 0
    accuracy = np.sum(tp) / np.sum(population) if np.sum(population) else 0

    # F1 score calculation
    fp = np.zeros(n_classes)
    fn = np.zeros(n_classes)
    precision = np.zeros(n_classes)
    recall = np.zeros(n_classes)
    f1_score = np.zeros(n_classes)

    for cls in range(num_selected_classes):
        fp[cls] = det[cls] - tp[cls]
        fn[cls] = population[cls] - tp[cls]
        precision[cls] = tp[cls] / (det[cls] + 1e-10)
        recall[cls] = tp[cls] / (tp[cls] + fn[cls] + 1e-10)
        f1_score[cls] = 2 * precision[cls]*recall[cls] / (precision[cls] + recall[cls] + 1e-10)

    return accuracy, mean_iou, iou, f1_score
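
# The per-class IoU above is intersection/union with intersection = tp[cls] and
# union = population[cls] + det[cls] - tp[cls] (i.e. TP + FN + FP). For example,
# tp=8, population=10, det=12 gives iou = 8 / (10 + 12 - 8) = 8/14 ~= 0.571.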


def infer_video(args, net):
    videoIpHandle = imageio.get_reader(args.input, 'ffmpeg')
    fps = math.ceil(videoIpHandle.get_meta_data()['fps'])
    print(videoIpHandle.get_meta_data())
    numFrames = min(len(videoIpHandle), args.num_images)
    videoOpHandle = imageio.get_writer(args.output, 'ffmpeg', fps=fps)
    for num in range(numFrames):
        print(num, end=' ')
        sys.stdout.flush()
        input_blob = videoIpHandle.get_data(num)
        input_blob = input_blob[..., ::-1]  # RGB->BGR
        output_blob = infer_blob(args, net, input_blob)
        output_blob = output_blob[..., ::-1]  # BGR->RGB
        videoOpHandle.append_data(output_blob)
    videoOpHandle.close()
    return
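
# Note: infer_video() additionally relies on args.input, args.output,
# args.num_images and an infer_blob() helper that are not defined in this file;
# it is not invoked by main() and is kept here as a reference.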


def absreldiff(x, y, eps=0.0, max_val=None):
    assert x.size() == y.size(), 'tensor dimension mismatch'
    if max_val is not None:
        x = torch.clamp(x, -max_val, max_val)
        y = torch.clamp(y, -max_val, max_val)
    #

    diff = torch.abs(x - y)
    y = torch.abs(y)

    den_valid = (y == 0).float()
    eps_arr = (den_valid * (1e-6))  # just to avoid divide by zero

    large_arr = (y > eps).float()   # ARD is not a good measure for small ref values. Avoid them.
    out = (diff / (y + eps_arr)) * large_arr
    return out


def absreldiff_rng3to80(x, y):
    return absreldiff(x, y, eps=3.0, max_val=80.0)
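
# Worked example (illustrative): x=[10., 50.], y=[8., 100.] with eps=3.0 and
# max_val=80.0 -> y is clamped to [8., 80.], diff=[2., 30.], and the output is
# [2/8, 30/80] = [0.25, 0.375]; reference values <= eps are masked to zero.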


def create_video(args, infer_path):
    op_file_name = args.data_path.split('/')[-1]
    # main() passes a list of per-task directories; make a video for each
    for idx, path in enumerate(infer_path if isinstance(infer_path, (list, tuple)) else [infer_path]):
        os.system(' ffmpeg -framerate 30 -pattern_type glob -i "{}/*.png" -c:v libx264 -vb 50000k -qmax 20 -r 10 -vf scale=1024:512 -pix_fmt yuv420p {}_task{}.MP4'.format(path, op_file_name, idx))


def write_onnx_model(args, model, save_path, name='checkpoint.onnx'):
    is_cuda = next(model.parameters()).is_cuda
    input_list = create_rand_inputs(args, is_cuda=is_cuda)
    #
    model.eval()
    torch.onnx.export(model, input_list, os.path.join(save_path, name), export_params=True, verbose=False,
                      do_constant_folding=True, opset_version=args.opset_version)
    # torch onnx export does not update names. Do it using onnx.save


def assign_write_layer_ip_op_hook(model=None, save_path=None, args=None, file_format='bin'):
    if args.write_layer_ip_op:
        def write_tensor_hook_function_save_path(m, inp, out):
            write_tensor_hook_function(m, inp, out, save_path=save_path, file_format=file_format)

        # for dumping module outputs
        for name, module in model.named_modules():
            module.name = name
            print(name)

            en_write_layer = False
            if args.write_layer_ip_op_names is None:
                # write all layers
                en_write_layer = True
            else:
                for layer_name_to_write in args.write_layer_ip_op_names:
                    if layer_name_to_write in name:
                        en_write_layer = True
                        break

            if en_write_layer:
                module.register_forward_hook(write_tensor_hook_function_save_path)
        print('{:7} {:33} {:12} {:8} {:6} {:30} : {:17} : {:4} : {:11} : {:7} : {:7}'.format("type", "name", "layer", "min", "max", "tensor_shape", "dtype", "scale", "dtype", "min", "max"))
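
# Sketch of the effect (assumed layer name): with args.write_layer_ip_op=True
# and args.write_layer_ip_op_names=['conv1'], every module whose qualified name
# contains 'conv1' gets a forward hook that taps its input/output tensors into
# save_path via write_tensor_hook_function (from engine_utils) in file_format.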


if __name__ == '__main__':
    train_args = get_config()
    main(train_args)