[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / transforms / functional.py
1 from __future__ import division
2 import torch
3 import sys
4 import math
5 from PIL import Image, ImageOps, ImageEnhance, PILLOW_VERSION
6 try:
7 import accimage
8 except ImportError:
9 accimage = None
10 import numpy as np
11 import numbers
12 import collections
13 import warnings
# ``collections.abc`` hosts the container ABCs from Python 3.3 onward; keep
# aliases so the rest of the module works on both old and new interpreters.
if sys.version_info >= (3, 3):
    Sequence = collections.abc.Sequence
    Iterable = collections.abc.Iterable
else:
    Sequence = collections.Sequence
    Iterable = collections.Iterable
def _is_pil_image(img):
    """Return True if ``img`` is a PIL image (or an accimage image when available)."""
    if accimage is None:
        return isinstance(img, Image.Image)
    return isinstance(img, (Image.Image, accimage.Image))
30 def _is_tensor_image(img):
31 return torch.is_tensor(img) and img.ndimension() == 3
34 def _is_numpy(img):
35 return isinstance(img, np.ndarray)
38 def _is_numpy_image(img):
39 return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
def to_tensor(pic):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    See ``ToTensor`` for more details.

    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image in (C, H, W) layout. 8-bit inputs are
        rescaled to floats in [0, 1]; other dtypes keep their values.

    Raises:
        TypeError: If ``pic`` is neither a PIL Image nor an ndarray.
        ValueError: If ``pic`` is an ndarray that is not 2- or 3-dimensional.
    """
    if not(_is_pil_image(pic) or _is_numpy(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

    if _is_numpy(pic) and not _is_numpy_image(pic):
        raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))

    if isinstance(pic, np.ndarray):
        # handle numpy array
        if pic.ndim == 2:
            # grayscale: add a trailing channel axis so the HWC->CHW transpose below works
            pic = pic[:, :, None]

        # HWC ndarray -> CHW tensor
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility: byte images become floats scaled to [0, 1]
        if isinstance(img, torch.ByteTensor):
            return img.float().div(255)
        else:
            return img

    if accimage is not None and isinstance(pic, accimage.Image):
        # accimage copies straight into a pre-allocated float32 CHW buffer
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)

    # handle PIL Image: pick the tensor dtype matching the PIL mode
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    elif pic.mode == 'F':
        img = torch.from_numpy(np.array(pic, np.float32, copy=False))
    elif pic.mode == '1':
        # 1-bit images: expand {0, 1} to {0, 255} so the final div(255) yields {0, 1}
        img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(pic.mode)
    # raw buffer is flat; reshape to (H, W, C) using PIL's (width, height) size
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    else:
        return img
def to_pil_image(pic, mode=None):
    """Convert a tensor or an ndarray to PIL Image.

    See :class:`~pytorch_jacinto_ai.vision.transforms.ToPILImage` for more details.

    Args:
        pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
            Tensors are expected in (C, H, W) layout, ndarrays in (H, W, C).
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes

    Returns:
        PIL Image: Image converted to PIL Image.

    Raises:
        TypeError: If ``pic`` is neither a torch tensor nor an ndarray, or its
            dtype does not map onto a supported PIL mode.
        ValueError: If ``pic`` is not 2/3 dimensional, or ``mode`` is
            incompatible with the channel count / dtype of ``pic``.
    """
    if not(isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
        raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))

    elif isinstance(pic, torch.Tensor):
        if pic.ndimension() not in {2, 3}:
            raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndimension()))

        elif pic.ndimension() == 2:
            # if 2D image, add channel dimension (CHW)
            pic = pic.unsqueeze(0)

    elif isinstance(pic, np.ndarray):
        if pic.ndim not in {2, 3}:
            raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))

        elif pic.ndim == 2:
            # if 2D image, add channel dimension (HWC)
            pic = np.expand_dims(pic, 2)

    npimg = pic
    if isinstance(pic, torch.FloatTensor) and mode != 'F':
        # float tensors are assumed to lie in [0, 1]; rescale into byte range
        pic = pic.mul(255).byte()
    if isinstance(pic, torch.Tensor):
        # CHW tensor -> HWC ndarray, which is the layout PIL expects
        npimg = np.transpose(pic.numpy(), (1, 2, 0))

    if not isinstance(npimg, np.ndarray):
        raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' +
                        'not {}'.format(type(npimg)))

    if npimg.shape[2] == 1:
        # single channel: derive the PIL mode from the array dtype
        expected_mode = None
        npimg = npimg[:, :, 0]
        if npimg.dtype == np.uint8:
            expected_mode = 'L'
        elif npimg.dtype == np.int16:
            expected_mode = 'I;16'
        elif npimg.dtype == np.int32:
            expected_mode = 'I'
        elif npimg.dtype == np.float32:
            expected_mode = 'F'
        if mode is not None and mode != expected_mode:
            # BUGFIX: report the input's actual dtype; the original formatted
            # ``np.dtype`` (the class object) instead of ``npimg.dtype``.
            raise ValueError("Incorrect mode ({}) supplied for input type {}. Should be {}"
                             .format(mode, npimg.dtype, expected_mode))
        mode = expected_mode

    elif npimg.shape[2] == 2:
        permitted_2_channel_modes = ['LA']
        if mode is not None and mode not in permitted_2_channel_modes:
            raise ValueError("Only modes {} are supported for 2D inputs".format(permitted_2_channel_modes))

        if mode is None and npimg.dtype == np.uint8:
            mode = 'LA'

    elif npimg.shape[2] == 4:
        permitted_4_channel_modes = ['RGBA', 'CMYK', 'RGBX']
        if mode is not None and mode not in permitted_4_channel_modes:
            raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes))

        if mode is None and npimg.dtype == np.uint8:
            mode = 'RGBA'
    else:
        permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV']
        if mode is not None and mode not in permitted_3_channel_modes:
            raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes))
        if mode is None and npimg.dtype == np.uint8:
            mode = 'RGB'

    if mode is None:
        raise TypeError('Input type {} is not supported'.format(npimg.dtype))

    return Image.fromarray(npimg, mode=mode)
def normalize(tensor, mean, std, inplace=False):
    """Normalize a tensor image channel-wise: ``(tensor - mean) / std``.

    .. note::
        Out of place by default: the input is cloned unless ``inplace`` is True.

    See :class:`~pytorch_jacinto_ai.vision.transforms.Normalize` for more details.

    Args:
        tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace(bool,optional): Bool to make this operation inplace.

    Returns:
        Tensor: Normalized Tensor image.
    """
    # Same 3-D tensor check _is_tensor_image performs, inlined.
    if not (torch.is_tensor(tensor) and tensor.ndimension() == 3):
        raise TypeError('tensor is not a torch image.')

    if not inplace:
        tensor = tensor.clone()

    dtype = tensor.dtype
    mean_t = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
    std_t = torch.as_tensor(std, dtype=dtype, device=tensor.device)
    # Broadcast the per-channel stats over the spatial dimensions.
    tensor.sub_(mean_t[:, None, None]).div_(std_t[:, None, None])
    return tensor
def resize(img, size, interpolation=Image.BILINEAR):
    r"""Resize the input PIL Image to the given size.

    Args:
        img (PIL Image): Image to be resized.
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), the output size will be matched to this. If size is an int,
            the smaller edge of the image will be matched to this number maintaining
            the aspect ratio. i.e, if height > width, then image will be rescaled to
            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``

    Returns:
        PIL Image: Resized image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)):
        raise TypeError('Got inappropriate size arg: {}'.format(size))

    if not isinstance(size, int):
        # Explicit (h, w) requested; PIL's resize wants (w, h).
        return img.resize(size[::-1], interpolation)

    # Integer size: match the shorter edge and keep the aspect ratio.
    w, h = img.size
    short_side, long_side = (w, h) if w <= h else (h, w)
    if short_side == size:
        return img
    new_short, new_long = size, int(size * long_side / short_side)
    new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
    return img.resize((new_w, new_h), interpolation)
def scale(*args, **kwargs):
    """Deprecated alias kept for backward compatibility; use :func:`resize`."""
    warnings.warn("The use of the transforms.Scale transform is deprecated, " +
                  "please use transforms.Resize instead.")
    return resize(*args, **kwargs)
def pad(img, padding, fill=0, padding_mode='constant'):
    r"""Pad the given PIL Image on all sides with specified padding mode and fill value.

    Args:
        img (PIL Image): Image to be padded.
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant
        padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value on the edge of the image

            - reflect: pads with reflection of image (without repeating the last value on the edge)

                padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
                will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image (repeating the last value on the edge)

                padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
                will result in [2, 1, 1, 2, 3, 4, 4, 3]

    Returns:
        PIL Image: Padded image.

    Raises:
        TypeError: If ``img``, ``padding``, ``fill`` or ``padding_mode`` has the wrong type.
        ValueError: If ``padding`` is a tuple whose length is neither 2 nor 4.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    if not isinstance(padding, (numbers.Number, tuple)):
        raise TypeError('Got inappropriate padding arg')
    if not isinstance(fill, (numbers.Number, str, tuple)):
        raise TypeError('Got inappropriate fill arg')
    if not isinstance(padding_mode, str):
        raise TypeError('Got inappropriate padding_mode arg')

    if isinstance(padding, Sequence) and len(padding) not in [2, 4]:
        raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                         "{} element tuple".format(len(padding)))

    assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'], \
        'Padding mode should be either constant, edge, reflect or symmetric'

    if padding_mode == 'constant':
        # Constant fill is handled entirely by PIL.
        if img.mode == 'P':
            # Palette images: ImageOps.expand drops the palette, so save and restore it.
            palette = img.getpalette()
            image = ImageOps.expand(img, border=padding, fill=fill)
            image.putpalette(palette)
            return image

        return ImageOps.expand(img, border=padding, fill=fill)
    else:
        # Non-constant modes go through numpy's np.pad, which needs explicit
        # per-side amounts; normalize the three accepted padding forms.
        if isinstance(padding, int):
            pad_left = pad_right = pad_top = pad_bottom = padding
        if isinstance(padding, Sequence) and len(padding) == 2:
            pad_left = pad_right = padding[0]
            pad_top = pad_bottom = padding[1]
        if isinstance(padding, Sequence) and len(padding) == 4:
            pad_left = padding[0]
            pad_top = padding[1]
            pad_right = padding[2]
            pad_bottom = padding[3]

        if img.mode == 'P':
            # Palette images: pad the index array, then restore the palette.
            palette = img.getpalette()
            img = np.asarray(img)
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
            img = Image.fromarray(img)
            img.putpalette(palette)
            return img

        img = np.asarray(img)
        # RGB image
        if len(img.shape) == 3:
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode)
        # Grayscale image
        if len(img.shape) == 2:
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)

        return Image.fromarray(img)
def crop(img, i, j, h, w):
    """Crop the given PIL Image.

    Args:
        img (PIL Image): Image to be cropped.
        i (int): i in (i,j) i.e coordinates of the upper left corner.
        j (int): j in (i,j) i.e coordinates of the upper left corner.
        h (int): Height of the cropped image.
        w (int): Width of the cropped image.

    Returns:
        PIL Image: Cropped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    # PIL's crop box is (left, upper, right, lower).
    left, upper = j, i
    return img.crop((left, upper, left + w, upper + h))
def center_crop(img, output_size):
    """Crop ``img`` centrally to ``output_size`` (h, w); an int requests a square crop."""
    if isinstance(output_size, numbers.Number):
        side = int(output_size)
        output_size = (side, side)
    w, h = img.size
    th, tw = output_size
    # Round the half-margins so odd differences land consistently.
    top = int(round((h - th) / 2.))
    left = int(round((w - tw) / 2.))
    return crop(img, top, left, th, tw)
def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):
    """Crop the given PIL Image and resize it to desired size.

    Notably used in :class:`~pytorch_jacinto_ai.vision.transforms.RandomResizedCrop`.

    Args:
        img (PIL Image): Image to be cropped.
        i (int): i in (i,j) i.e coordinates of the upper left corner
        j (int): j in (i,j) i.e coordinates of the upper left corner
        h (int): Height of the cropped image.
        w (int): Width of the cropped image.
        size (sequence or int): Desired output size. Same semantics as ``resize``.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``.
    Returns:
        PIL Image: Cropped image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    # Raise TypeError like every sibling transform instead of using assert,
    # which is silently stripped when Python runs with -O.
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    img = crop(img, i, j, h, w)
    img = resize(img, size, interpolation)
    return img
def hflip(img):
    """Horizontally flip the given PIL Image.

    Args:
        img (PIL Image): Image to be flipped.

    Returns:
        PIL Image: Horizontally flipped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    # Mirror around the vertical axis.
    return img.transpose(Image.FLIP_LEFT_RIGHT)
419 def _get_perspective_coeffs(startpoints, endpoints):
420 """Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.
422 In Perspective Transform each pixel (x, y) in the orignal image gets transformed as,
423 (x, y) -> ( (ax + by + c) / (gx + hy + 1), (dx + ey + f) / (gx + hy + 1) )
425 Args:
426 List containing [top-left, top-right, bottom-right, bottom-left] of the orignal image,
427 List containing [top-left, top-right, bottom-right, bottom-left] of the transformed
428 image
429 Returns:
430 octuple (a, b, c, d, e, f, g, h) for transforming each pixel.
431 """
432 matrix = []
434 for p1, p2 in zip(endpoints, startpoints):
435 matrix.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
436 matrix.append([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])
438 A = torch.tensor(matrix, dtype=torch.float)
439 B = torch.tensor(startpoints, dtype=torch.float).view(8)
440 res = torch.gels(B, A)[0]
441 return res.squeeze_(1).tolist()
def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC):
    """Perform perspective transform of the given PIL Image.

    Args:
        img (PIL Image): Image to be transformed.
        startpoints: List containing [top-left, top-right, bottom-right, bottom-left] of the original image
        endpoints: List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image
        interpolation: Default- Image.BICUBIC
    Returns:
        PIL Image: Perspectively transformed Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    # Solve for the eight transform coefficients, then let PIL apply them.
    return img.transform(img.size, Image.PERSPECTIVE,
                         _get_perspective_coeffs(startpoints, endpoints),
                         interpolation)
def vflip(img):
    """Vertically flip the given PIL Image.

    Args:
        img (PIL Image): Image to be flipped.

    Returns:
        PIL Image: Vertically flipped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    # Mirror around the horizontal axis.
    return img.transpose(Image.FLIP_TOP_BOTTOM)
def five_crop(img, size):
    """Crop the given PIL Image into four corners and the central crop.

    .. Note::
        This transform returns a tuple of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.

    Returns:
        tuple: tuple (tl, tr, bl, br, center)
            Corresponding top left, top right, bottom left, bottom right and center crop.
    """
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    else:
        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."

    w, h = img.size
    crop_h, crop_w = size
    if crop_w > w or crop_h > h:
        raise ValueError("Requested crop size {} is bigger than input size {}".format(size,
                                                                                      (h, w)))
    # Crop boxes for the four corners, in (left, upper, right, lower) order.
    corner_boxes = (
        (0, 0, crop_w, crop_h),           # top-left
        (w - crop_w, 0, w, crop_h),       # top-right
        (0, h - crop_h, crop_w, h),       # bottom-left
        (w - crop_w, h - crop_h, w, h),   # bottom-right
    )
    tl, tr, bl, br = (img.crop(box) for box in corner_boxes)
    center = center_crop(img, (crop_h, crop_w))
    return (tl, tr, bl, br, center)
def ten_crop(img, size, vertical_flip=False):
    r"""Crop the given PIL Image into four corners and the central crop plus the
    flipped version of these (horizontal flipping is used by default).

    .. Note::
        This transform returns a tuple of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        vertical_flip (bool): Use vertical flipping instead of horizontal

    Returns:
        tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
            Corresponding top left, top right, bottom left, bottom right and center crop
            and same for the flipped image.
    """
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    else:
        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."

    first_five = five_crop(img, size)
    # Same five crops again, taken from the mirrored image.
    flipped = vflip(img) if vertical_flip else hflip(img)
    second_five = five_crop(flipped, size)
    return first_five + second_five
def adjust_brightness(img, brightness_factor):
    """Adjust brightness of an Image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        brightness_factor (float): How much to adjust the brightness. Can be
            any non negative number. 0 gives a black image, 1 gives the
            original image while 2 increases the brightness by a factor of 2.

    Returns:
        PIL Image: Brightness adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return ImageEnhance.Brightness(img).enhance(brightness_factor)
def adjust_contrast(img, contrast_factor):
    """Adjust contrast of an Image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        contrast_factor (float): How much to adjust the contrast. Can be any
            non negative number. 0 gives a solid gray image, 1 gives the
            original image while 2 increases the contrast by a factor of 2.

    Returns:
        PIL Image: Contrast adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return ImageEnhance.Contrast(img).enhance(contrast_factor)
def adjust_saturation(img, saturation_factor):
    """Adjust color saturation of an image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        saturation_factor (float): How much to adjust the saturation. 0 will
            give a black and white image, 1 will give the original image while
            2 will enhance the saturation by a factor of 2.

    Returns:
        PIL Image: Saturation adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return ImageEnhance.Color(img).enhance(saturation_factor)
def adjust_hue(img, hue_factor):
    """Adjust hue of an image.

    The image hue is adjusted by converting the image to HSV and
    cyclically shifting the intensities in the hue channel (H).
    The image is then converted back to original image mode.

    `hue_factor` is the amount of shift in H channel and must be in the
    interval `[-0.5, 0.5]`.

    See `Hue`_ for more details.

    .. _Hue: https://en.wikipedia.org/wiki/Hue

    Args:
        img (PIL Image): PIL Image to be adjusted.
        hue_factor (float): How much to shift the hue channel. Should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
            HSV space in positive and negative direction respectively.
            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
            with complementary colors while 0 gives the original image.

    Returns:
        PIL Image: Hue adjusted image.

    Raises:
        ValueError: If ``hue_factor`` is outside [-0.5, 0.5].
        TypeError: If ``img`` is not a PIL Image.
    """
    if not(-0.5 <= hue_factor <= 0.5):
        # BUGFIX: the original message had no '{}' placeholder, so the
        # offending value was silently dropped from the error text.
        raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))

    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    input_mode = img.mode
    if input_mode in {'L', '1', 'I', 'F'}:
        # single-channel / binary modes carry no hue information; return unchanged
        return img

    h, s, v = img.convert('HSV').split()

    np_h = np.array(h, dtype=np.uint8)
    # uint8 addition take cares of rotation across boundaries
    with np.errstate(over='ignore'):
        np_h += np.uint8(hue_factor * 255)
    h = Image.fromarray(np_h, 'L')

    img = Image.merge('HSV', (h, s, v)).convert(input_mode)
    return img
def adjust_gamma(img, gamma, gain=1):
    r"""Perform gamma correction on an image.

    Also known as Power Law Transform. Intensities in RGB mode are adjusted
    based on the following equation:

    .. math::
        I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}

    See `Gamma Correction`_ for more details.

    .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction

    Args:
        img (PIL Image): PIL Image to be adjusted.
        gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
            gamma larger than 1 make the shadows darker,
            while gamma smaller than 1 make dark regions lighter.
        gain (float): The constant multiplier.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    if gamma < 0:
        raise ValueError('Gamma should be a non-negative real number')

    input_mode = img.mode
    img = img.convert('RGB')

    # Build one 256-entry lookup table and replicate it for the three RGB bands;
    # PIL's point-function then applies it at C speed.
    lut = [255 * gain * pow(level / 255., gamma) for level in range(256)]
    img = img.point(lut * 3)

    return img.convert(input_mode)
def rotate(img, angle, resample=False, expand=False, center=None):
    """Rotate the image by angle.

    Args:
        img (PIL Image): PIL Image to be rotated.
        angle (float or int): In degrees counter clockwise order.
        resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
            An optional resampling filter. See `filters`_ for more information.
            If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
        expand (bool, optional): Optional expansion flag.
            If true, expands the output image to make it large enough to hold the entire rotated image.
            If false or omitted, make the output image the same size as the input image.
            Note that the expand flag assumes rotation around the center and no translation.
        center (2-tuple, optional): Optional center of rotation.
            Origin is the upper left corner.
            Default is the center of the image.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    # Delegate straight to PIL, which handles expand/center bookkeeping.
    return img.rotate(angle, resample, expand, center)
717 def _get_inverse_affine_matrix(center, angle, translate, scale, shear):
718 # Helper method to compute inverse matrix for affine transformation
720 # As it is explained in PIL.Image.rotate
721 # We need compute INVERSE of affine transformation matrix: M = T * C * RSS * C^-1
722 # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
723 # C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
724 # RSS is rotation with scale and shear matrix
725 # RSS(a, scale, shear) = [ cos(a + shear_y)*scale -sin(a + shear_x)*scale 0]
726 # [ sin(a + shear_y)*scale cos(a + shear_x)*scale 0]
727 # [ 0 0 1]
728 # Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1
730 angle = math.radians(angle)
731 if isinstance(shear, (tuple, list)) and len(shear) == 2:
732 shear = [math.radians(s) for s in shear]
733 elif isinstance(shear, numbers.Number):
734 shear = math.radians(shear)
735 shear = [shear, 0]
736 else:
737 raise ValueError(
738 "Shear should be a single value or a tuple/list containing " +
739 "two values. Got {}".format(shear))
740 scale = 1.0 / scale
742 # Inverted rotation matrix with scale and shear
743 d = math.cos(angle + shear[0]) * math.cos(angle + shear[1]) + \
744 math.sin(angle + shear[0]) * math.sin(angle + shear[1])
745 matrix = [
746 math.cos(angle + shear[0]), math.sin(angle + shear[0]), 0,
747 -math.sin(angle + shear[1]), math.cos(angle + shear[1]), 0
748 ]
749 matrix = [scale / d * m for m in matrix]
751 # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
752 matrix[2] += matrix[0] * (-center[0] - translate[0]) + matrix[1] * (-center[1] - translate[1])
753 matrix[5] += matrix[3] * (-center[0] - translate[0]) + matrix[4] * (-center[1] - translate[1])
755 # Apply center translation: C * RSS^-1 * C^-1 * T^-1
756 matrix[2] += center[0]
757 matrix[5] += center[1]
758 return matrix
def affine(img, angle, translate, scale, shear, resample=0, fillcolor=None):
    """Apply affine transformation on the image keeping image center invariant

    Args:
        img (PIL Image): PIL Image to be rotated.
        angle (float or int): rotation angle in degrees between -180 and 180, clockwise direction.
        translate (list or tuple of integers): horizontal and vertical translations (post-rotation translation)
        scale (float): overall scale
        shear (float or tuple or list): shear angle value in degrees between -180 to 180, clockwise direction.
            If a tuple of list is specified, the first value corresponds to a shear parallel to the x axis, while
            the second value corresponds to a shear parallel to the y axis.
        resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
            An optional resampling filter.
            See `filters`_ for more information.
            If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
        fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0)
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
        "Argument translate should be a list or tuple of length 2"

    assert scale > 0.0, "Argument scale should be positive"

    output_size = img.size
    center = (img.size[0] * 0.5 + 0.5, img.size[1] * 0.5 + 0.5)
    matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
    # BUGFIX: compare the major version numerically. The old string comparison
    # PILLOW_VERSION[0] >= '5' evaluates '1' >= '5' -> False for Pillow >= 10,
    # silently dropping the fillcolor argument.
    kwargs = {"fillcolor": fillcolor} if int(PILLOW_VERSION.split('.')[0]) >= 5 else {}
    return img.transform(output_size, Image.AFFINE, matrix, resample, **kwargs)
def to_grayscale(img, num_output_channels=1):
    """Convert image to grayscale version of image.

    Args:
        img (PIL Image): Image to be converted to grayscale.
        num_output_channels (int): 1 returns a single-channel image;
            3 returns a 3-channel image with r == g == b.

    Returns:
        PIL Image: Grayscale version of the image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    if num_output_channels not in (1, 3):
        raise ValueError('num_output_channels should be either 1 or 3')

    gray = img.convert('L')
    if num_output_channels == 1:
        return gray
    # Replicate the single luminance channel into R, G and B.
    np_img = np.dstack([np.array(gray, dtype=np.uint8)] * 3)
    return Image.fromarray(np_img, 'RGB')
def erase(img, i, j, h, w, v, inplace=False):
    """ Erase the input Tensor Image with given value.

    Args:
        img (Tensor Image): Tensor image of size (C, H, W) to be erased
        i (int): i in (i,j) i.e coordinates of the upper left corner.
        j (int): j in (i,j) i.e coordinates of the upper left corner.
        h (int): Height of the erased region.
        w (int): Width of the erased region.
        v: Erasing value.
        inplace(bool, optional): For in-place operations. By default is set False.

    Returns:
        Tensor Image: Erased image.
    """
    if not isinstance(img, torch.Tensor):
        raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))

    # Work on a copy unless the caller explicitly asked for in-place erasing.
    target = img if inplace else img.clone()
    target[:, i:i + h, j:j + w] = v
    return target