]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/xvision/losses/basic_loss.py
renamed pytorch_jacinto_ai.vision to pytorch_jacinto_ai.xvision
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / losses / basic_loss.py
1 import torch
2 from .loss_utils import *
__all__ = [
    'BasicElementwiseLossModule',
    'BasicNormLossModule',
    'BasicSplitLossModule',
    'MeanLossModule',
]
class BasicElementwiseLossModule(torch.nn.Module):
    '''
    Basic loss module for elementwise error functions.

    Unlike the Norm loss modules, the error function here is not expected to
    reduce the channels dimension to size 1. Sparse handling therefore differs
    from the Norm loss: the validity mask is expanded over the channels and
    applied to the inputs *before* the error function is evaluated.

    Args:
        sparse: if True, positions where all target channels (dim 1) are
            exactly 0 are treated as invalid and excluded from the loss.
        error_fn: callable(input, target) returning the per-element error.
        error_name: name reported by info().
        is_avg: flag reported by info().
        min_elements: if the (non-scalar) error tensor holds fewer than this
            many elements, the loss is reported as NaN.
        square_root: if True, return sqrt(mean(error)), as in RMSE.
    '''
    def __init__(self, sparse=False, error_fn=None, error_name=None, is_avg=False, min_elements=0, square_root=False):
        super().__init__()
        self.sparse = sparse
        self.error_fn = error_fn
        self.error_name = error_name
        self.is_avg = is_avg
        self.min_elements = min_elements
        self.square_root = square_root

    def forward(self, input_img, input_flow, target_flow):
        '''
        Return the mean elementwise error between input_flow and target_flow.

        input_img is accepted for interface compatibility and is not used.
        The result is a scalar tensor; it is NaN when fewer than min_elements
        error values are available.
        '''
        if self.sparse:
            # invalid flow is defined with all target channels being exactly 0
            target_abs_sum = torch.sum(torch.abs(target_flow), dim=1, keepdim=True)
            valid = (target_abs_sum != 0).expand_as(target_flow)
            input_flow = input_flow[valid]
            target_flow = target_flow[valid]
        #
        error_flow = self.error_fn(input_flow, target_flow)
        error_val = error_flow.mean()
        # sqrt as in RMSE operation
        if self.square_root:
            error_val = torch.sqrt(error_val)
        #
        # Count every element (numel), not only the first dimension: for dense
        # (non-sparse) inputs len() would be the batch size, not the element count.
        # A scalar (already reduced) error cannot be counted, so it is not checked.
        if error_flow.dim() > 0 and error_flow.numel() < self.min_elements:
            # NaN sentinel with the same dtype/device as the loss value
            error_val = torch.full_like(error_val, float('nan'))
        #
        return error_val

    def clear(self):
        return

    def info(self):
        return {'value':'error', 'name':self.error_name, 'is_avg':self.is_avg}

    @classmethod
    def args(cls):
        return ['sparse']
class BasicNormLossModule(torch.nn.Module):
    '''
    Basic loss module for norm-type error functions.

    Norm error functions usually reduce the channels dimension to size 1.
    Because of that, sparse handling differs from the Elementwise loss: the
    error function runs on the full tensors first, and a per-position
    validity mask (target channels summed, kept as size 1) is applied to its
    output afterwards.
    '''
    def __init__(self, sparse=False, error_fn=None, error_name=None, is_avg=False):
        super().__init__()
        self.sparse = sparse
        self.error_fn = error_fn
        self.error_name = error_name
        self.is_avg = is_avg

    def forward(self, input_img, input_flow, target_flow):
        # input_img belongs to the common loss interface; it is unused here
        per_position_error = self.error_fn(input_flow, target_flow)
        if not self.sparse:
            return per_position_error.mean()
        # a position is invalid when every target channel is exactly 0
        keep = target_flow.abs().sum(dim=1, keepdim=True).ne(0)
        return per_position_error[keep].mean()

    def clear(self):
        return None

    def info(self):
        return dict(value='error', name=self.error_name, is_avg=self.is_avg)

    @classmethod
    def args(cls):
        return ['sparse']
class BasicSplitLossModule(torch.nn.Module):
    '''
    Loss module that splits the channels (dim 1) of prediction and target into
    consecutive groups and applies a separate loss to each group.

    Args:
        sparse, error_fn: stored for interface compatibility; not used by this
            module's forward (each per-split loss is called as-is).
        error_name, is_avg: reported by info().
        channels: list of channel counts, one per split, taken consecutively.
        losses: list of loss modules, one per split, each called as
            loss(input_img, input_split, target_split).
        weights: list of per-split multipliers; None means every split is
            weighted by 1.
    '''
    def __init__(self, sparse=False, error_fn=None, error_name=None, is_avg=False, channels=None, losses=None, weights=None):
        super().__init__()
        self.sparse = sparse
        self.error_fn = error_fn
        self.error_name = error_name
        self.is_avg = is_avg
        self.channels = channels
        self.losses = torch.nn.ModuleList(losses)
        self.weights = weights

    def forward(self, input_img, input_flow, target_flow):
        '''
        Return the weighted sum of the per-split losses.

        Raises:
            ValueError: if there are no channel splits to evaluate.
        '''
        # weights=None is the constructor default; treat it as unit weights
        # instead of failing inside zip()
        weights = self.weights if self.weights is not None else [1.0] * len(self.channels)
        total_loss = None
        ch_start = 0
        for idx, (ch, wt) in enumerate(zip(self.channels, weights)):
            ch_end = ch_start + ch
            input_flow_split = input_flow[:, ch_start:ch_end, ...]
            target_flow_split = target_flow[:, ch_start:ch_end, ...]
            ch_start = ch_end
            split_loss = self.losses[idx](input_img, input_flow_split, target_flow_split) * wt
            total_loss = split_loss if total_loss is None else total_loss + split_loss
        #
        if total_loss is None:
            raise ValueError('BasicSplitLossModule: no channel splits to compute the loss on')
        #
        return total_loss

    def clear(self):
        return

    def info(self):
        return {'value':'error', 'name':self.error_name, 'is_avg':self.is_avg}

    @classmethod
    def args(cls):
        return ['sparse']
# loss computed on the mean
class MeanLossModule(torch.nn.Module):
    '''
    Loss module that compares the global means of prediction and target.

    The error function is applied once, to the scalar mean of input_flow and
    the scalar mean of target_flow, not to per-element values. If sparse is
    set, positions where every target channel (dim 1) is exactly 0 are removed
    from both tensors before the means are taken.
    '''
    def __init__(self, sparse=False, error_fn=None, error_name=None, is_avg=False):
        super().__init__()
        self.sparse = sparse
        self.error_fn = error_fn
        self.error_name = error_name
        self.is_avg = is_avg

    def forward(self, input_img, input_flow, target_flow):
        # input_img belongs to the common loss interface; it is unused here
        if self.sparse:
            has_target = torch.abs(target_flow).sum(dim=1, keepdim=True) != 0
            has_target = has_target.expand_as(target_flow)
            input_flow, target_flow = input_flow[has_target], target_flow[has_target]
        return self.error_fn(input_flow.mean(), target_flow.mean())

    def clear(self):
        return None

    def info(self):
        return dict(value='error', name=self.error_name, is_avg=self.is_avg)

    @classmethod
    def args(cls):
        return ['sparse']