[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xnn / layers / multi_task.py
diff --git a/modules/pytorch_jacinto_ai/xnn/layers/multi_task.py b/modules/pytorch_jacinto_ai/xnn/layers/multi_task.py
index 7c9f8e63ac43bcc45a9f17c0bab71090fb101b2e..d3c1975fd3a90af81559dd6ede450968dac5e435 100644
'''
multi_task_type: "grad_norm" "pseudo_grad_norm" "naive_grad_norm" "dwa" "dwa_gradnorm" "uncertainty"
'''
- def __init__(self, num_splits = 1, multi_task_type=None, output_type=None):
+ def __init__(self, num_splits = 1, multi_task_type=None, output_type=None, multi_task_factors = None):
super().__init__()
################################
self.num_splits = num_splits
self.losses_short = None
self.losses_long = None
+
# self.loss_scales = torch.nn.Parameter(torch.ones(num_splits, device='cuda:0'))
# self.loss_scales = torch.ones(num_splits, device='cuda:0', dtype=torch.float32) #requires_grad=True
- self.loss_scales = torch.tensor([0.2, 0.2, 2.6], device='cuda:0', dtype=torch.float32)
+ self.loss_scales = torch.ones(num_splits, device='cuda:0', dtype=torch.float32) if multi_task_factors is None else \
+ torch.tensor(multi_task_factors, device='cuda:0', dtype=torch.float32)
self.loss_offsets = torch.zeros(num_splits, device='cuda:0', dtype=torch.float32) #None
self.dy_norms_smooth = None
self.register_backward_hook(self.backward_hook)
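The change above replaces a hard-coded per-task scale tensor ([0.2, 0.2, 2.6]) with a new multi_task_factors constructor argument: when no factors are given, the loss scales default to ones; otherwise the supplied per-task factors are used. Below is a minimal sketch of that initialization logic under stated assumptions: the class name MultiTaskLossScaler is hypothetical (the enclosing class is not visible in this hunk), the device fallback to CPU is added for portability (the original pins tensors to 'cuda:0'), and the other members of the real module (losses_short/losses_long, dy_norms_smooth, the backward hook) are omitted.

    import torch

    class MultiTaskLossScaler(torch.nn.Module):  # hypothetical name for illustration
        def __init__(self, num_splits=1, multi_task_type=None, output_type=None,
                     multi_task_factors=None):
            super().__init__()
            self.num_splits = num_splits
            # Device fallback added for this sketch; the original code uses 'cuda:0'.
            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
            # New behavior: uniform scales unless explicit per-task factors are supplied.
            if multi_task_factors is None:
                self.loss_scales = torch.ones(num_splits, device=device, dtype=torch.float32)
            else:
                self.loss_scales = torch.tensor(multi_task_factors, device=device, dtype=torch.float32)
            self.loss_offsets = torch.zeros(num_splits, device=device, dtype=torch.float32)

    # Example: three tasks with the previously hard-coded weights passed explicitly.
    scaler = MultiTaskLossScaler(num_splits=3, multi_task_factors=(0.2, 0.2, 2.6))

This keeps backward compatibility for callers that never passed per-task weights while letting the old hard-coded values be reproduced from configuration.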