65e5012330cfd5c9d448fd8d24a057aa0e3505f1
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / transforms / image_transform_utils.py
2 import numpy as np
3 import cv2
4 import torch
5 import types
7 class Compose(object):
8     """ Composes several co_transforms together.
9     For example:
10     >>> co_transforms.Compose([
11     >>>     co_transforms.CenterCrop(10),
12     >>>     co_transforms.ToTensor(),
13     >>>  ])
14     """
16     def __init__(self, t):
17         self.co_transforms = t
19     def extend(self, t):
20         self.co_transforms.extend(t)
22     def insert(self, index, t):
23         self.co_transforms.insert(index, t)
25     def write_img(self, img=[], ch_num=-1, name='', en=False):
26         if en == False:
27             return
28         #name = './data/checkpoints/tiad_interest_pt_descriptor/debug/{:02d}.jpg'.format(aug_idx)
29         scale_range = 255.0 / np.max(img)
30         img = np.clip(img * scale_range, 0.0, 255.0)
31         img = np.asarray(img, 'uint8')
33         non_zero_el = cv2.countNonZero(img)
35         print("non zero element: {}".format(non_zero_el))
36         cv2.imwrite('{}_nz{}.jpg'.format(name, non_zero_el), img)
38     def __call__(self, input, target):
39         if self.co_transforms:
40             for aug_idx, t in enumerate(self.co_transforms):
41                 if t:
42                     input,target = t(input,target)
44         return input,target
48 class Bypass(object):
49     def __init__(self):
50         pass
52     def __call__(self, images, targets):
53         return images,targets
56 class Lambda(object):
57     """Applies a lambda as a transform"""
59     def __init__(self, lambd):
60         assert isinstance(lambd, types.LambdaType)
61         self.lambd = lambd
63     def __call__(self, input,target):
64         return self.lambd(input,target)
67 class ImageTransformUtils(object):
68     @staticmethod
69     def apply_to_list(func, inputs):
70         for img_idx in range(len(inputs)):
71             inputs[img_idx] = func(inputs[img_idx], img_idx)
73         return inputs
75     @staticmethod
76     def crop(img, r, c, h, w):
77         img = img[r:(r+h), c:(c+w),...] if (len(img.shape)>2) else img[r:(r+h), c:(c+w)]
78         return img
80     @staticmethod
81     def resize_fast(img, output_size_rc, interpolation=-1):
82         in_h, in_w = img.shape[:2]
83         out_h, out_w = output_size_rc
84         if interpolation<0:
85             interpolation = cv2.INTER_AREA if ((out_h<in_h) or (out_w<in_w)) else cv2.INTER_LINEAR
87         img = cv2.resize(img, (out_w,out_h), interpolation=interpolation) #opencv expects size in (w,h) format
88         img = img[...,np.newaxis] if len(img.shape) < 3 else img
89         return img
91     @staticmethod
92     def resize_img(img, size, interpolation=-1, is_flow=False):
93         #if (len(img.shape) == 3) and (img.shape[2] == 1 or img.shape[2] == 3):
94         #    return __class__.resize_fast(img, size, interpolation)
96         in_h, in_w = img.shape[:2]
97         out_h, out_w = size
98         if interpolation<0:
99             interpolation = cv2.INTER_AREA if ((out_h<in_h) or (out_w<in_w)) else cv2.INTER_LINEAR
101         # opencv handles planar, 1 or 3 channel images
102         img = img[...,np.newaxis] if len(img.shape) < 3 else img
103         num_chans = img.shape[2]
104         img = np.concatenate([img]+[img[...,0:1]]*(3-num_chans), axis=2) if num_chans<3 else img
105         img = cv2.resize(img, (out_w, out_h), interpolation=interpolation)
106         img = img[...,:num_chans]
108         if is_flow:
109             ratio_h = out_h / in_h
110             ratio_w = out_w / in_w
111             img = ImageTransformUtils.scale_flow(img, ratio_w, ratio_h)
113         return img
115     @staticmethod
116     def resize_and_crop(img, r, c, h, w, size, interpolation=-1, is_flow=False):
117         img = ImageTransformUtils.resize_img(img, size, interpolation, is_flow)
118         img = ImageTransformUtils.crop(img, r, c, h, w)
119         return img
121     @staticmethod
122     def crop_and_resize(img, r, c, h, w, size, interpolation=-1, is_flow=False):
123         img = ImageTransformUtils.crop(img, r, c, h, w)
124         img = ImageTransformUtils.resize_img(img, size, interpolation, is_flow)
125         return img
127     @staticmethod
128     def rotate_img(img, angle, interpolation=-1):
129         h, w = img.shape[:2]
130         rmat2x3 = cv2.getRotationMatrix2D(center=(w//2,h//2), angle=angle, scale=1.0)
131         interpolation = cv2.INTER_NEAREST if interpolation < 0 else interpolation
132         img = cv2.warpAffine(img, rmat2x3, (w,h), flags=interpolation)
133         return img
135     @staticmethod
136     def array_to_tensor(array):
137         """Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W)."""
138         assert(isinstance(array, np.ndarray))
139         # put it from HWC to CHW format
140         array = np.transpose(array, (2, 0, 1))
141         if len(array.shape) < 3:
142             array = array[np.newaxis, ...]
143         #
144         tensor = torch.from_numpy(array)
145         return tensor.float()
147     @staticmethod
148     def scale_flow(flow, ratio_x, ratio_y):
149         flow = flow.astype(np.float32)
150         flow[...,0] *= ratio_x
151         flow[...,1] *= ratio_y
152         return flow
154     @staticmethod
155     def scale_flows(inputs, ratio_x, ratio_y, is_flow):
156         for img_idx in range(len(inputs)):
157             if is_flow and is_flow[img_idx]:
158                 inputs[img_idx] = inputs[img_idx].astype(np.float32)
159                 inputs[img_idx][...,0] *= ratio_x
160                 inputs[img_idx][...,1] *= ratio_y
161         #
162         return inputs