added mobilenetv3 from torchvision and also mobilenetv3_lite models, updated docs
[jacinto-ai/pytorch-jacinto-ai-devkit.git] modules/pytorch_jacinto_ai/xvision/ops/misc.py
index 17ae5bcfc8952aee699939b2974300e01c9bf4a1..17e69c506d8f2e626c9885633771707c47b1cd27 100644
@@ -1,5 +1,3 @@
-from __future__ import division
-
 """
 helper class that supports empty tensors on some nn functions.
 
@@ -10,149 +8,92 @@ This can be removed once https://github.com/pytorch/pytorch/issues/12013
 is implemented
 """
 
-import math
+import warnings
 import torch
-from torch.nn.modules.utils import _ntuple
-
-
-class _NewEmptyTensorOp(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x, new_shape):
-        ctx.shape = x.shape
-        return x.new_empty(new_shape)
-
-    @staticmethod
-    def backward(ctx, grad):
-        shape = ctx.shape
-        return _NewEmptyTensorOp.apply(grad, shape), None
+from torch import Tensor, Size
+from torch.jit.annotations import List, Optional, Tuple
 
 
 class Conv2d(torch.nn.Conv2d):
-    """
-    Equivalent to nn.Conv2d, but with support for empty batch sizes.
-    This will eventually be supported natively by PyTorch, and this
-    class can go away.
-    """
-    def forward(self, x):
-        if x.numel() > 0:
-            return super(Conv2d, self).forward(x)
-        # get output shape
-
-        output_shape = [
-            (i + 2 * p - (di * (k - 1) + 1)) // d + 1
-            for i, p, di, k, d in zip(
-                x.shape[-2:], self.padding, self.dilation, self.kernel_size, self.stride
-            )
-        ]
-        output_shape = [x.shape[0], self.weight.shape[0]] + output_shape
-        return _NewEmptyTensorOp.apply(x, output_shape)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        warnings.warn(
+            "torchvision.ops.misc.Conv2d is deprecated and will be "
+            "removed in future versions, use torch.nn.Conv2d instead.", FutureWarning)
 
 
 class ConvTranspose2d(torch.nn.ConvTranspose2d):
-    """
-    Equivalent to nn.ConvTranspose2d, but with support for empty batch sizes.
-    This will eventually be supported natively by PyTorch, and this
-    class can go away.
-    """
-    def forward(self, x):
-        if x.numel() > 0:
-            return super(ConvTranspose2d, self).forward(x)
-        # get output shape
-
-        output_shape = [
-            (i - 1) * d - 2 * p + (di * (k - 1) + 1) + op
-            for i, p, di, k, d, op in zip(
-                x.shape[-2:],
-                self.padding,
-                self.dilation,
-                self.kernel_size,
-                self.stride,
-                self.output_padding,
-            )
-        ]
-        output_shape = [x.shape[0], self.bias.shape[0]] + output_shape
-        return _NewEmptyTensorOp.apply(x, output_shape)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        warnings.warn(
+            "torchvision.ops.misc.ConvTranspose2d is deprecated and will be "
+            "removed in future versions, use torch.nn.ConvTranspose2d instead.", FutureWarning)
 
 
 class BatchNorm2d(torch.nn.BatchNorm2d):
-    """
-    Equivalent to nn.BatchNorm2d, but with support for empty batch sizes.
-    This will eventually be supported natively by PyTorch, and this
-    class can go away.
-    """
-    def forward(self, x):
-        if x.numel() > 0:
-            return super(BatchNorm2d, self).forward(x)
-        # get output shape
-        output_shape = x.shape
-        return _NewEmptyTensorOp.apply(x, output_shape)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        warnings.warn(
+            "torchvision.ops.misc.BatchNorm2d is deprecated and will be "
+            "removed in future versions, use torch.nn.BatchNorm2d instead.", FutureWarning)
 
 
-def interpolate(
-    input, size=None, scale_factor=None, mode="nearest", align_corners=None
-):
-    """
-    Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
-    This will eventually be supported natively by PyTorch, and this
-    class can go away.
-    """
-    if input.numel() > 0:
-        return torch.nn.functional.interpolate(
-            input, size, scale_factor, mode, align_corners
-        )
-
-    def _check_size_scale_factor(dim):
-        if size is None and scale_factor is None:
-            raise ValueError("either size or scale_factor should be defined")
-        if size is not None and scale_factor is not None:
-            raise ValueError("only one of size or scale_factor should be defined")
-        if (
-            scale_factor is not None and
-            isinstance(scale_factor, tuple) and
-            len(scale_factor) != dim
-        ):
-            raise ValueError(
-                "scale_factor shape must match input shape. "
-                "Input is {}D, scale_factor size is {}".format(dim, len(scale_factor))
-            )
-
-    def _output_size(dim):
-        _check_size_scale_factor(dim)
-        if size is not None:
-            return size
-        scale_factors = _ntuple(dim)(scale_factor)
-        # math.floor might return float in py2.7
-        return [
-            int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim)
-        ]
-
-    output_shape = tuple(_output_size(2))
-    output_shape = input.shape[:-2] + output_shape
-    return _NewEmptyTensorOp.apply(input, output_shape)
+interpolate = torch.nn.functional.interpolate
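With the empty-batch workaround gone, interpolate is simply a re-export of torch.nn.functional.interpolate. Standard usage, for reference:

import torch
from torch.nn import functional as F

x = torch.randn(2, 3, 16, 16)
y = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
print(y.shape)  # torch.Size([2, 3, 32, 32])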
 
 
 # This is not in nn
-class FrozenBatchNorm2d(torch.jit.ScriptModule):
+class FrozenBatchNorm2d(torch.nn.Module):
     """
     BatchNorm2d where the batch statistics and the affine parameters
     are fixed
     """
 
-    def __init__(self, n):
+    def __init__(
+        self,
+        num_features: int,
+        eps: float = 0.,
+        n: Optional[int] = None,
+    ):
+        # n=None for backward-compatibility
+        if n is not None:
+            warnings.warn("`n` argument is deprecated and has been renamed `num_features`",
+                          DeprecationWarning)
+            num_features = n
         super(FrozenBatchNorm2d, self).__init__()
-        self.register_buffer("weight", torch.ones(n))
-        self.register_buffer("bias", torch.zeros(n))
-        self.register_buffer("running_mean", torch.zeros(n))
-        self.register_buffer("running_var", torch.ones(n))
-
-    @torch.jit.script_method
-    def forward(self, x):
+        self.eps = eps
+        self.register_buffer("weight", torch.ones(num_features))
+        self.register_buffer("bias", torch.zeros(num_features))
+        self.register_buffer("running_mean", torch.zeros(num_features))
+        self.register_buffer("running_var", torch.ones(num_features))
+
+    def _load_from_state_dict(
+        self,
+        state_dict: dict,
+        prefix: str,
+        local_metadata: dict,
+        strict: bool,
+        missing_keys: List[str],
+        unexpected_keys: List[str],
+        error_msgs: List[str],
+    ):
+        num_batches_tracked_key = prefix + 'num_batches_tracked'
+        if num_batches_tracked_key in state_dict:
+            del state_dict[num_batches_tracked_key]
+
+        super(FrozenBatchNorm2d, self)._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict,
+            missing_keys, unexpected_keys, error_msgs)
+
+    def forward(self, x: Tensor) -> Tensor:
         # move reshapes to the beginning
         # to make it fuser-friendly
         w = self.weight.reshape(1, -1, 1, 1)
         b = self.bias.reshape(1, -1, 1, 1)
         rv = self.running_var.reshape(1, -1, 1, 1)
         rm = self.running_mean.reshape(1, -1, 1, 1)
-        scale = w * rv.rsqrt()
+        scale = w * (rv + self.eps).rsqrt()
         bias = b - rm * scale
         return x * scale + bias
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({self.weight.shape[0]})"
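FrozenBatchNorm2d computes scale = weight / sqrt(running_var + eps) and bias = bias - running_mean * scale, so its output x * scale + bias matches an eval-mode BatchNorm2d with the same statistics and the same eps (note this class defaults eps to 0.). A minimal sketch of freezing a trained BatchNorm2d; freeze_batchnorm is a hypothetical helper and the import path is assumed:

import torch
from pytorch_jacinto_ai.xvision.ops.misc import FrozenBatchNorm2d  # assumed import path

def freeze_batchnorm(bn: torch.nn.BatchNorm2d) -> FrozenBatchNorm2d:
    # Hypothetical helper: copy the affine parameters and running statistics
    # into fixed buffers so they can no longer be updated.
    frozen = FrozenBatchNorm2d(bn.num_features, eps=bn.eps)
    # num_batches_tracked from bn's state dict is dropped by the
    # _load_from_state_dict override above, so strict loading succeeds.
    frozen.load_state_dict(bn.state_dict())
    return frozen

bn = torch.nn.BatchNorm2d(8).eval()
with torch.no_grad():
    # give the running statistics and affine parameters non-trivial values
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)

frozen = freeze_batchnorm(bn)
x = torch.randn(4, 8, 16, 16)
with torch.no_grad():
    print(torch.allclose(bn(x), frozen(x), atol=1e-5))  # True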