[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / datasets / fakedata.py
1 import torch
2 from .vision import VisionDataset
3 from .. import transforms
class FakeData(VisionDataset):
    """A fake dataset that returns randomly generated images as PIL images.

    Every sample is generated deterministically from its index (plus
    ``random_offset``), so the same index always yields the same image and
    target. Generating samples does not modify any global random state.

    Args:
        size (int, optional): Size of the dataset. Default: 1000 images
        image_size (tuple, optional): Size of the returned images. Default: (3, 224, 224)
        num_classes (int, optional): Number of classes in the dataset. Default: 10
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        random_offset (int): Offsets the index-based random seed used to
            generate each image. Default: 0

    """

    def __init__(self, size=1000, image_size=(3, 224, 224), num_classes=10,
                 transform=None, target_transform=None, random_offset=0):
        super(FakeData, self).__init__(None, transform=transform,
                                       target_transform=target_transform)
        self.size = size
        self.num_classes = num_classes
        self.image_size = image_size
        self.random_offset = random_offset

    def __getitem__(self, index):
        """
        Args:
            index (int): Index. Negative values count from the end, as with
                Python sequences.

        Returns:
            tuple: (image, target) where target is class_index of the target class.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
        """
        # Normalize negative indices so that ds[-1] is the same sample as
        # ds[len(ds) - 1] instead of an unrelated seed.
        if index < 0:
            index += len(self)
        if not 0 <= index < len(self):
            raise IndexError("{} index out of range".format(self.__class__.__name__))

        # Create a random image that is consistent with the index id.
        # A private generator is used instead of torch.manual_seed(), which
        # reseeds the global generators of every device (CPU and CUDA). That
        # way no global RNG state is modified, even if generation raises.
        # int() mirrors the conversion torch.manual_seed applies to its seed.
        generator = torch.Generator()
        generator.manual_seed(int(index + self.random_offset))
        img = torch.randn(*self.image_size, generator=generator)
        target = torch.randint(0, self.num_classes, size=(1,), dtype=torch.long,
                               generator=generator)[0]

        # convert to PIL Image
        img = transforms.ToPILImage()(img)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return self.size