]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/engine/test_pixel2pixel_onnx.py
release commit
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / engine / test_pixel2pixel_onnx.py
1 import os
2 import time
3 import sys
4 import warnings
6 import torch
7 import torch.nn.parallel
8 import torch.optim
9 import torch.utils.data
10 import datetime
11 import numpy as np
12 import random
13 import cv2
14 import PIL
15 import PIL.Image
17 import onnx
18 import caffe2
19 import caffe2.python.onnx.backend
21 from .. import xnn
22 from .. import vision
26 # ################################################
27 def get_config():
28     args = xnn.utils.ConfigNode()
29     args.model_config = xnn.utils.ConfigNode()
30     args.dataset_config = xnn.utils.ConfigNode()
32     args.dataset_name = 'flying_chairs'              # dataset type
33     args.model_name = 'flownets'                # model architecture, overwritten if pretrained is specified: '
35     args.data_path = './data/datasets'                       # path to dataset
36     args.save_path = None            # checkpoints save path
37     args.pretrained = None
39     args.model_config.output_type = ['flow']                # the network is used to predict flow or depth or sceneflow')
40     args.model_config.output_channels = None                 # number of output channels
41     args.model_config.input_channels = None                  # number of input channels
42     args.n_classes = None                       # number of classes (for segmentation)
44     args.logger = None                          # logger stream to output into
46     args.prediction_type = 'flow'               # the network is used to predict flow or depth or sceneflow
47     args.split_file = None                      # train_val split file
48     args.split_files = None                     # split list files. eg: train.txt val.txt
49     args.split_value = 0.8                      # test_val split proportion (between 0 (only test) and 1 (only train))
51     args.workers = 8                            # number of data loading workers
53     args.epoch_size = 0                         # manual epoch size (will match dataset size if not specified)
54     args.epoch_size_val = 0                     # manual epoch size (will match dataset size if not specified)
55     args.batch_size = 8                         # mini_batch_size
56     args.total_batch_size = None                # accumulated batch size. total_batch_size = batch_size*iter_size
57     args.iter_size = 1                          # iteration size. total_batch_size = batch_size*iter_size
59     args.tensorboard_num_imgs = 5               # number of imgs to display in tensorboard
60     args.phase = 'validation'                   # evaluate model on validation set
61     args.pretrained = None                      # path to pre_trained model
62     args.date = None                            # don\'t append date timestamp to folder
63     args.print_freq = 10                        # print frequency (default: 100)
65     args.div_flow = 1.0                         # value by which flow will be divided. Original value is 20 but 1 with batchNorm gives good results
66     args.milestones = [100,150,200]             # epochs at which learning rate is divided by 2
67     args.losses = ['supervised_loss']           # loss functions to minimize
68     args.metrics = ['supervised_error']         # metric/measurement/error functions for train/validation
69     args.class_weights = None                   # class weights
71     args.multistep_gamma = 0.5                  # steps for step scheduler
72     args.polystep_power = 1.0                   # power for polynomial scheduler
73     args.train_fwbw = False                     # do forward backward step while training
75     args.rand_seed = 1                          # random seed
76     args.img_border_crop = None                 # image border crop rectangle. can be relative or absolute
77     args.target_mask = None                      # mask rectangle. can be relative or absolute. last value is the mask value
78     args.img_resize = None                      # image size to be resized to
79     args.rand_scale = (1,1.25)                  # random scale range for training
80     args.rand_crop = None                       # image size to be cropped to')
81     args.output_size = None                     # target output size to be resized to')
83     args.count_flops = True                     # count flops and report
85     args.shuffle = True                         # shuffle or not
86     args.is_flow = None                         # whether entries in images and targets lists are optical flow or not
88     args.multi_decoder = True                   # whether to use multiple decoders or unified decoder
90     args.create_video = False                   # whether to create video out of the inferred images
92     args.input_tensor_name = ['0']              # list of input tensore names
94     args.upsample_mode = 'nearest'              # upsample mode to use., choices=['nearest','bilinear']
96     args.image_prenorm = True                   # whether normalization is done before all other the transforms
97     args.image_mean = [128.0]                   # image mean for input image normalization
98     args.image_scale = [1.0/(0.25*256)]         # image scaling/mult for input iamge normalization
99     return args
102 # ################################################
103 # to avoid hangs in data loader with multi threads
104 # this was observed after using cv2 image processing functions
105 # https://github.com/pytorch/pytorch/issues/1355
106 cv2.setNumThreads(0)
109 # ################################################
110 def main(args):
112     assert not hasattr(args, 'model_surgery'), 'the argument model_surgery is deprecated, it is not needed now - remove it'
114     #################################################
115     # global settings. rand seeds for repeatability
116     random.seed(args.rand_seed)
117     np.random.seed(args.rand_seed)
118     torch.manual_seed(args.rand_seed)
120     ################################
121     # print everything for log
122     print('=> args: ', args)
124     ################################
125     # args check and config
126     if args.iter_size != 1 and args.total_batch_size is not None:
127         warnings.warn("only one of --iter_size or --total_batch_size must be set")
128     #
129     if args.total_batch_size is not None:
130         args.iter_size = args.total_batch_size//args.batch_size
131     else:
132         args.total_batch_size = args.batch_size*args.iter_size
133     #
135     assert args.pretrained is not None, 'pretrained onnx model path should be provided'
137     #################################################
138     # set some global flags and initializations
139     # keep it in args for now - although they don't belong here strictly
140     # using pin_memory is seen to cause issues, especially when when lot of memory is used.
141     args.use_pinned_memory = False
142     args.n_iter = 0
143     args.best_metric = -1
145     #################################################
146     if args.save_path is None:
147         save_path = get_save_path(args)
148     else:
149         save_path = args.save_path
150     #
152     print('=> will save everything to {}'.format(save_path))
153     
154     if not os.path.exists(save_path):
155         os.makedirs(save_path)
157     #################################################
158     if args.logger is None:
159         log_file = os.path.splitext(os.path.basename(__file__))[0] + '.log'
160         args.logger = xnn.utils.TeeLogger(filename=os.path.join(save_path,log_file))
161     transforms = get_transforms(args)
163     print("=> fetching img pairs in '{}'".format(args.data_path))
164     split_arg = args.split_file if args.split_file else (args.split_files if args.split_files else args.split_value)
165     val_dataset = vision.datasets.pixel2pixel.__dict__[args.dataset_name](args.dataset_config, args.data_path, split=split_arg, transforms=transforms)
167     print('=> {} val samples found'.format(len(val_dataset)))
168     
169     val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
170         num_workers=args.workers, pin_memory=args.use_pinned_memory, shuffle=args.shuffle)
171     #
173     #################################################
174     args.model_config.output_channels = val_dataset.num_classes() if (args.model_config.output_channels == None and 'num_classes' in dir(val_dataset)) else None
175     args.n_classes = args.model_config.output_channels[0]
177     #################################################
178     # create model
179     print("=> creating model '{}'".format(args.model_name))
181     model = onnx.load(args.pretrained)
182     # Run the ONNX model with Caffe2
183     onnx.checker.check_model(model)
184     model = caffe2.python.onnx.backend.prepare(model)
187     #################################################
188     if args.palette:
189         print('Creating palette')
190         eval_string = args.palette
191         palette = eval(eval_string)
192         args.palette = np.zeros((256,3))
193         for i, p in enumerate(palette):
194             args.palette[i,0] = p[0]
195             args.palette[i,1] = p[1]
196             args.palette[i,2] = p[2]
197         args.palette = args.palette[...,::-1] #RGB->BGR, since palette is expected to be given in RGB format
199     infer_path = os.path.join(save_path, 'inference')
200     if not os.path.exists(infer_path):
201         os.makedirs(infer_path)
203     #################################################
204     with torch.no_grad():
205         validate(args, val_dataset, val_loader, model, 0, infer_path)
207     if args.create_video:
208         create_video(args, infer_path=infer_path)
210 def validate(args, val_dataset, val_loader, model, epoch, infer_path):
211     data_time = xnn.utils.AverageMeter()
212     avg_metric = xnn.utils.AverageMeter()
214     # switch to evaluate mode
215     #model.eval()
216     metric_name = "Metric"
217     end_time = time.time()
218     writer_idx = 0
219     last_update_iter = -1
220     metric_ctx = [None] * len(args.metric_modules)
222     if args.label:
223         confusion_matrix = np.zeros((args.n_classes, args.n_classes+1))
224         for iter, (input_list, target_list, input_path, target_path) in enumerate(val_loader):
225             data_time.update(time.time() - end_time)
226             target_sizes = [tgt.shape for tgt in target_list]
228             input_dict = {args.input_tensor_name[idx]: input_list[idx].numpy() for idx in range(len(input_list))}
229             output = model.run(input_dict)[0]
231             list_output = type(output) in (list, tuple)
232             output_pred = output[0] if list_output else output
234             if args.output_size is not None:
235                 output_pred = upsample_tensors(output_pred, target_sizes, args.upsample_mode)
236             #
237             if args.blend:
238                 for index in range(output_pred.shape[0]):
239                     prediction = np.squeeze(np.array(output_pred[index]))
240                     #prediction = np.argmax(prediction, axis = 0)
241                     prediction_size = (prediction.shape[0], prediction.shape[1], 3)
242                     output_image = args.palette[prediction.ravel()].reshape(prediction_size)
243                     input_bgr = cv2.imread(input_path[0][index]) #Read the actual RGB image
244                     input_bgr = cv2.resize(input_bgr, dsize=(prediction.shape[1],prediction.shape[0]))
246                     output_image = chroma_blend(input_bgr, output_image)
247                     output_name = os.path.join(infer_path, os.path.basename(input_path[0][index]))
248                     cv2.imwrite(output_name, output_image)
250             if args.label:
251                 for index in range(output_pred.shape[0]): 
252                     prediction = np.array(output_pred[index])
253                     #prediction = np.argmax(prediction, axis = 0)
254                     label = np.squeeze(np.array(target_list[0][index]))
255                     confusion_matrix = eval_output(args, prediction, label, confusion_matrix)
256                     accuracy, mean_iou, iou, f1_score= compute_accuracy(args, confusion_matrix)
257                 print('pixel_accuracy={}, mean_iou={}, iou={}, f1_score = {}'.format(accuracy, mean_iou, iou, f1_score))
258                 sys.stdout.flush()
259     else:
260         for iter, (input_list, _ , input_path, _) in enumerate(val_loader):
261             data_time.update(time.time() - end_time)
263             input_dict = {args.input_tensor_name[idx]: input_list[idx].numpy() for idx in range(len(input_list))}
264             output = model.run(input_dict)[0]
266             list_output = type(output) in (list, tuple)
267             output_pred = output[0] if list_output else output
268             input_path = input_path[0]
269             
270             if args.output_size is not None:
271                 target_sizes = [args.output_size]
272                 output_pred = upsample_tensors(output_pred, target_sizes, args.upsample_mode)
273             #
274             if args.blend:
275                 for index in range(output_pred.shape[0]):
276                     prediction = np.squeeze(np.array(output_pred[index])) #np.squeeze(np.array(output_pred[index].cpu()))
277                     prediction_size = (prediction.shape[0], prediction.shape[1], 3)
278                     output_image = args.palette[prediction.ravel()].reshape(prediction_size)
279                     input_bgr = cv2.imread(input_path[index]) #Read the actual RGB image
280                     input_bgr = cv2.resize(input_bgr, (args.img_resize[1], args.img_resize[0]), interpolation=cv2.INTER_LINEAR)
281                     output_image = chroma_blend(input_bgr, output_image)
282                     output_name = os.path.join(infer_path, os.path.basename(input_path[index]))
283                     cv2.imwrite(output_name, output_image)
284                     print('Inferred image {}'.format(input_path[index]))
285             if args.car_mask:   #generating car_mask (required for localization)
286                 for index in range(output_pred.shape[0]):
287                     prediction = np.array(output_pred[index])
288                     prediction = np.argmax(prediction, axis = 0)
289                     car_mask = np.logical_or(prediction == 13, prediction == 14, prediction == 16, prediction, prediction == 17)
290                     prediction[car_mask] = 255
291                     prediction[np.invert(car_mask)] = 0
292                     output_image = prediction
293                     output_name = os.path.join(infer_path, os.path.basename(input_path[index]))
294                     cv2.imwrite(output_name, output_image)
297 def get_save_path(args, phase=None):
298     date = args.date if args.date else datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
299     save_path = os.path.join('./data/checkpoints', args.dataset_name, '{}_{}_'.format(date, args.model_name))
300     save_path += 'b{}'.format(args.batch_size)
301     phase = phase if (phase is not None) else args.phase
302     save_path = os.path.join(save_path, phase)
303     return save_path
306 def get_transforms(args):
307     # image normalization can be at the beginning of transforms or at the end
308     args.image_mean = np.array(args.image_mean, dtype=np.float32)
309     args.image_scale = np.array(args.image_scale, dtype=np.float32)
310     image_prenorm = vision.transforms.image_transforms.NormalizeMeanScale(mean=args.image_mean, scale=args.image_scale) if args.image_prenorm else None
311     image_postnorm = vision.transforms.image_transforms.NormalizeMeanScale(mean=args.image_mean, scale=args.image_scale) if (not image_prenorm) else None
313     #target size must be according to output_size. prediction will be resized to output_size before evaluation.
314     test_transform = vision.transforms.image_transforms.Compose([
315         image_prenorm,
316         vision.transforms.image_transforms.AlignImages(),
317         vision.transforms.image_transforms.MaskTarget(args.target_mask, 0),
318         vision.transforms.image_transforms.CropRect(args.img_border_crop),
319         vision.transforms.image_transforms.Scale(args.img_resize, target_size=args.output_size, is_flow=args.is_flow),
320         image_postnorm,
321         vision.transforms.image_transforms.ConvertToTensor()
322         ])
324     return test_transform
327 def _upsample_impl(tensor, output_size, upsample_mode):
328     # upsample of long tensor is not supported currently. covert to float, just to avoid error.
329     # we can do this only in the case of nearest mode, otherwise output will have invalid values.
330     convert_tensor_to_float = False
331     convert_np_to_float = False
332     if isinstance(tensor, (torch.LongTensor,torch.cuda.LongTensor)):
333         convert_tensor_to_float = True
334         original_dtype = tensor.dtype
335         tensor = tensor.float()
336     elif isinstance(tensor, np.ndarray) and (np.dtype != np.float32):
337         convert_np_to_float = True
338         original_dtype = tensor.dtype
339         tensor = tensor.astype(np.float32)
340     #
342     dim_added = False
343     if len(tensor.shape) < 4:
344         tensor = tensor[np.newaxis,...]
345         dim_added = True
346     #
347     if (tensor.shape[-2:] != output_size):
348         assert tensor.shape[1] == 1, 'TODO: add code for multi channel resizing'
349         out_tensor = np.zeros((tensor.shape[0],tensor.shape[1],output_size[0],output_size[1]),dtype=np.float32)
350         for b_idx in range(tensor.shape[0]):
351             b_tensor = PIL.Image.fromarray(tensor[b_idx,0])
352             b_tensor = b_tensor.resize((output_size[1],output_size[0]), PIL.Image.NEAREST)
353             out_tensor[b_idx,0,...] = np.asarray(b_tensor)
354         #
355         tensor = out_tensor
356     #
357     if dim_added:
358         tensor = tensor[0]
359     #
361     if convert_tensor_to_float:
362         tensor = tensor.long()
363     elif convert_np_to_float:
364         tensor = tensor.astype(original_dtype)
365     #
366     return tensor
368 def upsample_tensors(tensors, output_sizes, upsample_mode):
369     if isinstance(tensors, (list,tuple)):
370         for tidx, tensor in enumerate(tensors):
371             tensors[tidx] = _upsample_impl(tensor, output_sizes[tidx][-2:], upsample_mode)
372         #
373     else:
374         tensors = _upsample_impl(tensors, output_sizes[0][-2:], upsample_mode)
375     return tensors
378 def chroma_blend(image, color):
379     image_yuv = cv2.cvtColor(image.astype(np.uint8), cv2.COLOR_BGR2YUV)
380     image_y,image_u,image_v = cv2.split(image_yuv)
381     color_yuv = cv2.cvtColor(color.astype(np.uint8), cv2.COLOR_BGR2YUV)
382     color_y,color_u,color_v = cv2.split(color_yuv)
383     image_y = np.uint8(image_y)
384     color_u = np.uint8(color_u)
385     color_v = np.uint8(color_v)
386     image_yuv = cv2.merge((image_y,color_u,color_v))
387     image = cv2.cvtColor(image_yuv.astype(np.uint8), cv2.COLOR_YUV2BGR)
388     return image    
392 def eval_output(args, output, label, confusion_matrix):
394     if len(label.shape)>2:
395         label = label[:,:,0]
396     gt_labels = label.ravel()
397     det_labels = output.ravel().clip(0,args.n_classes)
398     gt_labels_valid_ind = np.where(gt_labels != 255)
399     gt_labels_valid = gt_labels[gt_labels_valid_ind]
400     det_labels_valid = det_labels[gt_labels_valid_ind]
401     for r in range(confusion_matrix.shape[0]):
402         for c in range(confusion_matrix.shape[1]):
403             confusion_matrix[r,c] += np.sum((gt_labels_valid==r) & (det_labels_valid==c))
405     return confusion_matrix
406     
407 def compute_accuracy(args, confusion_matrix):
409     #pdb.set_trace()
410     num_selected_classes = args.n_classes
411     tp = np.zeros(args.n_classes)
412     population = np.zeros(args.n_classes)
413     det = np.zeros(args.n_classes)
414     iou = np.zeros(args.n_classes)
415     
416     for r in range(args.n_classes):
417       for c in range(args.n_classes):   
418         population[r] += confusion_matrix[r][c]
419         det[c] += confusion_matrix[r][c]   
420         if r == c:
421           tp[r] += confusion_matrix[r][c]
423     for cls in range(num_selected_classes):
424       intersection = tp[cls]
425       union = population[cls] + det[cls] - tp[cls]
426       iou[cls] = (intersection / union) if union else 0     # For caffe jacinto script
427       #iou[cls] = (intersection / (union + np.finfo(np.float32).eps))  # For pytorch-jacinto script
429     num_nonempty_classes = 0
430     for pop in population:
431       if pop>0:
432         num_nonempty_classes += 1
433           
434     mean_iou = np.sum(iou) / num_nonempty_classes
435     accuracy = np.sum(tp) / np.sum(population)
436     
437     #F1 score calculation
438     fp = np.zeros(args.n_classes)
439     fn = np.zeros(args.n_classes)
440     precision = np.zeros(args.n_classes)
441     recall = np.zeros(args.n_classes)
442     f1_score = np.zeros(args.n_classes)
444     for cls in range(num_selected_classes):
445         fp[cls] = det[cls] - tp[cls]
446         fn[cls] = population[cls] - tp[cls]
447         precision[cls] = tp[cls] / (det[cls] + 1e-10)
448         recall[cls] = tp[cls] / (tp[cls] + fn[cls] + 1e-10)        
449         f1_score[cls] = 2 * precision[cls]*recall[cls] / (precision[cls] + recall[cls] + 1e-10)
451     return accuracy, mean_iou, iou, f1_score
452     
453         
454 def infer_video(args, net):
455     videoIpHandle = imageio.get_reader(args.input, 'ffmpeg')
456     fps = math.ceil(videoIpHandle.get_meta_data()['fps'])
457     print(videoIpHandle.get_meta_data())
458     numFrames = min(len(videoIpHandle), args.num_images)
459     videoOpHandle = imageio.get_writer(args.output,'ffmpeg', fps=fps)
460     for num in range(numFrames):
461         print(num, end=' ')
462         sys.stdout.flush()
463         input_blob = videoIpHandle.get_data(num)
464         input_blob = input_blob[...,::-1]    #RGB->BGR
465         output_blob = infer_blob(args, net, input_blob)     
466         output_blob = output_blob[...,::-1]  #BGR->RGB            
467         videoOpHandle.append_data(output_blob)
468     videoOpHandle.close()
469     return
471 def create_video(args, infer_path):
472     op_file_name = args.data_path.split('/')[-1]
473     os.system(' ffmpeg -framerate 30 -pattern_type glob -i "{}/*.png" -c:v libx264 -vb 50000k -qmax 20 -r 10 -vf \
474                  scale=1024:512  -pix_fmt yuv420p {}.MP4'.format(infer_path, op_file_name))
476 if __name__ == '__main__':
477     train_args = get_config()
478     train_args = parser.parse_args()
479     main(train_args)