[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / transforms / image_transform_utils.py
diff --git a/modules/pytorch_jacinto_ai/xvision/transforms/image_transform_utils.py b/modules/pytorch_jacinto_ai/xvision/transforms/image_transform_utils.py
index c4a60c6a8a6748352b3256bbd3443bffafafae47..98d9e5040f15d42f5d104330e00316cb1370f1df 100644 (file)
import cv2
import torch
import types
+import PIL
class Compose(object):
""" Composes several co_transforms together.
img = cv2.warpAffine(img, rmat2x3, (w,h), flags=interpolation)
return img
+ @staticmethod
+ def reverse_channels(img):
+ if isinstance(img, np.ndarray):
+ return img[:,:,::-1]
+ elif isinstance(img, PIL.Image):
+ return PIL.Image.fromarray(np.array(img)[:,:,::-1])
+ else:
+ assert False, 'unrecognized image type'
+
@staticmethod
def array_to_tensor(array):
"""Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W)."""