]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/xvision/datasets/fakedata.py
renamed pytorch_jacinto_ai.vision to pytorch_jacinto_ai.xvision
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / datasets / fakedata.py
1 import torch
2 from .vision import VisionDataset
3 from .. import transforms
6 class FakeData(VisionDataset):
7     """A fake dataset that returns randomly generated images and returns them as PIL images
9     Args:
10         size (int, optional): Size of the dataset. Default: 1000 images
11         image_size(tuple, optional): Size if the returned images. Default: (3, 224, 224)
12         num_classes(int, optional): Number of classes in the datset. Default: 10
13         transform (callable, optional): A function/transform that  takes in an PIL image
14             and returns a transformed version. E.g, ``transforms.RandomCrop``
15         target_transform (callable, optional): A function/transform that takes in the
16             target and transforms it.
17         random_offset (int): Offsets the index-based random seed used to
18             generate each image. Default: 0
20     """
22     def __init__(self, size=1000, image_size=(3, 224, 224), num_classes=10,
23                  transform=None, target_transform=None, random_offset=0):
24         super(FakeData, self).__init__(None, transform=transform,
25                                        target_transform=target_transform)
26         self.size = size
27         self.num_classes = num_classes
28         self.image_size = image_size
29         self.random_offset = random_offset
31     def __getitem__(self, index):
32         """
33         Args:
34             index (int): Index
36         Returns:
37             tuple: (image, target) where target is class_index of the target class.
38         """
39         # create random image that is consistent with the index id
40         if index >= len(self):
41             raise IndexError("{} index out of range".format(self.__class__.__name__))
42         rng_state = torch.get_rng_state()
43         torch.manual_seed(index + self.random_offset)
44         img = torch.randn(*self.image_size)
45         target = torch.randint(0, self.num_classes, size=(1,), dtype=torch.long)[0]
46         torch.set_rng_state(rng_state)
48         # convert to PIL Image
49         img = transforms.ToPILImage()(img)
50         if self.transform is not None:
51             img = self.transform(img)
52         if self.target_transform is not None:
53             target = self.target_transform(target)
55         return img, target
57     def __len__(self):
58         return self.size