class weights were not being used in segmentation loss due to a bug. fixed it.
authorManu Mathew <a0393608@ti.com>
Wed, 13 May 2020 11:36:04 +0000 (17:06 +0530)
committerManu Mathew <a0393608@ti.com>
Thu, 14 May 2020 07:56:30 +0000 (13:26 +0530)
modules/pytorch_jacinto_ai/vision/losses/segmentation_loss.py

index 90b715c1c90221e8f947c479cd0ea0d67ff2dcb9..96be9fc4cd348cd61a196330504ae28d461e4921 100755 (executable)
@@ -106,7 +106,11 @@ class SegmentationMetricsCalc(object):
 class SegmentationLoss(torch.nn.Module):
     def __init__(self, *args, ignore_index = 255, weight=None, **kwargs):
         super().__init__()
-        self.weight = None if weight is None else self.register_buffer('weight', torch.FloatTensor(weight))
+        if weight is None:
+            self.weight = None
+        else:
+            self.register_buffer('weight', torch.FloatTensor(weight))
+        #
         self.ignore_index = ignore_index
         self.is_avg = False
     #