]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/engine/test_pixel2pixel_onnx.py
support Hardtanh activation function also in quantization aware training
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / engine / test_pixel2pixel_onnx.py
1 import os
2 import time
3 import sys
4 import warnings
6 import torch
7 import torch.nn.parallel
8 import torch.optim
9 import torch.utils.data
10 import datetime
11 import numpy as np
12 import random
13 import cv2
14 import PIL
15 import PIL.Image
17 import onnx
18 import caffe2
19 import caffe2.python.onnx.backend
21 from .. import xnn
22 from .. import vision
26 # ################################################
27 def get_config():
28     args = xnn.utils.ConfigNode()
29     args.model_config = xnn.utils.ConfigNode()
30     args.dataset_config = xnn.utils.ConfigNode()
32     args.dataset_name = 'flying_chairs'              # dataset type
33     args.model_name = 'flownets'                # model architecture, overwritten if pretrained is specified: '
35     args.data_path = './data/datasets'                       # path to dataset
36     args.save_path = None            # checkpoints save path
37     args.pretrained = None
39     args.model_config.output_type = ['flow']                # the network is used to predict flow or depth or sceneflow')
40     args.model_config.output_channels = None                 # number of output channels
41     args.model_config.input_channels = None                  # number of input channels
42     args.n_classes = None                       # number of classes (for segmentation)
44     args.logger = None                          # logger stream to output into
46     args.prediction_type = 'flow'               # the network is used to predict flow or depth or sceneflow
47     args.split_file = None                      # train_val split file
48     args.split_files = None                     # split list files. eg: train.txt val.txt
49     args.split_value = 0.8                      # test_val split proportion (between 0 (only test) and 1 (only train))
51     args.workers = 8                            # number of data loading workers
53     args.epoch_size = 0                         # manual epoch size (will match dataset size if not specified)
54     args.epoch_size_val = 0                     # manual epoch size (will match dataset size if not specified)
55     args.batch_size = 8                         # mini_batch_size
56     args.total_batch_size = None                # accumulated batch size. total_batch_size = batch_size*iter_size
57     args.iter_size = 1                          # iteration size. total_batch_size = batch_size*iter_size
59     args.tensorboard_num_imgs = 5               # number of imgs to display in tensorboard
60     args.phase = 'validation'                   # evaluate model on validation set
61     args.pretrained = None                      # path to pre_trained model
62     args.date = None                            # don\'t append date timestamp to folder
63     args.print_freq = 10                        # print frequency (default: 100)
65     args.div_flow = 1.0                         # value by which flow will be divided. Original value is 20 but 1 with batchNorm gives good results
66     args.milestones = [100,150,200]             # epochs at which learning rate is divided by 2
67     args.losses = ['supervised_loss']           # loss functions to minimize
68     args.metrics = ['supervised_error']         # metric/measurement/error functions for train/validation
69     args.class_weights = None                   # class weights
71     args.multistep_gamma = 0.5                  # steps for step scheduler
72     args.polystep_power = 1.0                   # power for polynomial scheduler
74     args.rand_seed = 1                          # random seed
75     args.img_border_crop = None                 # image border crop rectangle. can be relative or absolute
76     args.target_mask = None                      # mask rectangle. can be relative or absolute. last value is the mask value
77     args.img_resize = None                      # image size to be resized to
78     args.rand_scale = (1,1.25)                  # random scale range for training
79     args.rand_crop = None                       # image size to be cropped to')
80     args.output_size = None                     # target output size to be resized to')
82     args.count_flops = True                     # count flops and report
84     args.shuffle = True                         # shuffle or not
85     args.is_flow = None                         # whether entries in images and targets lists are optical flow or not
87     args.multi_decoder = True                   # whether to use multiple decoders or unified decoder
89     args.create_video = False                   # whether to create video out of the inferred images
91     args.input_tensor_name = ['0']              # list of input tensore names
93     args.upsample_mode = 'nearest'              # upsample mode to use., choices=['nearest','bilinear']
95     args.image_prenorm = True                   # whether normalization is done before all other the transforms
96     args.image_mean = [128.0]                   # image mean for input image normalization
97     args.image_scale = [1.0/(0.25*256)]         # image scaling/mult for input iamge normalization
98     return args
101 # ################################################
102 # to avoid hangs in data loader with multi threads
103 # this was observed after using cv2 image processing functions
104 # https://github.com/pytorch/pytorch/issues/1355
105 cv2.setNumThreads(0)
108 # ################################################
109 def main(args):
111     assert not hasattr(args, 'model_surgery'), 'the argument model_surgery is deprecated, it is not needed now - remove it'
113     #################################################
114     # global settings. rand seeds for repeatability
115     random.seed(args.rand_seed)
116     np.random.seed(args.rand_seed)
117     torch.manual_seed(args.rand_seed)
119     ################################
120     # print everything for log
121     print('=> args: ', args)
123     ################################
124     # args check and config
125     if args.iter_size != 1 and args.total_batch_size is not None:
126         warnings.warn("only one of --iter_size or --total_batch_size must be set")
127     #
128     if args.total_batch_size is not None:
129         args.iter_size = args.total_batch_size//args.batch_size
130     else:
131         args.total_batch_size = args.batch_size*args.iter_size
132     #
134     assert args.pretrained is not None, 'pretrained onnx model path should be provided'
136     #################################################
137     # set some global flags and initializations
138     # keep it in args for now - although they don't belong here strictly
139     # using pin_memory is seen to cause issues, especially when when lot of memory is used.
140     args.use_pinned_memory = False
141     args.n_iter = 0
142     args.best_metric = -1
144     #################################################
145     if args.save_path is None:
146         save_path = get_save_path(args)
147     else:
148         save_path = args.save_path
149     #
151     print('=> will save everything to {}'.format(save_path))
152     
153     if not os.path.exists(save_path):
154         os.makedirs(save_path)
156     #################################################
157     if args.logger is None:
158         log_file = os.path.splitext(os.path.basename(__file__))[0] + '.log'
159         args.logger = xnn.utils.TeeLogger(filename=os.path.join(save_path,log_file))
160     transforms = get_transforms(args)
162     print("=> fetching img pairs in '{}'".format(args.data_path))
163     split_arg = args.split_file if args.split_file else (args.split_files if args.split_files else args.split_value)
164     val_dataset = vision.datasets.pixel2pixel.__dict__[args.dataset_name](args.dataset_config, args.data_path, split=split_arg, transforms=transforms)
166     print('=> {} val samples found'.format(len(val_dataset)))
167     
168     val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
169         num_workers=args.workers, pin_memory=args.use_pinned_memory, shuffle=args.shuffle)
170     #
172     #################################################
173     args.model_config.output_channels = val_dataset.num_classes() if (args.model_config.output_channels == None and 'num_classes' in dir(val_dataset)) else None
174     args.n_classes = args.model_config.output_channels[0]
176     #################################################
177     # create model
178     print("=> creating model '{}'".format(args.model_name))
180     model = onnx.load(args.pretrained)
181     # Run the ONNX model with Caffe2
182     onnx.checker.check_model(model)
183     model = caffe2.python.onnx.backend.prepare(model)
186     #################################################
187     if args.palette:
188         print('Creating palette')
189         eval_string = args.palette
190         palette = eval(eval_string)
191         args.palette = np.zeros((256,3))
192         for i, p in enumerate(palette):
193             args.palette[i,0] = p[0]
194             args.palette[i,1] = p[1]
195             args.palette[i,2] = p[2]
196         args.palette = args.palette[...,::-1] #RGB->BGR, since palette is expected to be given in RGB format
198     infer_path = os.path.join(save_path, 'inference')
199     if not os.path.exists(infer_path):
200         os.makedirs(infer_path)
202     #################################################
203     with torch.no_grad():
204         validate(args, val_dataset, val_loader, model, 0, infer_path)
206     if args.create_video:
207         create_video(args, infer_path=infer_path)
209 def validate(args, val_dataset, val_loader, model, epoch, infer_path):
210     data_time = xnn.utils.AverageMeter()
211     avg_metric = xnn.utils.AverageMeter()
213     # switch to evaluate mode
214     #model.eval()
215     metric_name = "Metric"
216     end_time = time.time()
217     writer_idx = 0
218     last_update_iter = -1
219     metric_ctx = [None] * len(args.metric_modules)
221     if args.label:
222         confusion_matrix = np.zeros((args.n_classes, args.n_classes+1))
223         for iter, (input_list, target_list, input_path, target_path) in enumerate(val_loader):
224             data_time.update(time.time() - end_time)
225             target_sizes = [tgt.shape for tgt in target_list]
227             input_dict = {args.input_tensor_name[idx]: input_list[idx].numpy() for idx in range(len(input_list))}
228             output = model.run(input_dict)[0]
230             list_output = type(output) in (list, tuple)
231             output_pred = output[0] if list_output else output
233             if args.output_size is not None:
234                 output_pred = upsample_tensors(output_pred, target_sizes, args.upsample_mode)
235             #
236             if args.blend:
237                 for index in range(output_pred.shape[0]):
238                     prediction = np.squeeze(np.array(output_pred[index]))
239                     #prediction = np.argmax(prediction, axis = 0)
240                     prediction_size = (prediction.shape[0], prediction.shape[1], 3)
241                     output_image = args.palette[prediction.ravel()].reshape(prediction_size)
242                     input_bgr = cv2.imread(input_path[0][index]) #Read the actual RGB image
243                     input_bgr = cv2.resize(input_bgr, dsize=(prediction.shape[1],prediction.shape[0]))
245                     output_image = chroma_blend(input_bgr, output_image)
246                     output_name = os.path.join(infer_path, os.path.basename(input_path[0][index]))
247                     cv2.imwrite(output_name, output_image)
249             if args.label:
250                 for index in range(output_pred.shape[0]): 
251                     prediction = np.array(output_pred[index])
252                     #prediction = np.argmax(prediction, axis = 0)
253                     label = np.squeeze(np.array(target_list[0][index]))
254                     confusion_matrix = eval_output(args, prediction, label, confusion_matrix)
255                     accuracy, mean_iou, iou, f1_score= compute_accuracy(args, confusion_matrix)
256                 print('pixel_accuracy={}, mean_iou={}, iou={}, f1_score = {}'.format(accuracy, mean_iou, iou, f1_score))
257                 sys.stdout.flush()
258     else:
259         for iter, (input_list, _ , input_path, _) in enumerate(val_loader):
260             data_time.update(time.time() - end_time)
262             input_dict = {args.input_tensor_name[idx]: input_list[idx].numpy() for idx in range(len(input_list))}
263             output = model.run(input_dict)[0]
265             list_output = type(output) in (list, tuple)
266             output_pred = output[0] if list_output else output
267             input_path = input_path[0]
268             
269             if args.output_size is not None:
270                 target_sizes = [args.output_size]
271                 output_pred = upsample_tensors(output_pred, target_sizes, args.upsample_mode)
272             #
273             if args.blend:
274                 for index in range(output_pred.shape[0]):
275                     prediction = np.squeeze(np.array(output_pred[index])) #np.squeeze(np.array(output_pred[index].cpu()))
276                     prediction_size = (prediction.shape[0], prediction.shape[1], 3)
277                     output_image = args.palette[prediction.ravel()].reshape(prediction_size)
278                     input_bgr = cv2.imread(input_path[index]) #Read the actual RGB image
279                     input_bgr = cv2.resize(input_bgr, (args.img_resize[1], args.img_resize[0]), interpolation=cv2.INTER_LINEAR)
280                     output_image = chroma_blend(input_bgr, output_image)
281                     output_name = os.path.join(infer_path, os.path.basename(input_path[index]))
282                     cv2.imwrite(output_name, output_image)
283                     print('Inferred image {}'.format(input_path[index]))
284             if args.car_mask:   #generating car_mask (required for localization)
285                 for index in range(output_pred.shape[0]):
286                     prediction = np.array(output_pred[index])
287                     prediction = np.argmax(prediction, axis = 0)
288                     car_mask = np.logical_or(prediction == 13, prediction == 14, prediction == 16, prediction, prediction == 17)
289                     prediction[car_mask] = 255
290                     prediction[np.invert(car_mask)] = 0
291                     output_image = prediction
292                     output_name = os.path.join(infer_path, os.path.basename(input_path[index]))
293                     cv2.imwrite(output_name, output_image)
296 def get_save_path(args, phase=None):
297     date = args.date if args.date else datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
298     save_path = os.path.join('./data/checkpoints', args.dataset_name, '{}_{}_'.format(date, args.model_name))
299     save_path += 'b{}'.format(args.batch_size)
300     phase = phase if (phase is not None) else args.phase
301     save_path = os.path.join(save_path, phase)
302     return save_path
305 def get_transforms(args):
306     # image normalization can be at the beginning of transforms or at the end
307     args.image_mean = np.array(args.image_mean, dtype=np.float32)
308     args.image_scale = np.array(args.image_scale, dtype=np.float32)
309     image_prenorm = vision.transforms.image_transforms.NormalizeMeanScale(mean=args.image_mean, scale=args.image_scale) if args.image_prenorm else None
310     image_postnorm = vision.transforms.image_transforms.NormalizeMeanScale(mean=args.image_mean, scale=args.image_scale) if (not image_prenorm) else None
312     #target size must be according to output_size. prediction will be resized to output_size before evaluation.
313     test_transform = vision.transforms.image_transforms.Compose([
314         image_prenorm,
315         vision.transforms.image_transforms.AlignImages(),
316         vision.transforms.image_transforms.MaskTarget(args.target_mask, 0),
317         vision.transforms.image_transforms.CropRect(args.img_border_crop),
318         vision.transforms.image_transforms.Scale(args.img_resize, target_size=args.output_size, is_flow=args.is_flow),
319         image_postnorm,
320         vision.transforms.image_transforms.ConvertToTensor()
321         ])
323     return test_transform
326 def _upsample_impl(tensor, output_size, upsample_mode):
327     # upsample of long tensor is not supported currently. covert to float, just to avoid error.
328     # we can do this only in the case of nearest mode, otherwise output will have invalid values.
329     convert_tensor_to_float = False
330     convert_np_to_float = False
331     if isinstance(tensor, (torch.LongTensor,torch.cuda.LongTensor)):
332         convert_tensor_to_float = True
333         original_dtype = tensor.dtype
334         tensor = tensor.float()
335     elif isinstance(tensor, np.ndarray) and (np.dtype != np.float32):
336         convert_np_to_float = True
337         original_dtype = tensor.dtype
338         tensor = tensor.astype(np.float32)
339     #
341     dim_added = False
342     if len(tensor.shape) < 4:
343         tensor = tensor[np.newaxis,...]
344         dim_added = True
345     #
346     if (tensor.shape[-2:] != output_size):
347         assert tensor.shape[1] == 1, 'TODO: add code for multi channel resizing'
348         out_tensor = np.zeros((tensor.shape[0],tensor.shape[1],output_size[0],output_size[1]),dtype=np.float32)
349         for b_idx in range(tensor.shape[0]):
350             b_tensor = PIL.Image.fromarray(tensor[b_idx,0])
351             b_tensor = b_tensor.resize((output_size[1],output_size[0]), PIL.Image.NEAREST)
352             out_tensor[b_idx,0,...] = np.asarray(b_tensor)
353         #
354         tensor = out_tensor
355     #
356     if dim_added:
357         tensor = tensor[0]
358     #
360     if convert_tensor_to_float:
361         tensor = tensor.long()
362     elif convert_np_to_float:
363         tensor = tensor.astype(original_dtype)
364     #
365     return tensor
367 def upsample_tensors(tensors, output_sizes, upsample_mode):
368     if isinstance(tensors, (list,tuple)):
369         for tidx, tensor in enumerate(tensors):
370             tensors[tidx] = _upsample_impl(tensor, output_sizes[tidx][-2:], upsample_mode)
371         #
372     else:
373         tensors = _upsample_impl(tensors, output_sizes[0][-2:], upsample_mode)
374     return tensors
377 def chroma_blend(image, color):
378     image_yuv = cv2.cvtColor(image.astype(np.uint8), cv2.COLOR_BGR2YUV)
379     image_y,image_u,image_v = cv2.split(image_yuv)
380     color_yuv = cv2.cvtColor(color.astype(np.uint8), cv2.COLOR_BGR2YUV)
381     color_y,color_u,color_v = cv2.split(color_yuv)
382     image_y = np.uint8(image_y)
383     color_u = np.uint8(color_u)
384     color_v = np.uint8(color_v)
385     image_yuv = cv2.merge((image_y,color_u,color_v))
386     image = cv2.cvtColor(image_yuv.astype(np.uint8), cv2.COLOR_YUV2BGR)
387     return image    
391 def eval_output(args, output, label, confusion_matrix):
393     if len(label.shape)>2:
394         label = label[:,:,0]
395     gt_labels = label.ravel()
396     det_labels = output.ravel().clip(0,args.n_classes)
397     gt_labels_valid_ind = np.where(gt_labels != 255)
398     gt_labels_valid = gt_labels[gt_labels_valid_ind]
399     det_labels_valid = det_labels[gt_labels_valid_ind]
400     for r in range(confusion_matrix.shape[0]):
401         for c in range(confusion_matrix.shape[1]):
402             confusion_matrix[r,c] += np.sum((gt_labels_valid==r) & (det_labels_valid==c))
404     return confusion_matrix
405     
406 def compute_accuracy(args, confusion_matrix):
408     #pdb.set_trace()
409     num_selected_classes = args.n_classes
410     tp = np.zeros(args.n_classes)
411     population = np.zeros(args.n_classes)
412     det = np.zeros(args.n_classes)
413     iou = np.zeros(args.n_classes)
414     
415     for r in range(args.n_classes):
416       for c in range(args.n_classes):   
417         population[r] += confusion_matrix[r][c]
418         det[c] += confusion_matrix[r][c]   
419         if r == c:
420           tp[r] += confusion_matrix[r][c]
422     for cls in range(num_selected_classes):
423       intersection = tp[cls]
424       union = population[cls] + det[cls] - tp[cls]
425       iou[cls] = (intersection / union) if union else 0     # For caffe jacinto script
426       #iou[cls] = (intersection / (union + np.finfo(np.float32).eps))  # For pytorch-jacinto script
428     num_nonempty_classes = 0
429     for pop in population:
430       if pop>0:
431         num_nonempty_classes += 1
432           
433     mean_iou = np.sum(iou) / num_nonempty_classes
434     accuracy = np.sum(tp) / np.sum(population)
435     
436     #F1 score calculation
437     fp = np.zeros(args.n_classes)
438     fn = np.zeros(args.n_classes)
439     precision = np.zeros(args.n_classes)
440     recall = np.zeros(args.n_classes)
441     f1_score = np.zeros(args.n_classes)
443     for cls in range(num_selected_classes):
444         fp[cls] = det[cls] - tp[cls]
445         fn[cls] = population[cls] - tp[cls]
446         precision[cls] = tp[cls] / (det[cls] + 1e-10)
447         recall[cls] = tp[cls] / (tp[cls] + fn[cls] + 1e-10)        
448         f1_score[cls] = 2 * precision[cls]*recall[cls] / (precision[cls] + recall[cls] + 1e-10)
450     return accuracy, mean_iou, iou, f1_score
451     
452         
453 def infer_video(args, net):
454     videoIpHandle = imageio.get_reader(args.input, 'ffmpeg')
455     fps = math.ceil(videoIpHandle.get_meta_data()['fps'])
456     print(videoIpHandle.get_meta_data())
457     numFrames = min(len(videoIpHandle), args.num_images)
458     videoOpHandle = imageio.get_writer(args.output,'ffmpeg', fps=fps)
459     for num in range(numFrames):
460         print(num, end=' ')
461         sys.stdout.flush()
462         input_blob = videoIpHandle.get_data(num)
463         input_blob = input_blob[...,::-1]    #RGB->BGR
464         output_blob = infer_blob(args, net, input_blob)     
465         output_blob = output_blob[...,::-1]  #BGR->RGB            
466         videoOpHandle.append_data(output_blob)
467     videoOpHandle.close()
468     return
470 def create_video(args, infer_path):
471     op_file_name = args.data_path.split('/')[-1]
472     os.system(' ffmpeg -framerate 30 -pattern_type glob -i "{}/*.png" -c:v libx264 -vb 50000k -qmax 20 -r 10 -vf \
473                  scale=1024:512  -pix_fmt yuv420p {}.MP4'.format(infer_path, op_file_name))
475 if __name__ == '__main__':
476     train_args = get_config()
477     train_args = parser.parse_args()
478     main(train_args)