]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blobdiff - modules/pytorch_jacinto_ai/xvision/transforms/functional_tensor.py
added mobilenetv3 from torchvision and also mobilenetv3_lite models, updated docs
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / transforms / functional_tensor.py
diff --git a/modules/pytorch_jacinto_ai/xvision/transforms/functional_tensor.py b/modules/pytorch_jacinto_ai/xvision/transforms/functional_tensor.py
new file mode 100644 (file)
index 0000000..36a1228
--- /dev/null
@@ -0,0 +1,891 @@
+import warnings
+from typing import Optional, Dict, Tuple
+
+import torch
+from torch import Tensor
+from torch.nn.functional import grid_sample
+from torch.jit.annotations import List, BroadcastingList2
+
+
+def _is_tensor_a_torch_image(x: Tensor) -> bool:
+    return x.ndim >= 2
+
+
+def _get_image_size(img: Tensor) -> List[int]:
+    """Returns (w, h) of tensor image"""
+    if _is_tensor_a_torch_image(img):
+        return [img.shape[-1], img.shape[-2]]
+    raise TypeError("Unexpected type {}".format(type(img)))
+
+
+def vflip(img: Tensor) -> Tensor:
+    """Vertically flip the given the Image Tensor.
+
+    Args:
+        img (Tensor): Image Tensor to be flipped in the form [C, H, W].
+
+    Returns:
+        Tensor:  Vertically flipped image Tensor.
+    """
+    if not _is_tensor_a_torch_image(img):
+        raise TypeError('tensor is not a torch image.')
+
+    return img.flip(-2)
+
+
+def hflip(img: Tensor) -> Tensor:
+    """Horizontally flip the given the Image Tensor.
+
+    Args:
+        img (Tensor): Image Tensor to be flipped in the form [C, H, W].
+
+    Returns:
+        Tensor:  Horizontally flipped image Tensor.
+    """
+    if not _is_tensor_a_torch_image(img):
+        raise TypeError('tensor is not a torch image.')
+
+    return img.flip(-1)
+
+
+def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
+    """Crop the given Image Tensor.
+
+    Args:
+        img (Tensor): Image to be cropped in the form [..., H, W]. (0,0) denotes the top left corner of the image.
+        top (int): Vertical component of the top left corner of the crop box.
+        left (int): Horizontal component of the top left corner of the crop box.
+        height (int): Height of the crop box.
+        width (int): Width of the crop box.
+
+    Returns:
+        Tensor: Cropped image.
+    """
+    if not _is_tensor_a_torch_image(img):
+        raise TypeError("tensor is not a torch image.")
+
+    return img[..., top:top + height, left:left + width]
+
+
+def rgb_to_grayscale(img: Tensor) -> Tensor:
+    """Convert the given RGB Image Tensor to Grayscale.
+    For RGB to Grayscale conversion, ITU-R 601-2 luma transform is performed which
+    is L = R * 0.2989 + G * 0.5870 + B * 0.1140
+
+    Args:
+        img (Tensor): Image to be converted to Grayscale in the form [C, H, W].
+
+    Returns:
+        Tensor: Grayscale image.
+
+    """
+    if img.shape[0] != 3:
+        raise TypeError('Input Image does not contain 3 Channels')
+
+    return (0.2989 * img[0] + 0.5870 * img[1] + 0.1140 * img[2]).to(img.dtype)
+
+
+def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
+    """Adjust brightness of an RGB image.
+
+    Args:
+        img (Tensor): Image to be adjusted.
+        brightness_factor (float):  How much to adjust the brightness. Can be
+            any non negative number. 0 gives a black image, 1 gives the
+            original image while 2 increases the brightness by a factor of 2.
+
+    Returns:
+        Tensor: Brightness adjusted image.
+    """
+    if brightness_factor < 0:
+        raise ValueError('brightness_factor ({}) is not non-negative.'.format(brightness_factor))
+
+    if not _is_tensor_a_torch_image(img):
+        raise TypeError('tensor is not a torch image.')
+
+    return _blend(img, torch.zeros_like(img), brightness_factor)
+
+
+def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
+    """Adjust contrast of an RGB image.
+
+    Args:
+        img (Tensor): Image to be adjusted.
+        contrast_factor (float): How much to adjust the contrast. Can be any
+            non negative number. 0 gives a solid gray image, 1 gives the
+            original image while 2 increases the contrast by a factor of 2.
+
+    Returns:
+        Tensor: Contrast adjusted image.
+    """
+    if contrast_factor < 0:
+        raise ValueError('contrast_factor ({}) is not non-negative.'.format(contrast_factor))
+
+    if not _is_tensor_a_torch_image(img):
+        raise TypeError('tensor is not a torch image.')
+
+    mean = torch.mean(rgb_to_grayscale(img).to(torch.float))
+
+    return _blend(img, mean, contrast_factor)
+
+
+def adjust_hue(img, hue_factor):
+    """Adjust hue of an image.
+
+    The image hue is adjusted by converting the image to HSV and
+    cyclically shifting the intensities in the hue channel (H).
+    The image is then converted back to original image mode.
+
+    `hue_factor` is the amount of shift in H channel and must be in the
+    interval `[-0.5, 0.5]`.
+
+    See `Hue`_ for more details.
+
+    .. _Hue: https://en.wikipedia.org/wiki/Hue
+
+    Args:
+        img (Tensor): Image to be adjusted. Image type is either uint8 or float.
+        hue_factor (float):  How much to shift the hue channel. Should be in
+            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
+            HSV space in positive and negative direction respectively.
+            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
+            with complementary colors while 0 gives the original image.
+
+    Returns:
+         Tensor: Hue adjusted image.
+    """
+    if not (-0.5 <= hue_factor <= 0.5):
+        raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
+
+    if not _is_tensor_a_torch_image(img):
+        raise TypeError('tensor is not a torch image.')
+
+    orig_dtype = img.dtype
+    if img.dtype == torch.uint8:
+        img = img.to(dtype=torch.float32) / 255.0
+
+    img = _rgb2hsv(img)
+    h, s, v = img.unbind(0)
+    h += hue_factor
+    h = h % 1.0
+    img = torch.stack((h, s, v))
+    img_hue_adj = _hsv2rgb(img)
+
+    if orig_dtype == torch.uint8:
+        img_hue_adj = (img_hue_adj * 255.0).to(dtype=orig_dtype)
+
+    return img_hue_adj
+
+
+def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
+    """Adjust color saturation of an RGB image.
+
+    Args:
+        img (Tensor): Image to be adjusted.
+        saturation_factor (float):  How much to adjust the saturation. Can be any
+            non negative number. 0 gives a black and white image, 1 gives the
+            original image while 2 enhances the saturation by a factor of 2.
+
+    Returns:
+        Tensor: Saturation adjusted image.
+    """
+    if saturation_factor < 0:
+        raise ValueError('saturation_factor ({}) is not non-negative.'.format(saturation_factor))
+
+    if not _is_tensor_a_torch_image(img):
+        raise TypeError('tensor is not a torch image.')
+
+    return _blend(img, rgb_to_grayscale(img), saturation_factor)
+
+
+def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
+    r"""Adjust gamma of an RGB image.
+
+    Also known as Power Law Transform. Intensities in RGB mode are adjusted
+    based on the following equation:
+
+    .. math::
+        `I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}`
+
+    See `Gamma Correction`_ for more details.
+
+    .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction
+
+    Args:
+        img (Tensor): Tensor of RBG values to be adjusted.
+        gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
+            gamma larger than 1 make the shadows darker,
+            while gamma smaller than 1 make dark regions lighter.
+        gain (float): The constant multiplier.
+    """
+
+    if not isinstance(img, torch.Tensor):
+        raise TypeError('img should be a Tensor. Got {}'.format(type(img)))
+
+    if gamma < 0:
+        raise ValueError('Gamma should be a non-negative real number')
+
+    result = img
+    dtype = img.dtype
+    if not torch.is_floating_point(img):
+        result = result / 255.0
+
+    result = (gain * result ** gamma).clamp(0, 1)
+
+    if result.dtype != dtype:
+        eps = 1e-3
+        result = (255 + 1.0 - eps) * result
+    result = result.to(dtype)
+    return result
+
+
+def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
+    """DEPRECATED. Crop the Image Tensor and resize it to desired size.
+
+    .. warning::
+
+        This method is deprecated and will be removed in future releases.
+        Please, use ``F.center_crop`` instead.
+
+    Args:
+        img (Tensor): Image to be cropped.
+        output_size (sequence or int): (height, width) of the crop box. If int,
+                it is used for both directions
+
+    Returns:
+            Tensor: Cropped image.
+    """
+    warnings.warn(
+        "This method is deprecated and will be removed in future releases. "
+        "Please, use ``F.center_crop`` instead."
+    )
+
+    if not _is_tensor_a_torch_image(img):
+        raise TypeError('tensor is not a torch image.')
+
+    _, image_width, image_height = img.size()
+    crop_height, crop_width = output_size
+    # crop_top = int(round((image_height - crop_height) / 2.))
+    # Result can be different between python func and scripted func
+    # Temporary workaround:
+    crop_top = int((image_height - crop_height + 1) * 0.5)
+    # crop_left = int(round((image_width - crop_width) / 2.))
+    # Result can be different between python func and scripted func
+    # Temporary workaround:
+    crop_left = int((image_width - crop_width + 1) * 0.5)
+
+    return crop(img, crop_top, crop_left, crop_height, crop_width)
+
+
+def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]:
+    """DEPRECATED. Crop the given Image Tensor into four corners and the central crop.
+
+    .. warning::
+
+        This method is deprecated and will be removed in future releases.
+        Please, use ``F.five_crop`` instead.
+
+    .. Note::
+
+        This transform returns a List of Tensors and there may be a
+        mismatch in the number of inputs and targets your ``Dataset`` returns.
+
+    Args:
+        img (Tensor): Image to be cropped.
+        size (sequence or int): Desired output size of the crop. If size is an
+            int instead of sequence like (h, w), a square crop (size, size) is
+            made.
+
+    Returns:
+       List: List (tl, tr, bl, br, center)
+                Corresponding top left, top right, bottom left, bottom right and center crop.
+    """
+    warnings.warn(
+        "This method is deprecated and will be removed in future releases. "
+        "Please, use ``F.five_crop`` instead."
+    )
+
+    if not _is_tensor_a_torch_image(img):
+        raise TypeError('tensor is not a torch image.')
+
+    assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+
+    _, image_width, image_height = img.size()
+    crop_height, crop_width = size
+    if crop_width > image_width or crop_height > image_height:
+        msg = "Requested crop size {} is bigger than input size {}"
+        raise ValueError(msg.format(size, (image_height, image_width)))
+
+    tl = crop(img, 0, 0, crop_width, crop_height)
+    tr = crop(img, image_width - crop_width, 0, image_width, crop_height)
+    bl = crop(img, 0, image_height - crop_height, crop_width, image_height)
+    br = crop(img, image_width - crop_width, image_height - crop_height, image_width, image_height)
+    center = center_crop(img, (crop_height, crop_width))
+
+    return [tl, tr, bl, br, center]
+
+
+def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = False) -> List[Tensor]:
+    """DEPRECATED. Crop the given Image Tensor into four corners and the central crop plus the
+        flipped version of these (horizontal flipping is used by default).
+
+    .. warning::
+
+        This method is deprecated and will be removed in future releases.
+        Please, use ``F.ten_crop`` instead.
+
+    .. Note::
+
+        This transform returns a List of images and there may be a
+        mismatch in the number of inputs and targets your ``Dataset`` returns.
+
+    Args:
+        img (Tensor): Image to be cropped.
+        size (sequence or int): Desired output size of the crop. If size is an
+            int instead of sequence like (h, w), a square crop (size, size) is
+            made.
+        vertical_flip (bool): Use vertical flipping instead of horizontal
+
+    Returns:
+       List: List (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
+                Corresponding top left, top right, bottom left, bottom right and center crop
+                and same for the flipped image's tensor.
+    """
+    warnings.warn(
+        "This method is deprecated and will be removed in future releases. "
+        "Please, use ``F.ten_crop`` instead."
+    )
+
+    if not _is_tensor_a_torch_image(img):
+        raise TypeError('tensor is not a torch image.')
+
+    assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+    first_five = five_crop(img, size)
+
+    if vertical_flip:
+        img = vflip(img)
+    else:
+        img = hflip(img)
+
+    second_five = five_crop(img, size)
+
+    return first_five + second_five
+
+
+def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
+    bound = 1 if img1.dtype in [torch.half, torch.float32, torch.float64] else 255
+    return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)
+
+
+def _rgb2hsv(img):
+    r, g, b = img.unbind(0)
+
+    maxc = torch.max(img, dim=0).values
+    minc = torch.min(img, dim=0).values
+
+    # The algorithm erases S and H channel where `maxc = minc`. This avoids NaN
+    # from happening in the results, because
+    #   + S channel has division by `maxc`, which is zero only if `maxc = minc`
+    #   + H channel has division by `(maxc - minc)`.
+    #
+    # Instead of overwriting NaN afterwards, we just prevent it from occuring so
+    # we don't need to deal with it in case we save the NaN in a buffer in
+    # backprop, if it is ever supported, but it doesn't hurt to do so.
+    eqc = maxc == minc
+
+    cr = maxc - minc
+    # Since `eqc => cr = 0`, replacing denominator with 1 when `eqc` is fine.
+    s = cr / torch.where(eqc, maxc.new_ones(()), maxc)
+    # Note that `eqc => maxc = minc = r = g = b`. So the following calculation
+    # of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it
+    # would not matter what values `rc`, `gc`, and `bc` have here, and thus
+    # replacing denominator with 1 when `eqc` is fine.
+    cr_divisor = torch.where(eqc, maxc.new_ones(()), cr)
+    rc = (maxc - r) / cr_divisor
+    gc = (maxc - g) / cr_divisor
+    bc = (maxc - b) / cr_divisor
+
+    hr = (maxc == r) * (bc - gc)
+    hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
+    hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
+    h = (hr + hg + hb)
+    h = torch.fmod((h / 6.0 + 1.0), 1.0)
+    return torch.stack((h, s, maxc))
+
+
+def _hsv2rgb(img):
+    h, s, v = img.unbind(0)
+    i = torch.floor(h * 6.0)
+    f = (h * 6.0) - i
+    i = i.to(dtype=torch.int32)
+
+    p = torch.clamp((v * (1.0 - s)), 0.0, 1.0)
+    q = torch.clamp((v * (1.0 - s * f)), 0.0, 1.0)
+    t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
+    i = i % 6
+
+    mask = i == torch.arange(6)[:, None, None]
+
+    a1 = torch.stack((v, q, p, p, t, v))
+    a2 = torch.stack((t, v, v, q, p, p))
+    a3 = torch.stack((p, p, t, v, v, q))
+    a4 = torch.stack((a1, a2, a3))
+
+    return torch.einsum("ijk, xijk -> xjk", mask.to(dtype=img.dtype), a4)
+
+
+def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
+    # padding is left, right, top, bottom
+    in_sizes = img.size()
+
+    x_indices = [i for i in range(in_sizes[-1])]  # [0, 1, 2, 3, ...]
+    left_indices = [i for i in range(padding[0] - 1, -1, -1)]  # e.g. [3, 2, 1, 0]
+    right_indices = [-(i + 1) for i in range(padding[1])]  # e.g. [-1, -2, -3]
+    x_indices = torch.tensor(left_indices + x_indices + right_indices)
+
+    y_indices = [i for i in range(in_sizes[-2])]
+    top_indices = [i for i in range(padding[2] - 1, -1, -1)]
+    bottom_indices = [-(i + 1) for i in range(padding[3])]
+    y_indices = torch.tensor(top_indices + y_indices + bottom_indices)
+
+    ndim = img.ndim
+    if ndim == 3:
+        return img[:, y_indices[:, None], x_indices[None, :]]
+    elif ndim == 4:
+        return img[:, :, y_indices[:, None], x_indices[None, :]]
+    else:
+        raise RuntimeError("Symmetric padding of N-D tensors are not supported yet")
+
+
+def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
+    r"""Pad the given Tensor Image on all sides with specified padding mode and fill value.
+
+    Args:
+        img (Tensor): Image to be padded.
+        padding (int or tuple or list): Padding on each border. If a single int is provided this
+            is used to pad all borders. If a tuple or list of length 2 is provided this is the padding
+            on left/right and top/bottom respectively. If a tuple or list of length 4 is provided
+            this is the padding for the left, top, right and bottom borders
+            respectively. In torchscript mode padding as single int is not supported, use a tuple or
+            list of length 1: ``[padding, ]``.
+        fill (int): Pixel fill value for constant fill. Default is 0.
+            This value is only used when the padding_mode is constant
+        padding_mode (str): Type of padding. Should be: constant, edge or reflect. Default is constant.
+            Mode symmetric is not yet supported for Tensor inputs.
+
+            - constant: pads with a constant value, this value is specified with fill
+
+            - edge: pads with the last value on the edge of the image
+
+            - reflect: pads with reflection of image (without repeating the last value on the edge)
+
+                       padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
+                       will result in [3, 2, 1, 2, 3, 4, 3, 2]
+
+            - symmetric: pads with reflection of image (repeating the last value on the edge)
+
+                         padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
+                         will result in [2, 1, 1, 2, 3, 4, 4, 3]
+
+    Returns:
+        Tensor: Padded image.
+    """
+    if not _is_tensor_a_torch_image(img):
+        raise TypeError("tensor is not a torch image.")
+
+    if not isinstance(padding, (int, tuple, list)):
+        raise TypeError("Got inappropriate padding arg")
+    if not isinstance(fill, (int, float)):
+        raise TypeError("Got inappropriate fill arg")
+    if not isinstance(padding_mode, str):
+        raise TypeError("Got inappropriate padding_mode arg")
+
+    if isinstance(padding, tuple):
+        padding = list(padding)
+
+    if isinstance(padding, list) and len(padding) not in [1, 2, 4]:
+        raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " +
+                         "{} element tuple".format(len(padding)))
+
+    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
+        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
+
+    if isinstance(padding, int):
+        if torch.jit.is_scripting():
+            # This maybe unreachable
+            raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")
+        pad_left = pad_right = pad_top = pad_bottom = padding
+    elif len(padding) == 1:
+        pad_left = pad_right = pad_top = pad_bottom = padding[0]
+    elif len(padding) == 2:
+        pad_left = pad_right = padding[0]
+        pad_top = pad_bottom = padding[1]
+    else:
+        pad_left = padding[0]
+        pad_top = padding[1]
+        pad_right = padding[2]
+        pad_bottom = padding[3]
+
+    p = [pad_left, pad_right, pad_top, pad_bottom]
+
+    if padding_mode == "edge":
+        # remap padding_mode str
+        padding_mode = "replicate"
+    elif padding_mode == "symmetric":
+        # route to another implementation
+        if p[0] < 0 or p[1] < 0 or p[2] < 0 or p[3] < 0:  # no any support for torch script
+            raise ValueError("Padding can not be negative for symmetric padding_mode")
+        return _pad_symmetric(img, p)
+
+    need_squeeze = False
+    if img.ndim < 4:
+        img = img.unsqueeze(dim=0)
+        need_squeeze = True
+
+    out_dtype = img.dtype
+    need_cast = False
+    if (padding_mode != "constant") and img.dtype not in (torch.float32, torch.float64):
+        # Here we temporary cast input tensor to float
+        # until pytorch issue is resolved :
+        # https://github.com/pytorch/pytorch/issues/40763
+        need_cast = True
+        img = img.to(torch.float32)
+
+    img = torch.nn.functional.pad(img, p, mode=padding_mode, value=float(fill))
+
+    if need_squeeze:
+        img = img.squeeze(dim=0)
+
+    if need_cast:
+        img = img.to(out_dtype)
+
+    return img
+
+
+def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
+    r"""Resize the input Tensor to the given size.
+
+    Args:
+        img (Tensor): Image to be resized.
+        size (int or tuple or list): Desired output size. If size is a sequence like
+            (h, w), the output size will be matched to this. If size is an int,
+            the smaller edge of the image will be matched to this number maintaining
+            the aspect ratio. i.e, if height > width, then image will be rescaled to
+            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.
+            In torchscript mode padding as a single int is not supported, use a tuple or
+            list of length 1: ``[size, ]``.
+        interpolation (int, optional): Desired interpolation. Default is bilinear (=2). Other supported values:
+            nearest(=0) and bicubic(=3).
+
+    Returns:
+        Tensor: Resized image.
+    """
+    if not _is_tensor_a_torch_image(img):
+        raise TypeError("tensor is not a torch image.")
+
+    if not isinstance(size, (int, tuple, list)):
+        raise TypeError("Got inappropriate size arg")
+    if not isinstance(interpolation, int):
+        raise TypeError("Got inappropriate interpolation arg")
+
+    _interpolation_modes = {
+        0: "nearest",
+        2: "bilinear",
+        3: "bicubic",
+    }
+
+    if interpolation not in _interpolation_modes:
+        raise ValueError("This interpolation mode is unsupported with Tensor input")
+
+    if isinstance(size, tuple):
+        size = list(size)
+
+    if isinstance(size, list) and len(size) not in [1, 2]:
+        raise ValueError("Size must be an int or a 1 or 2 element tuple/list, not a "
+                         "{} element tuple/list".format(len(size)))
+
+    w, h = _get_image_size(img)
+
+    if isinstance(size, int):
+        size_w, size_h = size, size
+    elif len(size) < 2:
+        size_w, size_h = size[0], size[0]
+    else:
+        size_w, size_h = size[1], size[0]  # Convention (h, w)
+
+    if isinstance(size, int) or len(size) < 2:
+        if w < h:
+            size_h = int(size_w * h / w)
+        else:
+            size_w = int(size_h * w / h)
+
+        if (w <= h and w == size_w) or (h <= w and h == size_h):
+            return img
+
+    # make image NCHW
+    need_squeeze = False
+    if img.ndim < 4:
+        img = img.unsqueeze(dim=0)
+        need_squeeze = True
+
+    mode = _interpolation_modes[interpolation]
+
+    out_dtype = img.dtype
+    need_cast = False
+    if img.dtype not in (torch.float32, torch.float64):
+        need_cast = True
+        img = img.to(torch.float32)
+
+    # Define align_corners to avoid warnings
+    align_corners = False if mode in ["bilinear", "bicubic"] else None
+
+    img = torch.nn.functional.interpolate(img, size=(size_h, size_w), mode=mode, align_corners=align_corners)
+
+    if need_squeeze:
+        img = img.squeeze(dim=0)
+
+    if need_cast:
+        if mode == "bicubic":
+            img = img.clamp(min=0, max=255)
+        img = img.to(out_dtype)
+
+    return img
+
+
+def _assert_grid_transform_inputs(
+        img: Tensor,
+        matrix: Optional[List[float]],
+        resample: int,
+        fillcolor: Optional[int],
+        _interpolation_modes: Dict[int, str],
+        coeffs: Optional[List[float]] = None,
+):
+    if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
+        raise TypeError("img should be Tensor Image. Got {}".format(type(img)))
+
+    if matrix is not None and not isinstance(matrix, list):
+        raise TypeError("Argument matrix should be a list. Got {}".format(type(matrix)))
+
+    if matrix is not None and len(matrix) != 6:
+        raise ValueError("Argument matrix should have 6 float values")
+
+    if coeffs is not None and len(coeffs) != 8:
+        raise ValueError("Argument coeffs should have 8 float values")
+
+    if fillcolor is not None:
+        warnings.warn("Argument fill/fillcolor is not supported for Tensor input. Fill value is zero")
+
+    if resample not in _interpolation_modes:
+        raise ValueError("Resampling mode '{}' is unsupported with Tensor input".format(resample))
+
+
+def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str) -> Tensor:
+    # make image NCHW
+    need_squeeze = False
+    if img.ndim < 4:
+        img = img.unsqueeze(dim=0)
+        need_squeeze = True
+
+    out_dtype = img.dtype
+    need_cast = False
+    if img.dtype not in (torch.float32, torch.float64):
+        need_cast = True
+        img = img.to(torch.float32)
+
+    img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)
+
+    if need_squeeze:
+        img = img.squeeze(dim=0)
+
+    if need_cast:
+        # it is better to round before cast
+        img = torch.round(img).to(out_dtype)
+
+    return img
+
+
+def _gen_affine_grid(
+        theta: Tensor, w: int, h: int, ow: int, oh: int,
+) -> Tensor:
+    # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
+    # AffineGridGenerator.cpp#L18
+    # Difference with AffineGridGenerator is that:
+    # 1) we normalize grid values after applying theta
+    # 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate
+
+    d = 0.5
+    base_grid = torch.empty(1, oh, ow, 3, dtype=theta.dtype, device=theta.device)
+    base_grid[..., 0].copy_(torch.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, steps=ow))
+    base_grid[..., 1].copy_(torch.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, steps=oh).unsqueeze_(-1))
+    base_grid[..., 2].fill_(1)
+
+    rescaled_theta = theta.transpose(1, 2) / torch.tensor([0.5 * w, 0.5 * h], dtype=theta.dtype, device=theta.device)
+    output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
+    return output_grid.view(1, oh, ow, 2)
+
+
+def affine(
+        img: Tensor, matrix: List[float], resample: int = 0, fillcolor: Optional[int] = None
+) -> Tensor:
+    """Apply affine transformation on the Tensor image keeping image center invariant.
+
+    Args:
+        img (Tensor): image to be rotated.
+        matrix (list of floats): list of 6 float values representing inverse matrix for affine transformation.
+        resample (int, optional): An optional resampling filter. Default is nearest (=0). Other supported values:
+            bilinear(=2).
+        fillcolor (int, optional): this option is not supported for Tensor input. Fill value for the area outside the
+            transform in the output image is always 0.
+
+    Returns:
+        Tensor: Transformed image.
+    """
+    _interpolation_modes = {
+        0: "nearest",
+        2: "bilinear",
+    }
+
+    _assert_grid_transform_inputs(img, matrix, resample, fillcolor, _interpolation_modes)
+
+    theta = torch.tensor(matrix, dtype=torch.float, device=img.device).reshape(1, 2, 3)
+    shape = img.shape
+    # grid will be generated on the same device as theta and img
+    grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
+    mode = _interpolation_modes[resample]
+    return _apply_grid_transform(img, grid, mode)
+
+
+def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
+
+    # Inspired of PIL implementation:
+    # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
+
+    # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
+    pts = torch.tensor([
+        [-0.5 * w, -0.5 * h, 1.0],
+        [-0.5 * w, 0.5 * h, 1.0],
+        [0.5 * w, 0.5 * h, 1.0],
+        [0.5 * w, -0.5 * h, 1.0],
+    ])
+    theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3)
+    new_pts = pts.view(1, 4, 3).bmm(theta.transpose(1, 2)).view(4, 2)
+    min_vals, _ = new_pts.min(dim=0)
+    max_vals, _ = new_pts.max(dim=0)
+
+    # Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
+    tol = 1e-4
+    cmax = torch.ceil((max_vals / tol).trunc_() * tol)
+    cmin = torch.floor((min_vals / tol).trunc_() * tol)
+    size = cmax - cmin
+    return int(size[0]), int(size[1])
+
+
+def rotate(
+        img: Tensor, matrix: List[float], resample: int = 0, expand: bool = False, fill: Optional[int] = None
+) -> Tensor:
+    """Rotate the Tensor image by angle.
+
+    Args:
+        img (Tensor): image to be rotated.
+        matrix (list of floats): list of 6 float values representing inverse matrix for rotation transformation.
+            Translation part (``matrix[2]`` and ``matrix[5]``) should be in pixel coordinates.
+        resample (int, optional): An optional resampling filter. Default is nearest (=0). Other supported values:
+            bilinear(=2).
+        expand (bool, optional): Optional expansion flag.
+            If true, expands the output image to make it large enough to hold the entire rotated image.
+            If false or omitted, make the output image the same size as the input image.
+            Note that the expand flag assumes rotation around the center and no translation.
+        fill (n-tuple or int or float): this option is not supported for Tensor input.
+            Fill value for the area outside the transform in the output image is always 0.
+
+    Returns:
+        Tensor: Rotated image.
+
+    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
+
+    """
+    _interpolation_modes = {
+        0: "nearest",
+        2: "bilinear",
+    }
+
+    _assert_grid_transform_inputs(img, matrix, resample, fill, _interpolation_modes)
+    w, h = img.shape[-1], img.shape[-2]
+    ow, oh = _compute_output_size(matrix, w, h) if expand else (w, h)
+    theta = torch.tensor(matrix, dtype=torch.float, device=img.device).reshape(1, 2, 3)
+    # grid will be generated on the same device as theta and img
+    grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)
+    mode = _interpolation_modes[resample]
+
+    return _apply_grid_transform(img, grid, mode)
+
+
+def _perspective_grid(coeffs: List[float], ow: int, oh: int, device: torch.device):
+    # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
+    # src/libImaging/Geometry.c#L394
+
+    #
+    # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
+    # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
+    #
+
+    theta1 = torch.tensor([[
+        [coeffs[0], coeffs[1], coeffs[2]],
+        [coeffs[3], coeffs[4], coeffs[5]]
+    ]], dtype=torch.float, device=device)
+    theta2 = torch.tensor([[
+        [coeffs[6], coeffs[7], 1.0],
+        [coeffs[6], coeffs[7], 1.0]
+    ]], dtype=torch.float, device=device)
+
+    d = 0.5
+    base_grid = torch.empty(1, oh, ow, 3, dtype=torch.float, device=device)
+    base_grid[..., 0].copy_(torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow))
+    base_grid[..., 1].copy_(torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh).unsqueeze_(-1))
+    base_grid[..., 2].fill_(1)
+
+    rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=torch.float, device=device)
+    output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1)
+    output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2))
+
+    output_grid = output_grid1 / output_grid2 - 1.0
+    return output_grid.view(1, oh, ow, 2)
+
+
+def perspective(
+        img: Tensor, perspective_coeffs: List[float], interpolation: int = 2, fill: Optional[int] = None
+) -> Tensor:
+    """Perform perspective transform of the given Tensor image.
+
+    Args:
+        img (Tensor): Image to be transformed.
+        perspective_coeffs (list of float): perspective transformation coefficients.
+        interpolation (int): Interpolation type. Default, ``PIL.Image.BILINEAR``.
+        fill (n-tuple or int or float): this option is not supported for Tensor input. Fill value for the area
+            outside the transform in the output image is always 0.
+
+    Returns:
+        Tensor: transformed image.
+    """
+    if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
+        raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))
+
+    _interpolation_modes = {
+        0: "nearest",
+        2: "bilinear",
+    }
+
+    _assert_grid_transform_inputs(
+        img,
+        matrix=None,
+        resample=interpolation,
+        fillcolor=fill,
+        _interpolation_modes=_interpolation_modes,
+        coeffs=perspective_coeffs
+    )
+
+    ow, oh = img.shape[-1], img.shape[-2]
+    grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, device=img.device)
+    mode = _interpolation_modes[interpolation]
+
+    return _apply_grid_transform(img, grid, mode)