# Source: [jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / engine / test_pixel2pixel_onnx.py
1 import os
2 import time
3 import sys
4 import warnings
6 import torch
7 import torch.nn.parallel
8 import torch.optim
9 import torch.utils.data
10 import datetime
11 import numpy as np
12 import random
13 import cv2
14 import PIL
15 import PIL.Image
17 import onnx
18 import caffe2
19 import caffe2.python.onnx.backend
21 from .. import xnn
22 from .. import vision
26 # ################################################
def get_config():
    """Build the default configuration node for ONNX pixel2pixel inference.

    Returns:
        An ``xnn.utils.ConfigNode`` populated with default values. Callers are
        expected to override fields (at least ``pretrained``) before passing
        the node to ``main``.
    """
    args = xnn.utils.ConfigNode()
    args.model_config = xnn.utils.ConfigNode()
    args.dataset_config = xnn.utils.ConfigNode()

    args.dataset_name = 'flying_chairs'         # dataset type
    args.model_name = 'flownets'                # model architecture, overwritten if pretrained is specified

    args.data_path = './data/datasets'          # path to dataset
    args.save_path = None                       # checkpoints save path
    args.pretrained = None                      # path to pre-trained (onnx) model

    args.model_config.output_type = ['flow']    # the network is used to predict flow or depth or sceneflow
    args.model_config.output_channels = None    # number of output channels
    args.model_config.input_channels = None     # number of input channels
    args.n_classes = None                       # number of classes (for segmentation)

    args.logger = None                          # logger stream to output into

    args.prediction_type = 'flow'               # the network is used to predict flow or depth or sceneflow
    args.split_file = None                      # train_val split file
    args.split_files = None                     # split list files. eg: train.txt val.txt
    args.split_value = 0.8                      # test_val split proportion (between 0 (only test) and 1 (only train))

    args.workers = 8                            # number of data loading workers

    args.epoch_size = 0                         # manual epoch size (will match dataset size if not specified)
    args.epoch_size_val = 0                     # manual epoch size (will match dataset size if not specified)
    args.batch_size = 8                         # mini_batch_size
    args.total_batch_size = None                # accumulated batch size. total_batch_size = batch_size*iter_size
    args.iter_size = 1                          # iteration size. total_batch_size = batch_size*iter_size

    args.tensorboard_num_imgs = 5               # number of imgs to display in tensorboard
    args.phase = 'validation'                   # evaluate model on validation set
    args.date = None                            # date string used in the save path; current timestamp is used if None
    args.print_freq = 10                        # print frequency

    args.div_flow = 1.0                         # value by which flow will be divided. Original value is 20 but 1 with batchNorm gives good results
    args.milestones = [100, 150, 200]           # epochs at which learning rate is divided by 2
    args.losses = ['supervised_loss']           # loss functions to minimize
    args.metrics = ['supervised_error']         # metric/measurement/error functions for train/validation
    args.class_weights = None                   # class weights

    args.multistep_gamma = 0.5                  # steps for step scheduler
    args.polystep_power = 1.0                   # power for polynomial scheduler

    args.rand_seed = 1                          # random seed
    args.img_border_crop = None                 # image border crop rectangle. can be relative or absolute
    args.target_mask = None                     # mask rectangle. can be relative or absolute. last value is the mask value
    args.img_resize = None                      # image size to be resized to
    args.rand_scale = (1, 1.25)                 # random scale range for training
    args.rand_crop = None                       # image size to be cropped to
    args.output_size = None                     # target output size to be resized to

    args.count_flops = True                     # count flops and report

    args.shuffle = True                         # shuffle or not
    args.is_flow = None                         # whether entries in images and targets lists are optical flow or not

    args.multi_decoder = True                   # whether to use multiple decoders or unified decoder

    args.create_video = False                   # whether to create video out of the inferred images

    args.input_tensor_name = ['0']              # list of input tensor names

    args.upsample_mode = 'nearest'              # upsample mode to use., choices=['nearest','bilinear']

    args.image_prenorm = True                   # whether normalization is done before all the other transforms
    args.image_mean = [128.0]                   # image mean for input image normalization
    args.image_scale = [1.0/(0.25*256)]         # image scaling/mult for input image normalization

    # Flags read by main()/validate(). They had no default before, so a config
    # that did not set them explicitly failed on first access.
    args.palette = None                         # string that evaluates to a list of RGB colors; used when blending
    args.blend = False                          # write color-blended prediction images
    args.label = False                          # compare predictions against labels and report accuracy
    args.car_mask = False                       # write binary mask images for selected class ids
    return args
101 # ################################################
# Disable OpenCV's internal threading at import time to avoid hangs in the
# data loader when it runs with multiple workers. The hangs were observed
# after using cv2 image processing functions.
# https://github.com/pytorch/pytorch/issues/1355
cv2.setNumThreads(0)
108 # ################################################
def main(args):
    """Run inference / evaluation of a pretrained ONNX model via the Caffe2 backend.

    Steps: seed the RNGs, resolve batch-size settings, create the output
    folder and logger, build the validation dataset and loader, load and
    check the ONNX model, optionally build the color palette, run
    ``validate`` and optionally assemble a video from the written images.

    Args:
        args: configuration node as produced by ``get_config`` (mutated in
            place: batch sizes, logger, image normalization arrays,
            ``n_classes`` and ``palette`` are updated).

    Raises:
        AssertionError: if the deprecated ``model_surgery`` argument is
            present or ``args.pretrained`` is not set.
    """
    assert not hasattr(args, 'model_surgery'), 'the argument model_surgery is deprecated, it is not needed now - remove it'

    #################################################
    # global settings. rand seeds for repeatability
    random.seed(args.rand_seed)
    np.random.seed(args.rand_seed)
    torch.manual_seed(args.rand_seed)

    ################################
    # print everything for log
    print('=> args: ', args)

    ################################
    # args check and config
    if args.iter_size != 1 and args.total_batch_size is not None:
        warnings.warn("only one of --iter_size or --total_batch_size must be set")
    #
    if args.total_batch_size is not None:
        args.iter_size = args.total_batch_size//args.batch_size
    else:
        args.total_batch_size = args.batch_size*args.iter_size
    #

    assert args.pretrained is not None, 'pretrained onnx model path should be provided'

    #################################################
    # set some global flags and initializations
    # keep it in args for now - although they don't belong here strictly
    # using pin_memory is seen to cause issues, especially when a lot of memory is used.
    args.use_pinned_memory = False
    args.n_iter = 0
    args.best_metric = -1

    #################################################
    save_path = get_save_path(args) if args.save_path is None else args.save_path

    print('=> will save everything to {}'.format(save_path))

    # exist_ok avoids the race between an existence check and the creation
    os.makedirs(save_path, exist_ok=True)

    #################################################
    if args.logger is None:
        log_file = os.path.splitext(os.path.basename(__file__))[0] + '.log'
        args.logger = xnn.utils.TeeLogger(filename=os.path.join(save_path, log_file))
    transforms = get_transforms(args)

    print("=> fetching img pairs in '{}'".format(args.data_path))
    split_arg = args.split_file if args.split_file else (args.split_files if args.split_files else args.split_value)
    val_dataset = vision.datasets.pixel2pixel.__dict__[args.dataset_name](args.dataset_config, args.data_path, split=split_arg, transforms=transforms)

    print('=> {} val samples found'.format(len(val_dataset)))

    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
                                             num_workers=args.workers, pin_memory=args.use_pinned_memory, shuffle=args.shuffle)
    #

    #################################################
    # Take the channel count from the dataset only when the user did not give
    # one. Previously an explicitly configured value was replaced by None,
    # and the indexing below then failed.
    if args.model_config.output_channels is None and hasattr(val_dataset, 'num_classes'):
        args.model_config.output_channels = val_dataset.num_classes()
    if args.model_config.output_channels is not None:
        args.n_classes = args.model_config.output_channels[0]

    #################################################
    # create model
    print("=> creating model '{}'".format(args.model_name))

    model = onnx.load(args.pretrained)
    # Run the ONNX model with Caffe2
    onnx.checker.check_model(model)
    model = caffe2.python.onnx.backend.prepare(model)

    #################################################
    if args.palette:
        print('Creating palette')
        # SECURITY NOTE: eval() executes arbitrary code contained in the
        # palette string. Only pass trusted configuration values here.
        palette = eval(args.palette)
        palette_lut = np.zeros((256, 3))
        for i, p in enumerate(palette):
            palette_lut[i, 0] = p[0]
            palette_lut[i, 1] = p[1]
            palette_lut[i, 2] = p[2]
        # RGB->BGR, since palette is expected to be given in RGB format
        args.palette = palette_lut[..., ::-1]

    infer_path = os.path.join(save_path, 'inference')
    os.makedirs(infer_path, exist_ok=True)

    #################################################
    with torch.no_grad():
        validate(args, val_dataset, val_loader, model, 0, infer_path)

    if args.create_video:
        create_video(args, infer_path=infer_path)
# Class ids that are turned into the binary "car mask". The ids are the ones
# listed in the original expression; their meaning depends on the dataset.
_CAR_MASK_CLASS_IDS = (13, 14, 16, 17)


def _run_onnx_model(args, model, input_list):
    """Feed one batch to the prepared backend and return its first output."""
    input_dict = {args.input_tensor_name[idx]: tensor.numpy() for idx, tensor in enumerate(input_list)}
    output = model.run(input_dict)[0]
    return output[0] if isinstance(output, (list, tuple)) else output


def validate(args, val_dataset, val_loader, model, epoch, infer_path):
    """Run the model over ``val_loader`` and write / report the results.

    When ``args.label`` is set, predictions are accumulated into a confusion
    matrix and accuracy figures are printed after every batch. Otherwise only
    inference is run. In both modes ``args.blend`` writes palette-blended
    images into ``infer_path``; without labels ``args.car_mask`` additionally
    writes binary mask images.

    Args:
        args: configuration node (reads ``label``, ``blend``, ``car_mask``,
            ``palette``, ``n_classes``, ``output_size``, ``upsample_mode``,
            ``img_resize`` and ``input_tensor_name``).
        val_dataset: unused; kept for interface compatibility.
        val_loader: iterable yielding
            ``(input_list, target_list, input_path, target_path)``.
        model: prepared backend exposing ``run(input_dict)``.
        epoch: unused; kept for interface compatibility.
        infer_path: directory that receives the output images.
    """
    if args.label:
        confusion_matrix = np.zeros((args.n_classes, args.n_classes+1))
        for input_list, target_list, input_path, _ in val_loader:
            target_sizes = [tgt.shape for tgt in target_list]
            output_pred = _run_onnx_model(args, model, input_list)

            if args.output_size is not None:
                output_pred = upsample_tensors(output_pred, target_sizes, args.upsample_mode)
            #
            if args.blend:
                for index in range(output_pred.shape[0]):
                    prediction = np.squeeze(np.array(output_pred[index]))
                    prediction_size = (prediction.shape[0], prediction.shape[1], 3)
                    output_image = args.palette[prediction.ravel()].reshape(prediction_size)
                    input_bgr = cv2.imread(input_path[0][index])  # read the actual input image
                    input_bgr = cv2.resize(input_bgr, dsize=(prediction.shape[1], prediction.shape[0]))

                    output_image = chroma_blend(input_bgr, output_image)
                    output_name = os.path.join(infer_path, os.path.basename(input_path[0][index]))
                    cv2.imwrite(output_name, output_image)

            for index in range(output_pred.shape[0]):
                prediction = np.array(output_pred[index])
                label = np.squeeze(np.array(target_list[0][index]))
                confusion_matrix = eval_output(args, prediction, label, confusion_matrix)
            # running figures over everything evaluated so far
            accuracy, mean_iou, iou, f1_score = compute_accuracy(args, confusion_matrix)
            print('pixel_accuracy={}, mean_iou={}, iou={}, f1_score = {}'.format(accuracy, mean_iou, iou, f1_score))
            sys.stdout.flush()
    else:
        for input_list, _, input_path, _ in val_loader:
            output_pred = _run_onnx_model(args, model, input_list)
            input_path = input_path[0]

            if args.output_size is not None:
                target_sizes = [args.output_size]
                output_pred = upsample_tensors(output_pred, target_sizes, args.upsample_mode)
            #
            if args.blend:
                for index in range(output_pred.shape[0]):
                    prediction = np.squeeze(np.array(output_pred[index]))
                    prediction_size = (prediction.shape[0], prediction.shape[1], 3)
                    output_image = args.palette[prediction.ravel()].reshape(prediction_size)
                    input_bgr = cv2.imread(input_path[index])  # read the actual input image
                    input_bgr = cv2.resize(input_bgr, (args.img_resize[1], args.img_resize[0]), interpolation=cv2.INTER_LINEAR)
                    output_image = chroma_blend(input_bgr, output_image)
                    output_name = os.path.join(infer_path, os.path.basename(input_path[index]))
                    cv2.imwrite(output_name, output_image)
                    print('Inferred image {}'.format(input_path[index]))
            if args.car_mask:  # generating car_mask (required for localization)
                for index in range(output_pred.shape[0]):
                    prediction = np.argmax(np.array(output_pred[index]), axis=0)
                    # np.logical_or only combines two arrays; the original
                    # five-argument call raised a TypeError. np.isin tests
                    # membership in the whole id set at once.
                    car_mask = np.isin(prediction, _CAR_MASK_CLASS_IDS)
                    # uint8 so that cv2.imwrite can encode the image
                    output_image = np.where(car_mask, 255, 0).astype(np.uint8)
                    output_name = os.path.join(infer_path, os.path.basename(input_path[index]))
                    cv2.imwrite(output_name, output_image)
def get_save_path(args, phase=None):
    """Compose the output directory for this run.

    The path has the form
    ``./data/checkpoints/<dataset_name>/<date>_<model_name>_b<batch_size>/<phase>``.
    ``args.date`` is used when set, otherwise the current timestamp; ``phase``
    falls back to ``args.phase`` when it is None.
    """
    if args.date:
        stamp = args.date
    else:
        stamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    run_dir = '{}_{}_'.format(stamp, args.model_name) + 'b{}'.format(args.batch_size)
    if phase is None:
        phase = args.phase
    return os.path.join('./data/checkpoints', args.dataset_name, run_dir, phase)
def get_transforms(args):
    """Build the transform pipeline applied to every test sample.

    ``args.image_mean`` and ``args.image_scale`` are converted to float32
    arrays in place. Mean/scale normalization is placed at the start of the
    pipeline when ``args.image_prenorm`` is set, otherwise just before the
    tensor conversion; the unused slot is passed to ``Compose`` as None.
    """
    image_transforms = vision.transforms.image_transforms

    args.image_mean = np.array(args.image_mean, dtype=np.float32)
    args.image_scale = np.array(args.image_scale, dtype=np.float32)

    # image normalization can be at the beginning of transforms or at the end
    normalize = image_transforms.NormalizeMeanScale(mean=args.image_mean, scale=args.image_scale)
    if args.image_prenorm:
        leading_norm, trailing_norm = normalize, None
    else:
        leading_norm, trailing_norm = None, normalize

    # target size must be according to output_size. prediction will be resized to output_size before evaluation.
    pipeline = [
        leading_norm,
        image_transforms.AlignImages(),
        image_transforms.MaskTarget(args.target_mask, 0),
        image_transforms.CropRect(args.img_border_crop),
        image_transforms.Scale(args.img_resize, target_size=args.output_size, is_flow=args.is_flow),
        trailing_norm,
        image_transforms.ConvertToTensor(),
    ]
    return image_transforms.Compose(pipeline)
def _upsample_impl(tensor, output_size, upsample_mode):
    """Resize the last two dimensions of ``tensor`` to ``output_size``.

    Args:
        tensor: numpy array or torch tensor. Inputs with fewer than 4
            dimensions get one leading dimension added for the resize and
            removed again afterwards; dimension 1 must then have size 1.
        output_size: ``(height, width)`` as any sequence of two ints.
        upsample_mode: currently ignored - resizing always uses nearest
            neighbour, which keeps integer (label) values valid.

    Returns:
        The input itself when it already has the requested size, otherwise a
        resized copy with the input's dtype (and device, for torch tensors).
    """
    out_h, out_w = (int(v) for v in output_size)
    # Compare as plain tuples: a list-valued output_size never equals a shape
    # tuple, which previously forced a needless resize.
    if tuple(tensor.shape[-2:]) == (out_h, out_w):
        return tensor

    # The resize below is numpy/PIL based, so torch tensors take a round trip
    # through numpy (the old LongTensor branch could not reach the end).
    is_torch = torch.is_tensor(tensor)
    if is_torch:
        device = tensor.device
        array = tensor.detach().cpu().numpy()
    else:
        array = np.asarray(tensor)

    # Resizing is done in float32. This is only safe for nearest mode;
    # otherwise integer outputs would contain invalid values.
    original_dtype = array.dtype
    if original_dtype != np.float32:
        array = array.astype(np.float32)

    dim_added = array.ndim < 4
    if dim_added:
        array = array[np.newaxis, ...]

    assert array.shape[1] == 1, 'TODO: add code for multi channel resizing'
    resized = np.zeros((array.shape[0], array.shape[1], out_h, out_w), dtype=np.float32)
    for b_idx in range(array.shape[0]):
        plane = PIL.Image.fromarray(array[b_idx, 0])
        plane = plane.resize((out_w, out_h), PIL.Image.NEAREST)
        resized[b_idx, 0, ...] = np.asarray(plane)

    if dim_added:
        resized = resized[0]
    if resized.dtype != original_dtype:
        resized = resized.astype(original_dtype)
    if is_torch:
        return torch.from_numpy(resized).to(device)
    return resized


def upsample_tensors(tensors, output_sizes, upsample_mode):
    """Resize one tensor or a sequence of tensors to the given sizes.

    Args:
        tensors: a single array/tensor, or a list/tuple of them.
        output_sizes: sequence of sizes; only the last two entries of each
            size are used. A single tensor uses ``output_sizes[0]``.
        upsample_mode: forwarded to ``_upsample_impl``.

    Returns:
        A list is updated in place and returned; a tuple yields a new tuple
        (tuples used to fail on item assignment); a single tensor yields the
        resized tensor.
    """
    if isinstance(tensors, list):
        for tidx, tensor in enumerate(tensors):
            tensors[tidx] = _upsample_impl(tensor, output_sizes[tidx][-2:], upsample_mode)
        return tensors
    if isinstance(tensors, tuple):
        return tuple(_upsample_impl(tensor, output_sizes[tidx][-2:], upsample_mode)
                     for tidx, tensor in enumerate(tensors))
    return _upsample_impl(tensors, output_sizes[0][-2:], upsample_mode)
def chroma_blend(image, color):
    """Combine the luma of ``image`` with the chroma of ``color``.

    Both inputs are cast to uint8 and converted from BGR to YUV. The result
    takes its Y plane from ``image`` and its U and V planes from ``color``,
    and is converted back to BGR.
    """
    luma_source = cv2.cvtColor(image.astype(np.uint8), cv2.COLOR_BGR2YUV)
    chroma_source = cv2.cvtColor(color.astype(np.uint8), cv2.COLOR_BGR2YUV)

    y_plane = np.uint8(cv2.split(luma_source)[0])
    _, u_plane, v_plane = cv2.split(chroma_source)

    blended_yuv = cv2.merge((y_plane, np.uint8(u_plane), np.uint8(v_plane)))
    return cv2.cvtColor(blended_yuv.astype(np.uint8), cv2.COLOR_YUV2BGR)
def eval_output(args, output, label, confusion_matrix):
    """Accumulate one prediction/label pair into ``confusion_matrix``.

    Args:
        args: object providing ``n_classes``; predictions are clipped to
            ``[0, n_classes]`` before counting.
        output: predicted class ids (any shape; flattened).
        label: ground-truth class ids. For arrays with more than two
            dimensions only ``label[:, :, 0]`` is used. Pixels whose label is
            255 are ignored.
        confusion_matrix: 2-D array indexed ``[ground truth, prediction]``;
            updated in place.

    Returns:
        The same ``confusion_matrix`` object.
    """
    if len(label.shape) > 2:
        label = label[:, :, 0]
    gt_labels = label.ravel()
    det_labels = output.ravel().clip(0, args.n_classes)
    valid_idx = np.flatnonzero(gt_labels != 255)
    gt_valid = gt_labels[valid_idx]
    det_valid = det_labels[valid_idx]

    # One bincount over the flattened (row, col) index replaces a full pass
    # over all pixels for every matrix cell.
    num_rows, num_cols = confusion_matrix.shape
    gt_idx = gt_valid.astype(np.int64)
    det_idx = det_valid.astype(np.int64)
    # Only count pairs that the per-cell equality test would have counted:
    # integral values that fall inside the matrix.
    countable = (
        (gt_idx == gt_valid) & (det_idx == det_valid)
        & (gt_idx >= 0) & (gt_idx < num_rows)
        & (det_idx >= 0) & (det_idx < num_cols)
    )
    flat_index = gt_idx[countable] * num_cols + det_idx[countable]
    counts = np.bincount(flat_index, minlength=num_rows * num_cols)
    confusion_matrix += counts.reshape(num_rows, num_cols).astype(confusion_matrix.dtype)

    return confusion_matrix
def compute_accuracy(args, confusion_matrix):
    """Derive accuracy figures from a confusion matrix.

    Only the leading ``n_classes x n_classes`` part of ``confusion_matrix``
    (indexed ``[ground truth, prediction]``) is used; extra columns are
    ignored.

    Args:
        args: object providing ``n_classes``.
        confusion_matrix: 2-D array of counts.

    Returns:
        Tuple ``(accuracy, mean_iou, iou, f1_score)``: overall pixel
        accuracy, IoU averaged over the classes that have ground-truth
        pixels, per-class IoU array and per-class F1 array. ``accuracy`` and
        ``mean_iou`` are 0.0 when nothing was counted (previously NaN from a
        0/0 division).
    """
    n_classes = args.n_classes
    counts = np.asarray(confusion_matrix, dtype=np.float64)[:n_classes, :n_classes]

    tp = np.diag(counts).copy()      # correctly classified pixels per class
    population = counts.sum(axis=1)  # ground-truth pixels per class
    det = counts.sum(axis=0)         # predicted pixels per class

    union = population + det - tp
    # classes with an empty union get an IoU of 0
    iou = np.divide(tp, union, out=np.zeros_like(tp), where=(union != 0))

    num_nonempty_classes = np.count_nonzero(population > 0)
    mean_iou = (np.sum(iou) / num_nonempty_classes) if num_nonempty_classes else 0.0
    total_population = np.sum(population)
    accuracy = (np.sum(tp) / total_population) if total_population else 0.0

    # F1 score calculation; the small epsilon avoids division by zero
    fn = population - tp
    precision = tp / (det + 1e-10)
    recall = tp / (tp + fn + 1e-10)
    f1_score = 2 * precision * recall / (precision + recall + 1e-10)

    return accuracy, mean_iou, iou, f1_score
def infer_video(args, net):
    """Run inference frame by frame on a video and write an output video.

    Reads at most ``args.num_images`` frames from ``args.input``, passes each
    frame (channel order reversed) through ``infer_blob`` and appends the
    result (channel order reversed back) to ``args.output``, using the input
    frame rate rounded up.

    NOTE(review): ``imageio``, ``math`` and ``infer_blob`` are not imported or
    defined in the visible part of this file, and nothing visible calls this
    function - confirm whether it is still needed.
    """
    videoIpHandle = imageio.get_reader(args.input, 'ffmpeg')
    fps = math.ceil(videoIpHandle.get_meta_data()['fps'])
    print(videoIpHandle.get_meta_data())
    # never process more frames than requested
    numFrames = min(len(videoIpHandle), args.num_images)
    videoOpHandle = imageio.get_writer(args.output,'ffmpeg', fps=fps)
    for num in range(numFrames):
        # progress indicator on a single line
        print(num, end=' ')
        sys.stdout.flush()
        input_blob = videoIpHandle.get_data(num)
        input_blob = input_blob[...,::-1] #RGB->BGR
        output_blob = infer_blob(args, net, input_blob)
        output_blob = output_blob[...,::-1] #BGR->RGB
        videoOpHandle.append_data(output_blob)
    videoOpHandle.close()
    return
def create_video(args, infer_path):
    """Encode the PNG images in ``infer_path`` into ``<name>.MP4`` with ffmpeg.

    ``<name>`` is the last component of ``args.data_path`` and the file is
    written to the current working directory. Failures (ffmpeg missing or a
    non-zero exit status) are reported as warnings and do not raise, matching
    the previous best-effort behavior.

    Args:
        args: configuration node providing ``data_path``.
        infer_path: directory containing the ``*.png`` frames.
    """
    import subprocess  # local import: only this function needs it

    # normpath drops a trailing slash, which used to yield an empty name
    op_file_name = os.path.basename(os.path.normpath(args.data_path))
    # An argument list (no shell) keeps paths with spaces or quotes intact.
    # The '*.png' pattern is expanded by ffmpeg itself (-pattern_type glob).
    command = [
        'ffmpeg', '-framerate', '30', '-pattern_type', 'glob',
        '-i', '{}/*.png'.format(infer_path),
        '-c:v', 'libx264', '-vb', '50000k', '-qmax', '20', '-r', '10',
        '-vf', 'scale=1024:512', '-pix_fmt', 'yuv420p',
        '{}.MP4'.format(op_file_name),
    ]
    try:
        result = subprocess.run(command, check=False)
    except OSError as err:
        warnings.warn('could not run ffmpeg: {}'.format(err))
        return
    if result.returncode != 0:
        warnings.warn('ffmpeg exited with status {}'.format(result.returncode))
if __name__ == '__main__':
    # Run with the default configuration. The former
    # `train_args = parser.parse_args()` line referenced an undefined
    # `parser` (NameError) and would have discarded the config built here.
    train_args = get_config()
    main(train_args)