release commit
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules/pytorch_jacinto_ai/xnn/layers/multi_task.py
index 7c9f8e63ac43bcc45a9f17c0bab71090fb101b2e..d3c1975fd3a90af81559dd6ede450968dac5e435 100644 (file)
@@ -9,7 +9,7 @@ class MultiTask(torch.nn.Module):
     '''
     multi_task_type: "grad_norm" "pseudo_grad_norm" "naive_grad_norm" "dwa" "dwa_gradnorm" "uncertainty"
     '''
-    def __init__(self, num_splits = 1, multi_task_type=None, output_type=None):
+    def __init__(self, num_splits = 1, multi_task_type=None, output_type=None, multi_task_factors = None):
         super().__init__()
 
         ################################
@@ -19,9 +19,11 @@ class MultiTask(torch.nn.Module):
         self.num_splits = num_splits
         self.losses_short = None
         self.losses_long = None
+
         # self.loss_scales = torch.nn.Parameter(torch.ones(num_splits, device='cuda:0'))
         # self.loss_scales = torch.ones(num_splits, device='cuda:0', dtype=torch.float32) #requires_grad=True
-        self.loss_scales = torch.tensor([0.2, 0.2, 2.6], device='cuda:0', dtype=torch.float32)
+        self.loss_scales = torch.ones(num_splits, device='cuda:0', dtype=torch.float32) if multi_task_factors is None else \
+                            torch.tensor(multi_task_factors, device='cuda:0', dtype=torch.float32)
         self.loss_offsets = torch.zeros(num_splits, device='cuda:0', dtype=torch.float32) #None
         self.dy_norms_smooth = None
         self.register_backward_hook(self.backward_hook)
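
Net effect of this hunk: the per-task loss scales, previously hard-coded as [0.2, 0.2, 2.6], become a multi_task_factors constructor argument, with loss_scales defaulting to torch.ones(num_splits) when it is omitted. Below is a minimal usage sketch, not part of the commit: the import path is inferred from the file location above, and the task count and the 'pseudo_grad_norm' choice (taken from the docstring's list of supported types) are illustrative assumptions. Note that the module places its tensors on 'cuda:0', so a CUDA device is required.

    # Sketch only; constructor arguments are from the diff, the rest is assumed.
    from pytorch_jacinto_ai.xnn.layers.multi_task import MultiTask

    # Three tasks, passing the scales the old code hard-coded.
    multi_task = MultiTask(num_splits=3, multi_task_type='pseudo_grad_norm',
                           multi_task_factors=(0.2, 0.2, 2.6))

    # Omitting multi_task_factors now initializes loss_scales to all-ones.
    multi_task_default = MultiTask(num_splits=3, multi_task_type='pseudo_grad_norm')

Moving the factors into the constructor removes the assumption that the model is trained with exactly three tasks at those fixed weights, so the same layer can serve configurations with a different num_splits.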