[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / losses / segmentation_loss.py
1 from torch.autograd import Variable
2 import torch.nn.functional as F
4 import numpy as np
5 import torch
6 from .loss_utils import *
8 __all__ = ['segmentation_loss', 'segmentation_metrics', 'SegmentationMetricsCalc']
def cross_entropy2d(input, target, weight=None, ignore_index=None, size_average=True):
    """Pixel-wise cross-entropy loss for dense (segmentation) predictions.

    Args:
        input: raw logits of shape (n, c, h, w).
        target: integer class labels with n*h*w elements (e.g. shape (n, h, w)),
            flattened in (n, h, w) order. Pixels with a negative label are dropped
            entirely (neither summed nor counted).
        weight: optional per-class weight tensor of length c, forwarded to
            ``F.nll_loss``.
        ignore_index: label value whose pixels contribute nothing to the loss sum.
            ``None`` selects PyTorch's default (-100).
        size_average: if True, divide the summed loss by the number of pixels with
            a non-negative label; otherwise return the plain sum.

    Returns:
        A scalar loss tensor; zero when no pixel has a non-negative label.
    """
    # F.nll_loss needs an int64 target and an int ignore_index (None is rejected).
    target = target.long()
    if ignore_index is None:
        ignore_index = -100

    c = input.size(1)

    # (n, c, h, w) -> (n*h*w, c): one row of log-probabilities per pixel, in the
    # same (n, h, w) order in which target is flattened below.
    log_p = F.log_softmax(input, dim=1)
    log_p = log_p.permute(0, 2, 3, 1).reshape(-1, c)

    # Drop pixels with negative labels from both predictions and targets.
    # reshape (not view) so non-contiguous targets are accepted.
    target = target.reshape(-1)
    valid = target >= 0
    log_p = log_p[valid]
    target = target[valid]

    loss = F.nll_loss(log_p, target, weight=weight, ignore_index=ignore_index, reduction='sum')
    if size_average:
        # NOTE(review): the denominator counts every non-negative pixel, including
        # pixels equal to ignore_index, and does not use class weights, so this is
        # not PyTorch's reduction='mean'. Kept to preserve the existing loss scale.
        num_valid = valid.sum().float()
        # With no valid pixels the summed loss is already 0, so clamping the
        # count to 1 avoids 0/0 without changing the result.
        loss = loss / num_valid.clamp(min=1)
    return loss
class SegmentationMetricsCalc(object):
    """Accumulate a confusion matrix across batches and derive segmentation scores.

    ``confusion_matrix[i, j]`` counts pixels whose ground-truth class is ``i`` and
    whose predicted class is ``j``. Counts keep accumulating over calls until
    ``clear()`` is invoked.
    """

    def __init__(self, n_classes):
        super().__init__()
        self.n_classes = n_classes
        self.confusion_matrix = np.zeros((n_classes, n_classes))

    def __call__(self, label_preds, label_trues):
        """Add one batch to the confusion matrix and return the updated scores."""
        return self._update(label_preds, label_trues)

    @staticmethod
    def _fast_hist(label_pred, label_true, n_class):
        """Return (n_class, n_class) confusion counts for two flat label arrays.

        Pixels whose ground-truth label lies outside [0, n_class) (e.g. an ignore
        label such as 255) are skipped. Both arrays are cast to int because
        ``np.bincount`` rejects floating-point input.
        """
        mask = (label_true >= 0) & (label_true < n_class)
        flat_index = n_class * label_true[mask].astype(int) + label_pred[mask].astype(int)
        hist = np.bincount(flat_index, minlength=n_class ** 2).reshape(n_class, n_class)
        return hist

    def _update(self, label_preds, label_trues):
        """Accumulate one batch into the confusion matrix and return the scores.

        Args:
            label_preds: torch tensor or numpy array laid out as (n, c, ...). When
                c > 1 these are per-class scores reduced by an argmax over dim 1;
                when c == 1 they are taken as class indices already.
            label_trues: torch tensor or numpy array of ground-truth class indices,
                one entry per predicted pixel for each of the n samples.

        Returns:
            The dict produced by ``_get_scores()``.
        """
        if isinstance(label_trues, torch.Tensor):
            label_trues = label_trues.detach().cpu().long().numpy()

        if isinstance(label_preds, torch.Tensor):
            # detach so network outputs that still require grad can go to numpy
            label_preds = label_preds.detach()
            if label_preds.shape[1] > 1:
                # reduce scores to class indices on-device before the host copy
                label_preds = label_preds.argmax(dim=1)
            label_preds = label_preds.cpu().numpy()
        else:
            label_preds = np.asarray(label_preds)
            if label_preds.shape[1] > 1:
                label_preds = label_preds.argmax(axis=1)

        for lt, lp in zip(label_trues, label_preds):
            self.confusion_matrix += self._fast_hist(lp.flatten(), lt.flatten(), self.n_classes)

        return self._get_scores()

    def _get_scores(self):
        """Compute scores (in percent) from the accumulated confusion matrix.

        Returns:
            dict with 'MeanIoU', 'OverallAcc' (pixel accuracy), 'MeanAcc'
            (mean per-class accuracy), 'FreqWtAcc' (frequency-weighted IoU) and
            'ClsIoU' (class index -> IoU).
        """
        hist = self.confusion_matrix
        eps = np.finfo(np.float32).eps
        tp = np.diag(hist)
        sum_a1 = hist.sum(axis=1)  # ground-truth pixel count per class

        acc = tp.sum() / (hist.sum() + eps)

        # NOTE(review): eps keeps every ratio finite. A class with no ground-truth
        # pixels therefore gets accuracy 0, and a class absent from both ground
        # truth and prediction gets IoU 0. No entry is NaN, so nanmean acts as mean.
        acc_cls = tp / (sum_a1 + eps)
        acc_cls = np.nanmean(acc_cls)

        iou = tp / (sum_a1 + hist.sum(axis=0) - tp + eps)
        mean_iou = np.nanmean(iou)

        freq = sum_a1 / (hist.sum() + eps)
        fwavacc = (freq[freq > 0] * iou[freq > 0]).sum()

        iou = iou * 100
        mean_iou = mean_iou * 100
        acc = acc * 100
        acc_cls = acc_cls * 100
        fwavacc = fwavacc * 100

        cls_iou = dict(zip(range(self.n_classes), iou))
        return {'MeanIoU': mean_iou, 'OverallAcc': acc,
                'MeanAcc': acc_cls, 'FreqWtAcc': fwavacc,
                'ClsIoU': cls_iou}

    def clear(self):
        """Reset the accumulated confusion matrix to zeros."""
        self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
class SegmentationLoss(torch.nn.Module):
    """Cross-entropy loss module for segmentation outputs, built on ``cross_entropy2d``.

    Optional class weights are kept as a float buffer named 'weight'. They are
    applied on a random ~20% of ``forward`` calls (``np.random.random() < 0.20``);
    all other calls are unweighted.
    """

    def __init__(self, *args, ignore_index=255, weight=None, **kwargs):
        super().__init__()
        if weight is not None:
            self.register_buffer('weight', torch.FloatTensor(weight))
        else:
            self.weight = None
        self.ignore_index = ignore_index
        self.is_avg = False

    def forward(self, input_img, input, target):
        """Return the loss of logits ``input`` against ``target``; ``input_img`` is unused."""
        class_weight = None
        if self.weight is not None:
            # the global numpy RNG is consumed only when weights were given
            if np.random.random() < 0.20:
                class_weight = self.weight
        return cross_entropy2d(input, target, class_weight, ignore_index=self.ignore_index)

    def info(self):
        """Describe this criterion for the training loop."""
        return {'value': 'loss', 'name': 'CrossEntropyLoss', 'is_avg': self.is_avg}

    def clear(self):
        """No accumulated state to reset."""
        return None

    @classmethod
    def args(cls):
        """Names of the keyword arguments this criterion accepts from configuration."""
        return ['weight']


segmentation_loss = SegmentationLoss
class SegmentationMetrics(torch.nn.Module):
    """Module wrapper reporting mean IoU from a running confusion matrix.

    Scores are computed by a ``SegmentationMetricsCalc`` whose confusion matrix
    keeps accumulating across calls until ``clear()`` is called.
    """

    def __init__(self, *args, num_classes=None, **kwargs):
        super().__init__()
        self.metrics_calc = SegmentationMetricsCalc(num_classes)
        # the output is computed from the confusion matrix accumulated since
        # clear() was last called, not from the current batch alone
        self.is_avg = True

    def forward(self, input_img, input, target):
        """Accumulate a batch and return the running mean IoU as a 1-element float tensor.

        The result is placed on the CUDA device of ``input`` (or of ``target`` when
        only it is on CUDA), and on the CPU otherwise. ``input_img`` is unused.
        """
        # create the tensor on the inputs' device instead of .cuda(), which would
        # use the *current* CUDA device and can differ on multi-GPU setups
        if input.is_cuda:
            device = input.device
        elif target.is_cuda:
            device = target.device
        else:
            device = torch.device('cpu')
        mean_iou = self.metrics_calc(input, target)['MeanIoU']
        return torch.tensor([mean_iou], dtype=torch.float32, device=device)

    def forward_all(self, input, target):
        """Accumulate a batch and return the full score dict of the metrics calculator."""
        metrics = self.metrics_calc(input, target)
        return metrics

    def clear(self):
        """Reset the accumulated confusion matrix."""
        self.metrics_calc.clear()
        return

    def info(self):
        """Describe this metric, including the current confusion matrix."""
        return {'value':'accuracy', 'name':'MeanIoU', 'is_avg':self.is_avg, 'confusion_matrix':self.metrics_calc.confusion_matrix}

    @classmethod
    def args(cls):
        """Names of the keyword arguments this metric accepts from configuration."""
        return ['num_classes']


segmentation_metrics = SegmentationMetrics