]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/vision/utils.py
f07a3bb40165e2d664284325012bf1df75a214f3
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / utils.py
1 import torch
2 import math
3 irange = range
6 def make_grid(tensor, nrow=8, padding=2,
7               normalize=False, range=None, scale_each=False, pad_value=0):
8     """Make a grid of images.
10     Args:
11         tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
12             or a list of images all of the same size.
13         nrow (int, optional): Number of images displayed in each row of the grid.
14             The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
15         padding (int, optional): amount of padding. Default: ``2``.
16         normalize (bool, optional): If True, shift the image to the range (0, 1),
17             by the min and max values specified by :attr:`range`. Default: ``False``.
18         range (tuple, optional): tuple (min, max) where min and max are numbers,
19             then these numbers are used to normalize the image. By default, min and max
20             are computed from the tensor.
21         scale_each (bool, optional): If ``True``, scale each image in the batch of
22             images separately rather than the (min, max) over all images. Default: ``False``.
23         pad_value (float, optional): Value for the padded pixels. Default: ``0``.
25     Example:
26         See this notebook `here <https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91>`_
28     """
29     if not (torch.is_tensor(tensor) or
30             (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
31         raise TypeError('tensor or list of tensors expected, got {}'.format(type(tensor)))
33     # if list of tensors, convert to a 4D mini-batch Tensor
34     if isinstance(tensor, list):
35         tensor = torch.stack(tensor, dim=0)
37     if tensor.dim() == 2:  # single image H x W
38         tensor = tensor.unsqueeze(0)
39     if tensor.dim() == 3:  # single image
40         if tensor.size(0) == 1:  # if single-channel, convert to 3-channel
41             tensor = torch.cat((tensor, tensor, tensor), 0)
42         tensor = tensor.unsqueeze(0)
44     if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
45         tensor = torch.cat((tensor, tensor, tensor), 1)
47     if normalize is True:
48         tensor = tensor.clone()  # avoid modifying tensor in-place
49         if range is not None:
50             assert isinstance(range, tuple), \
51                 "range has to be a tuple (min, max) if specified. min and max are numbers"
53         def norm_ip(img, min, max):
54             img.clamp_(min=min, max=max)
55             img.add_(-min).div_(max - min + 1e-5)
57         def norm_range(t, range):
58             if range is not None:
59                 norm_ip(t, range[0], range[1])
60             else:
61                 norm_ip(t, float(t.min()), float(t.max()))
63         if scale_each is True:
64             for t in tensor:  # loop over mini-batch dimension
65                 norm_range(t, range)
66         else:
67             norm_range(tensor, range)
69     if tensor.size(0) == 1:
70         return tensor.squeeze(0)
72     # make the mini-batch of images into a grid
73     nmaps = tensor.size(0)
74     xmaps = min(nrow, nmaps)
75     ymaps = int(math.ceil(float(nmaps) / xmaps))
76     height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
77     grid = tensor.new_full((3, height * ymaps + padding, width * xmaps + padding), pad_value)
78     k = 0
79     for y in irange(ymaps):
80         for x in irange(xmaps):
81             if k >= nmaps:
82                 break
83             grid.narrow(1, y * height + padding, height - padding)\
84                 .narrow(2, x * width + padding, width - padding)\
85                 .copy_(tensor[k])
86             k = k + 1
87     return grid
90 def save_image(tensor, filename, nrow=8, padding=2,
91                normalize=False, range=None, scale_each=False, pad_value=0):
92     """Save a given Tensor into an image file.
94     Args:
95         tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
96             saves the tensor as a grid of images by calling ``make_grid``.
97         **kwargs: Other arguments are documented in ``make_grid``.
98     """
99     from PIL import Image
100     grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
101                      normalize=normalize, range=range, scale_each=scale_each)
102     # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
103     ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
104     im = Image.fromarray(ndarr)
105     im.save(filename)