Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/vision/datasets/vision.py
7ee5a84dfccbc59590972872bfdcb71da6c461d7
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / datasets / vision.py
1 import os
2 import torch
3 import torch.utils.data as data
class VisionDataset(data.Dataset):
    """Base class for vision datasets.

    Stores the dataset ``root`` and the transforms to apply to samples.
    Callers pass EITHER a joint ``transforms`` callable taking
    ``(input, target)`` OR separate ``transform`` / ``target_transform``
    callables, never both. Separate callables are wrapped in a
    ``StandardTransform`` and exposed as ``self.transforms``.

    Subclasses must implement ``__getitem__`` and ``__len__``; the base
    implementations raise ``NotImplementedError``.
    """

    # Number of spaces used to indent body lines in ``__repr__``.
    _repr_indent = 4

    def __init__(self, root, transforms=None, transform=None, target_transform=None):
        """
        Args:
            root: dataset root. A ``str`` or ``bytes`` path is passed through
                ``os.path.expanduser``; any other value (e.g. ``None``) is
                stored unchanged.
            transforms: optional callable applied jointly to ``(input, target)``.
            transform: optional callable applied to the input only.
            target_transform: optional callable applied to the target only.

        Raises:
            ValueError: if ``transforms`` is combined with ``transform`` or
                ``target_transform``.
        """
        # ``torch._six.string_classes`` was a private alias for ``(str, bytes)``
        # on Python 3. ``torch._six`` was removed in PyTorch 2.0, so spell the
        # tuple out to work on every torch version.
        if isinstance(root, (str, bytes)):
            root = os.path.expanduser(root)
        self.root = root

        has_transforms = transforms is not None
        has_separate_transform = transform is not None or target_transform is not None
        if has_transforms and has_separate_transform:
            raise ValueError("Only transforms or transform/target_transform can "
                             "be passed as argument")

        # for backwards-compatibility
        self.transform = transform
        self.target_transform = target_transform

        # Normalise to a single joint callable so subclasses can always call
        # ``self.transforms(input, target)``.
        if has_separate_transform:
            transforms = StandardTransform(transform, target_transform)
        self.transforms = transforms

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __repr__(self):
        """Return a multi-line summary.

        The summary contains the dataset class name, the sample count, the
        root (when not ``None``), any ``extra_repr`` lines and the transforms.
        """
        head = "Dataset " + self.__class__.__name__
        body = ["Number of datapoints: {}".format(self.__len__())]
        if self.root is not None:
            body.append("Root location: {}".format(self.root))
        body += self.extra_repr().splitlines()
        if hasattr(self, "transforms") and self.transforms is not None:
            body += [repr(self.transforms)]
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return '\n'.join(lines)

    def _format_transform_repr(self, transform, head):
        """Prefix the first line of ``repr(transform)`` with ``head``.

        Continuation lines get a hanging indent of ``len(head)`` spaces.
        """
        lines = transform.__repr__().splitlines()
        return (["{}{}".format(head, lines[0])] +
                ["{}{}".format(" " * len(head), line) for line in lines[1:]])

    def extra_repr(self):
        """Hook for subclasses: extra text (possibly multi-line) for ``__repr__``."""
        return ""
class StandardTransform(object):
    """Adapt two single-argument callables to the joint ``(input, target)`` interface.

    ``transform`` is applied to the input and ``target_transform`` to the
    target. Either may be ``None``, in which case that element passes through
    unchanged.
    """

    def __init__(self, transform=None, target_transform=None):
        self.transform = transform
        self.target_transform = target_transform

    def __call__(self, input, target):
        """Return ``(input, target)`` with each configured transform applied, input first."""
        new_input = input if self.transform is None else self.transform(input)
        new_target = target if self.target_transform is None else self.target_transform(target)
        return new_input, new_target

    def _format_transform_repr(self, transform, head):
        """Render ``repr(transform)`` after ``head``, hanging-indenting continuation lines."""
        rendered = transform.__repr__().splitlines()
        padding = " " * len(head)
        formatted = ["{}{}".format(head, rendered[0])]
        formatted.extend("{}{}".format(padding, extra) for extra in rendered[1:])
        return formatted

    def __repr__(self):
        """List the class name, then each non-``None`` transform under its label."""
        parts = [self.__class__.__name__]
        labelled = (
            ("Transform: ", self.transform),
            ("Target transform: ", self.target_transform),
        )
        for label, fn in labelled:
            if fn is not None:
                parts.extend(self._format_transform_repr(fn, label))
        return '\n'.join(parts)