release commit
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / models / pixel2pixel / pixel2pixelnet.py
index b420347db92c1e26823beb2b25725356483efa86..fbec903270e03dc9223f46c43250ef8b2749ca17 100644 (file)
@@ -47,7 +47,8 @@ class Pixel2PixelNet(torch.nn.Module):
         self.output_channels = model_config.output_channels
         self.num_decoders = len(model_config.output_channels) if (model_config.num_decoders is None) else model_config.num_decoders
         self.split_outputs = model_config.split_outputs
-        self.multi_task = xnn.layers.MultiTask(self.num_decoders, model_config.multi_task_type, model_config.output_type) if model_config.multi_task else None
+        self.multi_task = xnn.layers.MultiTask(num_splits=self.num_decoders, multi_task_type=model_config.multi_task_type, output_type=model_config.output_type,
+                                               multi_task_factors=model_config.multi_task_factors) if model_config.multi_task else None
 
         #if model_config.freeze_encoder:
             #xnn.utils.freeze_bn(self.encoder)