[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / transforms / image_transforms.py
diff --git a/modules/pytorch_jacinto_ai/xvision/transforms/image_transforms.py b/modules/pytorch_jacinto_ai/xvision/transforms/image_transforms.py
index fbf9ac53908070fa0b25f62c25e2a645ab8105a5..df29b433cb630081b4f3b94e7c8b734ab3775bf3 100644 (file)
return images, targets
+class ReverseImageChannels(object):
+ """Reverse the channels fo the tensor. eg. RGB to BGR
+ """
+ def __call__(self, images, targets):
+ def func(imgs, img_idx):
+ imgs = [ImageTransformUtils.reverse_channels(img_plane) for img_plane in imgs] \
+ if isinstance(imgs, list) else ImageTransformUtils.reverse_channels(imgs)
+ return imgs
+
+ images = ImageTransformUtils.apply_to_list(func, images)
+ # do not apply to targets
+ return images, targets
+
+
class ConvertToTensor(object):
"""Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W)."""
def __call__(self, images, targets):