release commit
author: Manu Mathew <a0393608@ti.com>
Sat, 25 Jan 2020 04:01:32 +0000 (09:31 +0530)
committer: Manu Mathew <a0393608@ti.com>
Sat, 25 Jan 2020 04:01:32 +0000 (09:31 +0530)
modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py
modules/pytorch_jacinto_ai/vision/datasets/pixel2pixel/__init__.py
modules/pytorch_jacinto_ai/vision/datasets/pixel2pixel/calculate_class_weights.py [new file with mode: 0644]

index 237ac7520118e3c4fb63336c2a6c5c4e2e7530a2..2eec922c978a3c207144224b3da0afb8658289e2 100644 (file)
@@ -50,7 +50,8 @@ def get_config():
     args.model_config.freeze_encoder = False            # do not update encoder weights
     args.model_config.freeze_decoder = False            # do not update decoder weights
     args.model_config.multi_task_type = 'learned'       # find out loss multiplier by learning, choices=[None, 'learned', 'uncertainty', 'grad_norm', 'dwa_grad_norm']
-    
+
+    args.model = None                                   # the model itself can be given from outside
     args.model_name = 'deeplabv2lite_mobilenetv2'       # model architecture, overwritten if pretrained is specified
     args.dataset_name = 'cityscapes_segmentation'       # dataset type
     args.data_path = './data/cityscapes'                # 'path to dataset'
@@ -300,20 +301,29 @@ def main(args):
     pretrained_data = None
     model_surgery_quantize = False
     if args.pretrained and args.pretrained != "None":
-        if args.pretrained.startswith('http://') or args.pretrained.startswith('https://'):
-            pretrained_file = vision.datasets.utils.download_url(args.pretrained, './data/downloads')
+        if isinstance(args.pretrained, dict):
+            pretrained_data = args.pretrained
         else:
-            pretrained_file = args.pretrained
+            if args.pretrained.startswith('http://') or args.pretrained.startswith('https://'):
+                pretrained_file = vision.datasets.utils.download_url(args.pretrained, './data/downloads')
+            else:
+                pretrained_file = args.pretrained
+            #
+            print(f'=> using pre-trained weights from: {args.pretrained}')
+            pretrained_data = torch.load(pretrained_file)
         #
-        print(f'=> using pre-trained weights from: {args.pretrained}')
-        pretrained_data = torch.load(pretrained_file)
         model_surgery_quantize = pretrained_data['quantize'] if 'quantize' in pretrained_data else False
     #
 
     #################################################
     # create model
-    xnn.utils.print_yellow("=> creating model '{}'".format(args.model_name))
-    model = vision.models.pixel2pixel.__dict__[args.model_name](args.model_config)
+    if args.model is not None:
+        model, change_names_dict = args.model if isinstance(args.model, (list, tuple)) else (args.model, None)
+        assert isinstance(model, torch.nn.Module), 'args.model, if provided must be a valid torch.nn.Module'
+    else:
+        xnn.utils.print_yellow("=> creating model '{}'".format(args.model_name))
+        model = vision.models.pixel2pixel.__dict__[args.model_name](args.model_config)
+    #
 
     # check if we got the model as well as parameters to change the names in pretrained
     model, change_names_dict = model if isinstance(model, (list,tuple)) else (model,None)
@@ -609,6 +619,8 @@ def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writ
         ##########################
         # compute output
         task_outputs = model(input_list)
+
+        task_outputs = task_outputs if isinstance(task_outputs,(list,tuple)) else [task_outputs]
         # upsample output to target resolution
         task_outputs = upsample_tensors(task_outputs, target_sizes, args.upsample_mode)
 
@@ -747,6 +759,7 @@ def validate(args, val_dataset, val_loader, model, epoch, val_writer):
         # compute output
         task_outputs = model(input_list)
 
+        task_outputs = task_outputs if isinstance(task_outputs,(list,tuple)) else [task_outputs]
         task_outputs = upsample_tensors(task_outputs, target_sizes, args.upsample_mode)
 
         metric_total, metric_list, metric_names, metric_types, _ = \
index 5dd9de560706d88138562a9cea7662a84b431f72..7679729f0bb0ac374c945f27c6e25c33177fcdf5 100644 (file)
@@ -12,4 +12,6 @@ except: pass
 try: from .multi_dataset_internal import *
 except: pass
 
+from .calculate_class_weights import *
+
 
diff --git a/modules/pytorch_jacinto_ai/vision/datasets/pixel2pixel/calculate_class_weights.py b/modules/pytorch_jacinto_ai/vision/datasets/pixel2pixel/calculate_class_weights.py
new file mode 100644 (file)
index 0000000..c22ffd6
--- /dev/null
@@ -0,0 +1,77 @@
+import numpy as np
+import os
+import scipy.misc as misc
+import sys
+
+from  .... import xnn
+from .cityscapes_plus import CityscapesBaseSegmentationLoader, CityscapesBaseMotionLoader
+
+
+def calc_median_frequency(classes, present_num):
+    """
+    Class balancing by median frequency balancing method.
+    Reference: https://arxiv.org/pdf/1411.4734.pdf
+       'a = median_freq / freq(c) where freq(c) is the number of pixels
+        of class c divided by the total number of pixels in images where
+        c is present, and median_freq is the median of these frequencies.'
+    """
+    class_freq = classes / present_num
+    median_freq = np.median(class_freq)
+    return median_freq / class_freq
+
+
+def calc_log_frequency(classes, value=1.02):
+    """Class balancing by ERFNet method.
+       prob = each_sum_pixel / each_sum_pixel.max()
+       a = 1 / (log(1.02 + prob)).
+    """
+    class_freq = classes / classes.sum()  # ERFNet is max, but ERFNet is sum
+    # print(class_freq)
+    # print(np.log(value + class_freq))
+    return 1 / np.log(value + class_freq)
+
+
+def calc_weights():
+    method = "median"
+    result_path = "/afs/cg.cs.tu-bs.de/home/zhang/SEDPShuffleNet/datasets"
+
+    traval = "gtFine"
+    imgs_path = "./data/tiad/data/leftImg8bit/train"    #"./data/cityscapes/data/leftImg8bit/train"   #"./data/TIAD/data/leftImg8bit/train"
+    lbls_path = "./data/tiad/data/gtFine/train"         #"./data/cityscapes/data/gtFine/train"   # "./data/tiad/data/gtFine/train"  #"./data/cityscapes_frame_pair/data/gtFine/train"
+    labels = xnn.utils.recursive_glob(rootdir=lbls_path, suffix='labelTrainIds_motion.png')  #'labelTrainIds_motion.png'  #'labelTrainIds.png'
+
+    num_classes = 2       #5  #2
+
+    local_path = "./data/checkpoints"
+    dst = CityscapesBaseMotionLoader() #TiadBaseSegmentationLoader()  #CityscapesBaseSegmentationLoader()  #CityscapesBaseMotionLoader()
+
+    classes, present_num = ([0 for i in range(num_classes)] for i in range(2))
+
+    for idx, lbl_path in enumerate(labels):
+        lbl = misc.imread(lbl_path)
+        lbl = dst.encode_segmap(np.array(lbl, dtype=np.uint8))
+
+        for nc in range(num_classes):
+            num_pixel = (lbl == nc).sum()
+            if num_pixel:
+                classes[nc] += num_pixel
+                present_num[nc] += 1
+
+    if 0 in classes:
+        raise Exception("Some classes are not found")
+
+    classes = np.array(classes, dtype="f")
+    presetn_num = np.array(classes, dtype="f")
+    if method == "median":
+        class_weight = calc_median_frequency(classes, present_num)
+    elif method == "log":
+        class_weight = calc_log_frequency(classes)
+    else:
+        raise Exception("Please assign method to 'mean' or 'log'")
+
+    print("class weight", class_weight)
+    print("Done!")
+
+
+if __name__ == '__main__':
+    calc_weights()
\ No newline at end of file