release commit
author Manu Mathew <a0393608@ti.com>
Mon, 17 Feb 2020 09:36:53 +0000 (15:06 +0530)
committer Manu Mathew <a0393608@ti.com>
Mon, 17 Feb 2020 09:36:53 +0000 (15:06 +0530)
17 files changed:
.gitignore
docs/Semantic_Segmentation.md
modules/pytorch_jacinto_ai/engine/infer_classification_onnx_rt.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/engine/infer_pixel2pixel.py
modules/pytorch_jacinto_ai/engine/infer_pixel2pixel_onnx_rt.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/engine/train_classification.py
modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py
modules/pytorch_jacinto_ai/vision/datasets/pixel2pixel/a2d2.py
modules/pytorch_jacinto_ai/vision/datasets/pixel2pixel/calculate_class_weights.py
modules/pytorch_jacinto_ai/vision/losses/segmentation_loss.py
modules/pytorch_jacinto_ai/vision/models/__init__.py
modules/pytorch_jacinto_ai/vision/models/classification/__init__.py
modules/pytorch_jacinto_ai/vision/models/mobilenetv1.py
modules/pytorch_jacinto_ai/vision/transforms/functional.py
modules/pytorch_jacinto_ai/xnn/utils/__init__.py
scripts/infer_classification_onnx_rt_main.py [new file with mode: 0755]
scripts/infer_segmentation_main.py

index 6b0f32adc791a095f1bd17bd758fc3fe533d505d..6b0362c2e03bf225332ec8eac4e30e4df3a3cf4c 100644 (file)
@@ -14,4 +14,6 @@ data/datasets/*
 !data/datasets/readme.txt
 data/downloads/*
 !data/downloads/readme.txt
-
+checkpoints/*
+scripts_internal/data/checkpoints/*
+venv/*
index 7cf06964f1515cfd3fdc87692d8bcb48d589e744..5d9d048c6f1bf8f6e93b74f9ecd554f314913d50 100644 (file)
@@ -84,6 +84,7 @@ Inference can be done as follows (fill in the path to the pretrained model):<br>
|Dataset    |Model Architecture        |Backbone Model |Backbone Stride|Resolution |Complexity (GigaMACS)|MeanIoU%  |Model Configuration Name                  |
 |---------  |----------                |-----------    |-------------- |-----------|--------             |----------|----------------------------------------  |
 |Cityscapes |FPNPixel2Pixel with DWASPP|FD-MobileNetV2 |64             |768x384    |0.99                 |62.43     |fpn_pixel2pixel_aspp_mobilenetv2_tv_fd    |
+|Cityscapes |UNet with DWASPP          |MobileNetV2    |32             |768x384    |**2.20**             |**68.94** |**unet_pixel2pixel_aspp_mobilenetv2_tv**  |
 |Cityscapes |DeepLabV3Lite with DWASPP |MobileNetV2    |16             |768x384    |**3.54**             |**69.13** |**deeplabv3lite_mobilenetv2_tv**          |
 |Cityscapes |FPNPixel2Pixel            |MobileNetV2    |32(\*2\*2)     |768x384    |3.66                 |70.30     |fpn_pixel2pixel_mobilenetv2_tv            |
 |Cityscapes |FPNPixel2Pixel with DWASPP|MobileNetV2    |32             |768x384    |3.84                 |70.39     |fpn_pixel2pixel_aspp_mobilenetv2_tv       |
diff --git a/modules/pytorch_jacinto_ai/engine/infer_classification_onnx_rt.py b/modules/pytorch_jacinto_ai/engine/infer_classification_onnx_rt.py
new file mode 100644 (file)
index 0000000..f1129cf
--- /dev/null
@@ -0,0 +1,447 @@
+import os
+import sys
+import shutil
+import time
+import datetime
+
+import random
+import numpy as np
+from colorama import Fore
+import progiter
+import warnings
+
+import torch
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.optim
+import torch.utils.data
+import torch.utils.data.distributed
+
+import onnx
+import onnxruntime
+
+from .. import xnn
+from .. import vision
+
+
+# ################################################
+def get_config():
+    args = xnn.utils.ConfigNode()
+    args.model_config = xnn.utils.ConfigNode()
+    args.dataset_config = xnn.utils.ConfigNode()
+
+    args.model_name = 'mobilenet_v2_classification'     # model architecture
+    args.dataset_name = 'imagenet_classification'       # image folder classification
+
+    args.data_path = './data/datasets/ilsvrc'           # path to dataset
+    args.save_path = None                               # checkpoints save path
+    args.pretrained = './data/modelzoo/pretrained/pytorch/imagenet_classification/ericsun99/MobileNet-V2-Pytorch/mobilenetv2_Top1_71.806_Top2_90.410.pth.tar' # path to pre_trained model
+
+    args.workers = 8                                    # number of data loading workers (default: 4)
+    args.batch_size = 256                               # mini_batch size (default: 256)
+    args.print_freq = 100                               # print frequency (default: 100)
+
+    args.img_resize = 256                               # image resize
+    args.img_crop = 224                                 # image crop
+
+    args.image_mean = (123.675, 116.28, 103.53)         # image mean for input image normalization
+    args.image_scale = (0.017125, 0.017507, 0.017429)   # image scaling/mult for input image normalization
+
+    args.logger = None                                  # logger stream to output into
+
+    args.data_augument = 'inception'                    # data augmentation method, choices=['inception','resize','adaptive_resize']
+    args.dataset_format = 'folder'                      # dataset format, choices=['folder','lmdb']
+    args.count_flops = True                             # count flops and report
+
+    args.lr_calib = 0.1                                 # lr for bias calibration
+
+    args.rand_seed = 1                                  # random seed
+    args.generate_onnx = False                          # whether to export an onnx model or not
+    args.print_model = False                            # print the model to text
+    args.run_soon = True                                # Set to false if only cfs files/onnx models are needed but no training
+    args.parallel_model = True                          # parallel or not
+    args.shuffle = True                                 # shuffle or not
+    args.epoch_size = 0                                 # epoch size
+    args.date = None                                    # date to add to save path. if this is None, current date will be added.
+    args.write_layer_ip_op = False
+    args.gpu_mode = True                                # False will make inference run on CPU
+
+    args.quantize = False                               # apply quantized inference or not
+    #args.model_surgery = None                           # replace activations with PAct2 activation module. Helpful in quantized training.
+    args.bitwidth_weights = 8                           # bitwidth for weights
+    args.bitwidth_activations = 8                       # bitwidth for activations
+    args.histogram_range = True                         # histogram range for calibration
+    args.per_channel_q = False                          # apply separate quantization factor for each channel in depthwise or not
+    args.bias_calibration = False                        # apply bias correction during quantized inference calibration
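+    # note: args.phase (e.g. 'validation') is expected to be set by the calling script
+    # (e.g. scripts/infer_classification_onnx_rt_main.py) before main() is invoked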
+    return args
+
+
+def main(args):
+    assert not hasattr(args, 'model_surgery'), 'the argument model_surgery is deprecated, it is not needed now - remove it'
+
+    if (args.phase == 'validation' and args.bias_calibration):
+        args.bias_calibration = False
+        warnings.warn('switching off bias calibration in validation')
+    #
+
+    if args.save_path is None:
+        save_path = get_save_path(args)
+    else:
+        save_path = args.save_path
+    #
+
+    if not os.path.exists(save_path):
+        os.makedirs(save_path)
+    #
+
+    #################################################
+    if args.logger is None:
+        log_file = os.path.splitext(os.path.basename(__file__))[0] + '.log'
+        args.logger = xnn.utils.TeeLogger(filename=os.path.join(save_path,log_file))
+
+    #################################################
+    # global settings. rand seeds for repeatability
+    random.seed(args.rand_seed)
+    np.random.seed(args.rand_seed)
+    torch.manual_seed(args.rand_seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = True
+
+    ################################
+    # print everything for log
+    # reset character color, in case it is different
+    print('{}'.format(Fore.RESET))
+    print("=> args: ", args)
+    print("=> resize resolution: {}".format(args.img_resize))
+    print("=> crop resolution  : {}".format(args.img_crop))
+    sys.stdout.flush()
+
+
+    #################################################
+    # define loss function (criterion) and optimizer
+    criterion = torch.nn.CrossEntropyLoss().cuda()
+
+    model = onnx.load(args.pretrained)
+    # check that the loaded ONNX model is well formed (inference is run with onnxruntime in validate)
+    onnx.checker.check_model(model)
+
+    val_loader = get_data_loaders(args)
+    validate(args, val_loader, model, criterion)
+
+
+def validate(args, val_loader, model, criterion):
+
+    # change color to green
+    print('{}'.format(Fore.GREEN), end='')
+
+    session = onnxruntime.InferenceSession(args.pretrained, None)
+    input_name = session.get_inputs()[0].name
+    input_details = session.get_inputs()
+    output_details = session.get_outputs()
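+    # the first graph input is assumed to carry the image batch; the remaining input/output
+    # metadata is fetched only for inspection/debugging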
+
+    batch_time = AverageMeter()
+    losses = AverageMeter()
+    top1 = AverageMeter()
+    top5 = AverageMeter()
+    use_progressbar = True
+    epoch_size = get_epoch_size(args, val_loader, args.epoch_size)
+
+    if use_progressbar:
+        progress_bar = progiter.ProgIter(np.arange(epoch_size), chunksize=1)
+        last_update_iter = -1
+
+    end = time.time()
+    for iteration, (input, target) in enumerate(val_loader):
+        # concatenate multi-input lists along the channel dimension; onnxruntime consumes
+        # numpy arrays, so the input is kept on the CPU
+        input = torch.cat(list(input), dim=1) if (type(input) in (list,tuple)) else input
+        # compute output
+        output = session.run([], {input_name: input.cpu().numpy()})
+        output = [torch.tensor(output[index]) for index in range(len(output))]
+        if type(output) in (list, tuple):
+            output = output[0]
+        #
+        if args.gpu_mode:
+            output = output.cuda()
+            target = target.cuda(non_blocking=True)
+
+        loss = criterion(output, target)
+
+        # measure accuracy and record loss
+        prec1, prec5 = accuracy(output, target, topk=(1, 5))
+        losses.update(loss.item(), input.size(0))
+        top1.update(prec1[0], input.size(0))
+        top5.update(prec5[0], input.size(0))
+
+        # measure elapsed time
+        batch_time.update(time.time() - end)
+        end = time.time()
+        final_iter = (iteration >= (epoch_size-1))
+
+        if ((iteration % args.print_freq) == 0) or final_iter:
+            status_str = 'Time {batch_time.val:.2f}({batch_time.avg:.2f}) LR {cur_lr:.4f} ' \
+                         'Loss {loss.val:.2f}({loss.avg:.2f}) Prec@1 {top1.val:.2f}({top1.avg:.2f}) Prec@5 {top5.val:.2f}({top5.avg:.2f})' \
+                         .format(batch_time=batch_time, cur_lr=0.0, loss=losses, top1=top1, top5=top5)
+            #
+            prefix = '**' if final_iter else '=>'
+            if use_progressbar:
+                progress_bar.set_description('{} validation'.format(prefix))
+                progress_bar.set_postfix(Epoch='{}'.format(status_str))
+                progress_bar.update(iteration - last_update_iter)
+                last_update_iter = iteration
+            else:
+                iter_str = '{:6}/{:6}    : '.format(iteration+1, len(val_loader))
+                status_str = prefix + ' ' + iter_str + status_str
+                if final_iter:
+                    xnn.utils.print_color(status_str, color=Fore.GREEN)
+                else:
+                    xnn.utils.print_color(status_str)
+
+        if final_iter:
+            break
+
+    if use_progressbar:
+        progress_bar.close()
+
+    # to print a new line - do not provide end=''
+    print('{}'.format(Fore.RESET), end='')
+
+    return top1.avg
+
+
+#######################################################################
+def get_save_path(args, phase=None):
+    date = args.date if args.date else datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+    save_path_base = os.path.join('./data/checkpoints', args.dataset_name, date + '_' + args.dataset_name + '_' + args.model_name)
+    save_path = save_path_base + '_resize{}_crop{}'.format(args.img_resize, args.img_crop)
+    phase = phase if (phase is not None) else args.phase
+    save_path = os.path.join(save_path, phase)
+    return save_path
+
+
+def get_model_orig(model):
+    is_parallel_model = isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel))
+    model_orig = (model.module if is_parallel_model else model)
+    model_orig = (model_orig.module if isinstance(model_orig, (xnn.quantize.QuantBaseModule)) else model_orig)
+    return model_orig
+
+
+def create_rand_inputs(args, is_cuda=True):
+    x = torch.rand((1, args.model_config.input_channels, args.img_crop, args.img_crop))
+    x = x.cuda() if is_cuda else x
+    return x
+
+
+def count_flops(args, model):
+    is_cuda = next(model.parameters()).is_cuda
+    input_list = create_rand_inputs(args, is_cuda)
+    model.eval()
+    flops = xnn.utils.forward_count_flops(model, input_list)
+    gflops = flops/1e9
+    print('=> Resize = {}, Crop = {}, GFLOPs = {}, GMACs = {}'.format(args.img_resize, args.img_crop, gflops, gflops/2))
+
+
+def write_onnx_model(args, model, save_path, name='checkpoint.onnx'):
+    is_cuda = next(model.parameters()).is_cuda
+    dummy_input = create_rand_inputs(args, is_cuda)
+    #
+    model.eval()
+    torch.onnx.export(model, dummy_input, os.path.join(save_path,name), export_params=True, verbose=False)
+
+def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
+    torch.save(state, filename)
+    if is_best:
+        shutil.copyfile(filename, 'model_best.pth.tar')
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
+def accuracy(output, target, topk=(1,)):
+    """Computes the precision@k for the specified values of k"""
+    with torch.no_grad():
+        maxk = max(topk)
+        batch_size = target.size(0)
+
+        _, pred = output.topk(maxk, 1, True, True)
+        pred = pred.t()
+        correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+        res = []
+        for k in topk:
+            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
+            res.append(correct_k.mul_(100.0 / batch_size))
+        return res
+
+
+def get_epoch_size(args, loader, args_epoch_size):
+    if args_epoch_size == 0:
+        epoch_size = len(loader)
+    elif args_epoch_size < 1:
+        epoch_size = int(len(loader) * args_epoch_size)
+    else:
+        epoch_size = min(len(loader), int(args_epoch_size))
+    return epoch_size
+
+
+def get_data_loaders(args):
+    # Data loading code
+    normalize = vision.transforms.NormalizeMeanScale(mean=args.image_mean, scale=args.image_scale) \
+                        if (args.image_mean is not None and args.image_scale is not None) else None
+
+    # pass tuple to Resize() to resize to exact size without respecting aspect ratio (typical caffe style)
+    val_transform = vision.transforms.Compose([vision.transforms.Resize(size=args.img_resize),
+                                               vision.transforms.CenterCrop(size=args.img_crop),
+                                               vision.transforms.ToFloat(),
+                                               vision.transforms.ToTensor(),
+                                               normalize])
+
+    val_dataset = vision.datasets.classification.__dict__[args.dataset_name](args.dataset_config, args.data_path, transforms=val_transform)
+
+    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=args.shuffle, num_workers=args.workers,
+                                             pin_memory=True, drop_last=False)
+
+    return val_loader
+
+
+#################################################
+def shape_as_string(shape=[]):
+    shape_str = ''
+    for dim in shape:
+        shape_str += '_' + str(dim)
+    return shape_str
+
+
+def write_tensor_int(m=[], tensor=[], suffix='op', bitwidth=8, power2_scaling=True, file_format='bin',
+                     rnd_type='rnd_sym'):
+    mn = tensor.min()
+    mx = tensor.max()
+
+    print(
+        '{:6}, {:32}, {:10}, {:7.2f}, {:7.2f}'.format(suffix, m.name, m.__class__.__name__, tensor.min(), tensor.max()),
+        end=" ")
+
+    [tensor_scale, clamp_limits] = xnn.quantize.compute_tensor_scale(tensor, mn, mx, bitwidth, power2_scaling)
+    print("{:30} : {:15} : {:8.2f}".format(str(tensor.shape), str(tensor.dtype), tensor_scale), end=" ")
+
+    print_weight_bias = False
+    if rnd_type == 'rnd_sym':
+        # use best rounding for offline quantities
+        if suffix == 'weight' and print_weight_bias:
+            no_idx = 0
+            torch.set_printoptions(precision=32)
+            print("tensor_scale: ", tensor_scale)
+            print(tensor[no_idx])
+        if tensor.dtype != torch.int64:
+            tensor = xnn.quantize.symmetric_round_tensor(tensor * tensor_scale)
+        if suffix == 'weight' and print_weight_bias:
+            print(tensor[no_idx])
+    else:
+        # for activation use HW friendly rounding
+        if tensor.dtype != torch.int64:
+            tensor = xnn.quantize.upward_round_tensor(tensor * tensor_scale)
+    tensor = tensor.clamp(clamp_limits[0], clamp_limits[1]).float()
+
+    if bitwidth == 8:
+        data_type = np.int8
+    elif bitwidth == 16:
+        data_type = np.int16
+    elif bitwidth == 32:
+        data_type = np.int32
+    else:
+        exit("Bit widths other than 8,16,32 are not supported for writing layer level outputs")
+
+    tensor = tensor.cpu().numpy().astype(data_type)
+
+    print("{:7} : {:7d} : {:7d}".format(str(tensor.dtype), tensor.min(), tensor.max()))
+
+    tensor_dir = './data/checkpoints/debug/test_vecs/' + '{}_{}_{}_scale_{:010.4f}'.format(m.name,
+                                                                                            m.__class__.__name__,
+                                                                                            suffix, tensor_scale)
+
+    if not os.path.exists(tensor_dir):
+        os.makedirs(tensor_dir)
+
+    if file_format == 'bin':
+        tensor_name = tensor_dir + "/{}_shape{}.bin".format(m.name, shape_as_string(shape=tensor.shape))
+        tensor.tofile(tensor_name)
+    elif file_format == 'npy':
+        tensor_name = tensor_dir + "/{}_shape{}.npy".format(m.name, shape_as_string(shape=tensor.shape))
+        np.save(tensor_name, tensor)
+    else:
+        warnings.warn('unknown file_format for write_tensor - no file written')
+    #
+
+    # utils_hist.comp_hist_tensor3d(x=tensor, name=m.name, en=True, dir=m.name, log=True, ch_dim=0)
+
+
+def write_tensor_float(m=[], tensor=[], suffix='op'):
+    mn = tensor.min()
+    mx = tensor.max()
+
+    print(
+        '{:6}, {:32}, {:10}, {:7.2f}, {:7.2f}'.format(suffix, m.name, m.__class__.__name__, tensor.min(), tensor.max()))
+    root = os.getcwd()
+    tensor_dir = root + '/checkpoints/debug/test_vecs/' + '{}_{}_{}'.format(m.name, m.__class__.__name__, suffix)
+
+    if not os.path.exists(tensor_dir):
+        os.makedirs(tensor_dir)
+
+    tensor_name = tensor_dir + "/{}_shape{}.npy".format(m.name, shape_as_string(shape=tensor.shape))
+    np.save(tensor_name, tensor.data)
+
+
+def write_tensor(data_type='int', m=[], tensor=[], suffix='op', bitwidth=8, power2_scaling=True, file_format='bin',
+                 rnd_type='rnd_sym'):
+    if data_type == 'int':
+        write_tensor_int(m=m, tensor=tensor, suffix=suffix, rnd_type=rnd_type, file_format=file_format)
+    elif data_type == 'float':
+        write_tensor_float(m=m, tensor=tensor, suffix=suffix)
+
+
+enable_hook_function = True
+def write_tensor_hook_function(m, inp, out, file_format='bin'):
+    if not enable_hook_function:
+        return
+
+    #Output
+    if isinstance(out, (torch.Tensor)):
+        write_tensor(m=m, tensor=out, suffix='op', rnd_type ='rnd_up', file_format=file_format)
+
+    #Input(s)
+    if type(inp) is tuple:
+        #if there are more than 1 inputs
+        for index, sub_ip in enumerate(inp[0]):
+            if isinstance(sub_ip, (torch.Tensor)):
+                write_tensor(m=m, tensor=sub_ip, suffix='ip_{}'.format(index), rnd_type ='rnd_up', file_format=file_format)
+    elif isinstance(inp, (torch.Tensor)):
+        write_tensor(m=m, tensor=inp, suffix='ip', rnd_type ='rnd_up', file_format=file_format)
+
+    #weights
+    if hasattr(m, 'weight'):
+        if isinstance(m.weight,torch.Tensor):
+            write_tensor(m=m, tensor=m.weight, suffix='weight', rnd_type ='rnd_sym', file_format=file_format)
+
+    #bias
+    if hasattr(m, 'bias'):
+        if m.bias is not None:
+            write_tensor(m=m, tensor=m.bias, suffix='bias', rnd_type ='rnd_sym', file_format=file_format)
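+# note: the hooks above are not registered in this file; a calling script would typically
+# attach them to the modules of interest, e.g. module.register_forward_hook(write_tensor_hook_function),
+# when args.write_layer_ip_op is enabled (an illustrative usage sketch, not part of this script)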
+
+
+if __name__ == '__main__':
+    args = get_config()
+    args.phase = 'validation'   # phase is normally set by the calling script
+    main(args)
index 27213f2662b7e2cde3063f8251d0367dc3b44c97..b7e45d91b8c9a0f65917c33d1edfd6e30de5daf9 100644 (file)
@@ -124,13 +124,14 @@ def get_config():
     args.do_pred_cordi_f2r = False              #true: Do f2r operation on detected location for interet point task
     args.depth_cmap_plasma = False      
     args.visualize_gt = False                   #to vis pred or GT
-    args.viz_depth_color_type = 'plasma'       #color type for dpeth visualization
+    args.viz_depth_color_type = 'plasma'        #color type for depth visualization
     args.depth = [False]
 
     args.palette = None
     args.label_infer = False
     args.viz_op_type = None
     args.car_mask = None
+    args.en_accuracy_measurement = True         #enabling accuracy measurement makes the whole operation sequential and hence slows down inference significantly.
     return args
 
 
@@ -426,16 +427,17 @@ def validate(args, val_dataset, val_loader, model, epoch, infer_path):
                 if target_list:
                     label = np.squeeze(np.array(target_list[task_index][index]))
                     if not args.model_config.output_type[task_index] is 'depth':
-                        confusion_matrix[task_index] = eval_output(args, prediction, label, confusion_matrix[task_index], args.model_config.output_channels[task_index])
-                        accuracy, mean_iou, iou, f1_score= compute_accuracy(args, confusion_matrix[task_index], args.model_config.output_channels[task_index])
-                        temp_txt = []
-                        temp_txt.append(input_path[-1][index])
-                        temp_txt.extend(iou)
-                        metric_txt.append(temp_txt)
-                        print('{}/{} Inferred Frame {} mean_iou={},'.format((args.batch_size*iter+index+1), len(val_dataset), input_path[-1][index], mean_iou))
-                        if index == output.shape[0]-1:
-                            print('Task={},\npixel_accuracy={},\nmean_iou={},\niou={},\nf1_score = {}'.format(task_index, accuracy, mean_iou, iou, f1_score))
-                            sys.stdout.flush()
+                        if args.en_accuracy_measurement:
+                            confusion_matrix[task_index] = eval_output(args, prediction, label, confusion_matrix[task_index], args.model_config.output_channels[task_index])
+                            accuracy, mean_iou, iou, f1_score= compute_accuracy(args, confusion_matrix[task_index], args.model_config.output_channels[task_index])
+                            temp_txt = []
+                            temp_txt.append(input_path[-1][index])
+                            temp_txt.extend(iou)
+                            metric_txt.append(temp_txt)
+                            print('{}/{} Inferred Frame {} mean_iou={},'.format((args.batch_size*iter+index+1), len(val_dataset), input_path[-1][index], mean_iou))
+                            if index == output.shape[0]-1:
+                                print('Task={},\npixel_accuracy={},\nmean_iou={},\niou={},\nf1_score = {}'.format(task_index, accuracy, mean_iou, iou, f1_score))
+                                sys.stdout.flush()
                     elif args.model_config.output_type[task_index] is 'depth':
                         valid = (label != 0)
                         gt = torch.tensor(label[valid]).float()
diff --git a/modules/pytorch_jacinto_ai/engine/infer_pixel2pixel_onnx_rt.py b/modules/pytorch_jacinto_ai/engine/infer_pixel2pixel_onnx_rt.py
new file mode 100644 (file)
index 0000000..e243d9d
--- /dev/null
@@ -0,0 +1,836 @@
+import os
+import time
+import sys
+import math
+import copy
+import warnings
+
+import torch
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.optim
+import torch.utils.data
+import datetime
+import numpy as np
+import random
+import cv2
+import matplotlib.pyplot as plt
+
+import onnx
+import onnxruntime
+
+from onnx import helper
+
+
+from .. import xnn
+from .. import vision
+
+#sys.path.insert(0, '../devkit-datasets/TI/')
+#from fisheye_calib import r_fish_to_theta_rect
+
+# ################################################
+def get_config():
+    args = xnn.utils.ConfigNode()
+
+    args.dataset_config = xnn.utils.ConfigNode()
+    args.dataset_config.split_name = 'val'
+    args.dataset_config.max_depth_bfr_scaling = 80
+    args.dataset_config.depth_scale = 1
+    args.dataset_config.train_depth_log = 1
+    args.use_semseg_for_depth = False
+
+    args.model_config = xnn.utils.ConfigNode()
+    args.model_name = 'deeplabv2lite_mobilenetv2'       # model architecture, overwritten if pretrained is specified
+    args.dataset_name = 'flying_chairs'              # dataset type
+
+    args.data_path = './data/datasets'                       # path to dataset
+    args.save_path = None            # checkpoints save path
+    args.pretrained = None
+
+    args.model_config.output_type = ['flow']                # the network is used to predict flow or depth or sceneflow
+    args.model_config.output_channels = None                 # number of output channels
+    args.model_config.input_channels = None                  # number of input channels
+    args.model_config.num_classes = None                       # number of classes (for segmentation)
+    args.model_config.output_range = None  # max range of output
+
+    args.model_config.num_decoders = None               # number of decoders to use. [options: 0, 1, None]
+    args.sky_dir = False
+
+    args.logger = None                          # logger stream to output into
+
+    args.split_file = None                      # train_val split file
+    args.split_files = None                     # split list files. eg: train.txt val.txt
+    args.split_value = 0.8                      # test_val split proportion (between 0 (only test) and 1 (only train))
+
+    args.workers = 8                            # number of data loading workers
+
+    args.epoch_size = 0                         # manual epoch size (will match dataset size if not specified)
+    args.epoch_size_val = 0                     # manual epoch size (will match dataset size if not specified)
+    args.batch_size = 8                         # mini_batch_size
+    args.total_batch_size = None                # accumulated batch size. total_batch_size = batch_size*iter_size
+    args.iter_size = 1                          # iteration size. total_batch_size = batch_size*iter_size
+
+    args.tensorboard_num_imgs = 5               # number of imgs to display in tensorboard
+    args.phase = 'validation'                        # evaluate model on validation set
+    args.pretrained = None                      # path to pre_trained model
+    args.date = None                            # don't append date timestamp to folder
+    args.print_freq = 10                        # print frequency (default: 100)
+
+    args.div_flow = 1.0                         # value by which flow will be divided. Original value is 20 but 1 with batchNorm gives good results
+    args.losses = ['supervised_loss']           # loss functions to minimize
+    args.metrics = ['supervised_error']         # metric/measurement/error functions for train/validation
+    args.class_weights = None                   # class weights
+
+    args.multistep_gamma = 0.5                  # steps for step scheduler
+    args.polystep_power = 1.0                   # power for polynomial scheduler
+    args.train_fwbw = False                     # do forward backward step while training
+
+    args.rand_seed = 1                          # random seed
+    args.img_border_crop = None                 # image border crop rectangle. can be relative or absolute
+    args.target_mask = None                      # mask rectangle. can be relative or absolute. last value is the mask value
+    args.img_resize = None                      # image size to be resized to
+    args.rand_scale = (1,1.25)                  # random scale range for training
+    args.rand_crop = None                       # image size to be cropped to
+    args.output_size = None                     # target output size to be resized to
+
+    args.count_flops = True                     # count flops and report
+
+    args.shuffle = True                         # shuffle or not
+    args.is_flow = None                         # whether entries in images and targets lists are optical flow or not
+
+    args.multi_decoder = True                   # whether to use multiple decoders or unified decoder
+
+    args.create_video = False                   # whether to create video out of the inferred images
+
+    args.input_tensor_name = ['0']              # list of input tensor names
+
+    args.upsample_mode = 'nearest'              # upsample mode to use., choices=['nearest','bilinear']
+
+    args.image_prenorm = True                   # whether normalization is done before all the other transforms
+    args.image_mean = [128.0]                   # image mean for input image normalization
+    args.image_scale = [1.0/(0.25*256)]         # image scaling/mult for input image normalization
+    args.quantize = False                       # apply quantized inference or not
+    #args.model_surgery = None                   # replace activations with PAct2 activation module. Helpful in quantized training.
+    args.bitwidth_weights = 8                   # bitwidth for weights
+    args.bitwidth_activations = 8               # bitwidth for activations
+    args.histogram_range = True                 # histogram range for calibration
+    args.per_channel_q = False                  # apply separate quantization factor for each channel in depthwise or not
+    args.bias_calibration = False                # apply bias correction during quantized inference calibration
+
+    args.frame_IOU = False                      # Print mIOU for each frame
+    args.make_score_zero_mean = False           #to make score and desc zero mean
+    args.learn_scaled_values_interest_pt = True
+    args.save_mod_files = False                 # saves modified files after last commit. Also  stores commit id.
+    args.gpu_mode = True                        #False will make inference run on CPU
+    args.write_layer_ip_op = False              # True will tap layer inputs/outputs
+    args.file_format = 'none'                   # file format for tapped layer inputs/outputs; 'none' means nothing is written, though prints still appear
+    args.generate_onnx = True
+    args.remove_ignore_lbls_in_pred = False     #True: if in the pred where GT has ignore label do not visualize for GT visualization
+    args.do_pred_cordi_f2r = False              #true: Do f2r operation on detected location for interet point task
+    args.depth_cmap_plasma = False      
+    args.visualize_gt = False                   #to vis pred or GT
+    args.viz_depth_color_type = 'plasma'        # color type for depth visualization
+    args.depth = [False]
+    args.dump_layers = True
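+    # note: a few options referenced later in this file (e.g. args.blend, args.palette,
+    # args.label_infer, args.car_mask, args.write_desc_type, args.max_depth) are expected
+    # to be filled in by the calling script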
+    return args
+
+
+# ################################################
+# to avoid hangs in data loader with multi threads
+# this was observed after using cv2 image processing functions
+# https://github.com/pytorch/pytorch/issues/1355
+cv2.setNumThreads(0)
+
+##################################################
+np.set_printoptions(precision=3)
+
+# ################################################
+def main(args):
+
+    assert not hasattr(args, 'model_surgery'), 'the argument model_surgery is deprecated, it is not needed now - remove it'
+
+    #################################################
+    # global settings. rand seeds for repeatability
+    random.seed(args.rand_seed)
+    np.random.seed(args.rand_seed)
+    torch.manual_seed(args.rand_seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = True
+
+    ################################
+    # args check and config
+    if args.iter_size != 1 and args.total_batch_size is not None:
+        warnings.warn("only one of --iter_size or --total_batch_size must be set")
+    #
+    if args.total_batch_size is not None:
+        args.iter_size = args.total_batch_size//args.batch_size
+    else:
+        args.total_batch_size = args.batch_size*args.iter_size
+    #
+
+    assert args.pretrained is not None, 'pretrained path must be provided'
+
+    # onnx generation is failing for the post-quantized module
+    # args.generate_onnx = False if (args.quantize) else args.generate_onnx
+    #################################################
+    # set some global flags and initializations
+    # keep it in args for now - although they don't belong here strictly
+    # using pin_memory is seen to cause issues, especially when a lot of memory is used.
+    args.use_pinned_memory = False
+    args.n_iter = 0
+    args.best_metric = -1
+
+    #################################################
+    if args.save_path is None:
+        save_path = get_save_path(args)
+    else:
+        save_path = args.save_path
+    #
+    print('=> will save everything to {}'.format(save_path))
+    if not os.path.exists(save_path):
+        os.makedirs(save_path)
+
+    #################################################
+    if args.logger is None:
+        log_file = os.path.splitext(os.path.basename(__file__))[0] + '.log'
+        args.logger = xnn.utils.TeeLogger(filename=os.path.join(save_path,log_file))
+
+    ################################
+    # print everything for log
+    print('=> args: ', args)
+
+    if args.save_mod_files:
+        #store all the files after the last commit.
+        mod_files_path = save_path+'/mod_files'
+        os.makedirs(mod_files_path)
+        
+        cmd = "git ls-files --modified | xargs -i cp {} {}".format("{}", mod_files_path)
+        print("cmd:", cmd)    
+        os.system(cmd)
+
+        #store last commit id.
+        cmd = "git log -n 1  >> {}".format(mod_files_path + '/commit_id.txt')
+        print("cmd:", cmd)    
+        os.system(cmd)
+
+    transforms = get_transforms(args)
+
+    print("=> fetching img pairs in '{}'".format(args.data_path))
+    split_arg = args.split_file if args.split_file else (args.split_files if args.split_files else args.split_value)
+
+    val_dataset = vision.datasets.pixel2pixel.__dict__[args.dataset_name](args.dataset_config, args.data_path, split=split_arg, transforms=transforms)
+
+    print('=> {} val samples found'.format(len(val_dataset)))
+    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
+        num_workers=args.workers, pin_memory=args.use_pinned_memory, shuffle=args.shuffle)
+
+    #################################################
+    if (args.model_config.input_channels is None):
+        args.model_config.input_channels = (3,)
+        print("=> input channels is not given - setting to {}".format(args.model_config.input_channels))
+
+    if (args.model_config.output_channels is None):
+        if ('num_classes' in dir(val_dataset)):
+            args.model_config.output_channels = val_dataset.num_classes()
+        else:
+            args.model_config.output_channels = (2 if args.model_config.output_type == 'flow' else args.model_config.output_channels)
+            xnn.utils.print_yellow("=> output channels is not given - setting to {} - this may not work".format(args.model_config.output_channels))
+        #
+        if not isinstance(args.model_config.output_channels,(list,tuple)):
+            args.model_config.output_channels = [args.model_config.output_channels]
+
+    #################################################
+    #Load the ONNX model
+    onnx_model = onnx.load(args.pretrained)
+    try:
+        #Check that the IR is well formed
+        onnx.checker.check_model(onnx_model)
+    except:
+        print("ONNX model check failed: IR (Intermediate Representation) is not well formed")
+
+    if args.dump_layers: #add intermediate outputs to the onnx model
+        intermediate_layer_value_info = helper.ValueInfoProto()
+        intermediate_layer_value_info.name = ''
+        for i in range(len(onnx_model.graph.node)):
+            for j in range(len(onnx_model.graph.node[i].output)):
+                print('-' * 60)
+                print("Node:", i, "output_node:", j, onnx_model.graph.node[i].op_type, onnx_model.graph.node[i].output)
+                # add each intermediate layer one by one
+                if (onnx_model.graph.node[i].op_type == 'Relu') | (onnx_model.graph.node[i].op_type == 'Add') | \
+                (onnx_model.graph.node[i].op_type == 'Concat') | (onnx_model.graph.node[i].op_type == 'Resize') | \
+                (onnx_model.graph.node[i].op_type == 'Upsample'):
+                    intermediate_layer_value_info.name = onnx_model.graph.node[i].output[0]
+                    onnx_model.graph.output.append(intermediate_layer_value_info)
+        onnx.save(onnx_model, os.path.join(save_path, 'model_mod.onnx'))
+        args.pretrained = os.path.join(save_path, 'model_mod.onnx')
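+        # validate() builds its onnxruntime InferenceSession from args.pretrained, so point it
+        # at the modified model that exposes the intermediate tensors as extra graph outputs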
+
+    #################################################
+    args.loss_modules = copy.deepcopy(args.losses)
+    for task_dx, task_losses in enumerate(args.losses):
+        for loss_idx, loss_fn in enumerate(task_losses):
+            kw_args = {}
+            loss_args = vision.losses.__dict__[loss_fn].args()
+            for arg in loss_args:
+                #if arg == 'weight':
+                #    kw_args.update({arg:args.class_weights[task_dx]})
+                if arg == 'num_classes':
+                    kw_args.update({arg:args.model_config.output_channels[task_dx]})
+                elif arg == 'sparse':
+                    kw_args.update({arg:args.sparse})
+                #
+            #
+            loss_fn = vision.losses.__dict__[loss_fn](**kw_args)
+            loss_fn = loss_fn.cuda()
+            args.loss_modules[task_dx][loss_idx] = loss_fn
+
+    args.metric_modules = copy.deepcopy(args.metrics)
+    for task_dx, task_metrics in enumerate(args.metrics):
+        for midx, metric_fn in enumerate(task_metrics):
+            kw_args = {}
+            loss_args = vision.losses.__dict__[metric_fn].args()
+            for arg in loss_args:
+                if arg == 'weight':
+                    kw_args.update({arg:args.class_weights[task_dx]})
+                elif arg == 'num_classes':
+                    kw_args.update({arg:args.model_config.output_channels[task_dx]})
+                elif arg == 'sparse':
+                    kw_args.update({arg:args.sparse})
+                #
+            #
+            metric_fn = vision.losses.__dict__[metric_fn](**kw_args)
+            metric_fn = metric_fn.cuda()
+            args.metric_modules[task_dx][midx] = metric_fn
+
+    #################################################
+    if args.palette:
+        print('Creating palette')
+        args.palette = val_dataset.create_palette()
+        for i, p in enumerate(args.palette):
+            args.palette[i] = np.array(p, dtype = np.uint8)
+            args.palette[i] = args.palette[i][..., ::-1]  # RGB->BGR, since palette is expected to be given in RGB format
+
+    infer_path = []
+    for i, p in enumerate(args.model_config.output_channels):
+        infer_path.append(os.path.join(save_path, 'Task{}'.format(i)))
+        if not os.path.exists(infer_path[i]):
+            os.makedirs(infer_path[i])
+
+    #################################################
+    with torch.no_grad():
+        validate(args, val_dataset, val_loader, onnx_model, 0, infer_path)
+
+    if args.create_video:
+        create_video(args, infer_path=infer_path)
+
+
+def validate(args, val_dataset, val_loader, model, epoch, infer_path):
+    data_time = xnn.utils.AverageMeter()
+    avg_metric = [xnn.utils.AverageMeter(print_avg=(not task_metric[0].info()['is_avg'])) for task_metric in args.metric_modules]
+
+    # switch to evaluate mode
+    # model.eval()
+
+    session = onnxruntime.InferenceSession(args.pretrained, None)
+    input_name = session.get_inputs()[0].name
+    input_details = session.get_inputs()
+    output_details = session.get_outputs()
+
+    metric_name = "Metric"
+    end_time = time.time()
+    writer_idx = 0
+    last_update_iter = -1
+    metric_ctx = [None] * len(args.metric_modules)
+
+    confusion_matrix = []
+    for n_cls in args.model_config.output_channels:
+        confusion_matrix.append(np.zeros((n_cls, n_cls+1)))
+    metric_txt = []
+    ard_err = None
+    for iter, (input_list, target_list, input_path, target_path) in enumerate(val_loader):
+        file_name =  input_path[-1][0]
+        print("started inference of file_name:", file_name)
+        data_time.update(time.time() - end_time)
+
+        outputs = session.run([], {input_name: input_list[0].cpu().numpy()})
+        if args.dump_layers:
+            dst_dir = os.path.join(*infer_path[0].split('/')[:-1],'layers_dump' ,"{:04d}".format(args.batch_size*iter))
+            if not os.path.exists(dst_dir):
+                os.makedirs(dst_dir)
+            for idx, output in enumerate(outputs):
+                onnx_idx = model.graph.output[idx].name
+                dst_file = os.path.join(dst_dir,  "{:04d}".format(int(onnx_idx)) + '.bin')
+                output.tofile(dst_file)
+            outputs = [outputs[0]]  # First element from the model is the final output
+
+        outputs = [torch.tensor(outputs[index]) for index in range(len(outputs))]
+        if args.output_size is not None and target_list:
+           target_sizes = [tgt.shape for tgt in target_list]
+           outputs = upsample_tensors(outputs, target_sizes, args.upsample_mode)
+        elif args.output_size is not None and not target_list:
+           target_sizes = [args.output_size for _ in range(len(outputs))]
+           outputs = upsample_tensors(outputs, target_sizes, args.upsample_mode)
+        outputs = [out.cpu() for out in outputs]
+
+        for task_index in range(len(outputs)):
+            output = outputs[task_index]
+            gt_target = target_list[task_index] if target_list else None
+            if args.visualize_gt and target_list:
+                if args.model_config.output_type[task_index] == 'depth':
+                    output = gt_target
+                else:
+                    output = gt_target.to(dtype=torch.int8)
+
+            if args.remove_ignore_lbls_in_pred and not (args.model_config.output_type[task_index] == 'depth') and target_list:
+                output[gt_target == 255] = args.palette[task_index-1].shape[0]-1
+            for index in range(output.shape[0]):
+                if args.frame_IOU:
+                    confusion_matrix[task_index] = np.zeros((args.model_config.output_channels[task_index], args.model_config.output_channels[task_index] + 1))
+                prediction = np.array(output[index])
+                if output.shape[1]>1:
+                    prediction = np.argmax(prediction, axis=0)
+                #
+                prediction = np.squeeze(prediction)
+
+                if target_list:
+                    label = np.squeeze(np.array(target_list[task_index][index]))
+                    if args.model_config.output_type[task_index] != 'depth':
+                        confusion_matrix[task_index] = eval_output(args, prediction, label, confusion_matrix[task_index], args.model_config.output_channels[task_index])
+                        accuracy, mean_iou, iou, f1_score= compute_accuracy(args, confusion_matrix[task_index], args.model_config.output_channels[task_index])
+                        temp_txt = []
+                        temp_txt.append(input_path[-1][index])
+                        temp_txt.extend(iou)
+                        metric_txt.append(temp_txt)
+                        print('{}/{} Inferred Frame {} mean_iou={},'.format((args.batch_size*iter+index+1), len(val_dataset), input_path[-1][index], mean_iou))
+                        if index == output.shape[0]-1:
+                            print('Task={},\npixel_accuracy={},\nmean_iou={},\niou={},\nf1_score = {}'.format(task_index, accuracy, mean_iou, iou, f1_score))
+                            sys.stdout.flush()
+                    elif args.model_config.output_type[task_index] == 'depth':
+                        valid = (label != 0)
+                        gt = torch.tensor(label[valid]).float()
+                        inference = torch.tensor(prediction[valid]).float()
+                        if len(gt) > 2:
+                            if ard_err is None:
+                                ard_err = [absreldiff_rng3to80(inference, gt).mean()]
+                            else:
+                                ard_err.append(absreldiff_rng3to80(inference, gt).mean())
+                        elif len(gt) < 2:
+                            if ard_err is None:
+                                ard_err = [0.0]
+                            else:
+                                ard_err.append(0.0)
+
+                        print('{}/{} ARD: {}'.format((args.batch_size * iter + index), len(val_dataset),torch.tensor(ard_err).mean()))
+
+                seq = input_path[-1][index].split('/')[-4]
+                base_file = os.path.basename(input_path[-1][index])
+
+                if args.label_infer:
+                    output_image = prediction
+                    output_name = os.path.join(infer_path[task_index], seq + '_' + input_path[-1][index].split('/')[-3] + '_' + base_file)
+                    cv2.imwrite(output_name, output_image)
+                    print('{}/{}'.format((args.batch_size*iter+index), len(val_dataset)))
+
+                if hasattr(args, 'interest_pt') and args.interest_pt[task_index]:
+                    print('{}/{}'.format((args.batch_size * iter + index), len(val_dataset)))
+                    output_name = os.path.join(infer_path[task_index], seq + '_' + input_path[-1][index].split('/')[-3] + '_' + base_file)
+                    output_name_short = os.path.join(infer_path[task_index], os.path.basename(input_path[-1][index]))
+                    wrapper_write_desc(args=args, task_index=task_index, outputs=outputs, index=index, output_name=output_name, output_name_short=output_name_short)
+                    
+                if args.model_config.output_type[task_index] == 'depth':
+                    output_name = os.path.join(infer_path[task_index], seq + '_' + input_path[-1][index].split('/')[-3] + '_' + base_file)
+                    viz_depth(prediction = prediction, args=args, output_name = output_name, input_name=input_path[-1][task_index])
+                    print('{}/{}'.format((args.batch_size * iter + index), len(val_dataset)))
+
+                if args.blend[task_index]:
+                    prediction_size = (prediction.shape[0], prediction.shape[1], 3)
+                    output_image = args.palette[task_index-1][prediction.ravel()].reshape(prediction_size)
+                    input_bgr = cv2.imread(input_path[-1][index]) #Read the actual RGB image
+                    input_bgr = cv2.resize(input_bgr, dsize=(prediction.shape[1],prediction.shape[0]))
+                    output_image = xnn.utils.chroma_blend(input_bgr, output_image)
+                    output_name = os.path.join(infer_path[task_index], input_path[-1][index].split('/')[-4] + '_' + input_path[-1][index].split('/')[-3] + '_' +os.path.basename(input_path[-1][index]))
+                    cv2.imwrite(output_name, output_image)
+                    print('{}/{}'.format((args.batch_size*iter+index), len(val_dataset)))
+                #
+
+                if args.car_mask:   # generating car_mask (required for localization)
+                    car_mask = np.logical_or.reduce((prediction == 13, prediction == 14, prediction == 16, prediction == 17))
+                    prediction[car_mask] = 255
+                    prediction[np.invert(car_mask)] = 0
+                    output_image = prediction
+                    output_name = os.path.join(infer_path[task_index], os.path.basename(input_path[-1][index]))
+                    cv2.imwrite(output_name, output_image)
+    np.savetxt('metric.txt', metric_txt, fmt='%s')
+
+
+###############################################################
+def get_save_path(args, phase=None):
+    date = args.date if args.date else datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
+    save_path = os.path.join('./data/checkpoints', args.dataset_name, date + '_' + args.dataset_name + '_' + args.model_name)
+    save_path += '_resize{}x{}'.format(args.img_resize[1], args.img_resize[0])
+    if args.rand_crop:
+        save_path += '_crop{}x{}'.format(args.rand_crop[1], args.rand_crop[0])
+    #
+    phase = phase if (phase is not None) else args.phase
+    save_path = os.path.join(save_path, phase)
+    return save_path
+
+
+def get_model_orig(model):
+    is_parallel_model = isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel))
+    model_orig = (model.module if is_parallel_model else model)
+    model_orig = (model_orig.module if isinstance(model_orig, (xnn.quantize.QuantBaseModule)) else model_orig)
+    return model_orig
+
+
+def create_rand_inputs(args, is_cuda):
+    dummy_input = []
+    for i_ch in args.model_config.input_channels:
+        x = torch.rand((1, i_ch, args.img_resize[0], args.img_resize[1]))
+        x = x.cuda() if is_cuda else x
+        dummy_input.append(x)
+    #
+    return dummy_input
+
+
+# FIX_ME:SN move to utils
+def store_desc(args=[], output_name=[], write_dense=False, desc_tensor=[], prediction=[],
+               scale_to_write_kp_loc_to_orig_res=[1.0, 1.0],
+               learn_scaled_values=True):
+    sys.path.insert(0, './scripts/')
+    import write_desc as write_desc
+
+    if args.write_desc_type != 'NONE':
+        txt_file_name = output_name.replace(".png", ".txt")
+        if write_dense:
+            # write desc
+            desc_tensor = desc_tensor.astype(np.int16)
+            print("writing dense desc(64 ch) op: {} : {} : {} : {}".format(desc_tensor.shape, desc_tensor.dtype,
+                                                                           desc_tensor.min(), desc_tensor.max()))
+            desc_tensor_name = output_name.replace(".png", "_desc.npy")
+            np.save(desc_tensor_name, desc_tensor)
+
+            # utils_hist.comp_hist_tensor3d(x=desc_tensor, name='desc_64ch', en=True, dir='desc_64ch', log=True, ch_dim=0)
+
+            # write score channel
+            prediction = prediction.astype(np.int16)
+
+            print("writing dense score ch op: {} : {} : {} : {}".format(prediction.shape, prediction.dtype,
+                                                                        prediction.min(),
+                                                                        prediction.max()))
+            score_tensor_name = output_name.replace(".png", "_score.npy")
+            np.save(score_tensor_name, prediction)
+
+            # utils_hist.hist_tensor2D(x_ch = prediction, dir='score', name='score', en=True, log=True)
+        else:
+            prediction[prediction < 0.0] = 0.0
+
+            if learn_scaled_values:
+                img_interest_pt_cur = prediction.astype(np.uint16)
+                score_th = 127
+            else:
+                img_interest_pt_cur = prediction
+                score_th = 0.001
+
+            # at boundary pixels score/desc will be wrong, so do not write predicted quantities near the border.
+            guard_band = 32 if args.write_desc_type == 'PRED' else 0
+
+            write_desc.write_score_desc_as_text(desc_tensor_cur=desc_tensor, img_interest_pt_cur=img_interest_pt_cur,
+                                                txt_file_name=txt_file_name, score_th=score_th,
+                                                skip_fac_for_reading_desc=1, en_nms=args.en_nms,
+                                                scale_to_write_kp_loc_to_orig_res=scale_to_write_kp_loc_to_orig_res,
+                                                recursive_nms=True, learn_scaled_values=learn_scaled_values,
+                                                guard_band=guard_band, true_nms=True, f2r=args.do_pred_cordi_f2r)
+
+
+      #utils_hist.hist_tensor2D(x_ch = prediction, dir='score', name='score', en=True, log=True)
+    else:  
+      prediction[prediction < 0.0] = 0.0
+      
+      if learn_scaled_values:
+        img_interest_pt_cur = prediction.astype(np.uint16)
+        score_th = 127
+      else:  
+        img_interest_pt_cur = prediction
+        score_th = 0.001
+
+      # at boundary pixels score/desc will be wrong, so do not write predicted quantities near the border.
+      guard_band = 32 if args.write_desc_type == 'PRED' else 0
+
+      write_desc.write_score_desc_as_text(desc_tensor_cur = desc_tensor, img_interest_pt_cur = img_interest_pt_cur,
+        txt_file_name = txt_file_name, score_th = score_th, skip_fac_for_reading_desc = 1, en_nms=args.en_nms,
+        scale_to_write_kp_loc_to_orig_res = scale_to_write_kp_loc_to_orig_res,
+        recursive_nms=True, learn_scaled_values=learn_scaled_values, guard_band = guard_band, true_nms=True, f2r=args.do_pred_cordi_f2r)
+
+def viz_depth(prediction = [], args=[], output_name=[], input_name=[]):
+    max_value_depth = args.max_depth
+    output_image = torch.tensor(prediction)
+    if args.viz_depth_color_type == 'rainbow':
+        not_valid_indices = output_image == 0
+        output_image = 255*xnn.utils.tensor2array(output_image, max_value=max_value_depth, colormap='rainbow')[:,:,::-1]
+        output_image[not_valid_indices] = 0
+    elif args.viz_depth_color_type == 'rainbow_blend':
+        print(max_value_depth)
+        #scale_mul = 1 if args.visualize_gt else 255
+        print(output_image.min())
+        print(output_image.max())
+        not_valid_indices = output_image == 0
+        output_image = 255*xnn.utils.tensor2array(output_image, max_value=max_value_depth, colormap='rainbow')[:,:,::-1]
+        print(output_image.max())
+        #output_image[label == 1] = 0
+        input_bgr = cv2.imread(input_name)  # Read the actual RGB image
+        input_bgr = cv2.resize(input_bgr, dsize=(prediction.shape[1], prediction.shape[0]))
+        if args.sky_dir:
+            label_file = os.path.join(args.sky_dir, seq, seq + '_image_00_' + base_file)
+            label = cv2.imread(label_file)
+            label = cv2.resize(label, dsize=(prediction.shape[1], prediction.shape[0]),
+                                interpolation=cv2.INTER_NEAREST)
+            output_image[label == 1] = 0
+        output_image[not_valid_indices] = 0
+        output_image = xnn.utils.chroma_blend(input_bgr, output_image)  # chroma_blend(input_bgr, output_image)
+
+    elif args.viz_depth_color_type == 'bone':
+        output_image = 255 * xnn.utils.tensor2array(output_image, max_value=max_value_depth, colormap='bone')
+    elif args.viz_depth_color_type == 'raw_depth':
+        output_image = np.array(output_image)
+        output_image[output_image > max_value_depth] = max_value_depth
+        output_image[output_image < 0] = 0
+        scale = 2.0**16 - 1.0 #255
+        output_image = (output_image / max_value_depth) * scale
+        output_image = output_image.astype(np.uint16)
+        # output_image[(label[:,:,0]==1)|(label[:,:,0]==4)]=255
+    elif args.viz_depth_color_type == 'plasma':
+        plt.imsave(output_name, output_image, cmap='plasma', vmin=0, vmax=max_value_depth)
+    elif args.viz_depth_color_type == 'log_greys':        
+        plt.imsave(output_name, np.log10(output_image), cmap='Greys', vmin=0, vmax=np.log10(max_value_depth))
+        #plt.imsave(output_name, output_image, cmap='Greys', vmin=0, vmax=max_value_depth)
+    else:
+        print("undefined color type for visualization")
+        exit(0)
+
+    if args.viz_depth_color_type != 'plasma':
+        # plasma type will be handled by imsave
+        cv2.imwrite(output_name, output_image)
+
+
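+# write the interest-point score map as an image and store the descriptors via store_desc();
+# the source tensor is either the ground truth or the prediction, depending on args.write_desc_type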
+def wrapper_write_desc(args=[], task_index=0, outputs=[], index=0, output_name=[], output_name_short=[]):
+    if args.write_desc_type == 'GT':
+        # write GT desc
+        tensor_to_write = target_list[task_index]
+    elif args.write_desc_type == 'PRED':
+        # write predicted desc
+        tensor_to_write = outputs[task_index]
+
+    interest_pt_score = np.array(tensor_to_write[index, 0, ...])
+
+    if args.make_score_zero_mean:
+        # visualization code assumes the range [0,255]; add 128 so the zero-mean case covers the same range.
+        interest_pt_score += 128.0
+
+    if args.write_desc_type == 'NONE':
+        # scale + clip score between 0-255 and convert score_array to image
+        # scale_range = 127.0/0.005
+        # scale_range = 255.0/np.max(interest_pt_score)
+        scale_range = 1.0
+        interest_pt_score = np.clip(interest_pt_score * scale_range, 0.0, 255.0)
+        interest_pt_score = np.asarray(interest_pt_score, 'uint8')
+
+    interest_pt_descriptor = np.array(tensor_to_write[index, 1:, ...])
+
+    # output_name = os.path.join(infer_path[task_index], seq + '_' + input_path[-1][index].split('/')[-3] + '_' + base_file)
+    cv2.imwrite(output_name, interest_pt_score)
+
+    # output_name_short = os.path.join(infer_path[task_index], os.path.basename(input_path[-1][index]))
+
+    scale_to_write_kp_loc_to_orig_res = args.scale_to_write_kp_loc_to_orig_res
+    if args.scale_to_write_kp_loc_to_orig_res[0] == -1:
+        scale_to_write_kp_loc_to_orig_res[0] = input_list[task_index].shape[2] / target_list[task_index].shape[2]
+        scale_to_write_kp_loc_to_orig_res[1] = scale_to_write_kp_loc_to_orig_res[0]
+
+    print("scale_to_write_kp_loc_to_orig_res: ", scale_to_write_kp_loc_to_orig_res)
+    store_desc(args=args, output_name=output_name_short, desc_tensor=interest_pt_descriptor,
+               prediction=interest_pt_score,
+               scale_to_write_kp_loc_to_orig_res=scale_to_write_kp_loc_to_orig_res,
+               learn_scaled_values=args.learn_scaled_values_interest_pt,
+               write_dense=False)
+
+
+def get_transforms(args):
+    # image normalization can be at the beginning of transforms or at the end
+    args.image_mean = np.array(args.image_mean, dtype=np.float32)
+    args.image_scale = np.array(args.image_scale, dtype=np.float32)
+    image_prenorm = vision.transforms.image_transforms.NormalizeMeanScale(mean=args.image_mean, scale=args.image_scale) if args.image_prenorm else None
+    image_postnorm = vision.transforms.image_transforms.NormalizeMeanScale(mean=args.image_mean, scale=args.image_scale) if (not image_prenorm) else None
+
+    # the target size must match output_size; the prediction will be resized to output_size before evaluation.
+    test_transform = vision.transforms.image_transforms.Compose([
+        image_prenorm,
+        vision.transforms.image_transforms.AlignImages(),
+        vision.transforms.image_transforms.MaskTarget(args.target_mask, 0),
+        vision.transforms.image_transforms.CropRect(args.img_border_crop),
+        vision.transforms.image_transforms.Scale(args.img_resize, target_size=args.output_size, is_flow=args.is_flow),
+        image_postnorm,
+        vision.transforms.image_transforms.ConvertToTensor()
+        ])
+
+    return test_transform
+
+
+def _upsample_impl(tensor, output_size, upsample_mode):
+    # upsampling of a long tensor is not supported currently; convert to float just to avoid the error.
+    # this is safe only in nearest mode, otherwise the output would contain invalid values.
+    convert_to_float = False
+    if isinstance(tensor, (torch.LongTensor,torch.cuda.LongTensor)):
+        convert_to_float = True
+        tensor = tensor.float()
+        upsample_mode = 'nearest'
+    #
+
+    dim_added = False
+    if len(tensor.shape) < 4:
+        tensor = tensor[np.newaxis,...]
+        dim_added = True
+    #
+    if (tensor.size()[-2:] != output_size):
+        tensor = torch.nn.functional.interpolate(tensor, output_size, mode=upsample_mode)
+    # --
+    if dim_added:
+        tensor = tensor[0,...]
+    #
+
+    if convert_to_float:
+        tensor = tensor.long()
+    #
+    return tensor
+
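+# upsample a single tensor, or each tensor in a list/tuple, to the corresponding entry in output_sizes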
+def upsample_tensors(tensors, output_sizes, upsample_mode):
+    if not output_sizes:
+        return tensors
+    #
+    if isinstance(tensors, (list,tuple)):
+        for tidx, tensor in enumerate(tensors):
+            tensors[tidx] = _upsample_impl(tensor, output_sizes[tidx][-2:], upsample_mode)
+        #
+    else:
+        tensors = _upsample_impl(tensors, output_sizes[0][-2:], upsample_mode)
+    return tensors
+
+
+
+
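+# accumulate the segmentation confusion matrix over one output/label pair; pixels labelled 255 are ignored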
+def eval_output(args, output, label, confusion_matrix, n_classes):
+    if len(label.shape)>2:
+        label = label[:,:,0]
+    gt_labels = label.ravel()
+    det_labels = output.ravel().clip(0,n_classes)
+    gt_labels_valid_ind = np.where(gt_labels != 255)
+    gt_labels_valid = gt_labels[gt_labels_valid_ind]
+    det_labels_valid = det_labels[gt_labels_valid_ind]
+    for r in range(confusion_matrix.shape[0]):
+        for c in range(confusion_matrix.shape[1]):
+            confusion_matrix[r,c] += np.sum((gt_labels_valid==r) & (det_labels_valid==c))
+
+    return confusion_matrix
+    
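+# compute overall accuracy, mean IoU, per-class IoU and per-class F1 score from the accumulated confusion matrix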
+def compute_accuracy(args, confusion_matrix, n_classes):
+    num_selected_classes = n_classes
+    tp = np.zeros(n_classes)
+    population = np.zeros(n_classes)
+    det = np.zeros(n_classes)
+    iou = np.zeros(n_classes)
+    
+    for r in range(n_classes):
+      for c in range(n_classes):
+        population[r] += confusion_matrix[r][c]
+        det[c] += confusion_matrix[r][c]   
+        if r == c:
+          tp[r] += confusion_matrix[r][c]
+
+    for cls in range(num_selected_classes):
+      intersection = tp[cls]
+      union = population[cls] + det[cls] - tp[cls]
+      iou[cls] = (intersection / union) if union else 0     # For caffe jacinto script
+      #iou[cls] = (intersection / (union + np.finfo(np.float32).eps))  # For pytorch-jacinto script
+
+    num_nonempty_classes = 0
+    for pop in population:
+      if pop>0:
+        num_nonempty_classes += 1
+          
+    mean_iou = np.sum(iou) / num_nonempty_classes if num_nonempty_classes else 0
+    accuracy = np.sum(tp) / np.sum(population) if np.sum(population) else 0
+    
+    #F1 score calculation
+    fp = np.zeros(n_classes)
+    fn = np.zeros(n_classes)
+    precision = np.zeros(n_classes)
+    recall = np.zeros(n_classes)
+    f1_score = np.zeros(n_classes)
+
+    for cls in range(num_selected_classes):
+        fp[cls] = det[cls] - tp[cls]
+        fn[cls] = population[cls] - tp[cls]
+        precision[cls] = tp[cls] / (det[cls] + 1e-10)
+        recall[cls] = tp[cls] / (tp[cls] + fn[cls] + 1e-10)        
+        f1_score[cls] = 2 * precision[cls]*recall[cls] / (precision[cls] + recall[cls] + 1e-10)
+
+    return accuracy, mean_iou, iou, f1_score
+    
+        
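+# run inference frame by frame on a video file and write the visualized outputs to a new video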
+def infer_video(args, net):
+    videoIpHandle = imageio.get_reader(args.input, 'ffmpeg')
+    fps = math.ceil(videoIpHandle.get_meta_data()['fps'])
+    print(videoIpHandle.get_meta_data())
+    numFrames = min(len(videoIpHandle), args.num_images)
+    videoOpHandle = imageio.get_writer(args.output,'ffmpeg', fps=fps)
+    for num in range(numFrames):
+        print(num, end=' ')
+        sys.stdout.flush()
+        input_blob = videoIpHandle.get_data(num)
+        input_blob = input_blob[...,::-1]    #RGB->BGR
+        output_blob = infer_blob(args, net, input_blob)     
+        output_blob = output_blob[...,::-1]  #BGR->RGB            
+        videoOpHandle.append_data(output_blob)
+    videoOpHandle.close()
+    return
+
+
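+# absolute relative difference |x - y| / |y|, with small reference values masked out (used as a depth metric)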
+def absreldiff(x, y, eps = 0.0, max_val=None):
+    assert x.size() == y.size(), 'tensor dimension mismatch'
+    if max_val is not None:
+        x = torch.clamp(x, -max_val, max_val)
+        y = torch.clamp(y, -max_val, max_val)
+    #
+
+    diff = torch.abs(x - y)
+    y = torch.abs(y)
+
+    den_valid = (y == 0).float()
+    eps_arr = (den_valid * (1e-6))   # Just to avoid divide by zero
+
+    large_arr = (y > eps).float()    # ARD is not a good measure for small ref values. Avoid them.
+    out = (diff / (y + eps_arr)) * large_arr
+    return out
+
+
+def absreldiff_rng3to80(x, y):
+    return absreldiff(x, y, eps = 3.0, max_val=80.0)
+
+
+
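+# encode the png outputs in infer_path into an MP4 video using ffmpeg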
+def create_video(args, infer_path):
+    op_file_name = args.data_path.split('/')[-1]
+    os.system(' ffmpeg -framerate 30 -pattern_type glob -i "{}/*.png" -c:v libx264 -vb 50000k -qmax 20 -r 10 -vf scale=1024:512  -pix_fmt yuv420p {}.MP4'.format(infer_path, op_file_name))
+
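+# export the model to ONNX using a random input of the configured shape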
+def write_onnx_model(args, model, save_path, name='checkpoint.onnx'):
+    is_cuda = next(model.parameters()).is_cuda
+    input_list = create_rand_inputs(args, is_cuda=is_cuda)
+    #
+    model.eval()
+    torch.onnx.export(get_model_orig(model), input_list, os.path.join(save_path, name), export_params=True, verbose=False)
+    # torch onnx export does not update names. Do it using onnx.save
+
+
+
+if __name__ == '__main__':
+    train_args = get_config()
+    main(train_args)
index d1d60995f9fb8f94ab19bd3f8b6225d098995563..4c62bac0d2aadfe753ff08ebe901ee9154f07da5 100644 (file)
@@ -29,6 +29,8 @@ def get_config():
     args = xnn.utils.ConfigNode()
     args.model_config = xnn.utils.ConfigNode()
     args.dataset_config = xnn.utils.ConfigNode()
+    args.model_config.num_tiles_x = int(1)
+    args.model_config.num_tiles_y = int(1)
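+    # when either of the above is >1, each input image is split into a num_tiles_y x num_tiles_x grid and
+    # every tile is treated as a separate classification sample; see the reshaping in train()/validate()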
 
     args.model_config.input_channels = 3                # num input channels
 
@@ -398,7 +400,8 @@ def get_model_orig(model):
 
 
 def create_rand_inputs(args, is_cuda):
-    dummy_input = torch.rand((1, args.model_config.input_channels, args.img_crop, args.img_crop))
+    dummy_input = torch.rand((1, args.model_config.input_channels, args.img_crop*args.model_config.num_tiles_y,
+      args.img_crop*args.model_config.num_tiles_x))
     dummy_input = dummy_input.cuda() if is_cuda else dummy_input
     return dummy_input
 
@@ -453,9 +456,22 @@ def train(args, train_loader, model, criterion, optimizer, epoch):
 
         data_time.update(time.time() - end)
 
+        # preprocess to make tiles
+        if args.model_config.num_tiles_y>1 or args.model_config.num_tiles_x>1:
+            input = xnn.utils.reshape_input_4d(input, args.model_config.num_tiles_y, args.model_config.num_tiles_x)
+        #
+
         # compute output
         output = model(input)
 
+        if args.model_config.num_tiles_y>1 or args.model_config.num_tiles_x>1:
+            # [1,n_class,n_tiles_y, n_tiles_x] to [1,n_tiles_y, n_tiles_x, n_class]
+            # e.g. [1,10,4,5] to [1,4,5,10]
+            output = output.permute(0, 2, 3, 1)
+            #change shape from [1,n_tiles_y, n_tiles_x, n_class] to [1*n_tiles_y*n_tiles_x, n_class]
+            output = torch.reshape(output, (-1, output.shape[-1]))
+        #
+
         # compute loss
         loss = criterion(output, target) / args.iter_size
 
@@ -531,8 +547,22 @@ def validate(args, val_loader, model, criterion, epoch):
             input_size = input[0].size() if xnn.utils.is_list(input) else input.size()
             target = target.cuda(non_blocking=True)
 
+            # preprocess to make tiles
+            if args.model_config.num_tiles_y > 1 or args.model_config.num_tiles_x > 1:
+                input = xnn.utils.reshape_input_4d(input, args.model_config.num_tiles_y, args.model_config.num_tiles_x)
+            #
+
             # compute output
             output = model(input)
+
+            if args.model_config.num_tiles_y > 1 or args.model_config.num_tiles_x > 1:
+                # [1,n_class,n_tiles_y, n_tiles_x] to [1,n_tiles_y, n_tiles_x, n_class] 
+                # e.g. [1,10,4,5] to [1,4,5,10]
+                output = output.permute(0,2,3,1)
+                #change shape from [1,n_tiles_y, n_tiles_x, n_class] to [1*n_tiles_y*n_tiles_x, n_class]
+                output = torch.reshape(output, (-1, output.shape[-1]))
+            #
+
             loss = criterion(output, target)
 
             # measure accuracy and record loss
index 1732e993cdab120e8f54d66e4ba3446b664afa07..2662799f5e76577cac426aae4a83ea9d11e8b0a7 100644 (file)
@@ -24,6 +24,7 @@ import warnings
 
 from .. import xnn
 from .. import vision
+from . infer_pixel2pixel import compute_accuracy
 
 
 ##################################################
@@ -165,6 +166,9 @@ def get_config():
     args.viz_colormap = 'rainbow'                       # colormap for tensorboard: 'rainbow', 'plasma', 'magma', 'bone'
 
     args.freeze_bn = False                              # freeze the statistics of bn
+    args.tensorboard_enable = True                      # enable/disable writing to TensorBoard
+    args.print_train_class_iou = False                  # print per-class IoU after each training epoch
+    args.print_val_class_iou = False                    # print per-class IoU after each validation epoch
 
     return args
 
@@ -259,8 +263,8 @@ def main(args):
     print('=> will save everything to {}'.format(save_path))
 
     #################################################
-    train_writer = SummaryWriter(os.path.join(save_path,'train'))
-    val_writer = SummaryWriter(os.path.join(save_path,'val'))
+    train_writer = SummaryWriter(os.path.join(save_path,'train')) if args.tensorboard_enable else None
+    val_writer = SummaryWriter(os.path.join(save_path,'val')) if args.tensorboard_enable else None
     transforms = get_transforms(args) if args.transforms is None else args.transforms
     assert isinstance(transforms, (list,tuple)) and len(transforms) == 2, 'incorrect transforms were given'
 
@@ -551,9 +555,10 @@ def main(args):
                             'quantize' : args.quantize}
 
         save_checkpoint(args, save_path, get_model_orig(model), checkpoint_dict, is_best)
-        
-        train_writer.file_writer.flush()
-        val_writer.file_writer.flush()
+
+        if args.tensorboard_enable:
+            train_writer.file_writer.flush()
+            val_writer.file_writer.flush()
 
         # adjust the learning rate using lr scheduler
         if 'training' in args.phase:
@@ -638,8 +643,14 @@ def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writ
                          task_mults=args.multi_task_factors, task_offsets=args.multi_task_offsets,
                          loss_mult_factors=args.loss_mult_factors)
 
-        metric_total, metric_list, metric_names, metric_types, _ = \
-            compute_task_objectives(args, args.metric_modules, input_list, task_outputs, target_list)
+        if args.print_train_class_iou:
+            metric_total, metric_list, metric_names, metric_types, _, confusion_matrix = \
+                compute_task_objectives(args, args.metric_modules, input_list, task_outputs, target_list, 
+                get_confusion_matrix=args.print_train_class_iou)
+        else:        
+            metric_total, metric_list, metric_names, metric_types, _ = \
+                compute_task_objectives(args, args.metric_modules, input_list, task_outputs, target_list, 
+                get_confusion_matrix=args.print_train_class_iou)
 
         if args.model_config.multi_task_type is not None and len(args.model_config.output_channels) > 1:
             xnn.layers.set_losses(model, loss_list_orig)
@@ -660,16 +671,19 @@ def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writ
         for task_idx, task_losses in enumerate(args.loss_modules):
             avg_loss[task_idx].update(float(loss_list[task_idx].cpu()), batch_size_cur)
             avg_loss_orig[task_idx].update(float(loss_list_orig[task_idx].cpu()), batch_size_cur)
-            train_writer.add_scalar('Training/Task{}_{}_Loss_Iter'.format(task_idx,loss_names[task_idx]), float(loss_list[task_idx]), args.n_iter)
-            if args.model_config.multi_task_type is not None and len(args.model_config.output_channels) > 1:
-                train_writer.add_scalar('Training/multi_task_Factor_Task{}_{}'.format(task_idx,loss_names[task_idx]), float(args.multi_task_factors[task_idx]), args.n_iter)
+            if args.tensorboard_enable:
+                train_writer.add_scalar('Training/Task{}_{}_Loss_Iter'.format(task_idx,loss_names[task_idx]), float(loss_list[task_idx]), args.n_iter)
+                if args.model_config.multi_task_type is not None and len(args.model_config.output_channels) > 1:
+                    train_writer.add_scalar('Training/multi_task_Factor_Task{}_{}'.format(task_idx,loss_names[task_idx]), float(args.multi_task_factors[task_idx]), args.n_iter)
 
         # record error/accuracy.
         for task_idx, task_metrics in enumerate(args.metric_modules):
             avg_metric[task_idx].update(float(metric_list[task_idx].cpu()), batch_size_cur)
 
         ##########################
-        write_output(args, 'Training_', epoch_size, iter, epoch, train_dataset, train_writer, input_list, task_outputs, target_list, metric_name, writer_idx)
+        if args.tensorboard_enable:
+            write_output(args, 'Training_', epoch_size, iter, epoch, train_dataset, train_writer, input_list, task_outputs, target_list, metric_name, writer_idx)
+
         if ((iter % args.print_freq) == 0) or (iter == (epoch_size-1)):
             output_string = ''
             for task_idx, task_metrics in enumerate(args.metric_modules):
@@ -692,22 +706,26 @@ def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writ
         #if epoch == 0 and iter == 0:
         #    input_zero = torch.zeros(input_var.shape)
         #    train_writer.add_graph(model, input_zero)
-
+        # this cache operation slows down training
         #torch.cuda.empty_cache()
 
         if iter >= epoch_size:
             break
 
+    if args.print_train_class_iou:
+        print_class_iou(args=args, confusion_matrix=confusion_matrix, task_idx=task_idx)
+        
     progress_bar.close()
 
     # to print a new line - do not provide end=''
     print('{}'.format(Fore.RESET), end='')
 
-    for task_idx, task_losses in enumerate(args.loss_modules):
-        train_writer.add_scalar('Training/Task{}_{}_Loss_Epoch'.format(task_idx,loss_names[task_idx]), float(avg_loss[task_idx]), epoch)
+    if args.tensorboard_enable:
+        for task_idx, task_losses in enumerate(args.loss_modules):
+            train_writer.add_scalar('Training/Task{}_{}_Loss_Epoch'.format(task_idx,loss_names[task_idx]), float(avg_loss[task_idx]), epoch)
 
-    for task_idx, task_metrics in enumerate(args.metric_modules):
-        train_writer.add_scalar('Training/Task{}_{}_Metric_Epoch'.format(task_idx,metric_names[task_idx]), float(avg_metric[task_idx]), epoch)
+        for task_idx, task_metrics in enumerate(args.metric_modules):
+            train_writer.add_scalar('Training/Task{}_{}_Metric_Epoch'.format(task_idx,metric_names[task_idx]), float(avg_metric[task_idx]), epoch)
 
     output_name = metric_names[args.pivot_task_idx]
     output_metric = float(avg_metric[args.pivot_task_idx])
@@ -764,15 +782,22 @@ def validate(args, val_dataset, val_loader, model, epoch, val_writer):
 
         task_outputs = task_outputs if isinstance(task_outputs,(list,tuple)) else [task_outputs]
         task_outputs = upsample_tensors(task_outputs, target_sizes, args.upsample_mode)
-
-        metric_total, metric_list, metric_names, metric_types, _ = \
-            compute_task_objectives(args, args.metric_modules, input_list, task_outputs, target_list)
+        
+        if args.print_val_class_iou:
+            metric_total, metric_list, metric_names, metric_types, _, confusion_matrix = \
+                compute_task_objectives(args, args.metric_modules, input_list, task_outputs, target_list, 
+                get_confusion_matrix = args.print_val_class_iou)
+        else:        
+            metric_total, metric_list, metric_names, metric_types, _ = \
+                compute_task_objectives(args, args.metric_modules, input_list, task_outputs, target_list, 
+                get_confusion_matrix = args.print_val_class_iou)
 
         # record error/accuracy.
         for task_idx, task_metrics in enumerate(args.metric_modules):
             avg_metric[task_idx].update(float(metric_list[task_idx].cpu()), batch_size_cur)
 
-        write_output(args, 'Validation_', epoch_size, iter, epoch, val_dataset, val_writer, input_list, task_outputs, target_list, metric_names, writer_idx)
+        if args.tensorboard_enable:
+            write_output(args, 'Validation_', epoch_size, iter, epoch, val_dataset, val_writer, input_list, task_outputs, target_list, metric_names, writer_idx)
 
         if ((iter % args.print_freq) == 0) or (iter == (epoch_size-1)):
             output_string = ''
@@ -791,13 +816,18 @@ def validate(args, val_dataset, val_loader, model, epoch, val_writer):
         if iter >= epoch_size:
             break
 
+    if args.print_val_class_iou:
+        print_class_iou(args = args, confusion_matrix = confusion_matrix, task_idx=task_idx)
+    
+    #print_conf_matrix(conf_matrix=conf_matrix, en=False)
     progress_bar.close()
 
     # to print a new line - do not provide end=''
     print('{}'.format(Fore.RESET), end='')
 
-    for task_idx, task_metrics in enumerate(args.metric_modules):
-        val_writer.add_scalar('Validation/Task{}_{}_Metric_Epoch'.format(task_idx,metric_names[task_idx]), float(avg_metric[task_idx]), epoch)
+    if args.tensorboard_enable:
+        for task_idx, task_metrics in enumerate(args.metric_modules):
+            val_writer.add_scalar('Validation/Task{}_{}_Metric_Epoch'.format(task_idx,metric_names[task_idx]), float(avg_metric[task_idx]), epoch)
 
     output_name = metric_names[args.pivot_task_idx]
     output_metric = float(avg_metric[args.pivot_task_idx])
@@ -955,8 +985,23 @@ def write_output(args, prefix, val_epoch_size, iter, epoch, dataset, output_writ
             output_writer.add_image(prefix+'Task{}_{}_Output_Bone_Visualization/{}'.format(task_idx,output_type,writer_idx), xnn.utils.tensor2array(output_score_to_write, max_value=max_value_score, colormap='bone').transpose(2,0,1), epoch)
         #
 
-
-def compute_task_objectives(args, objective_fns, input_var, task_outputs, task_targets, task_mults=None, task_offsets=None, loss_mult_factors=None):
+def print_conf_matrix(conf_matrix = [], en = False):
+    if not en:
+        return
+    num_rows = conf_matrix.shape[0]
+    num_cols = conf_matrix.shape[1]
+    print("-"*64)
+    num_ele = 1
+    for r_idx in range(num_rows):
+        print("\n")
+        for c_idx in range(0,num_cols,num_ele):
+            print(conf_matrix[r_idx][c_idx:c_idx+num_ele], end="")
+    print("\n")
+    print("-" * 64)
+
+def compute_task_objectives(args, objective_fns, input_var, task_outputs, task_targets, task_mults=None, 
+  task_offsets=None, loss_mult_factors=None, get_confusion_matrix = False):
+  
     ##########################
     objective_total = torch.zeros_like(task_outputs[0].view(-1)[0])
     objective_list = []
@@ -977,6 +1022,9 @@ def compute_task_objectives(args, objective_fns, input_var, task_outputs, task_t
             objective_batch = objective_batch.mean() if isinstance(objective_fn, torch.nn.DataParallel) else objective_batch
             objective_name = objective_fn.info()['name']
             objective_type = objective_fn.info()['is_avg']
+            if get_confusion_matrix:
+                confusion_matrix = objective_fn.info()['confusion_matrix']
+
             loss_mult = loss_mult_factors[task_idx][oidx] if (loss_mult_factors is not None) else 1.0
             # --
             objective_batch_not_nan = (objective_batch if not torch.isnan(objective_batch) else 0.0)
@@ -992,8 +1040,11 @@ def compute_task_objectives(args, objective_fns, input_var, task_outputs, task_t
 
         objective_total = objective_sum_value*task_mult + task_offset + objective_total
 
-    return objective_total, objective_list, objective_names, objective_types, objective_list_orig
+    return_list = [objective_total, objective_list, objective_names, objective_types, objective_list_orig]
+    if get_confusion_matrix:
+        return_list.append(confusion_matrix)
 
+    return return_list 
 
 
 def save_checkpoint(args, save_path, model, checkpoint_dict, is_best, filename='checkpoint.pth.tar'):
@@ -1110,6 +1161,14 @@ def upsample_tensors(tensors, output_sizes, upsample_mode):
         tensors = _upsample_impl(tensors, output_sizes[0][-2:], upsample_mode)
     return tensors
 
+#print IoU for each class
+def print_class_iou(args = None, confusion_matrix = None, task_idx = 0):    
+    n_classes = args.model_config.output_channels[task_idx]
+    [accuracy, mean_iou, iou, f1_score] = compute_accuracy(args, confusion_matrix, n_classes)
+    print("\n Class IoU: [", end = "")
+    for class_iou in iou:
+        print("{:0.3f}".format(class_iou), end=",")
+    print("]")    
 
 if __name__ == '__main__':
     train_args = get_config()
index 2c55bc904bd89dc172b0db32a98293de45dc891a..62a2377aa5b03ffbcae97a1922e7654c6d0fd88d 100644 (file)
@@ -24,6 +24,7 @@ def get_config():
     dataset_config.input_offsets = None
     dataset_config.load_segmentation = True
     dataset_config.load_segmentation_five_class = False
+    dataset_config.split = 'val'
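+    # dataset split used by the *_infer and *_measure loaders below; overridable from the calling script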
     return dataset_config
 
 
@@ -32,31 +33,50 @@ class A2D2BaseSegmentationLoader():
     """A2D2Loader: Data is derived from A2D2, and can be downloaded from here: https://www.A2D2-dataset.com/downloads/
     Many Thanks to @fvisin for the loader repo: https://github.com/fvisin/dataset_loaders/blob/master/dataset_loaders/images/A2D2.py"""
     
+    train_only_difficult_classes = False
+    class_weights_ = None
+
     colors = [
         [255,0,0],[182,89,6],[204,153,255],[255,128,0],[0,255,0],[0,128,255],[0,255,255],[255,255,0],[233,100,0],[110,110,0],[128,128,0],[255,193,37],[64,0,64],[185,122,87],[0,0,100],[139,99,108],[210,50,115],[255,0,128],[255,246,143],[150,0,150],[204,255,153],[238,162,173],[33,44,177],[180,50,180],[255,70,185],[238,233,191],[147,253,194],[150,150,200],[180,150,200],[72,209,204],[200,125,210],[159,121,238],[128,0,255],[255,0,255],[135,206,255],[241,230,255],[96,69,143],[53,46,82], [0, 0, 0]]
 
     num_classes_ = 38
     label_colours = dict(zip(range(num_classes_), colors))
     
-    void_classes = []
-    
-    valid_classes = range(0,num_classes_)
+    # difficult classes: classes scoring below 0.4 mAP
+    if train_only_difficult_classes:
+        valid_classes = [8, 9, 11, 13, 14, 16, 17, 18, 19, 20, 22, 24, 25, 36, 37]
+        void_classes = []
+        # equivalently, as a list comprehension:
+        # void_classes = [x for x in range(num_classes_) if x not in valid_classes]
+        for idx in range(num_classes_):
+            if idx not in valid_classes:
+                void_classes.append(idx)
+        
+        print("void_classes: ", void_classes)
+        colors_valid_class = []
+        for idx, valid_class in enumerate(valid_classes):
+            colors_valid_class.append(colors[valid_class]) 
+        colors = colors_valid_class
+    else:        
+        valid_classes = range(0,num_classes_)
+        void_classes = []
+        # class_weights_ = np.ones(num_classes_)
+        # # set high-frequency category weights low so they do not overpower the other categories
+        # # Nature object 26
+        # # RD normal street 33
+        # # Sky 34
+        # # Buildings 35
+
+        # cat_with_high_freq = [26, 33, 34, 35]
+        # for cat_idx in cat_with_high_freq:
+        #     class_weights_[cat_idx] = 0.05
+
+    num_valid_classes = len(valid_classes)
     class_names = ['Car  0','Bicycle  1','Pedestrian  2','Truck  3','Small vehicles  4','Traffic signal  5','Traffic sign  6','Utility vehicle  7','Sidebars 8','Speed bumper 9','Curbstone 10','Solid line 11','Irrelevant signs 12','Road blocks 13','Tractor 14','Non-drivable street 15','Zebra crossing 16','Obstacles / trash 17','Poles 18','RD restricted area 19','Animals 20','Grid structure 21','Signal corpus 22','Drivable cobbleston 23','Electronic traffic 24','Slow drive area 25','Nature object 26','Parking area 27','Sidewalk 28','Ego car 29','Painted driv. instr. 30','Traffic guide obj. 31','Dashed line 32','RD normal street 33','Sky 34','Buildings 35','Blurred area 36','Rain dirt 37']
 
     ignore_index = 255
     class_map = dict(zip(valid_classes, range(num_classes_)))
 
-    class_weights_ = np.ones(num_classes_)
-    #set high freq category weights to low to not over power other categorie
-    # Nature object 26
-    # RD normal street 33
-    # Sky 34
-    # Buildings 35
-
-    cat_with_high_freq = [26, 33, 34, 35]
-    for cat_idx in cat_with_high_freq:
-        class_weights_[cat_idx] = 0.05
-
     @classmethod
     def decode_segmap(cls, temp):
         r = temp.copy()
@@ -141,7 +161,7 @@ class A2D2DataLoader(data.Dataset):
                  search_images=False, load_segmentation=True, load_depth=False, load_motion=False, load_flow=False,
                  load_segmentation_five_class=False, inference=False, additional_info=False, input_offsets=None):
         super().__init__()
-        if split not in ['train', 'val', 'test']:
+        if split not in ['train', 'val', 'test', 'test_val']:
             warnings.warn(f'unknown split specified: {split}')
         #
         self.root = root
@@ -376,18 +396,17 @@ def a2d2_segmentation(dataset_config, root, split=None, transforms=None):
     dataset_config = get_config().merge_from(dataset_config)
     gt = "gtFine"
     train_split = val_split = None
-    split = ['train', 'val']
+    if split is None:
+        split = ['train', 'val']
     for split_name in split:
         if split_name == 'train':
             train_split = A2D2DataLoader(dataset_config, root, split_name, gt, transforms=transforms[0],
                                             load_segmentation=dataset_config.load_segmentation,
                                             load_segmentation_five_class=dataset_config.load_segmentation_five_class)
-        elif split_name == 'val':
+        else:
             val_split = A2D2DataLoader(dataset_config, root, split_name, gt, transforms=transforms[1],
                                             load_segmentation=dataset_config.load_segmentation,
                                             load_segmentation_five_class=dataset_config.load_segmentation_five_class)
-        else:
-            pass
     #
     return train_split, val_split
 
@@ -420,7 +439,7 @@ def a2d2_depth(dataset_config, root, split=None, transforms=None):
 def a2d2_segmentation_infer(dataset_config, root, split=None, transforms=None):
     dataset_config = get_config().merge_from(dataset_config)
     gt = "gtFine"
-    split_name = 'val'
+    split_name = dataset_config.split #'val'
     infer_split = A2D2DataLoader(dataset_config, root, split_name, gt, transforms=transforms, image_folders=dataset_config.image_folders,
                                        load_segmentation=dataset_config.load_segmentation,
                                        load_segmentation_five_class=dataset_config.load_segmentation_five_class,
@@ -431,7 +450,7 @@ def a2d2_segmentation_infer(dataset_config, root, split=None, transforms=None):
 def a2d2_segmentation_measure(dataset_config, root, split=None, transforms=None):
     dataset_config = get_config().merge_from(dataset_config)
     gt = "gtFine"
-    split_name = 'val'
+    split_name = dataset_config.split #'val'
     infer_split = A2D2DataLoader(dataset_config, root, split_name, gt, transforms=transforms, image_folders=dataset_config.image_folders,
                                        load_segmentation=dataset_config.load_segmentation,
                                        load_segmentation_five_class=dataset_config.load_segmentation_five_class,
index db2adecc5f27ac3c36e43bb35f9334a7f122a6b5..11b6dd9d6007e9781a1f3efc342b46f9ca997bee 100644 (file)
@@ -2,12 +2,9 @@ import numpy as np
 import os
 import scipy.misc as misc
 import sys
-import cv2
-#__package__ = "pytorch_jacinto_ai.vision.datasets.pixel2pixel"
-from .... import xnn
 
 from .cityscapes_plus import CityscapesBaseSegmentationLoader, CityscapesBaseMotionLoader
-from .a2d2 import A2D2BaseSegmentationLoader, A2D2BaseMotionLoader
+
 
 def calc_median_frequency(classes, present_num):
     """
@@ -18,9 +15,10 @@ def calc_median_frequency(classes, present_num):
         c is present, and median_freq is the median of these frequencies.'
     """
     class_freq = classes / present_num
-    median_freq = np.median(class_freq[classes != 1.0])
+    median_freq = np.median(class_freq)
     return median_freq / class_freq
 
+
 def calc_log_frequency(classes, value=1.02):
     """Class balancing by ERFNet method.
        prob = each_sum_pixel / each_sum_pixel.max()
@@ -31,41 +29,25 @@ def calc_log_frequency(classes, value=1.02):
     # print(np.log(value + class_freq))
     return 1 / np.log(value + class_freq)
 
-def print_stats(classes = [], class_weight = []):
-    print("class_freq \n","-"*32)
-    for idx, class_freq in enumerate(classes):
-        print("{} : {:.0f}".format(idx, class_freq))
-    print("-"*32)
-    print("class_freq in % \n","-"*32)
-    for idx, class_freq in enumerate(classes):
-        print("{} : {:.2f}".format(idx, class_freq*100.0/np.sum(classes)))
-    print("-"*32)
-
-    print("-"*32)
-    print("class weights \n","-"*32)
-    for idx, class_wt in enumerate(class_weight):
-        print("{} : {:.08}".format(idx, class_wt))
-
-def calc_weights():    
+
+def calc_weights():
     method = "median"
-    result_path = "/data/ssd/datasets/a2d2_v2/info/"
+    result_path = "/afs/cg.cs.tu-bs.de/home/zhang/SEDPShuffleNet/datasets"
 
     traval = "gtFine"
-    #imgs_path = "/data/ssd/datasets/a2d2_v1_full/leftImg8bit/train"    #"./data/cityscapes/data/leftImg8bit/train"   #"./data/TIAD/data/leftImg8bit/train"
-    lbls_path = "/data/ssd/datasets/a2d2_v2/gtFine/train/"         #"./data/cityscapes/data/gtFine/train"   # "./data/tiad/data/gtFine/train"  #"./data/cityscapes_frame_pair/data/gtFine/train"
-    #labels = xnn.utils.recursive_glob(rootdir=lbls_path, suffix='labelTrainIds_motion.png')  #'labelTrainIds_motion.png'  #'labelTrainIds.png'
-    labels = xnn.utils.recursive_glob(rootdir=lbls_path, suffix='.png')  #'labelTrainIds_motion.png'  #'labelTrainIds.png'
+    imgs_path = "./data/tiad/data/leftImg8bit/train"    #"./data/cityscapes/data/leftImg8bit/train"   #"./data/TIAD/data/leftImg8bit/train"
+    lbls_path = "./data/tiad/data/gtFine/train"         #"./data/cityscapes/data/gtFine/train"   # "./data/tiad/data/gtFine/train"  #"./data/cityscapes_frame_pair/data/gtFine/train"
+    labels = xnn.utils.recursive_glob(rootdir=lbls_path, suffix='labelTrainIds_motion.png')  #'labelTrainIds_motion.png'  #'labelTrainIds.png'
 
-    num_classes = 38       #5  #2
+    num_classes = 2       #5  #2
 
     local_path = "./data/checkpoints"
-    dst = A2D2BaseSegmentationLoader() #TiadBaseSegmentationLoader()  #CityscapesBaseSegmentationLoader()  #CityscapesBaseMotionLoader(), #A2D2BaseSegmentationLoader()
+    dst = CityscapesBaseMotionLoader() #TiadBaseSegmentationLoader()  #CityscapesBaseSegmentationLoader()  #CityscapesBaseMotionLoader()
 
     classes, present_num = ([0 for i in range(num_classes)] for i in range(2))
 
     for idx, lbl_path in enumerate(labels):
-        print("lbl_path: ", lbl_path)
-        lbl = cv2.imread(lbl_path, 0)
+        lbl = misc.imread(lbl_path)
         lbl = dst.encode_segmap(np.array(lbl, dtype=np.uint8))
 
         for nc in range(num_classes):
@@ -74,11 +56,10 @@ def calc_weights():
                 classes[nc] += num_pixel
                 present_num[nc] += 1
 
-    classes = np.array(classes, dtype="f")
-    
-    #if any class had 0 occurnace then set to 1 to avoid div by 0 kind of error
-    classes[classes==0] = 1
+    if 0 in classes:
+        raise Exception("Some classes are not found")
 
+    classes = np.array(classes, dtype="f")
+    present_num = np.array(present_num, dtype="f")
     if method == "median":
         class_weight = calc_median_frequency(classes, present_num)
@@ -86,9 +67,8 @@ def calc_weights():
         class_weight = calc_log_frequency(classes)
     else:
         raise Exception("Please assign method to 'mean' or 'log'")
-    
-    print_stats(classes = classes, class_weight = class_weight)
 
+    print("class weight", class_weight)
     print("Done!")
 
 
index 22e7daf72423e80cf25fb508ec3ba907da29cd0a..9e8c76fcb0f3cfb2b3a3219b5dc586ab26656291 100755 (executable)
@@ -146,7 +146,7 @@ class SegmentationMetrics(torch.nn.Module):
         self.metrics_calc.clear()
         return
     def info(self):
-        return {'value':'accuracy', 'name':'MeanIoU', 'is_avg':self.is_avg}
+        return {'value':'accuracy', 'name':'MeanIoU', 'is_avg':self.is_avg, 'confusion_matrix':self.metrics_calc.confusion_matrix}
     @classmethod
     def args(cls):
         return ['num_classes']
index d9397105a18ba886c9720a670d308fd9daddf635..e7329cfcb65cd77f91099379af9b8774c97fff72 100644 (file)
@@ -34,6 +34,7 @@ except: pass
 try: from .flownetbase_internal import *
 except: pass
 
+
 @property
 def name():
     return 'pytorch_jacinto_ai.vision.models'
index 0950ce29d1727e06060713d2d4bb23b79428e62a..e9e90b9b04cb3b3660ef89e6a92560c027d3b648 100644 (file)
@@ -14,13 +14,17 @@ except: pass
 try: from .. import flownetbase_internal
 except: pass
 
+try: from .. import mobilenetv1_internal
+except: pass
+
+
 from .... import xnn
 
 __all__ = ['mobilenetv1_x1', 'mobilenetv2_tv_x1', 'mobilenetv2_x1', 'mobilenetv2_tv_x2_t2',
            'resnet50_x1', 'resnet50_xp5',
            # experimental
            'mobilenetv2_ericsun_x1', 'mobilenetv2_shicai_x1',
-           'mobilenetv2_tv_gws_x1', 'flownetslite_base_x1']
+           'mobilenetv2_tv_gws_x1', 'flownetslite_base_x1', 'mobilenetv1_multi_label_x1']
 
 
 #####################################################################
@@ -50,6 +54,13 @@ def mobilenetv1_x1(model_config, pretrained=None):
         model = xnn.utils.load_weights(model, pretrained)
     return model
 
+def mobilenetv1_multi_label_x1(model_config, pretrained=None):
+    model_config = mobilenetv1.get_config().merge_from(model_config)
+    model = mobilenetv1_internal.MobileNetV1MultiLabel(model_config=model_config)
+    if pretrained:
+        model = xnn.utils.load_weights(model, pretrained)
+    return model
+
 
 #####################################################################
 def mobilenetv2_tv_x1(model_config, pretrained=None):
index c0701a1837b1f17eb66433d11334cb9098c5e436..0f13921bddcb43226dd9b1f111ba1c2c18434db2 100644 (file)
@@ -20,6 +20,7 @@ def get_config():
     model_config.dropout = False
     model_config.linear_dw = False
     model_config.layer_setting = None
+    model_config.classifier_type = torch.nn.Linear
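+    # class used to construct the final classifier layer; overridable, e.g. for a multi-label head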
     return model_config
 
 model_urls = {
@@ -84,7 +85,7 @@ class MobileNetV1Base(torch.nn.Module):
         if self.model_config.num_classes != None:
             self.classifier = torch.nn.Sequential(
                 torch.nn.Dropout(0.2) if self.model_config.dropout else xnn.layers.BypassBlock(),
-                torch.nn.Linear(channels, self.num_classes),
+                model_config.classifier_type(channels, self.num_classes),
             )
         #
 
@@ -100,8 +101,11 @@ class MobileNetV1Base(torch.nn.Module):
         if self.num_classes is not None:
             xnn.utils.print_once('=> feature size is: ', x.size())
             x = torch.nn.functional.adaptive_avg_pool2d(x, (1,1))
+            #xnn.utils.print_once('=> size after pool2d: ', x.size())
             x = torch.flatten(x, 1)
+            #xnn.utils.print_once('=> size after flatten: ', x.size())
             x = self.classifier(x)
+            #xnn.utils.print_once('=> size after classifier: ', x.size())
         #
         return x
 
@@ -128,3 +132,4 @@ def mobilenet_v1(pretrained=False, progress=True, **kwargs):
         state_dict = load_state_dict_from_url(model_urls['mobilenet_v1'], progress=progress)
         model.load_state_dict(state_dict)
     return model
+
index 3bbf54f6754420ebe7d498dd8bcad1d20c0df70c..469f458ca29635a86381f29c73feb60833dc5264 100644 (file)
@@ -2,7 +2,7 @@ from __future__ import division
 import torch
 import sys
 import math
-from PIL import Image, ImageOps, ImageEnhance, PILLOW_VERSION
+from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION
 try:
     import accimage
 except ImportError:
index d327afed6f92ed9f55c67593001daac7c020b8ee..2d49146f048c933726ecce45049452cb0df41f2d 100644 (file)
@@ -3,6 +3,7 @@ from .util_functions import *
 from .utils_data import *
 from .load_weights import *
 from .tensor_utils import *
+#from .tensor_utils_internal import *
 from .logger import *
 from .utils_hist import *
 from .attr_dict import *
@@ -11,3 +12,5 @@ from .image_utils import *
 from .module_utils import *
 from .count_flops import forward_count_flops
 from .bn_utils import *
+try: from .tensor_utils_internal import *
+except: pass
diff --git a/scripts/infer_classification_onnx_rt_main.py b/scripts/infer_classification_onnx_rt_main.py
new file mode 100755 (executable)
index 0000000..0e77238
--- /dev/null
@@ -0,0 +1,121 @@
+#!/usr/bin/env python
+
+import sys
+import cv2
+import os
+import datetime
+import argparse
+
+
+################################
+#sys.path.insert(0, os.path.abspath('./modules'))
+
+
+################################
+from pytorch_jacinto_ai.xnn.utils import str2bool
+parser = argparse.ArgumentParser()
+parser.add_argument('--save_path', type=str, default=None, help='checkpoint save folder')
+parser.add_argument('--gpus', type=int, nargs='*', default=None, help='GPU ids to use')
+parser.add_argument('--batch_size', type=int, default=None, help='Batch size')
+# parser.add_argument('--lr', type=float, default=None, help='Base learning rate')
+# parser.add_argument('--lr_clips', type=float, default=None, help='Learning rate for clips in PAct2')
+parser.add_argument('--lr_calib', type=float, default=None, help='Learning rate for calibration')
+parser.add_argument('--model_name', type=str, default=None, help='model name')
+parser.add_argument('--dataset_name', type=str, default=None, help='dataset name')
+parser.add_argument('--data_path', type=str, default=None, help='data path')
+parser.add_argument('--epoch_size', type=float, default=None, help='epoch size. using a fraction will reduce the data used for one epoch')
+# parser.add_argument('--epochs', type=int, default=None, help='number of epochs')
+# parser.add_argument('--milestones', type=int, nargs=2, default=None, help='change lr at these milestones')
+parser.add_argument('--img_resize', type=int, default=None, help='images will be first resized to this size during training and validation')
+# parser.add_argument('--rand_scale', type=float, nargs=2, default=None, help='during training (only) fraction of the image to crop (this will then be resized to img_crop)')
+parser.add_argument('--img_crop', type=int, default=None, help='the cropped portion (validation); the cropped portion will be resized to this size (training)')
+parser.add_argument('--quantize', type=str2bool, default=None, help='Quantize the model')
+#parser.add_argument('--model_surgery', type=str, default=None, choices=[None, 'pact2'], help='whether to transform the model after defining')
+parser.add_argument('--pretrained', type=str, default=None, help='pretrained model')
+# parser.add_argument('--resume', type=str, default=None, help='resume an unfinished training from this model')
+parser.add_argument('--bitwidth_weights', type=int, default=None, help='bitwidth for weight quantization')
+parser.add_argument('--bitwidth_activations', type=int, default=None, help='bitwidth for activation quantization')
+cmds = parser.parse_args()
+
+################################
+# handled first, since this has to be done before importing pytorch
+if 'gpus' in vars(cmds):
+    value = getattr(cmds, 'gpus')
+    if value is not None:
+        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(v) for v in value])
+#
+
+# avoid hangs in the data loader when using multiple worker threads;
+# this was observed after using cv2 image processing functions
+# https://github.com/pytorch/pytorch/issues/1355
+cv2.setNumThreads(0)
+
+################################
+from pytorch_jacinto_ai.engine import infer_classification_onnx_rt
+
+# create the config and set the default arguments
+args = infer_classification_onnx_rt.get_config()
+
+################################
+date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+
+################################
+#Set arguments
+# args.model_name = 'mobilenetv2_tv_x1' #'resnet50_x1', 'mobilenetv2_tv_x1', 'mobilenetv2_ericsun_x1', 'mobilenetv2_shicai_x1'
+
+args.dataset_name = 'image_folder_classification_validation' # 'image_folder_classification', 'imagenet_classification', 'cifar10_classification', 'cifar100_classification'
+
+#args.save_path = './data/checkpoints'
+
+args.data_path = f'./data/datasets/{args.dataset_name}'
+args.gpu_mode = False
+args.pretrained = '/data/adas_vision_data1/users/manu/expt/modelzoo/image_classification/imagenet/mobilenetv1/imagenet_mobilenetv1_2019-09-06_17-15-44.onnx'
+                #'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth'
+                #'./data/modelzoo/pretrained/pytorch/imagenet_classification/torchvision/resnet50-19c8e357.pth'
+                #'./data/modelzoo/pretrained/pytorch/imagenet_classification/ericsun99/MobileNet-V2-Pytorch/mobilenetv2_Top1_71.806_Top2_90.410.pth.tar'
+                #'./data/modelzoo/pretrained/pytorch/imagenet_classification/shicai/MobileNet-Caffe/mobilenetv2_shicai_rgb.tar'
+
+args.model_config.input_channels = 3
+args.model_config.output_type = 'classification'
+args.model_config.output_channels = None
+
+args.batch_size = 1 #1 #256                   #16 #32 #64
+args.workers    = 8 #1
+#args.shuffle = True
+#args.epoch_size = 0
+args.count_flops = True
+
+# args.quantize = True
+# args.write_layer_ip_op = True
+
+
+# args.histogram_range = True
+# args.bias_calibration = True
+# args.per_channel_q = False
+
+args.phase = 'validation'
+args.print_freq = 10 #100
+
+################################
+for key in vars(cmds):
+    if key == 'gpus':
+        pass # already taken care of above, since this has to be done before importing pytorch
+    elif hasattr(args, key):
+        value = getattr(cmds, key)
+        if value != 'None' and value is not None:
+            setattr(args, key, value)
+    else:
+        assert False, f'invalid argument {key}'
+#
+
+################################
+# these depend on the dataset chosen
+args.img_resize = (args.img_resize if args.img_resize else 256)
+args.img_crop = (args.img_crop if args.img_crop else 224)
+args.model_config.num_classes = (100 if 'cifar100' in args.dataset_name else (10  if 'cifar10' in args.dataset_name else 1000))
+args.model_config.strides = (1,1,1,2,2) if args.img_crop<64 else ((1,1,2,2,2) if args.img_crop<128 else (2,2,2,2,2))
+
+
+################################
+# run the inference
+infer_classification_onnx_rt.main(args)
\ No newline at end of file
index b4a368a14ada57805338711328c4e898f1a271b6..71fc384d2256b1d7d0972094b765d9d7be8ac076 100755 (executable)
@@ -56,13 +56,14 @@ args = infer_pixel2pixel.get_config()
 args.model_name = "deeplabv3lite_mobilenetv2_tv" #"deeplabv3lite_mobilenetv2_relu" #"deeplabv3lite_mobilenetv2_relu_x1p5" #"deeplabv3plus"
 
 args.dataset_name = 'a2d2_segmentation_measure' #'tiad_segmentation_infer'   #'cityscapes_segmentation_infer' #'tiad_segmentation'  #'cityscapes_segmentation_measure'
+args.dataset_config.split = 'val'
 
 #args.save_path = './data/checkpoints'
 args.data_path = '/data/ssd/datasets/a2d2_v2/' #'./data/datasets/cityscapes/data'   #'/data/hdd/datasets/cityscapes_leftImg8bit_sequence_trainvaltest/' #'./data/datasets/cityscapes/data'  #'./data/tiad/data/demoVideo/sequence0021'  #'./data/tiad/data/demoVideo/sequence0025'   #'./data/tiad/data/demoVideo/sequence0001_2017'
 #args.pretrained = './data/modelzoo/semantic_segmentation/cityscapes/deeplabv3lite-mobilenetv2/cityscapes_segmentation_deeplabv3lite-mobilenetv2_2019-06-26-08-59-32.pth'
 #args.pretrained = './data/checkpoints/tiad_segmentation/2019-10-18_00-50-03_tiad_segmentation_deeplabv3lite_mobilenetv2_ericsun_resize768x384_traincrop768x384_float/checkpoint.pth.tar'
 
-args.pretrained = '/data/files/work/bitbucket_TI/pytorch-jacinto-models/data/checkpoints/a2d2_segmentation/2020-01-25_13-06-18_a2d2_segmentation_deeplabv3lite_mobilenetv2_tv_resize768x384_traincrop768x384/training/model_best_ep172.pth.tar'
+args.pretrained = '/data/files/work/bitbucket_TI/pytorch-jacinto-models/data/checkpoints/a2d2_segmentation/2020-01-25_13-06-18_a2d2_segmentation_deeplabv3lite_mobilenetv2_tv_resize768x384_traincrop768x384_v2_val41.83_train56.02/training/model_best.pth.tar'
 
 args.model_config.input_channels = (3,)
 args.model_config.output_type = ['segmentation']
@@ -73,10 +74,10 @@ args.metrics = [['segmentation_metrics']]
 args.frame_IOU =  False # Print mIOU for each frame
 args.shuffle = False
 
-args.num_images = 30000   # Max number of images to run inference on
-#args.blend = [False]
-#'color', 'blend'
-args.viz_op_type = ['color']
+args.num_images = 50000   # Max number of images to run inference on
+
+#['color'], ['blend'], ['']
+args.viz_op_type = ['blend']
 args.visualize_gt = False
 args.car_mask = False  # False   #True
 args.label = [True]    # False   #True
@@ -90,14 +91,12 @@ args.depth = [False]
 args.epoch_size = 0                     #0 #0.5
 args.iter_size = 1                      #2
 
-args.batch_size = 1 #80                  #12 #16 #32 #64
+args.batch_size = 32 #80                  #12 #16 #32 #64
 args.img_resize = (384, 768)         #(256,512) #(512,512) # #(1024, 2048) #(512,1024)  #(720, 1280)
 
 args.output_size = (1208, 1920)          #(1024, 2048)
 #args.rand_scale = (1.0, 2.0)            #(1.0,2.0) #(1.0,1.5) #(1.0,1.25)
 
-args.quantize = True
-
 args.depth = [False]
 args.quantize = False
 args.histogram_range = True