]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/xvision/losses/scale_loss.py
renamed pytorch_jacinto_ai.vision to pytorch_jacinto_ai.xvision
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / losses / scale_loss.py
1 import torch
2 from .basic_loss import *
3 from .loss_utils import *
6 class ScaleLoss(MeanLossModule):
7     def __init__(self, sparse=False, error_fn=charbonnier, error_name='ScaleLoss'):
8         super().__init__(sparse=sparse, error_fn=error_fn, error_name=error_name)
9 charbonnier_scale_loss = ScaleLoss
11 class ScaleDiff(MeanLossModule):
12     def __init__(self, sparse=False, error_fn=abs_diff, error_name='ScaleDiff'):
13         super().__init__(sparse=sparse, error_fn=error_fn, error_name=error_name)
14 scale_abs_loss = ScaleDiff
16 class SmoothL1ScaleLoss(MeanLossModule):
17     def __init__(self, sparse=False, error_fn=smooth_l1_loss, error_name='SmoothL1ScaleLoss'):
18         super().__init__(sparse=sparse, error_fn=error_fn, error_name=error_name)
19 scale_loss = smooth_l1_norm_scale_loss = ScaleLoss