# modules/pytorch_jacinto_ai/xnn/layers/multi_task.py
import torch
import torch.optim as optim
import numpy as np
import sys
import torch.nn.functional as F

#####################################################
class MultiTask(torch.nn.Module):
    '''
    Computes per-task loss weighting factors for multi-task training.
    multi_task_type: "grad_norm", "pseudo_grad_norm", "naive_grad_norm", "dwa", "dwa_gradnorm", "uncertainty"
    '''
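    # Illustrative construction (hypothetical argument values, not from the original code):
    #   multi_task_layer = MultiTask(num_splits=2, multi_task_type='uncertainty',
    #                                output_type=('segmentation', 'depth'))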
    def __init__(self, num_splits=1, multi_task_type=None, output_type=None, multi_task_factors=None):
        super().__init__()

        ################################
        # check args
        assert multi_task_type in (None, 'learned', 'uncertainty', 'grad_norm', 'pseudo_grad_norm',
                                   'naive_grad_norm', 'dwa', 'dwa_gradnorm'), 'invalid value for multi_task_type'

        self.num_splits = num_splits
        self.losses_short = None
        self.losses_long = None

        # self.loss_scales = torch.nn.Parameter(torch.ones(num_splits, device='cuda:0'))
        # self.loss_scales = torch.ones(num_splits, device='cuda:0', dtype=torch.float32) #requires_grad=True
        self.loss_scales = torch.ones(num_splits, device='cuda:0', dtype=torch.float32) if multi_task_factors is None else \
                            torch.tensor(multi_task_factors, device='cuda:0', dtype=torch.float32)
        self.loss_offsets = torch.zeros(num_splits, device='cuda:0', dtype=torch.float32) #None
        self.dy_norms_smooth = None
        self.register_backward_hook(self.backward_hook)
        self.alpha = 0.75  #0.75 #0.5 #0.12  #1.5
        self.lr = 1e-4  #1e-2 #1e-3 #1e-4 #1e-5
        self.momentum = 0.9
        self.beta = 0.999
        self.multi_task_type = ('uncertainty' if multi_task_type == 'learned' else multi_task_type)
        self.output_type = output_type
        self.long_smooth = 1e-6
        self.short_smooth = 1e-3
        self.eps = 1e-6
        # L1 loss between the actual and target gradient norms; needed by the 'grad_norm' update below
        self.grad_loss = torch.nn.L1Loss()
        self.temperature = 1.0 #2.0
        self.dy_norms = None
        if self.multi_task_type == 'uncertainty':
            self.sigma_factor = torch.zeros(num_splits, device='cuda:0', dtype=torch.float32)  # requires_grad=True
            self.uncertainty_factors = self.loss_scales
            for task_idx in range(num_splits):
                output_type = self.output_type[task_idx]
                discrete_loss = (output_type in ('segmentation', 'classification'))
                self.sigma_factor[task_idx] = (1 if discrete_loss else 2)
                self.loss_scales[task_idx] = (-0.5)*torch.log(self.uncertainty_factors[task_idx]*self.sigma_factor[task_idx])

        if self.multi_task_type in ["grad_norm", "uncertainty"]:
            if self.multi_task_type == 'grad_norm':
                param_groups = [{'params': self.loss_scales}]
            elif self.multi_task_type == 'uncertainty':
                param_groups = [{'params': self.uncertainty_factors}]
            self.gradnorm_solver = 'sgd' #'adam' #'sgd'
            if self.gradnorm_solver == 'adam':
                self.optimizer = torch.optim.Adam(param_groups, self.lr, betas=(self.momentum, self.beta))
            elif self.gradnorm_solver == 'sgd':
                self.optimizer = torch.optim.SGD(param_groups, self.lr, momentum=self.momentum)

    def forward(self, x):
        # replicate the input once per task so that the backward pass produces one gradient per task branch
        return torch.stack([x for split in range(self.num_splits)])

    def backward_hook(self, module, dy_list, dx):
        # L2 norm of each task's incoming gradient, plus a short-term exponential moving average
        self.dy_norms = torch.stack([torch.norm(dy, p=2) for dy in dy_list]).to('cuda:0')
        if self.dy_norms_smooth is not None:
            self.dy_norms_smooth = self.dy_norms_smooth*(1-self.short_smooth) + self.dy_norms*self.short_smooth
        else:
            self.dy_norms_smooth = self.dy_norms
        # dy_norms_smooth_mean = self.dy_norms_smooth.mean()
        self.update_loss_scale()
        del dx, module, dy_list

    def set_losses(self, losses):
        # maintain short-term and long-term exponential moving averages of the (detached) task losses
        if self.losses_short is not None:
            self.losses_short = self.losses_short*(1-self.short_smooth) + torch.stack([loss.detach() for loss in losses])*self.short_smooth
            self.losses_long = self.losses_long*(1-self.long_smooth) + torch.stack([loss.detach() for loss in losses])*self.long_smooth
        else:
            self.losses_short = torch.stack([loss.detach() for loss in losses])
            self.losses_long = torch.stack([loss.detach() for loss in losses])

    def get_loss_scales(self):
        return self.loss_scales, self.loss_offsets

    def update_loss_scale(self, model=None, loss_list=None):
        # wc  = model.module.encoder.features.stream0._modules['17'].conv._modules['6'].weight #final common layer
        # dy_list = [torch.autograd.grad(loss_list[index], wc, retain_graph=True)[0] #, create_graph=True
        #                 for index in range(len(loss_list))]
        # dy_norms = torch.stack([torch.norm(dy, p=2) for dy in dy_list])
        #
        # if self.dy_norms_smooth is not None:
        #     self.dy_norms_smooth = self.dy_norms_smooth*(1-self.short_smooth) + dy_norms*self.short_smooth
        # else:
        #     self.dy_norms_smooth = dy_norms
        #
        dy_norms_mean = self.dy_norms.mean().detach()
        dy_norms_smooth_mean = self.dy_norms_smooth.mean()
        inverse_training_rate = self.losses_short / (self.losses_long + self.eps)

        if self.multi_task_type == "grad_norm": # the second order update takes a long time to compute
            rel_inverse_training_rate = inverse_training_rate / inverse_training_rate.mean()
            target_dy_norm = dy_norms_mean * rel_inverse_training_rate**self.alpha
            self.optimizer.zero_grad()
            dy_norms_loss = self.grad_loss(self.dy_norms, target_dy_norm)
            # dy_norms_loss.backward()
            self.loss_scales.grad = torch.autograd.grad(dy_norms_loss, self.loss_scales)[0]
            self.optimizer.step()
            # re-initialize the optimizer and the loss_scales
            self.loss_scales = torch.nn.Parameter(3.0 * self.loss_scales / self.loss_scales.sum())
            param_groups = [{'params': self.loss_scales}]
            # self.optimizer = torch.optim.Adam(param_groups, self.lr, betas=(self.momentum, self.beta))
            self.optimizer = torch.optim.SGD(param_groups, self.lr, momentum=self.momentum)
            del inverse_training_rate, rel_inverse_training_rate, target_dy_norm, dy_norms_loss
            torch.cuda.empty_cache()
            return self.loss_scales

        elif self.multi_task_type == "naive_grad_norm": # special case of pseudo_grad_norm with alpha=1.0
            update_factor = ((dy_norms_smooth_mean / (self.dy_norms_smooth + self.eps)) * (self.losses_short / (self.losses_long + self.eps)))
            self.loss_scales = self.loss_scales + self.lr*(self.loss_scales*(update_factor-1))
            self.loss_scales = 3.0*self.loss_scales/self.loss_scales.sum()
            del update_factor

        elif self.multi_task_type == "pseudo_grad_norm": # works reasonably well
            rel_inverse_training_rate = inverse_training_rate / inverse_training_rate.sum()
            target_dy_norm = dy_norms_smooth_mean * rel_inverse_training_rate**self.alpha
            update_factor = (target_dy_norm/(self.dy_norms_smooth + self.eps))
            self.loss_scales = self.loss_scales + self.lr*(self.loss_scales*(update_factor-1))
            self.loss_scales = 3.0*self.loss_scales/self.loss_scales.sum()
            del inverse_training_rate, rel_inverse_training_rate, target_dy_norm, update_factor

        elif self.multi_task_type == "dwa": # dynamic weight averaging; does not work well because the weight updates are too drastic
            self.loss_scales = 3.0 * F.softmax(inverse_training_rate/self.temperature, dim=0)

        elif self.multi_task_type == "dwa_gradnorm": # dynamic weight averaging combined with gradient information; best results so far
            inverse_training_rate = F.softmax(inverse_training_rate/self.temperature, dim=0)
            target_dy_norm = dy_norms_smooth_mean * inverse_training_rate**self.alpha  # try tuning the smoothness parameter of the gradient
            update_factor = (target_dy_norm/(self.dy_norms_smooth + self.eps))
            self.loss_scales = self.loss_scales + self.lr*(self.loss_scales*(update_factor-1))
            self.loss_scales = 3.0*self.loss_scales/self.loss_scales.sum()
            # print(self.loss_scales)
        elif self.multi_task_type == "uncertainty":  # uncertainty based weight update
            self.optimizer.step()
            for task_idx in range(self.num_splits):
                clip_scale = True
                loss_scale = torch.exp(self.uncertainty_factors[task_idx]*(-2))/self.sigma_factor[task_idx]
                self.loss_scales[task_idx] = torch.tanh(loss_scale) if clip_scale else loss_scale
            #
            self.loss_offsets = self.uncertainty_factors
        #
        del dy_norms_smooth_mean, dy_norms_mean

    def find_last_common_weight(self):
        """
        Given a model, return the last common layer of the encoder. This is not required for the
        current implementation, but may be needed in the future.
        """
        pass

#####################################################
def set_losses(module, losses):
    def set_losses_func(op):
        if isinstance(op, MultiTask):
            op.set_losses(losses)
    #--
    module.apply(set_losses_func)


loss_scales, loss_offsets = None, None
def get_loss_scales(module):
    def get_loss_scales_func(op):
        global loss_scales, loss_offsets
        if isinstance(op, MultiTask):
            loss_scales, loss_offsets = op.get_loss_scales()
    #--
    module.apply(get_loss_scales_func)
    return loss_scales, loss_offsets
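
#####################################################
# Illustrative usage sketch (not part of the original module). The encoder, heads,
# criteria, train_loader and optimizer below are hypothetical placeholders; how the
# returned loss_scales/loss_offsets are combined with the task losses is an assumption
# made by this example, not something defined in this file.
#
#   multi_task_layer = MultiTask(num_splits=2, multi_task_type='pseudo_grad_norm')
#
#   for inputs, targets in train_loader:
#       shared = encoder(inputs)                 # hypothetical shared feature extractor
#       splits = multi_task_layer(shared)        # (num_splits, ...) replicas of the shared feature
#       losses = [criteria[i](heads[i](splits[i]), targets[i]) for i in range(2)]
#
#       # feed detached losses so the short/long term statistics can be updated
#       set_losses(multi_task_layer, losses)
#       scales, offsets = get_loss_scales(multi_task_layer)
#
#       # weighted total; the backward pass triggers backward_hook() -> update_loss_scale()
#       total_loss = sum(s * l for s, l in zip(scales, losses))  # offsets could be added for 'uncertainty'
#       optimizer.zero_grad()
#       total_loss.backward()
#       optimizer.step()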