[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / datasets / pixel2pixel / calculate_class_weight.py
1 import numpy as np
2 import os
3 import scipy.misc as misc
4 from .... import xnn
6 from .cityscapes_plus import CityscapesBaseSegmentationLoader, CityscapesBaseMotionLoader
def calc_median_frequency(classes, present_num):
    """
    Class balancing by median frequency balancing method.
    Reference: https://arxiv.org/pdf/1411.4734.pdf

    The paper defines
        a = median_freq / freq(c)
    where freq(c) is the number of pixels of class c divided by the total
    number of pixels in images where c is present, and median_freq is the
    median of these frequencies over all classes.

    Here freq(c) is computed as ``classes[c] / present_num[c]``. When
    ``present_num`` holds the *number of images* that contain class c (as the
    ``__main__`` section of this module computes it), freq(c) is the mean
    pixel count per such image. That differs from the paper's definition only
    by the (constant) image size when all images have the same resolution,
    and a constant factor cancels in ``median_freq / freq(c)``.

    Args:
        classes: array-like of per-class pixel counts.
        present_num: array-like of the same length; per-class number of
            images containing the class (or the total pixels of those images).

    Returns:
        numpy array of per-class weights (float dtype).

    Note:
        Every class must have non-zero counts; a zero entry produces inf/NaN
        weights (numpy emits a divide warning rather than raising).
    """
    # asarray lets callers pass plain lists; true division yields floats even
    # for integer inputs and keeps float32 inputs as float32.
    class_freq = np.asarray(classes) / np.asarray(present_num)
    median_freq = np.median(class_freq)
    return median_freq / class_freq
def calc_log_frequency(classes, value=1.02):
    """Class balancing by the ERFNet (ENet-style) log weighting.

        prob(c) = classes[c] / sum(classes)
        a(c)    = 1 / ln(value + prob(c))

    The class probability is normalised by the *sum* of all counts (not by
    the maximum). With ``value > 1`` and ``prob`` in [0, 1] the weights are
    finite, positive and bounded in [1 / ln(value + 1), 1 / ln(value)].

    Args:
        classes: array-like of per-class pixel counts.
        value: constant added inside the logarithm (default 1.02).

    Returns:
        numpy array of per-class weights (float dtype).

    Note:
        If all counts are zero the normalisation divides by zero and the
        result is NaN (numpy warns rather than raising).
    """
    counts = np.asarray(classes)  # accept plain lists as well as arrays
    class_freq = counts / counts.sum()
    return 1 / np.log(value + class_freq)
if __name__ == '__main__':
    # Script: count per-class pixels over a set of label images and print
    # class-balancing weights computed with the selected method.

    # "median" -> calc_median_frequency, "log" -> calc_log_frequency.
    method = "median"
    if method not in ("median", "log"):
        # Fail fast, before reading every label image.
        raise ValueError("Please assign method to 'median' or 'log'")

    # Label images to analyse. Other roots used with this script include
    # ./data/cityscapes/data/gtFine/train and
    # ./data/cityscapes_frame_pair/data/gtFine/train; use the suffix
    # 'labelTrainIds.png' for semantic-segmentation labels.
    lbls_path = "./data/tiad/data/gtFine/train"
    labels = list(xnn.utils.recursive_glob(rootdir=lbls_path, suffix='labelTrainIds_motion.png'))
    if not labels:
        raise FileNotFoundError("No label files found under {}".format(lbls_path))

    # Must match the label values produced by dst.encode_segmap below.
    num_classes = 2
    # Alternative: CityscapesBaseSegmentationLoader() for segmentation labels.
    dst = CityscapesBaseMotionLoader()

    classes = np.zeros(num_classes, dtype=np.int64)      # total pixels per class
    present_num = np.zeros(num_classes, dtype=np.int64)  # images containing each class

    for lbl_path in labels:
        # NOTE(review): scipy.misc.imread was removed in SciPy 1.2; on newer
        # SciPy this call needs a replacement image reader.
        lbl = misc.imread(lbl_path)
        lbl = dst.encode_segmap(np.array(lbl, dtype=np.uint8))

        for nc in range(num_classes):
            num_pixel = (lbl == nc).sum()
            if num_pixel:
                classes[nc] += num_pixel
                present_num[nc] += 1

    if np.any(classes == 0):
        raise RuntimeError("Some classes are not found")

    classes = classes.astype("f")
    # Original code assigned np.array(classes) to a misspelled name here;
    # convert the per-class image counts themselves.
    present_num = present_num.astype("f")

    if method == "median":
        class_weight = calc_median_frequency(classes, present_num)
    else:  # method == "log" (validated above)
        class_weight = calc_log_frequency(classes)

    print("class weight", class_weight)
    print("Done!")