improved training speed for pixel2pixel models, added unet models, other fixes
author Manu Mathew <a0393608@ti.com>
Tue, 4 Feb 2020 03:29:14 +0000 (08:59 +0530)
committer Manu Mathew <a0393608@ti.com>
Tue, 4 Feb 2020 04:42:06 +0000 (10:12 +0530)
21 files changed:
docs/Quantization.md
modules/pytorch_jacinto_ai/engine/engine_utils.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/engine/infer_pixel2pixel.py
modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py
modules/pytorch_jacinto_ai/vision/datasets/pixel2pixel/cityscapes_plus.py
modules/pytorch_jacinto_ai/vision/models/pixel2pixel/__init__.py
modules/pytorch_jacinto_ai/vision/models/pixel2pixel/deeplabv3lite.py
modules/pytorch_jacinto_ai/vision/models/pixel2pixel/fpn_pixel2pixel.py
modules/pytorch_jacinto_ai/vision/models/pixel2pixel/unet_pixel2pixel.py [new file with mode: 0644]
modules/pytorch_jacinto_ai/xnn/layers/__init__.py
modules/pytorch_jacinto_ai/xnn/layers/upsample_blocks.py [new file with mode: 0644]
run_depth.sh
run_segmentation.sh
scripts/infer_segmentation_main.py
scripts/infer_segmentation_onnx_main.py
scripts/test_classification_main.py
scripts/train_classification_main.py
scripts/train_depth_main.py
scripts/train_motion_segmentation_main.py
scripts/train_pixel2pixel_multitask_main.py
scripts/train_segmentation_main.py

index fe927afe810858dea113f187edfa16b795e6a56b..e2cb297c588f6633ab22e02af9c20edf610d9d0f 100644 (file)
@@ -99,11 +99,11 @@ The table below shows the Quantized Accuracy with various Calibration and method
 
 ###### Dataset: ImageNet Classification (Image Classification)
 
-|Mode Name          |Backbone   |Stride|Resolution|Float Acc%|Simple Calib Acc%|Adv Calib Acc%|Adv DW Calib Acc%|QAT Acc% |Acc Drop-Adv DW Calib|Acc Drop - QAT|
-|----------         |-----------|------|----------|--------- |---              |---           |---              |---      |---                   |---          |
-|ResNet50(TV)       |ResNet50   |32    |224x224   |**76.15** |75.56            |**75.56**     |75.56            |**76.05**|-0.59                 |-0.10        |
-|MobileNetV2(TV)    |MobileNetV2|32    |224x224   |**71.89** |67.77            |**68.39**     |69.34            |**70.74**|-2.55                 |-1.34        |
-|MobileNetV2(Shicai)|MobileNetV2|32    |224x224   |**71.44** |0.0              |**68.81**     |70.65            |**70.54**|-0.79                 |-0.9         |
+|Model Name         |Backbone   |Stride|Resolution|Float Acc%|Simple Calib Acc%|Adv Calib Acc%|Adv DW Calib Acc%|QAT Acc% |Acc Drop-Adv Calib|Acc Drop - QAT|
+|----------         |-----------|------|----------|--------- |---              |---           |---              |---      |---               |---          |
+|ResNet50(TV)       |ResNet50   |32    |224x224   |**76.15** |75.56            |**75.56**     |75.56            |**76.05**|-0.59             |-0.10        |
+|MobileNetV2(TV)    |MobileNetV2|32    |224x224   |**71.89** |67.77            |**68.39**     |69.34            |**70.74**|-3.50             |-1.34        |
+|MobileNetV2(Shicai)|MobileNetV2|32    |224x224   |**71.44** |0.0              |**68.81**     |70.65            |**70.54**|-2.63             |-0.9         |
 
 Notes:
 - For Image Classification, the accuracy measure used is % Top-1 Classification Accuracy. 'Top-1 Classification Accuracy' is abbreviated by Acc in the above table.<br>
@@ -112,9 +112,9 @@ Notes:
 
 ###### Dataset: Cityscapes Segmentation (Semantic Segmentation)
 
-|Mode Name    |Backbone   |Stride|Resolution|Float Acc%|Simple Calib Acc%|Adv Calib Acc%|Adv DW Calib Acc%|QAT Acc% |Acc Drop-Advanced DW Calib|Acc Drop - QAT|
-|----------   |-----------|------|----------|----------|---              |---           |---              |---      |---                       |---           |
-|DeepLabV3Lite|MobileNetV2|16    |768x384   |**69.13** |61.71            |**67.95**     |68.47            |**68.44**|-0.66                     |-0.69         |
+|Model Name   |Backbone   |Stride|Resolution|Float Acc%|Simple Calib Acc%|Adv Calib Acc%|Adv DW Calib Acc%|QAT Acc% |Acc Drop-Advanced Calib|Acc Drop - QAT|
+|----------   |-----------|------|----------|----------|---              |---           |---              |---      |---                    |---           |
+|DeepLabV3Lite|MobileNetV2|16    |768x384   |**69.13** |61.71            |**67.95**     |68.47            |**68.44**|-1.18                  |-0.69         |
 
 Note: For Semantic Segmentation, the accuracy measure used is MeanIoU Accuracy. 'MeanIoU Accuracy' is abbreviated by Acc in the above table.
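
The 'Acc Drop' columns in the updated rows are simply the quantized accuracy minus the float accuracy (negative values are a loss). A small sketch of the arithmetic, using values from the tables above:

    # accuracy drop = quantized accuracy - float accuracy (values taken from the tables above)
    mobilenetv2_tv_drop = 68.39 - 71.89   # Adv Calib drop for MobileNetV2(TV): -3.50
    deeplabv3lite_drop = 67.95 - 69.13    # Adv Calib drop for DeepLabV3Lite:   -1.18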
 
diff --git a/modules/pytorch_jacinto_ai/engine/engine_utils.py b/modules/pytorch_jacinto_ai/engine/engine_utils.py
new file mode 100644 (file)
index 0000000..e9b727d
--- /dev/null
@@ -0,0 +1,129 @@
+import os
+import numpy as np
+import torch
+from .. import xnn
+
+
+#################################################
+def shape_as_string(shape=[]):
+    shape_str = ''
+    for dim in shape:
+        shape_str += '_' + str(dim)
+    return shape_str
+
+
+def write_tensor_int(m=[], tensor=[], suffix='op', bitwidth=8, power2_scaling=True, file_format='bin',
+                     rnd_type='rnd_sym'):
+    mn = tensor.min()
+    mx = tensor.max()
+
+    print(
+        '{:6}, {:32}, {:10}, {:7.2f}, {:7.2f}'.format(suffix, m.name, m.__class__.__name__, tensor.min(), tensor.max()),
+        end=" ")
+
+    [tensor_scale, clamp_limits] = xnn.quantize.compute_tensor_scale(tensor, mn, mx, bitwidth, power2_scaling)
+    print("{:30} : {:15} : {:8.2f}".format(str(tensor.shape), str(tensor.dtype), tensor_scale), end=" ")
+
+    print_weight_bias = False
+    if rnd_type == 'rnd_sym':
+        # use best rounding for offline quantities
+        if suffix == 'weight' and print_weight_bias:
+            no_idx = 0
+            torch.set_printoptions(precision=32)
+            print("tensor_scale: ", tensor_scale)
+            print(tensor[no_idx])
+        if tensor.dtype != torch.int64:
+            tensor = xnn.quantize.symmetric_round_tensor(tensor * tensor_scale)
+        if suffix == 'weight' and print_weight_bias:
+            print(tensor[no_idx])
+    else:
+        # for activation use HW friendly rounding
+        if tensor.dtype != torch.int64:
+            tensor = xnn.quantize.upward_round_tensor(tensor * tensor_scale)
+    tensor = tensor.clamp(clamp_limits[0], clamp_limits[1]).float()
+
+    if bitwidth == 8:
+        data_type = np.int8
+    elif bitwidth == 16:
+        data_type = np.int16
+    elif bitwidth == 32:
+        data_type = np.int32
+    else:
+        exit("Bit width other 8,16,32 not supported for writing layer level op")
+
+    tensor = tensor.cpu().numpy().astype(data_type)
+
+    print("{:7} : {:7d} : {:7d}".format(str(tensor.dtype), tensor.min(), tensor.max()))
+
+    root = os.getcwd()
+    tensor_dir = root + '/checkpoints/debug/test_vecs/' + '{}_{}_{}_scale_{:010.4f}'.format(m.name,
+                                                                                            m.__class__.__name__,
+                                                                                            suffix, tensor_scale)
+
+    if not os.path.exists(tensor_dir):
+        os.makedirs(tensor_dir)
+
+    if file_format == 'bin':
+        tensor_name = tensor_dir + "/{}_shape{}.bin".format(m.name, shape_as_string(shape=tensor.shape))
+        tensor.tofile(tensor_name)
+    elif file_format == 'npy':
+        tensor_name = tensor_dir + "/{}_shape{}.npy".format(m.name, shape_as_string(shape=tensor.shape))
+        np.save(tensor_name, tensor)
+
+    # utils_hist.comp_hist_tensor3d(x=tensor, name=m.name, en=True, dir=m.name, log=True, ch_dim=0)
+
+
+def write_tensor_float(m=[], tensor=[], suffix='op'):
+    mn = tensor.min()
+    mx = tensor.max()
+
+    print(
+        '{:6}, {:32}, {:10}, {:7.2f}, {:7.2f}'.format(suffix, m.name, m.__class__.__name__, tensor.min(), tensor.max()))
+    root = os.getcwd()
+    tensor_dir = root + '/checkpoints/debug/test_vecs/' + '{}_{}_{}'.format(m.name, m.__class__.__name__, suffix)
+
+    if not os.path.exists(tensor_dir):
+        os.makedirs(tensor_dir)
+
+    tensor_name = tensor_dir + "/{}_shape{}.npy".format(m.name, shape_as_string(shape=tensor.shape))
+    np.save(tensor_name, tensor.data)
+
+
+def write_tensor(data_type='int', m=[], tensor=[], suffix='op', bitwidth=8, power2_scaling=True, file_format='bin',
+                 rnd_type='rnd_sym'):
+    if data_type == 'int':
+        write_tensor_int(m=m, tensor=tensor, suffix=suffix, rnd_type=rnd_type, file_format=file_format)
+    elif data_type == 'float':
+        write_tensor_float(m=m, tensor=tensor, suffix=suffix)
+
+
+enable_hook_function = True
+
+
+def write_tensor_hook_function(m, inp, out, file_format='none'):
+    if not enable_hook_function:
+        return
+
+    # Output
+    if isinstance(out, (torch.Tensor)):
+        write_tensor(m=m, tensor=out, suffix='op', rnd_type='rnd_up', file_format=file_format)
+
+    # Input(s)
+    if type(inp) is tuple:
+        # if there is more than one input
+        for index, sub_ip in enumerate(inp[0]):
+            if isinstance(sub_ip, (torch.Tensor)):
+                write_tensor(m=m, tensor=sub_ip, suffix='ip_{}'.format(index), rnd_type='rnd_up',
+                             file_format=file_format)
+    elif isinstance(inp, (torch.Tensor)):
+        write_tensor(m=m, tensor=inp, suffix='ip', rnd_type='rnd_up', file_format=file_format)
+
+    # weights
+    if hasattr(m, 'weight'):
+        if isinstance(m.weight, torch.Tensor):
+            write_tensor(m=m, tensor=m.weight, suffix='weight', rnd_type='rnd_sym', file_format=file_format)
+
+    # bias
+    if hasattr(m, 'bias'):
+        if m.bias is not None:
+            write_tensor(m=m, tensor=m.bias, suffix='bias', rnd_type='rnd_sym', file_format=file_format)
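
The hook above writes activations, weights and biases per module and builds file names from a `name` attribute on each module, which the engine code is expected to assign. A minimal sketch (not part of this commit) of how the hook might be registered on the leaf modules of a model:

    import functools

    def attach_write_tensor_hooks(model, file_format='npy'):
        # hypothetical helper: registers write_tensor_hook_function on every leaf module;
        # write_tensor_int/_float format file names from m.name, so a name is assigned here
        handles = []
        for name, module in model.named_modules():
            if list(module.children()):
                continue  # only leaf modules are hooked
            module.name = name
            hook = functools.partial(write_tensor_hook_function, file_format=file_format)
            handles.append(module.register_forward_hook(hook))
        return handles  # call handle.remove() on each to detach the hooks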
index 81e1ec5aff9e5533d3efd50d55b44e08ab1edfee..27213f2662b7e2cde3063f8251d0367dc3b44c97 100644 (file)
@@ -18,14 +18,13 @@ import matplotlib.pyplot as plt
 
 from .. import xnn
 from .. import vision
-
-#sys.path.insert(0, '../devkit-datasets/TI/')
-#from fisheye_calib import r_fish_to_theta_rect
+from .engine_utils import *
 
 # ################################################
 def get_config():
     args = xnn.utils.ConfigNode()
 
+    args.dataset = None
     args.dataset_config = xnn.utils.ConfigNode()
     args.dataset_config.split_name = 'val'
     args.dataset_config.max_depth_bfr_scaling = 80
@@ -33,9 +32,11 @@ def get_config():
     args.dataset_config.train_depth_log = 1
     args.use_semseg_for_depth = False
 
+    args.model = None
     args.model_config = xnn.utils.ConfigNode()
     args.model_name = 'deeplabv2lite_mobilenetv2'       # model architecture, overwritten if pretrained is specified
     args.dataset_name = 'flying_chairs'              # dataset type
+    args.transforms = None
 
     args.data_path = './data/datasets'                       # path to dataset
     args.save_path = None            # checkpoints save path
@@ -89,7 +90,7 @@ def get_config():
 
     args.count_flops = True                     # count flops and report
 
-    args.shuffle = True                         # shuffle or not
+    args.shuffle = False                         # shuffle or not
     args.is_flow = None                         # whether entries in images and targets lists are optical flow or not
 
     args.multi_decoder = True                   # whether to use multiple decoders or unified decoder
@@ -125,6 +126,11 @@ def get_config():
     args.visualize_gt = False                   #to vis pred or GT
     args.viz_depth_color_type = 'plasma'       #color type for depth visualization
     args.depth = [False]
+
+    args.palette = None
+    args.label_infer = False
+    args.viz_op_type = None
+    args.car_mask = None
     return args
 
 
@@ -138,119 +144,6 @@ cv2.setNumThreads(0)
 np.set_printoptions(precision=3)
 
 
-#################################################
-def shape_as_string(shape=[]):
-    shape_str = ''
-    for dim in shape:
-        shape_str += '_' + str(dim)
-    return shape_str
-
-def write_tensor_int(m = [], tensor = [], suffix='op', bitwidth = 8, power2_scaling = True, file_format='bin', rnd_type = 'rnd_sym'):
-    mn = tensor.min()
-    mx = tensor.max()
-
-    print('{:6}, {:32}, {:10}, {:7.2f}, {:7.2f}'.format(suffix, m.name, m.__class__.__name__, tensor.min(), tensor.max()),  end =" ")
-
-    [tensor_scale, clamp_limits] = xnn.quantize.compute_tensor_scale(tensor, mn, mx, bitwidth, power2_scaling)
-    print("{:30} : {:15} : {:8.2f}".format(str(tensor.shape), str(tensor.dtype), tensor_scale), end =" ")
-    
-    print_weight_bias = False
-    if rnd_type == 'rnd_sym':
-      #use best rounding for offline quantities
-      if suffix == 'weight' and print_weight_bias:
-          no_idx = 0
-          torch.set_printoptions(precision=32)
-          print("tensor_scale: ", tensor_scale)
-          print(tensor[no_idx])
-      if tensor.dtype != torch.int64:
-          tensor = xnn.quantize.symmetric_round_tensor(tensor * tensor_scale)
-      if suffix == 'weight'  and print_weight_bias:
-          print(tensor[no_idx])
-    else:  
-      #for activation use HW friendly rounding  
-      if tensor.dtype != torch.int64:
-          tensor = xnn.quantize.upward_round_tensor(tensor * tensor_scale)
-    tensor = tensor.clamp(clamp_limits[0], clamp_limits[1]).float()
-
-    if bitwidth == 8:
-      data_type = np.int8
-    elif bitwidth == 16:
-      data_type = np.int16
-    elif bitwidth == 32:
-      data_type = np.int32
-    else:
-       exit("Bit width other 8,16,32 not supported for writing layer level op")
-
-    tensor = tensor.cpu().numpy().astype(data_type)
-
-    print("{:7} : {:7d} : {:7d}".format(str(tensor.dtype), tensor.min(), tensor.max()))
-
-    root = os.getcwd()
-    tensor_dir = root + '/checkpoints/debug/test_vecs/' + '{}_{}_{}_scale_{:010.4f}'.format(m.name, m.__class__.__name__, suffix,  tensor_scale)
-
-    if not os.path.exists(tensor_dir):
-        os.makedirs(tensor_dir)
-
-    if file_format == 'bin':
-        tensor_name = tensor_dir + "/{}_shape{}.bin".format(m.name, shape_as_string(shape=tensor.shape))
-        tensor.tofile(tensor_name)
-    elif file_format == 'npy':
-        tensor_name = tensor_dir + "/{}_shape{}.npy".format(m.name, shape_as_string(shape=tensor.shape))
-        np.save(tensor_name, tensor)
-
-    #utils_hist.comp_hist_tensor3d(x=tensor, name=m.name, en=True, dir=m.name, log=True, ch_dim=0)
-
-
-def write_tensor_float(m = [], tensor = [], suffix='op'):
-    
-    mn = tensor.min()
-    mx = tensor.max()
-
-    print('{:6}, {:32}, {:10}, {:7.2f}, {:7.2f}'.format(suffix, m.name, m.__class__.__name__, tensor.min(), tensor.max()))
-    root = os.getcwd()
-    tensor_dir = root + '/checkpoints/debug/test_vecs/' + '{}_{}_{}'.format(m.name, m.__class__.__name__, suffix)
-
-    if not os.path.exists(tensor_dir):
-        os.makedirs(tensor_dir)
-
-    tensor_name = tensor_dir + "/{}_shape{}.npy".format(m.name, shape_as_string(shape=tensor.shape))
-    np.save(tensor_name, tensor.data)
-
-def write_tensor(data_type = 'int', m = [], tensor = [], suffix='op', bitwidth = 8, power2_scaling = True, file_format='bin', 
-    rnd_type = 'rnd_sym'):
-    
-    if data_type == 'int':
-        write_tensor_int(m=m, tensor=tensor, suffix=suffix, rnd_type =rnd_type, file_format=file_format)
-    elif  data_type == 'float':
-        write_tensor_float(m=m, tensor=tensor, suffix=suffix)
-
-enable_hook_function = True
-def write_tensor_hook_function(m, inp, out, file_format='none'):
-    if not enable_hook_function:
-        return
-
-    #Output
-    if isinstance(out, (torch.Tensor)):
-        write_tensor(m=m, tensor=out, suffix='op', rnd_type ='rnd_up', file_format=file_format)
-
-    #Input(s)
-    if type(inp) is tuple:
-        #if there are more than 1 inputs
-        for index, sub_ip in enumerate(inp[0]):
-            if isinstance(sub_ip, (torch.Tensor)):
-                write_tensor(m=m, tensor=sub_ip, suffix='ip_{}'.format(index), rnd_type ='rnd_up', file_format=file_format)
-    elif isinstance(inp, (torch.Tensor)):
-         write_tensor(m=m, tensor=inp, suffix='ip', rnd_type ='rnd_up', file_format=file_format)
-
-    #weights
-    if hasattr(m, 'weight'):
-        if isinstance(m.weight,torch.Tensor):
-            write_tensor(m=m, tensor=m.weight, suffix='weight', rnd_type ='rnd_sym', file_format=file_format)
-
-    #bias
-    if hasattr(m, 'bias'):
-        if m.bias is not None:
-            write_tensor(m=m, tensor=m.bias, suffix='bias', rnd_type ='rnd_sym', file_format=file_format)
 
 # ################################################
 def main(args):
@@ -321,12 +214,19 @@ def main(args):
         print("cmd:", cmd)    
         os.system(cmd)
 
-    transforms = get_transforms(args)
+    transforms = get_transforms(args) if args.transforms is None else args.transforms
 
     print("=> fetching img pairs in '{}'".format(args.data_path))
     split_arg = args.split_file if args.split_file else (args.split_files if args.split_files else args.split_value)
 
-    val_dataset = vision.datasets.pixel2pixel.__dict__[args.dataset_name](args.dataset_config, args.data_path, split=split_arg, transforms=transforms)
+    if args.dataset is not None:
+        dataset = args.dataset
+    else:
+        dataset = vision.datasets.pixel2pixel.__dict__[args.dataset_name](args.dataset_config, args.data_path, split=split_arg, transforms=transforms)
+    #
+
+    # if a pair is given, take the second one
+    val_dataset = (dataset[1] if (isinstance(dataset, (list, tuple)) and len(dataset) == 2) else dataset)
 
     print('=> {} val samples found'.format(len(val_dataset)))
     val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
@@ -351,23 +251,29 @@ def main(args):
     pretrained_data = None
     model_surgery_quantize = False
     if args.pretrained and args.pretrained != "None":
-        if args.pretrained.startswith('http://') or args.pretrained.startswith('https://'):
-            pretrained_file = vision.datasets.utils.download_url(args.pretrained, './data/downloads')
+        if isinstance(args.pretrained, dict):
+            pretrained_data = args.pretrained
         else:
-            pretrained_file = args.pretrained
+            if args.pretrained.startswith('http://') or args.pretrained.startswith('https://'):
+                pretrained_file = vision.datasets.utils.download_url(args.pretrained, './data/downloads')
+            else:
+                pretrained_file = args.pretrained
+            #
+            print(f'=> using pre-trained weights from: {args.pretrained}')
+            pretrained_data = torch.load(pretrained_file)
         #
-        print(f'=> using pre-trained weights from: {args.pretrained}')
-        pretrained_data = torch.load(pretrained_file)
         model_surgery_quantize = pretrained_data['quantize'] if 'quantize' in pretrained_data else False
     #
 
     #################################################
-    # the portion before comma is used as the model name
-    # string after comma (if present is used as decoder names) in the decoder ModuleDict()
-    model = vision.models.pixel2pixel.__dict__[args.model_name](args.model_config)
-
-    # check if we got the model as well as parameters to change the names in pretrained
-    model, change_names_dict = model if isinstance(model, (list,tuple)) else (model,None)
+    if args.model is not None:
+        model, change_names_dict = args.model if isinstance(args.model, (list, tuple)) else (args.model, None)
+        assert isinstance(model, torch.nn.Module), 'args.model, if provided must be a valid torch.nn.Module'
+    else:
+        model = vision.models.pixel2pixel.__dict__[args.model_name](args.model_config)
+        # check if we got the model as well as parameters to change the names in pretrained
+        model, change_names_dict = model if isinstance(model, (list,tuple)) else (model,None)
+    #
 
     #################################################
     if args.quantize:
@@ -485,7 +391,10 @@ def validate(args, val_dataset, val_loader, model, epoch, infer_path):
         data_time.update(time.time() - end_time)
         if args.gpu_mode:
             input_list = [img.cuda() for img in input_list]
+
         outputs = model(input_list)
+        outputs = outputs if isinstance(outputs,(list,tuple)) else [outputs]
+
         if args.output_size is not None and target_list:
            target_sizes = [tgt.shape for tgt in target_list]
            outputs = upsample_tensors(outputs, target_sizes, args.upsample_mode)
@@ -509,7 +418,7 @@ def validate(args, val_dataset, val_loader, model, epoch, infer_path):
                 if args.frame_IOU:
                     confusion_matrix[task_index] = np.zeros((args.model_config.output_channels[task_index], args.model_config.output_channels[task_index] + 1))
                 prediction = np.array(output[index])
-                if output.shape[1]>1:
+                if len(prediction.shape)>2 and prediction.shape[0]>1:
                     prediction = np.argmax(prediction, axis=0)
                 #
                 prediction = np.squeeze(prediction)
@@ -564,7 +473,7 @@ def validate(args, val_dataset, val_loader, model, epoch, infer_path):
                     viz_depth(prediction = prediction, args=args, output_name = output_name, input_name=input_path[-1][task_index])
                     print('{}/{}'.format((args.batch_size * iter + index), len(val_dataset)))
 
-                if args.viz_op_type[task_index] == 'blend':
+                if args.viz_op_type is not None and args.viz_op_type[task_index] == 'blend':
                     prediction_size = (prediction.shape[0], prediction.shape[1], 3)
                     output_image = args.palette[task_index-1][prediction.ravel()].reshape(prediction_size)
                     input_bgr = cv2.imread(input_path[-1][index]) #Read the actual RGB image
@@ -573,7 +482,7 @@ def validate(args, val_dataset, val_loader, model, epoch, infer_path):
                     output_name = os.path.join(infer_path[task_index], input_path[-1][index].split('/')[-4] + '_' + input_path[-1][index].split('/')[-3] + '_' +os.path.basename(input_path[-1][index]))
                     cv2.imwrite(output_name, output_image)
                     print('{}/{}'.format((args.batch_size*iter+index), len(val_dataset)))
-                elif args.viz_op_type[task_index] == 'color':
+                elif args.viz_op_type is not None and args.viz_op_type[task_index] == 'color':
                     prediction_size = (prediction.shape[0], prediction.shape[1], 3)
                     output_image = args.palette[task_index-1][prediction.ravel()].reshape(prediction_size)
                     output_name = os.path.join(infer_path[task_index], input_path[-1][index].split('/')[-4] + '_' + input_path[-1][index].split('/')[-3] + '_' +os.path.basename(input_path[-1][index]))
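
The new args.dataset, args.model and args.transforms entries let a caller drive the inference engine with pre-built objects instead of registry names. A hedged usage sketch (the import path is inferred from the repository layout; my_model, train_set, val_set and my_transforms are placeholders supplied by the caller):

    import numpy as np
    from pytorch_jacinto_ai.engine import infer_pixel2pixel  # import path assumed from the repo layout

    args = infer_pixel2pixel.get_config()
    args.model = my_model                  # placeholder: a torch.nn.Module built by the caller
    args.dataset = (train_set, val_set)    # placeholder: if a (train, val) pair is given, the second one is used
    args.transforms = my_transforms        # placeholder: used in place of get_transforms(args)
    args.viz_op_type = ['color']           # per-task visualization mode; left as None, visualization stays disabled
    args.palette = [np.array([[128, 64, 128], [244, 35, 232]], dtype=np.uint8)]  # one color triple per class id
    infer_pixel2pixel.main(args)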
index 25d2f42827ef5dd1b9c27503e06e332eb6c0dc7b..1732e993cdab120e8f54d66e4ba3446b664afa07 100644 (file)
@@ -127,6 +127,7 @@ def get_config():
     args.count_flops = True                             # count flops and report
 
     args.shuffle = True                                 # shuffle or not
+    args.shuffle_val = False                            # shuffle val dataset or not
 
     args.transform_rotation = 0.                        # apply rotation augmentation. value is rotation in degrees. 0 indicates no rotation
     args.is_flow = None                                 # whether entries in images and targets lists are optical flow or not
@@ -276,7 +277,7 @@ def main(args):
         num_workers=args.workers, pin_memory=args.use_pinned_memory, sampler=train_sampler, shuffle=args.shuffle)
 
     val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
-        num_workers=args.workers, pin_memory=args.use_pinned_memory, sampler=val_sampler, shuffle=args.shuffle)
+        num_workers=args.workers, pin_memory=args.use_pinned_memory, sampler=val_sampler, shuffle=args.shuffle_val)
 
     #################################################
     if (args.model_config.input_channels is None):
@@ -321,16 +322,15 @@ def main(args):
     #################################################
     # create model
     if args.model is not None:
-        model, change_names_dict = model if isinstance(args.model, (list, tuple)) else (args.model, None)
+        model, change_names_dict = args.model if isinstance(args.model, (list, tuple)) else (args.model, None)
         assert isinstance(model, torch.nn.Module), 'args.model, if provided must be a valid torch.nn.Module'
     else:
         xnn.utils.print_yellow("=> creating model '{}'".format(args.model_name))
         model = vision.models.pixel2pixel.__dict__[args.model_name](args.model_config)
+        # check if we got the model as well as parameters to change the names in pretrained
+        model, change_names_dict = model if isinstance(model, (list,tuple)) else (model,None)
     #
 
-    # check if we got the model as well as parameters to change the names in pretrained
-    model, change_names_dict = model if isinstance(model, (list,tuple)) else (model,None)
-
     if args.quantize:
         # dummy input is used by quantized models to analyze graph
         is_cuda = next(model.parameters()).is_cuda
@@ -693,7 +693,7 @@ def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writ
         #    input_zero = torch.zeros(input_var.shape)
         #    train_writer.add_graph(model, input_zero)
 
-        torch.cuda.empty_cache()
+        #torch.cuda.empty_cache()
 
         if iter >= epoch_size:
             break
index f0967395c804de7403b461db4bccc7b8e795fc76..0eae1e7fd418a62a4a7642b05bc26a1313d198b7 100755 (executable)
@@ -418,7 +418,7 @@ def cityscapes_segmentation_train(dataset_config, root, split=None, transforms=N
     return train_split
 
 
-def cityscapes_segmentation(dataset_config, root, split=None, transforms=None):
+def cityscapes_segmentation(dataset_config, root, split=None, transforms=None, *args, **kwargs):
     dataset_config = get_config().merge_from(dataset_config)
     gt = "gtFine"
     train_split = val_split = None
@@ -427,17 +427,23 @@ def cityscapes_segmentation(dataset_config, root, split=None, transforms=None):
         if split_name == 'train':
             train_split = CityscapesDataLoader(dataset_config, root, split_name, gt, transforms=transforms[0],
                                             load_segmentation=dataset_config.load_segmentation,
-                                            load_segmentation_five_class=dataset_config.load_segmentation_five_class)
+                                            load_segmentation_five_class=dataset_config.load_segmentation_five_class,
+                                            *args, **kwargs)
         elif split_name == 'val':
             val_split = CityscapesDataLoader(dataset_config, root, split_name, gt, transforms=transforms[1],
                                             load_segmentation=dataset_config.load_segmentation,
-                                            load_segmentation_five_class=dataset_config.load_segmentation_five_class)
+                                            load_segmentation_five_class=dataset_config.load_segmentation_five_class,
+                                            *args, **kwargs)
         else:
             pass
     #
     return train_split, val_split
 
 
+def cityscapes_segmentation_with_additional_info(dataset_config, root, split=None, transforms=None, *args, **kwargs):
+    return cityscapes_segmentation(dataset_config, root, split=split, transforms=transforms, additional_info=True, *args, **kwargs)
+
+
 def cityscapes_depth_train(dataset_config, root, split=None, transforms=None):
     dataset_config = get_config().merge_from(dataset_config)
     gt = "gtFine"
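
A hedged sketch of how the new loader is reached: it registers under its function name, so the engine's existing registry lookup (shown in infer_pixel2pixel.py above) picks it up unchanged; args, split_arg and transforms below are the engine-side values from that call site.

    # inside the engine, with split_arg and transforms prepared as shown above:
    args.dataset_name = 'cityscapes_segmentation_with_additional_info'
    dataset = vision.datasets.pixel2pixel.__dict__[args.dataset_name](
        args.dataset_config, args.data_path, split=split_arg, transforms=transforms)
    train_set, val_set = dataset   # this loader returns a (train, val) pair; additional_info=True is forwarded to CityscapesDataLoader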
index 1802dd5e2e60b2479d7fe199be16c4fb12b59d47..a0aa684634795373bbe2d23d41b2e0c247a0cb80 100644 (file)
@@ -1,5 +1,6 @@
 from .deeplabv3lite import *
 from .fpn_pixel2pixel import *
+from .unet_pixel2pixel import *
 
 try: from .deeplabv3lite_internal import *
 except: pass
index ea43dfc2f1696101c24caeff799c0ae67a37f82d..4bf0c0b23b6e76908e77bd51c7c5cc149218334a 100644 (file)
@@ -42,7 +42,9 @@ class DeepLabV3LiteDecoder(torch.nn.Module):
         self.decoder_channels = merged_channels = (current_channels+model_config.shortcut_out)
 
         upstride1 = model_config.shortcut_strides[-1]//model_config.shortcut_strides[0]
-        self.upsample1 = xnn.layers.UpsampleGenericTo(decoder_channels, decoder_channels, upstride1, model_config.interpolation_type, model_config.interpolation_mode)
+        # use UpsampleGenericTo() instead of UpsampleTo() to break a large upsampling factor into a cascade of x4 and x2 stages -
+        # useful if upsampling factors other than 4 and 2 are not supported.
+        self.upsample1 = xnn.layers.UpsampleTo(decoder_channels, decoder_channels, upstride1, model_config.interpolation_type, model_config.interpolation_mode)
 
         self.cat = xnn.layers.CatBlock()
 
@@ -53,7 +55,9 @@ class DeepLabV3LiteDecoder(torch.nn.Module):
             self.pred = ConvXWSepBlock(merged_channels, model_config.output_channels, kernel_size=3, normalization=((not model_config.linear_dw),False), activation=(False,final_activation), groups=1)
             if self.model_config.final_upsample:
                 upstride2 = model_config.shortcut_strides[0]
-                self.upsample2 = xnn.layers.UpsampleGenericTo(model_config.output_channels, model_config.output_channels, upstride2, model_config.interpolation_type, model_config.interpolation_mode)
+                # use UpsampleGenericTo() instead of UpsampleTo() to break a large upsampling factor into a cascade of x4 and x2 stages -
+                # useful if upsampling factors other than 4 and 2 are not supported.
+                self.upsample2 = xnn.layers.UpsampleTo(model_config.output_channels, model_config.output_channels, upstride2, model_config.interpolation_type, model_config.interpolation_mode)
             #
 
     # the upsampling is using functional form to support size based upsampling for odd sizes
index 4b13855dcc470719f149a5121b3ec9d6618837f9..8bed13818f8c534a06f3ac3e96345f699e840896 100644 (file)
@@ -24,6 +24,7 @@ def get_config_fpnp2p_mnv2():
     model_config.use_extra_strides = False
     model_config.groupwise_sep = False
     model_config.fastdown = False
+    model_config.width_mult = 1.0
 
     model_config.strides = (2,2,2,2,2)
     encoder_stride = np.prod(model_config.strides)
diff --git a/modules/pytorch_jacinto_ai/vision/models/pixel2pixel/unet_pixel2pixel.py b/modules/pytorch_jacinto_ai/vision/models/pixel2pixel/unet_pixel2pixel.py
new file mode 100644 (file)
index 0000000..94b8204
--- /dev/null
@@ -0,0 +1,286 @@
+import torch
+import numpy as np
+from .... import xnn
+
+from .pixel2pixelnet import *
+from ..multi_input_net import MobileNetV2TVMI4, ResNet50MI4
+
+
+__all__ = ['UNetPixel2PixelASPP', 'UNetPixel2PixelDecoder',
+           'unet_pixel2pixel_aspp_mobilenetv2_tv', 'unet_pixel2pixel_aspp_mobilenetv2_tv_fd',
+           'unet_pixel2pixel_aspp_resnet50', 'unet_pixel2pixel_aspp_resnet50_fd',
+           ]
+
+# config settings for mobilenetv2 backbone
+def get_config_unetp2p_mnv2():
+    model_config = xnn.utils.ConfigNode()
+    model_config.num_classes = None
+    model_config.num_decoders = None
+    model_config.intermediate_outputs = True
+    model_config.use_aspp = True
+    model_config.use_extra_strides = False
+    model_config.groupwise_sep = False
+    model_config.fastdown = False
+    model_config.width_mult = 1.0
+
+    model_config.strides = (2,2,2,2,2)
+    encoder_stride = np.prod(model_config.strides)
+    model_config.shortcut_strides = (2,4,8,16,encoder_stride)
+    model_config.shortcut_channels = (16,24,32,96,320) # this is for mobilenetv2 - change for other networks
+    model_config.decoder_chan = 256
+    model_config.aspp_chan = 256
+    model_config.aspp_dil = (6,12,18)
+
+    model_config.kernel_size_smooth = 3
+    model_config.interpolation_type = 'upsample'
+    model_config.interpolation_mode = 'bilinear'
+
+    model_config.final_prediction = True
+    model_config.final_upsample = True
+    model_config.output_range = None
+
+    model_config.normalize_input = False
+    model_config.split_outputs = False
+    model_config.decoder_factor = 1.0
+    model_config.activation = xnn.layers.DefaultAct2d
+    model_config.linear_dw = False
+    model_config.normalize_gradients = False
+    model_config.freeze_encoder = False
+    model_config.freeze_decoder = False
+    model_config.multi_task = False
+    return model_config
+
+
+###########################################
+class UNetPyramid(torch.nn.Module):
+    def __init__(self, current_channels, minimum_channels, shortcut_strides, shortcut_channels, activation, kernel_size_smooth, interpolation_type, interpolation_mode):
+        super().__init__()
+        self.shortcut_strides = shortcut_strides
+        self.shortcut_channels = shortcut_channels
+        self.upsamples = torch.nn.ModuleList()
+        self.concats = torch.nn.ModuleList()
+        self.smooth_convs = torch.nn.ModuleList()
+
+        self.smooth_convs.append(None)
+        self.concats.append(None)
+
+        upstride = 2
+        activation2 = (activation, activation)
+        for idx, (s_stride, feat_chan) in enumerate(zip(shortcut_strides, shortcut_channels)):
+            self.upsamples.append(xnn.layers.UpsampleTo(current_channels, current_channels, upstride, interpolation_type, interpolation_mode))
+            self.concats.append(xnn.layers.CatBlock())
+            smooth_channels = max(minimum_channels, feat_chan)
+            self.smooth_convs.append( xnn.layers.ConvDWSepNormAct2d(current_channels+feat_chan, smooth_channels, kernel_size=kernel_size_smooth, activation=activation2))
+            current_channels = smooth_channels
+        #
+    #
+
+
+    def forward(self, x_input, x_list):
+        in_shape = x_input.shape
+        x = x_list[-1]
+
+        outputs = []
+
+        x = self.smooth_convs[0](x) if (self.smooth_convs[0] is not None) else x
+        outputs.append(x)
+
+        for idx, (concat, smooth_conv, s_stride, short_chan, upsample) in \
+                enumerate(zip(self.concats[1:], self.smooth_convs[1:], self.shortcut_strides, self.shortcut_channels, self.upsamples)):
+            # get the feature of lower stride
+            shape_s = xnn.utils.get_shape_with_stride(in_shape, s_stride)
+            shape_s[1] = short_chan
+            x_s = xnn.utils.get_blob_from_list(x_list, shape_s)
+            # upsample current output and concat to that
+            x = upsample((x,x_s))
+            x = concat((x,x_s)) if (concat is not None) else x
+            # smooth conv
+            x = smooth_conv(x) if (smooth_conv is not None) else x
+            # output
+            outputs.append(x)
+        #
+        return outputs[::-1]
+
+
+###########################################
+class UNetPixel2PixelDecoder(torch.nn.Module):
+    def __init__(self, model_config):
+        super().__init__()
+        self.model_config = model_config
+        activation = self.model_config.activation
+        self.output_type = model_config.output_type
+        self.decoder_channels = decoder_channels = round(self.model_config.decoder_chan*self.model_config.decoder_factor)
+
+        self.rfblock = None
+        if self.model_config.use_aspp:
+            current_channels = self.model_config.shortcut_channels[-1]
+            aspp_channels = round(self.model_config.aspp_chan * self.model_config.decoder_factor)
+            self.rfblock = xnn.layers.DWASPPLiteBlock(current_channels, aspp_channels, decoder_channels, dilation=self.model_config.aspp_dil, avg_pool=False, activation=activation)
+            current_channels = decoder_channels
+        elif self.model_config.use_extra_strides:
+            # a low complexity pyramid
+            current_channels = self.model_config.shortcut_channels[-3]
+            self.rfblock = torch.nn.Sequential(xnn.layers.ConvDWSepNormAct2d(current_channels, current_channels, kernel_size=3, stride=2, activation=(activation, activation)),
+                                               xnn.layers.ConvDWSepNormAct2d(current_channels, decoder_channels, kernel_size=3, stride=2, activation=(activation, activation)))
+            current_channels = decoder_channels
+        else:
+            current_channels = self.model_config.shortcut_channels[-1]
+            self.rfblock = xnn.layers.ConvNormAct2d(current_channels, decoder_channels, kernel_size=1, stride=1)
+            current_channels = decoder_channels
+        #
+
+        minimum_channels = max(self.model_config.output_channels*2, 32)
+        shortcut_strides = self.model_config.shortcut_strides[::-1][1:]
+        shortcut_channels = self.model_config.shortcut_channels[::-1][1:]
+        self.unet = UNetPyramid(current_channels, minimum_channels, shortcut_strides, shortcut_channels, self.model_config.activation, self.model_config.kernel_size_smooth,
+                           self.model_config.interpolation_type, self.model_config.interpolation_mode)
+        current_channels = max(minimum_channels, shortcut_channels[-1])
+
+        # prediction
+        if self.model_config.final_prediction:
+            final_activation = xnn.layers.get_fixed_pact2(output_range = model_config.output_range) if (model_config.output_range is not None) else False
+            self.pred = xnn.layers.ConvDWSepNormAct2d(current_channels, self.model_config.output_channels, kernel_size=3, normalization=(True,False), activation=(False,final_activation))
+
+            if self.model_config.final_upsample:
+                upstride_final = self.model_config.shortcut_strides[0]
+                self.upsample = xnn.layers.UpsampleTo(model_config.output_channels, model_config.output_channels, upstride_final, model_config.interpolation_type, model_config.interpolation_mode)
+            #
+        #
+
+    def forward(self, x_input, x, x_list):
+        assert isinstance(x_input, (list,tuple)) and len(x_input)<=2, 'incorrect input'
+        assert x is x_list[-1], 'the features must be the last entry in x_list'
+        x_input = x_input[0]
+        in_shape = x_input.shape
+
+        if self.model_config.use_extra_strides:
+            for blk in self.rfblock:
+                x = blk(x)
+                x_list += [x]
+            #
+        elif self.rfblock is not None:
+            x = self.rfblock(x)
+            x_list[-1] = x
+        #
+
+        x_list = self.unet(x_input, x_list)
+        x = x_list[0]
+
+        if self.model_config.final_prediction:
+            # prediction
+            x = self.pred(x)
+
+            # final prediction is the upsampled one
+            if self.model_config.final_upsample:
+                x = self.upsample((x,x_input))
+
+            if (not self.training) and (self.output_type == 'segmentation'):
+                x = torch.argmax(x, dim=1, keepdim=True)
+
+            assert int(in_shape[2]) == int(x.shape[2]) and int(in_shape[3]) == int(x.shape[3]), 'incorrect output shape'
+
+        return x
+
+
+###########################################
+class UNetPixel2PixelASPP(Pixel2PixelNet):
+    def __init__(self, base_model, model_config):
+        super().__init__(base_model, UNetPixel2PixelDecoder, model_config)
+
+
+###########################################
+def unet_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=None):
+    model_config = get_config_unetp2p_mnv2().merge_from(model_config)
+    # encoder setup
+    model_config_e = model_config.clone()
+    base_model = MobileNetV2TVMI4(model_config_e)
+    # decoder setup
+    model = UNetPixel2PixelASPP(base_model, model_config)
+
+    num_inputs = len(model_config.input_channels)
+    num_decoders = len(model_config.output_channels) if (model_config.num_decoders is None) else model_config.num_decoders
+    if num_inputs > 1:
+        change_names_dict = {'^features.': ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)],
+                            '^classifier.': 'encoder.classifier.',
+                            '^encoder.features.': ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)],
+                            '^decoders.0.': ['decoders.{}.'.format(d) for d in range(num_decoders)]}
+    else:
+        change_names_dict = {'^features.': 'encoder.features.',
+                             '^classifier.': 'encoder.classifier.',
+                             '^decoders.0.': ['decoders.{}.'.format(d) for d in range(num_decoders)]}
+    #
+
+    if pretrained:
+        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
+    #
+    return model, change_names_dict
+
+
+# fast down sampling model (encoder stride 64 model)
+def unet_pixel2pixel_aspp_mobilenetv2_tv_fd(model_config, pretrained=None):
+    model_config = get_config_unetp2p_mnv2().merge_from(model_config)
+    model_config.fastdown = True
+    model_config.strides = (2,2,2,2,2)
+    model_config.shortcut_strides = (4,8,16,32,64)
+    model_config.shortcut_channels = (16,24,32,96,320)
+    model_config.decoder_chan = 256
+    model_config.aspp_chan = 256
+    return unet_pixel2pixel_aspp_mobilenetv2_tv(model_config, pretrained=pretrained)
+
+
+###########################################
+def get_config_unetp2p_resnet50():
+    # only the delta compared to the one defined for mobilenetv2
+    model_config = get_config_unetp2p_mnv2()
+    model_config.shortcut_strides = (2,4,8,16,32)
+    model_config.shortcut_channels = (64,256,512,1024,2048)
+    return model_config
+
+
+def unet_pixel2pixel_aspp_resnet50(model_config, pretrained=None):
+    model_config = get_config_unetp2p_resnet50().merge_from(model_config)
+    # encoder setup
+    model_config_e = model_config.clone()
+    base_model = ResNet50MI4(model_config_e)
+    # decoder setup
+    model = UNetPixel2PixelASPP(base_model, model_config)
+
+    # the pretrained model provided by torchvision and what is defined here differ slightly
+    # note that this change_names_dict will take effect only if the direct load fails
+    # finally, take care of the change for unet (features->encoder.features)
+    num_inputs = len(model_config.input_channels)
+    num_decoders = len(model_config.output_channels) if (model_config.num_decoders is None) else model_config.num_decoders
+    if num_inputs > 1:
+        change_names_dict = {'^conv1.': ['encoder.features.stream{}.conv1.'.format(stream) for stream in range(num_inputs)],
+                            '^bn1.': ['encoder.features.stream{}.bn1.'.format(stream) for stream in range(num_inputs)],
+                            '^relu.': ['encoder.features.stream{}.relu.'.format(stream) for stream in range(num_inputs)],
+                            '^maxpool.': ['encoder.features.stream{}.maxpool.'.format(stream) for stream in range(num_inputs)],
+                            '^layer': ['encoder.features.stream{}.layer'.format(stream) for stream in range(num_inputs)],
+                            '^features.': ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)],
+                            '^encoder.features.': ['encoder.features.stream{}.'.format(stream) for stream in range(num_inputs)],
+                            '^decoders.0.': ['decoders.{}.'.format(d) for d in range(num_decoders)]}
+    else:
+        change_names_dict = {'^conv1.': 'encoder.features.conv1.',
+                             '^bn1.': 'encoder.features.bn1.',
+                             '^relu.': 'encoder.features.relu.',
+                             '^maxpool.': 'encoder.features.maxpool.',
+                             '^layer': 'encoder.features.layer',
+                             '^features.': 'encoder.features.',
+                             '^decoders.0.': ['decoders.{}.'.format(d) for d in range(num_decoders)]}
+    #
+
+    if pretrained:
+        model = xnn.utils.load_weights(model, pretrained, change_names_dict, ignore_size=True, verbose=True)
+
+    return model, change_names_dict
+
+
+def unet_pixel2pixel_aspp_resnet50_fd(model_config, pretrained=None):
+    model_config = get_config_unetp2p_resnet50().merge_from(model_config)
+    model_config.fastdown = True
+    model_config.strides = (2,2,2,2,2)
+    model_config.shortcut_strides = (2,4,8,16,32,64) #(4,8,16,32,64)
+    model_config.shortcut_channels = (64,64,256,512,1024,2048) #(64,256,512,1024,2048)
+    model_config.decoder_chan = 256 #128
+    model_config.aspp_chan = 256 #128
+    return unet_pixel2pixel_aspp_resnet50(model_config, pretrained=pretrained)
\ No newline at end of file
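
A hedged instantiation sketch for the new UNet models (the import path is assumed from the repository layout; input_channels, output_channels and output_type are fields the factory and decoder read but that get_config_unetp2p_mnv2() does not set, so the values below are illustrative):

    from pytorch_jacinto_ai import xnn, vision  # import path assumed from the repo layout

    model_config = xnn.utils.ConfigNode()
    model_config.input_channels = (3,)          # one RGB input stream
    model_config.output_channels = [19]         # e.g. 19 classes for cityscapes segmentation
    model_config.output_type = 'segmentation'   # argmax is applied at eval time for this output type
    model, change_names_dict = vision.models.pixel2pixel.unet_pixel2pixel_aspp_mobilenetv2_tv(model_config)
    # the training scripts pass an imagenet-pretrained mobilenetv2 checkpoint via --pretrained;
    # change_names_dict maps those torchvision parameter names onto encoder.features.* as defined above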
index e5267789751557f9f1f19a8c6b2be76133ed199b..b4205f6a2bd15b6062e7e1751d8469074fcf98f7 100644 (file)
@@ -3,7 +3,7 @@ from .activation import *
 from .common_blocks import *
 
 from .conv_blocks import *
-from .deconv_blocks import *
+from .upsample_blocks import *
 from .multi_task import *
 
 from .rf_blocks import *
diff --git a/modules/pytorch_jacinto_ai/xnn/layers/upsample_blocks.py b/modules/pytorch_jacinto_ai/xnn/layers/upsample_blocks.py
new file mode 100644 (file)
index 0000000..20bc2e4
--- /dev/null
@@ -0,0 +1,129 @@
+from .conv_blocks import *
+from .layer_config import *
+from .common_blocks import *
+
+###############################################################
+def UpsampleTo(input_channels, output_channels, upstride, interpolation_type, interpolation_mode):
+    upsample = []
+    if interpolation_type == 'upsample':
+        upsample = [ResizeTo(mode=interpolation_mode)]
+    elif interpolation_type == 'deconv':
+        upsample = [SplitListTakeFirst(),
+                    DeConvDWLayer2d(input_channels, output_channels, kernel_size=upstride * 2, stride=upstride)]
+    elif interpolation_type == 'upsample_conv':
+        upsample = [ResizeTo(mode=interpolation_mode),
+                    ConvDWLayer2d(input_channels, output_channels, kernel_size=int(upstride * 1.5 + 1))]
+    elif interpolation_type == 'subpixel_conv':
+        upsample = [SplitListTakeFirst(),
+                    ConvDWSepNormAct2d(input_channels, output_channels*upstride*upstride, kernel_size=int(upstride + 1), normalization=(True,False), activation=(False,False)),
+                    torch.nn.PixelShuffle(upscale_factor=upstride)]
+    else:
+        assert False, f'invalid interpolation_type: {interpolation_type}'
+    #
+    upsample = torch.nn.Sequential(*upsample)
+    return upsample
+
+
+class UpsampleGenericTo(torch.nn.Module):
+    def __init__(self, input_channels, output_channels, upstride, interpolation_type, interpolation_mode):
+        super().__init__()
+        self.upsample_list = torch.nn.ModuleList()
+        self.upstride_list = []
+        while upstride >= 2:
+            upstride_layer = 4 if upstride > 4 else upstride
+            upsample = UpsampleTo(input_channels, output_channels, upstride_layer, interpolation_type, interpolation_mode)
+            self.upsample_list.append(upsample)
+            self.upstride_list.append(upstride_layer)
+            upstride = upstride//4
+
+    def forward(self, x):
+        assert isinstance(x, (list,tuple)) and len(x)==2, 'input must be a tuple/list of size 2'
+        x, x_target = x
+        xt_shape = x.shape
+        for idx, (upsample, upstride) in enumerate(zip(self.upsample_list,self.upstride_list)):
+            xt_shape = (xt_shape[0], xt_shape[1], xt_shape[2]*upstride, xt_shape[3]*upstride)
+            xt = torch.zeros(xt_shape).to(x.device)
+            x = upsample((x, xt))
+            xt_shape = x.shape
+        #
+        return x
+
+
+############################################################### 
+def DeConvLayer2d(in_planes, out_planes, kernel_size, stride=1, groups=1, dilation=1, padding=None, output_padding=None, bias=False):
+    """convolution with padding"""
+    if (output_padding is None) and (padding is None):
+        if kernel_size % 2 == 0:
+            padding = (kernel_size - stride) // 2
+            output_padding = 0
+        else:
+            padding = (kernel_size - stride + 1) // 2
+            output_padding = 1
+
+    return torch.nn.ConvTranspose2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding,
+                                    output_padding=output_padding, bias=bias, groups=groups)
+
+
+def DeConvDWLayer2d(in_planes, out_planes, stride=1, dilation=1, kernel_size=None, padding=None, output_padding=None, bias=False):
+    """convolution with padding"""
+    assert in_planes == out_planes, 'in DW layer channels must not change'
+    return DeConvLayer2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation, groups=in_planes,
+                       padding=padding, output_padding=output_padding, bias=bias)
+    
+
+############################################################### 
+def DeConvBNAct(in_planes, out_planes, kernel_size=None, stride=1, groups=1, dilation=1, padding=None, output_padding=None, bias=False, \
+              normalization=DefaultNorm2d, activation=DefaultAct2d):
+    """convolution with padding, BN, ReLU"""
+    if (output_padding is None) and (padding is None):
+        if kernel_size % 2 == 0:
+            padding = (kernel_size - stride) // 2
+            output_padding = 0
+        else:
+            padding = (kernel_size - stride + 1) // 2
+            output_padding = 1
+
+    if activation is True:
+        activation = DefaultAct2d
+
+    if normalization is True:
+        normalization = DefaultNorm2d
+
+    layers = [torch.nn.ConvTranspose2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding,
+                                       output_padding=output_padding, bias=bias, groups=groups)]
+    if normalization:
+        layers.append(normalization(out_planes))
+
+    if activation:
+        layers.append(activation(inplace=True))
+    #
+    layers = torch.nn.Sequential(*layers)
+    return layers
+
+    
+def DeConvDWBNAct(in_planes, out_planes, stride=1, kernel_size=None, dilation=1, padding=None, output_padding=None, bias=False,
+                  normalization=DefaultNorm2d, activation=DefaultAct2d):
+    """convolution with padding, BN, ReLU"""
+    assert in_planes == out_planes, 'in DW layer channels must not change'
+    return DeConvBNAct(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding, output_padding=output_padding,
+                       bias=bias, groups=in_planes, normalization=normalization, activation=activation)
+
+
+###########################################################
+def DeConvDWSepBNAct(in_planes, out_planes, stride=1, kernel_size=None, groups=1, dilation=1, bias=False, \
+                   first_1x1=False, normalization=(DefaultNorm2d,DefaultNorm2d), activation=(DefaultAct2d,DefaultAct2d)):
+    if first_1x1:
+        layers = [
+            ConvNormAct2d(in_planes, out_planes, kernel_size=1, groups=groups, bias=bias,
+                      normalization=normalization[0], activation=activation[0]),
+            DeConvDWBNAct(out_planes, out_planes, stride=stride, kernel_size=kernel_size, dilation=dilation, bias=bias,
+                        normalization=normalization[1], activation=activation[1])]
+    else:
+        layers = [DeConvDWBNAct(in_planes, in_planes, stride=stride, kernel_size=kernel_size, dilation=dilation, bias=bias,
+                              normalization=normalization[0], activation=activation[0]),
+                  ConvNormAct2d(in_planes, out_planes, groups=groups, kernel_size=1, bias=bias,
+                            normalization=normalization[1], activation=activation[1])]
+
+    layers = torch.nn.Sequential(*layers)
+    return layers
+
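
The deeplabv3lite change above swaps UpsampleGenericTo for UpsampleTo; both blocks take a (tensor, target-like tensor) pair, and the generic variant chains x4/x2 stages internally. A hedged sketch of the two (import path assumed from the repository layout):

    import torch
    from pytorch_jacinto_ai import xnn  # import path assumed from the repo layout

    x = torch.randn(1, 19, 24, 48)           # low-resolution prediction
    target = torch.zeros(1, 19, 192, 384)    # tensor whose spatial size should be reached (factor 8)

    # single-stage resize to the target size
    up_single = xnn.layers.UpsampleTo(19, 19, 8, 'upsample', 'bilinear')
    y1 = up_single((x, target))

    # same overall factor, decomposed internally into a x4 stage followed by a x2 stage -
    # useful when only x4 and x2 resize factors are supported by the deployment target
    up_generic = xnn.layers.UpsampleGenericTo(19, 19, 8, 'upsample', 'bilinear')
    y2 = up_generic((x, target))

    print(y1.shape, y2.shape)   # both torch.Size([1, 19, 192, 384])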
index 861b29496f1e67dc3ff8cc872e354cacc49d554c..f96d1afefc010b755927e6f788793cffa67937ab 100755 (executable)
@@ -5,8 +5,8 @@
 ## Training
 ## =====================================================================================
 #### KITTI Depth (Manual Download) - Training with MobileNetV2+DeeplabV3Lite
-python ./scripts/train_depth_main.py --dataset_name kitti_depth --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/kitti/kitti_depth/data --img_resize 384 768 --output_size 374 1242 \
---pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
+#python ./scripts/train_depth_main.py --dataset_name kitti_depth --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/kitti/kitti_depth/data --img_resize 384 768 --output_size 374 1242 \
+#--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
 
 #### KITTI Depth (Manual Download) - Training with ResNet50+FPN
 #python ./scripts/train_depth_main.py --dataset_name kitti_depth --model_name fpn_pixel2pixel_aspp_resnet50 --data_path ./data/datasets/kitti/kitti_depth/data --img_resize 384 768 --output_size 374 1242 \
index 8f7f2cf15c074fabcde05f1c3aee3149d29b440f..c1b6d670cfa7199f8744f545b653e042b3766e7b 100755 (executable)
@@ -1,6 +1,21 @@
 # Summary of commands - uncomment one and run this script
 #### Manual Download: It is expected that the dataset is manually downloaded and kept in the folder specified against the --data_path option.
 
+## =====================================================================================
+# Models Supported:
+## =====================================================================================
+# deeplabv3lite_mobilenetv2_tv: deeplabv3lite decoder
+# fpn_pixel2pixel_aspp_mobilenetv2_tv: fpn decoder
+# unet_pixel2pixel_aspp_mobilenetv2_tv: unet decoder
+# deeplabv3lite_mobilenetv2_tv_fd: low complexity model with fast downsampling strategy
+# fpn_pixel2pixel_aspp_mobilenetv2_tv_fd: low complexity model with fast downsampling strategy
+# unet_pixel2pixel_aspp_mobilenetv2_tv_fd: low complexity model with fast downsampling strategy
+#
+# deeplabv3lite_resnet50: uses resnet50 encoder
+# deeplabv3lite_resnet50_p5: low complexity model - uses resnet50 encoder with half the number of channels (1/4 the complexity). note: this needs specially trained resnet50 pretrained weights
+# fpn_pixel2pixel_aspp_resnet50_fd: low complexity model - with fast downsampling strategy
+
+
 ## =====================================================================================
 ## Training
 ## =====================================================================================
 #python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
 #--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
 
-#### Cityscapes Semantic Segmentation - Training with MobileNetV2+FPN
-#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name fpn_pixel2pixel_aspp_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
-#--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
-
-#### Cityscapes Semantic Segmentation - MobileNetV2+FPN - no aspp model, stride 64 model - Low Complexity Model
-#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name fpn_pixel2pixel_aspp_mobilenetv2_tv_fd --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
-#--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
-
 #### Cityscapes Semantic Segmentation - Training with MobileNetV2+DeeplabV3Lite, Higher Resolution
 #python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/cityscapes/data --img_resize 768 1536 --rand_crop 512 1024 --output_size 1024 2048 --gpus 0 1 \
 #--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
 
 #### Cityscapes Semantic Segmentation - MobileNetV2+DeeplabV3Lite - fast downsampling, stride 64 model, Higher Resolution - Low Complexity Model
-#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name fpn_pixel2pixel_aspp_mobilenetv2_tv_fd --data_path ./data/datasets/cityscapes/data --img_resize 768 1536 --rand_crop 512 1024 --output_size 1024 2048 --gpus 0 1 \
+#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_mobilenetv2_tv_fd --data_path ./data/datasets/cityscapes/data --img_resize 768 1536 --rand_crop 512 1024 --output_size 1024 2048 --gpus 0 1 \
 #--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
 
 
 
+#--ResNet50 + deeplabv3lite
 #### Cityscapes Semantic Segmentation - Training with ResNet50+DeeplabV3Lite
 #python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_resnet50 --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
 #--pretrained https://download.pytorch.org/models/resnet50-19c8e357.pth
 
-#### Cityscapes Semantic Segmentation - Training with ResNet50+FPN
-#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name fpn_pixel2pixel_aspp_resnet50 --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --rand_crop 512 1024 --output_size 1024 2048 --gpus 0 1 \
-#--pretrained https://download.pytorch.org/models/resnet50-19c8e357.pth
-
-#### Cityscapes Semantic Segmentation - Training with FD-ResNet50+FPN - High Resolution
-#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name fpn_pixel2pixel_aspp_resnet50 --data_path ./data/datasets/cityscapes/data --img_resize 768 1536 --rand_crop 512 1024 --output_size 1024 2048 --gpus 0 1 \
+#### Cityscapes Semantic Segmentation - Training with FD-ResNet50+DeeplabV3Lite - High Resolution - Low Complexity Model
+#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_resnet50_fd --data_path ./data/datasets/cityscapes/data --img_resize 768 1536 --rand_crop 512 1024 --output_size 1024 2048 --gpus 0 1 \
 #--pretrained https://download.pytorch.org/models/resnet50-19c8e357.pth
 
+#--ResNet50 encoder with half the channels + deeplabv3lite
 #### Cityscapes Semantic Segmentation - Training with ResNet50_p5+DeeplabV3Lite (ResNet50 encoder with half the channels): deeplabv3lite_resnet50_p5 & deeplabv3lite_resnet50_p5_fd
 #python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_resnet50_p5 --data_path ./data/datasets/cityscapes/data --img_resize 384 768 --output_size 1024 2048 --gpus 0 1 \
 #--pretrained "./data/modelzoo/pretrained/pytorch/imagenet_classification/resnet50-0.5_b256_lr0.1_step30_inception-aug(0.08-1.0)_epoch(92of100)_1gmac_(72.05%)/model_best.pth.tar"
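 
 #### Cityscapes Semantic Segmentation - Training with ResNet50_p5_fd+DeeplabV3Lite - Low Complexity Model (illustrative command, not from the original script: it assumes deeplabv3lite_resnet50_p5_fd, named in the heading above, takes the same options and the same half-channel resnet50 pretrained weights as deeplabv3lite_resnet50_p5)
 #python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name deeplabv3lite_resnet50_p5_fd --data_path ./data/datasets/cityscapes/data --img_resize 768 1536 --rand_crop 512 1024 --output_size 1024 2048 --gpus 0 1 \
 #--pretrained "./data/modelzoo/pretrained/pytorch/imagenet_classification/resnet50-0.5_b256_lr0.1_step30_inception-aug(0.08-1.0)_epoch(92of100)_1gmac_(72.05%)/model_best.pth.tar"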
 
-#### Cityscapes Semantic Segmentation - Training with FD-ResNet50+FPN - High Resolution - Low Complexity Model
-#python ./scripts/train_segmentation_main.py --dataset_name cityscapes_segmentation --model_name fpn_pixel2pixel_aspp_resnet50_fd --data_path ./data/datasets/cityscapes/data --img_resize 768 1536 --rand_crop 512 1024 --output_size 1024 2048 --gpus 0 1 \
-#--pretrained https://download.pytorch.org/models/resnet50-19c8e357.pth
-
-
 
+#-- VOC Segmentation
 #### VOC Segmentation - Training with MobileNetV2+DeeplabV3Lite
 #python ./scripts/train_segmentation_main.py --dataset_name voc_segmentation --model_name deeplabv3lite_mobilenetv2_tv --data_path ./data/datasets/voc --img_resize 512 512 --output_size 512 512 --gpus 0 1 \
 #--pretrained https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
index 0e99ec66b1e690d4066fe87aedd2955a8259fb23..b4a368a14ada57805338711328c4e898f1a271b6 100755 (executable)
@@ -32,8 +32,9 @@ cmds = parser.parse_args()
 # taken care first, since this has to be done before importing pytorch
 if 'gpus' in vars(cmds):
     value = getattr(cmds, 'gpus')
-    if value is not None:
+    if (value is not None) and ("CUDA_VISIBLE_DEVICES" not in os.environ):
         os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(v) for v in value])
+    #
 #
 
 # to avoid hangs in data loader with multi threads
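
The hunks above and below all make the same change: the --gpus argument is mapped to CUDA_VISIBLE_DEVICES before pytorch is imported, and a CUDA_VISIBLE_DEVICES value already exported in the environment now takes precedence over --gpus. A minimal standalone sketch of the pattern follows; the argparse setup here is illustrative, the real scripts define many more options.

import os
import argparse

# illustrative parser - the actual scripts register many more arguments
parser = argparse.ArgumentParser()
parser.add_argument('--gpus', type=int, nargs='*', default=None)
cmds = parser.parse_args()

# map --gpus to CUDA_VISIBLE_DEVICES, but only if the user has not already
# exported it; this has to be done before importing pytorch
if 'gpus' in vars(cmds):
    value = getattr(cmds, 'gpus')
    if (value is not None) and ("CUDA_VISIBLE_DEVICES" not in os.environ):
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(v) for v in value])
    #
#

import torch  # only the selected gpus are visible as cuda devices from here on
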
index a1c044ef83152b606c65387e2d76de3e07754067..3fa0f156234e2ed5db22ce4b3ede8e40565e5ea1 100755 (executable)
@@ -36,8 +36,9 @@ cmds = parser.parse_args()
 # taken care first, since this has to be done before importing pytorch
 if 'gpus' in vars(cmds):
     value = getattr(cmds, 'gpus')
-    if value is not None:
+    if (value is not None) and ("CUDA_VISIBLE_DEVICES" not in os.environ):
         os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(v) for v in value])
+    #
 #
 
 # to avoid hangs in data loader with multi threads
index 6b17e6ae43f0774372ed5402015af84d17a85af5..704f60ecd54db5bfc914fad082c15a8dd4e46c83 100755 (executable)
@@ -37,8 +37,9 @@ cmds = parser.parse_args()
 # taken care first, since this has to be done before importing pytorch
 if 'gpus' in vars(cmds):
     value = getattr(cmds, 'gpus')
-    if value is not None:
+    if (value is not None) and ("CUDA_VISIBLE_DEVICES" not in os.environ):
         os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(v) for v in value])
+    #
 #
 
 # to avoid hangs in data loader with multi threads
index d7317b47fec6272e7668d65cedc81aca6f79c1e9..51d619264753cf4994750a8ca769a0de245a4c08 100755 (executable)
@@ -48,8 +48,9 @@ cmds = parser.parse_args()
 # taken care first, since this has to be done before importing pytorch
 if 'gpus' in vars(cmds):
     value = getattr(cmds, 'gpus')
-    if value is not None:
+    if (value is not None) and ("CUDA_VISIBLE_DEVICES" not in os.environ):
         os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(v) for v in value])
+    #
 #
 
 # to avoid hangs in data loader with multi threads
index e5db9066de4452c241e1ef8ce62bb2cde7c81fd4..acf1ad5b5ec3446503d450fe6481a020b4b61951 100755 (executable)
@@ -46,8 +46,9 @@ cmds = parser.parse_args()
 # taken care first, since this has to be done before importing pytorch
 if 'gpus' in vars(cmds):
     value = getattr(cmds, 'gpus')
-    if value is not None:
+    if (value is not None) and ("CUDA_VISIBLE_DEVICES" not in os.environ):
         os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(v) for v in value])
+    #
 #
 
 ################################
index 89717df7eb2d4513c49a338da43a257019e056d8..9686aa3d1f7eb7f8272aab74c1b0e75658a41bd6 100755 (executable)
@@ -46,10 +46,10 @@ cmds = parser.parse_args()
 # taken care first, since this has to be done before importing pytorch
 if 'gpus' in vars(cmds):
     value = getattr(cmds, 'gpus')
-    if value is not None:
+    if (value is not None) and ("CUDA_VISIBLE_DEVICES" not in os.environ):
         os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(v) for v in value])
+    #
 #
-
 ################################
 # to avoid hangs in data loader with multi threads
 # this was observed after using cv2 image processing functions
index 3d31769ef21d1243daacd0d983d24a7c55e7bb7f..8b7e3f3caaa3f5f485b5d79ecd0eb106fb189ebc 100755 (executable)
@@ -44,8 +44,9 @@ cmds = parser.parse_args()
 # taken care first, since this has to be done before importing pytorch
 if 'gpus' in vars(cmds):
     value = getattr(cmds, 'gpus')
-    if value is not None:
+    if (value is not None) and ("CUDA_VISIBLE_DEVICES" not in os.environ):
         os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(v) for v in value])
+    #
 #
 
 ################################
index 271d4d40e34c1b7534c0ec568fdb1ff791305a35..8435b7ea5c3505060803f7eb7d58d508d5263b7b 100755 (executable)
@@ -46,8 +46,9 @@ cmds = parser.parse_args()
 # taken care first, since this has to be done before importing pytorch
 if 'gpus' in vars(cmds):
     value = getattr(cmds, 'gpus')
-    if value is not None:
+    if (value is not None) and ("CUDA_VISIBLE_DEVICES" not in os.environ):
         os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(v) for v in value])
+    #
 #
 
 ################################