]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/xvision/datasets/pixel2pixel/calculate_class_weights.py
updated python package requirements (don't need tensorflow for tensorboard). not...
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / datasets / pixel2pixel / calculate_class_weights.py
1 import numpy as np
2 import os
3 import scipy.misc as misc
4 import sys
6 from .cityscapes_plus import CityscapesBaseSegmentationLoader, CityscapesBaseMotionLoader
def calc_median_frequency(classes, present_num):
    """
    Class balancing by median frequency balancing method.
    Reference: https://arxiv.org/pdf/1411.4734.pdf
       'a = median_freq / freq(c) where freq(c) is the number of pixels
        of class c divided by the total number of pixels in images where
        c is present, and median_freq is the median of these frequencies.'

    Here ``freq(c)`` is computed as ``classes[c] / present_num[c]``, i.e. the
    number of class-c pixels per image in which class c occurs. When all
    images have the same size, this differs from the paper's definition only
    by the constant pixels-per-image factor, which cancels in
    ``median_freq / freq(c)``.

    Args:
        classes: per-class pixel counts, a sequence or array of length C.
        present_num: per-class number of images containing that class,
            a sequence or array of length C.

    Returns:
        np.ndarray of float64 weights, one per class. Entries with a zero
        count produce inf/nan, so callers should ensure every class occurs.
    """
    # Coerce to float arrays so plain Python lists are accepted
    # (list / list raises TypeError) and integer inputs divide as floats.
    classes = np.asarray(classes, dtype=np.float64)
    present_num = np.asarray(present_num, dtype=np.float64)

    class_freq = classes / present_num
    median_freq = np.median(class_freq)
    return median_freq / class_freq
def calc_log_frequency(classes, value=1.02):
    """Class balancing by the ERFNet log-frequency method.

       prob = each_sum_pixel / each_sum_pixel.sum()
       a = 1 / log(value + prob)

    ``prob`` is normalised by the total pixel count (sum over classes), so it
    is the fraction of all counted pixels that belong to each class.

    Args:
        classes: per-class pixel counts, a sequence or array of length C.
        value: additive constant inside the log (default 1.02). It bounds
            the weight of very rare classes, since log(value + 0) > 0 for
            value > 1.

    Returns:
        np.ndarray of float64 weights, one per class.
    """
    # Coerce so plain lists work (list has no .sum()) and division is float.
    classes = np.asarray(classes, dtype=np.float64)
    class_freq = classes / classes.sum()
    return 1 / np.log(value + class_freq)
def _recursive_glob(rootdir='.', suffix=''):
    """Return paths of all files under *rootdir* whose names end with *suffix*."""
    return [os.path.join(looproot, filename)
            for looproot, _, filenames in os.walk(rootdir)
            for filename in filenames if filename.endswith(suffix)]


def calc_weights(method="median",
                 lbls_path="./data/tiad/data/gtFine/train",
                 suffix='labelTrainIds_motion.png',
                 num_classes=2,
                 dst=None):
    """Compute, print and return per-class loss weights from label images.

    Every file under ``lbls_path`` whose name ends with ``suffix`` is read,
    passed through ``dst.encode_segmap`` and counted for class ids
    ``0 .. num_classes-1``. Pixels with other values are ignored. The counts
    are then turned into weights with the selected balancing method.

    The defaults reproduce the original hard-coded configuration (TIAD motion
    labels, 2 classes, ``CityscapesBaseMotionLoader``). Other setups used in
    this file's history include ``suffix='labelTrainIds.png'`` with
    ``CityscapesBaseSegmentationLoader()``.

    Args:
        method: "median" for median-frequency balancing, or "log" for
            ERFNet-style log-frequency balancing.
        lbls_path: root directory that is searched recursively for label files.
        suffix: filename suffix that selects label files.
        num_classes: number of class ids to count.
        dst: loader object providing ``encode_segmap``. Defaults to a new
            ``CityscapesBaseMotionLoader()``.

    Returns:
        np.ndarray of per-class weights.

    Raises:
        Exception: on an unknown ``method``, when no label files are found,
            or when some class never occurs (a zero count would give an
            infinite or undefined weight).
    """
    # Validate cheaply before the expensive pass over the dataset.
    if method not in ("median", "log"):
        raise Exception("Please assign method to 'median' or 'log'")

    if dst is None:
        dst = CityscapesBaseMotionLoader()

    labels = _recursive_glob(rootdir=lbls_path, suffix=suffix)
    if not labels:
        raise Exception("No label files ending with '{}' found under '{}'".format(suffix, lbls_path))

    # Integer accumulators: pixels per class, and images in which each class occurs.
    classes = np.zeros(num_classes, dtype=np.int64)
    present_num = np.zeros(num_classes, dtype=np.int64)

    for lbl_path in labels:
        # NOTE(review): scipy.misc.imread was removed in SciPy 1.2. This needs
        # an older SciPy, or a switch to another image reader.
        lbl = misc.imread(lbl_path)
        lbl = dst.encode_segmap(np.array(lbl, dtype=np.uint8))
        for nc in range(num_classes):
            num_pixel = (lbl == nc).sum()
            if num_pixel:
                classes[nc] += num_pixel
                present_num[nc] += 1

    if np.any(classes == 0):
        raise Exception("Some classes are not found")

    if method == "median":
        class_weight = calc_median_frequency(classes, present_num)
    else:
        class_weight = calc_log_frequency(classes)

    print("class weight", class_weight)
    print("Done!")
    return class_weight
if __name__ == '__main__':
    # Script entry point: compute and print the class weights.
    calc_weights()