summary | shortlog | log | commit | commitdiff | tree
raw | patch | inline | side by side (parent: f23f684)
raw | patch | inline | side by side (parent: f23f684)
author | Manu Mathew <a0393608@ti.com> | |
Sat, 25 Jan 2020 04:01:32 +0000 (09:31 +0530) | ||
committer | Manu Mathew <a0393608@ti.com> | |
Sat, 25 Jan 2020 04:01:32 +0000 (09:31 +0530) |
diff --git a/modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py b/modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py
index 237ac7520118e3c4fb63336c2a6c5c4e2e7530a2..2eec922c978a3c207144224b3da0afb8658289e2 100644 (file)
args.model_config.freeze_encoder = False # do not update encoder weights
args.model_config.freeze_decoder = False # do not update decoder weights
args.model_config.multi_task_type = 'learned' # find out loss multiplier by learning, choices=[None, 'learned', 'uncertainty', 'grad_norm', 'dwa_grad_norm']
args.model_config.freeze_encoder = False # do not update encoder weights
args.model_config.freeze_decoder = False # do not update decoder weights
args.model_config.multi_task_type = 'learned' # find out loss multiplier by learning, choices=[None, 'learned', 'uncertainty', 'grad_norm', 'dwa_grad_norm']
-
+
+ args.model = None # the model itself can be given from outside
args.model_name = 'deeplabv2lite_mobilenetv2' # model architecture, overwritten if pretrained is specified
args.dataset_name = 'cityscapes_segmentation' # dataset type
args.data_path = './data/cityscapes' # 'path to dataset'
args.model_name = 'deeplabv2lite_mobilenetv2' # model architecture, overwritten if pretrained is specified
args.dataset_name = 'cityscapes_segmentation' # dataset type
args.data_path = './data/cityscapes' # 'path to dataset'
pretrained_data = None
model_surgery_quantize = False
if args.pretrained and args.pretrained != "None":
pretrained_data = None
model_surgery_quantize = False
if args.pretrained and args.pretrained != "None":
- if args.pretrained.startswith('http://') or args.pretrained.startswith('https://'):
- pretrained_file = vision.datasets.utils.download_url(args.pretrained, './data/downloads')
+ if isinstance(args.pretrained, dict):
+ pretrained_data = args.pretrained
else:
else:
- pretrained_file = args.pretrained
+ if args.pretrained.startswith('http://') or args.pretrained.startswith('https://'):
+ pretrained_file = vision.datasets.utils.download_url(args.pretrained, './data/downloads')
+ else:
+ pretrained_file = args.pretrained
+ #
+ print(f'=> using pre-trained weights from: {args.pretrained}')
+ pretrained_data = torch.load(pretrained_file)
#
#
- print(f'=> using pre-trained weights from: {args.pretrained}')
- pretrained_data = torch.load(pretrained_file)
model_surgery_quantize = pretrained_data['quantize'] if 'quantize' in pretrained_data else False
#
#################################################
# create model
model_surgery_quantize = pretrained_data['quantize'] if 'quantize' in pretrained_data else False
#
#################################################
# create model
- xnn.utils.print_yellow("=> creating model '{}'".format(args.model_name))
- model = vision.models.pixel2pixel.__dict__[args.model_name](args.model_config)
+ if args.model is not None:
+ model, change_names_dict = model if isinstance(args.model, (list, tuple)) else (args.model, None)
+ assert isinstance(model, torch.nn.Module), 'args.model, if provided must be a valid torch.nn.Module'
+ else:
+ xnn.utils.print_yellow("=> creating model '{}'".format(args.model_name))
+ model = vision.models.pixel2pixel.__dict__[args.model_name](args.model_config)
+ #
# check if we got the model as well as parameters to change the names in pretrained
model, change_names_dict = model if isinstance(model, (list,tuple)) else (model,None)
# check if we got the model as well as parameters to change the names in pretrained
model, change_names_dict = model if isinstance(model, (list,tuple)) else (model,None)
@@ -609,6 +619,8 @@ def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writ
##########################
# compute output
task_outputs = model(input_list)
##########################
# compute output
task_outputs = model(input_list)
+
+ task_outputs = task_outputs if isinstance(task_outputs,(list,tuple)) else [task_outputs]
# upsample output to target resolution
task_outputs = upsample_tensors(task_outputs, target_sizes, args.upsample_mode)
# upsample output to target resolution
task_outputs = upsample_tensors(task_outputs, target_sizes, args.upsample_mode)
# compute output
task_outputs = model(input_list)
# compute output
task_outputs = model(input_list)
+ task_outputs = task_outputs if isinstance(task_outputs,(list,tuple)) else [task_outputs]
task_outputs = upsample_tensors(task_outputs, target_sizes, args.upsample_mode)
metric_total, metric_list, metric_names, metric_types, _ = \
task_outputs = upsample_tensors(task_outputs, target_sizes, args.upsample_mode)
metric_total, metric_list, metric_names, metric_types, _ = \
diff --git a/modules/pytorch_jacinto_ai/vision/datasets/pixel2pixel/__init__.py b/modules/pytorch_jacinto_ai/vision/datasets/pixel2pixel/__init__.py
index 5dd9de560706d88138562a9cea7662a84b431f72..7679729f0bb0ac374c945f27c6e25c33177fcdf5 100644 (file)
try: from .multi_dataset_internal import *
except: pass
try: from .multi_dataset_internal import *
except: pass
+from .calculate_class_weights import *
+
diff --git a/modules/pytorch_jacinto_ai/vision/datasets/pixel2pixel/calculate_class_weights.py b/modules/pytorch_jacinto_ai/vision/datasets/pixel2pixel/calculate_class_weights.py
--- /dev/null
@@ -0,0 +1,77 @@
+import numpy as np
+import os
+import scipy.misc as misc
+import sys
+
+from .... import xnn
+from .cityscapes_plus import CityscapesBaseSegmentationLoader, CityscapesBaseMotionLoader
+
+
def calc_median_frequency(classes, present_num):
    """Median-frequency class balancing (Eigen & Fergus, arXiv:1411.4734).

    freq(c) is the pixel count of class c divided by the total pixel count of
    the images in which c is present; the weight for class c is
    median_freq / freq(c), so rare classes get weights above 1.
    """
    freq = np.divide(classes, present_num)
    return np.median(freq) / freq
+
+
def calc_log_frequency(classes, value=1.02):
    """ERFNet-style class weighting: w_c = 1 / log(value + p_c).

    p_c is class c's share of the total pixel count (sum-normalized, not
    max-normalized as in the original ERFNet formulation).
    """
    prob = classes / np.sum(classes)
    return np.reciprocal(np.log(value + prob))
+
+
def calc_weights():
    """Compute class-balancing weights for a segmentation/motion dataset.

    Scans every label image found under ``lbls_path``, accumulates the total
    pixel count of each class (``classes``) and the number of images in which
    each class appears (``present_num``), then derives weights either by
    median-frequency balancing ("median") or by the ERFNet log-frequency
    formula ("log"). The resulting weights are printed to stdout.

    Raises:
        Exception: if some class never occurs in the scanned labels, or if
            ``method`` is not one of the supported choices.
    """
    method = "median"  # choices: "median", "log"

    # NOTE(review): path and suffix are hard-coded for a specific local
    # setup (TIAD / Cityscapes motion labels) — adjust before running.
    lbls_path = "./data/tiad/data/gtFine/train" #"./data/cityscapes/data/gtFine/train" #"./data/cityscapes_frame_pair/data/gtFine/train"
    labels = xnn.utils.recursive_glob(rootdir=lbls_path, suffix='labelTrainIds_motion.png') #'labelTrainIds.png'

    num_classes = 2 #5 for segmentation

    dst = CityscapesBaseMotionLoader() #CityscapesBaseSegmentationLoader()

    # classes[c]: total pixels of class c; present_num[c]: images containing c.
    classes = [0 for _ in range(num_classes)]
    present_num = [0 for _ in range(num_classes)]

    for lbl_path in labels:
        # NOTE(review): scipy.misc.imread was removed in scipy>=1.2;
        # imageio.imread is the documented drop-in replacement.
        lbl = misc.imread(lbl_path)
        lbl = dst.encode_segmap(np.array(lbl, dtype=np.uint8))

        for nc in range(num_classes):
            num_pixel = (lbl == nc).sum()
            if num_pixel:
                classes[nc] += num_pixel
                present_num[nc] += 1

    if 0 in classes:
        raise Exception("Some classes are not found")

    classes = np.array(classes, dtype="f")
    # BUGFIX: was `presetn_num = np.array(classes, dtype="f")` — a typo'd
    # name that converted the wrong array and left present_num a plain list.
    present_num = np.array(present_num, dtype="f")
    if method == "median":
        class_weight = calc_median_frequency(classes, present_num)
    elif method == "log":
        class_weight = calc_log_frequency(classes)
    else:
        # BUGFIX: message previously said 'mean', but the valid choice is 'median'.
        raise Exception("Please assign method to 'median' or 'log'")

    print("class weight", class_weight)
    print("Done!")
+
+
+if __name__ == '__main__':
+ calc_weights()
\ No newline at end of file