Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/xvision/losses/segmentation_loss.py
updated python package requirements (don't need tensorflow for tensorboard). not...
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / losses / segmentation_loss.py
1 from torch.autograd import Variable
2 import torch.nn.functional as F
4 import numpy as np
5 import torch
6 from .loss_utils import *
# Public API of this module (used by `from ... import *`).
__all__ = [
    'segmentation_loss',
    'segmentation_metrics',
    'SegmentationMetricsCalc',
]
def cross_entropy2d(input, target, weight=None, ignore_index=None, size_average=True):
    """Pixel-wise cross-entropy for dense (segmentation) predictions.

    Args:
        input: logits of shape (n, c, h, w).
        target: integer class map with n*h*w elements, e.g. (n, h, w); it is cast
            to long. Pixels with a negative label are dropped entirely.
        weight: optional per-class weight tensor of length c, passed to F.nll_loss.
        ignore_index: label whose pixels contribute zero loss (e.g. 255). None
            means no label is ignored.
        size_average: if True, divide the summed loss by the number of pixels with
            a non-negative label; otherwise return the plain sum.

    Returns:
        Scalar loss tensor (zero when no pixel has a non-negative label).
    """
    # unpacking also enforces a 4-D input
    _, c, _, _ = input.size()
    # nll_loss expects a long (int64) target
    target = target.long().reshape(-1)

    # (n, c, h, w) -> (n*h*w, c): class axis last, one row of log-probabilities per
    # pixel, in the same pixel order as the flattened target
    log_p = F.log_softmax(input, dim=1).permute(0, 2, 3, 1).reshape(-1, c)

    # drop pixels with a negative label
    mask = target >= 0
    log_p = log_p[mask]
    target = target[mask]

    # F.nll_loss requires an integer ignore_index; -100 is its own default, and
    # negative labels were already removed above, so it ignores nothing here.
    if ignore_index is None:
        ignore_index = -100
    loss = F.nll_loss(log_p, target, weight=weight, ignore_index=ignore_index, reduction='sum')

    if size_average:
        # NOTE(review): the denominator counts every pixel with a non-negative label,
        # including pixels equal to ignore_index, and does not account for class
        # weights. This preserves the existing loss scale; confirm it is intended.
        num_valid = int(mask.sum())
        loss = loss / num_valid if num_valid > 0 else loss * 0.0
    return loss
class SegmentationMetricsCalc(object):
    """Accumulates a confusion matrix across calls and derives segmentation scores.

    Rows of the confusion matrix are indexed by ground-truth class and columns by
    predicted class. Ground-truth pixels outside [0, n_classes) (for example an
    ignore label such as 255) are not counted.
    """

    def __init__(self, n_classes):
        super().__init__()
        self.n_classes = n_classes
        self.confusion_matrix = np.zeros((n_classes, n_classes))

    def __call__(self, label_preds, label_trues):
        """Add a batch to the confusion matrix and return the scores accumulated so far."""
        return self._update(label_preds, label_trues)

    @staticmethod
    def _fast_hist(label_pred, label_true, n_class):
        """Return the (n_class, n_class) confusion matrix of two flat label arrays.

        Ground-truth entries outside [0, n_class) are skipped. Predictions are
        assumed to lie in [0, n_class) -- NOTE(review): an out-of-range prediction
        lands in a wrong cell or breaks the reshape.
        """
        mask = (label_true >= 0) & (label_true < n_class)
        # encode each (true, pred) pair as one index into a flattened n_class x n_class grid
        index = n_class * label_true[mask].astype(int) + label_pred[mask].astype(int)
        return np.bincount(index, minlength=n_class ** 2).reshape(n_class, n_class)

    def _update(self, label_preds, label_trues):
        """Accumulate one batch into the confusion matrix and return the current scores.

        Args:
            label_preds: torch tensor or numpy array, either per-class scores
                (n, c, ...) -- reduced with argmax over dim 1 -- or label maps that
                already hold one label per ground-truth pixel.
            label_trues: integer ground-truth label maps, torch tensor or numpy array.

        Returns:
            The dict produced by _get_scores().
        """
        if isinstance(label_trues, torch.Tensor):
            label_trues = label_trues.cpu().long().numpy()
        if isinstance(label_preds, torch.Tensor):
            # Reduce class scores to labels, unless the tensor already has exactly one
            # label per ground-truth pixel (e.g. an (n, h, w) label map, whose dim 1 is
            # the height axis and must not be argmax-ed).
            if label_preds.shape[1] > 1 and label_preds.numel() != np.size(label_trues):
                label_preds = label_preds.max(1)[1]
            label_preds = label_preds.cpu().numpy()

        for lt, lp in zip(label_trues, label_preds):
            self.confusion_matrix += self._fast_hist(lp.flatten(), lt.flatten(), self.n_classes)

        return self._get_scores()

    def _get_scores(self):
        """Compute scores from the accumulated confusion matrix.

        Returns:
            dict with (all values in percent):
                'MeanIoU'    -- mean of the per-class IoU
                'OverallAcc' -- fraction of counted pixels predicted correctly
                'MeanAcc'    -- mean of the per-class accuracy (tp / ground-truth count)
                'FreqWtAcc'  -- per-class IoU weighted by ground-truth class frequency
                'ClsIoU'     -- {class index: IoU}

        NOTE: the eps added to every denominator means a class absent from both the
        ground truth and the predictions scores 0 instead of NaN, so np.nanmean
        does not exclude it from the means.
        """
        eps = np.finfo(np.float32).eps
        hist = self.confusion_matrix
        tp = np.diag(hist)
        gt_count = hist.sum(axis=1)  # pixels per ground-truth class
        pred_count = hist.sum(axis=0)  # pixels per predicted class
        total = hist.sum()

        acc = tp.sum() / (total + eps)
        acc_cls = np.nanmean(tp / (gt_count + eps))

        iou = tp / (gt_count + pred_count - tp + eps)
        mean_iou = np.nanmean(iou)

        freq = gt_count / (total + eps)
        fwavacc = (freq[freq > 0] * iou[freq > 0]).sum()

        iou = iou * 100
        cls_iou = dict(zip(range(self.n_classes), iou))
        return {'MeanIoU': mean_iou * 100, 'OverallAcc': acc * 100,
                'MeanAcc': acc_cls * 100, 'FreqWtAcc': fwavacc * 100,
                'ClsIoU': cls_iou}

    def clear(self):
        """Reset the accumulated confusion matrix.

        A new array is allocated (rather than zeroing in place), so any previously
        returned reference to the old matrix keeps its values.
        """
        self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
class SegmentationLoss(torch.nn.Module):
    """Pixel-wise cross-entropy loss module for segmentation, built on cross_entropy2d.

    Args:
        ignore_index: label value passed to cross_entropy2d as ignore_index (default 255).
        weight: optional per-class weights; registered as a float32 buffer.
    Any other positional/keyword arguments are accepted and ignored.
    """
    def __init__(self, *args, ignore_index=255, weight=None, **kwargs):
        super().__init__()
        if weight is not None:
            self.register_buffer('weight', torch.FloatTensor(weight))
        else:
            self.weight = None
        self.ignore_index = ignore_index
        # reported through info(); presumably tells the caller this loss is a
        # per-call value rather than a running average -- confirm with the consumer
        self.is_avg = False

    def forward(self, input_img, input, target):
        """Return the cross-entropy of `input` logits against `target` labels.

        `input_img` is accepted for interface compatibility and not used.
        """
        # Class weights are applied on a random ~20% of calls only. The draw uses
        # numpy's global RNG and happens only when weights were given.
        use_weight = self.weight is not None and np.random.random() < 0.20
        class_weight = self.weight if use_weight else None
        return cross_entropy2d(input, target, class_weight, ignore_index=self.ignore_index)

    def info(self):
        """Describe this loss for logging/reporting."""
        return {'value': 'loss', 'name': 'CrossEntropyLoss', 'is_avg': self.is_avg}

    def clear(self):
        """No accumulated state to reset."""
        return None

    @classmethod
    def args(cls):
        """Names of the constructor keyword arguments this loss consumes."""
        return ['weight']


# lower-case alias (the name listed in __all__)
segmentation_loss = SegmentationLoss
class SegmentationMetrics(torch.nn.Module):
    """Mean-IoU metric module backed by a SegmentationMetricsCalc confusion matrix.

    Args:
        num_classes: number of segmentation classes (size of the confusion matrix).
    Any other positional/keyword arguments are accepted and ignored.
    """
    def __init__(self, *args, num_classes=None, **kwargs):
        super().__init__()
        self.metrics_calc = SegmentationMetricsCalc(num_classes)
        # The output is computed from the confusion matrix accumulated since the last
        # call to clear(), so it reflects all batches seen so far.
        self.is_avg = True

    def forward(self, input_img, input, target):
        """Accumulate a batch and return the running MeanIoU (percent) as a 1-element tensor.

        The float32 result is placed on input's CUDA device if input is on CUDA,
        otherwise on target's CUDA device if target is on CUDA, otherwise on the CPU.
        `input_img` is accepted for interface compatibility and not used.
        """
        mean_iou = self.metrics_calc(input, target)['MeanIoU']
        # Use the data's own device: `.cuda()` would always pick the current default
        # GPU, which is wrong when the data lives on another GPU.
        if input.is_cuda:
            device = input.device
        elif target.is_cuda:
            device = target.device
        else:
            device = torch.device('cpu')
        return torch.tensor([mean_iou], dtype=torch.float32, device=device)

    def forward_all(self, input, target):
        """Accumulate a batch and return the full score dict from SegmentationMetricsCalc."""
        return self.metrics_calc(input, target)

    def clear(self):
        """Reset the accumulated confusion matrix."""
        self.metrics_calc.clear()
        return

    def info(self):
        """Describe this metric, including the current confusion matrix."""
        return {'value': 'accuracy', 'name': 'MeanIoU', 'is_avg': self.is_avg,
                'confusion_matrix': self.metrics_calc.confusion_matrix}

    @classmethod
    def args(cls):
        """Names of the constructor keyword arguments this metric consumes."""
        return ['num_classes']


# lower-case alias (the name listed in __all__)
segmentation_metrics = SegmentationMetrics