[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / losses / scale_loss.py
1 import torch
2 from .basic_loss import *
3 from .loss_utils import *
class ScaleLoss(MeanLossModule):
    """Scale loss built on ``MeanLossModule`` with the Charbonnier error as default.

    Args:
        sparse: forwarded unchanged to ``MeanLossModule`` (its meaning is
            defined there).
        error_fn: error function handed to the base class; defaults to
            ``charbonnier``.
        error_name: label handed to the base class; defaults to 'ScaleLoss'.
    """

    def __init__(self, sparse=False, error_fn=charbonnier, error_name='ScaleLoss'):
        # Collect the configuration once and pass it through to the base class verbatim.
        base_kwargs = dict(sparse=sparse, error_fn=error_fn, error_name=error_name)
        super().__init__(**base_kwargs)


# Alternate, error-function-explicit name for ScaleLoss.
charbonnier_scale_loss = ScaleLoss
class ScaleDiff(MeanLossModule):
    """Scale difference metric built on ``MeanLossModule`` with ``abs_diff`` as default.

    Args:
        sparse: forwarded unchanged to ``MeanLossModule`` (its meaning is
            defined there).
        error_fn: error function handed to the base class; defaults to
            ``abs_diff``.
        error_name: label handed to the base class; defaults to 'ScaleDiff'.
    """

    def __init__(self, sparse=False, error_fn=abs_diff, error_name='ScaleDiff'):
        # Collect the configuration once and pass it through to the base class verbatim.
        base_kwargs = dict(sparse=sparse, error_fn=error_fn, error_name=error_name)
        super().__init__(**base_kwargs)


# Alternate name for ScaleDiff.
scale_abs_loss = ScaleDiff
class SmoothL1ScaleLoss(MeanLossModule):
    """Scale loss built on ``MeanLossModule`` with ``smooth_l1_loss`` as default.

    Args:
        sparse: forwarded unchanged to ``MeanLossModule`` (its meaning is
            defined there).
        error_fn: error function handed to the base class; defaults to
            ``smooth_l1_loss``.
        error_name: label handed to the base class; defaults to
            'SmoothL1ScaleLoss'.
    """

    def __init__(self, sparse=False, error_fn=smooth_l1_loss, error_name='SmoothL1ScaleLoss'):
        super().__init__(sparse=sparse, error_fn=error_fn, error_name=error_name)


# The smooth-L1-named alias must refer to the smooth-L1 variant defined above;
# it previously pointed at the Charbonnier-based ScaleLoss by mistake.
smooth_l1_norm_scale_loss = SmoothL1ScaleLoss
# The generic name keeps its existing binding (ScaleLoss) so callers using it
# see no behavior change.
scale_loss = ScaleLoss