]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/vision/datasets/folder.py
release commit
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / datasets / folder.py
1 from .vision import VisionDataset
3 from PIL import Image
5 import os
6 import os.path
7 import sys
8 import warnings
10 warnings.filterwarnings('ignore', 'Corrupt EXIF data', UserWarning)
11 warnings.filterwarnings('ignore', 'Possibly corrupt EXIF data', UserWarning)
13 def has_file_allowed_extension(filename, extensions):
14     """Checks if a file is an allowed extension.
16     Args:
17         filename (string): path to a file
18         extensions (tuple of strings): extensions to consider (lowercase)
20     Returns:
21         bool: True if the filename ends with one of given extensions
22     """
23     return filename.lower().endswith(extensions)
26 def is_image_file(filename):
27     """Checks if a file is an allowed image extension.
29     Args:
30         filename (string): path to a file
32     Returns:
33         bool: True if the filename ends with a known image extension
34     """
35     return has_file_allowed_extension(filename, IMG_EXTENSIONS)
38 def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None):
39     images = []
40     dir = os.path.expanduser(dir)
41     if not ((extensions is None) ^ (is_valid_file is None)):
42         raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
43     if extensions is not None:
44         def is_valid_file(x):
45             return has_file_allowed_extension(x, extensions)
46     for target in sorted(class_to_idx.keys()):
47         d = os.path.join(dir, target)
48         if not os.path.isdir(d):
49             continue
50         for root, _, fnames in sorted(os.walk(d)):
51             for fname in sorted(fnames):
52                 path = os.path.join(root, fname)
53                 if is_valid_file(path):
54                     item = (path, class_to_idx[target])
55                     images.append(item)
57     return images
60 class DatasetFolder(VisionDataset):
61     """A generic data loader where the samples are arranged in this way: ::
63         root/class_x/xxx.ext
64         root/class_x/xxy.ext
65         root/class_x/xxz.ext
67         root/class_y/123.ext
68         root/class_y/nsdf3.ext
69         root/class_y/asd932_.ext
71     Args:
72         root (string): Root directory path.
73         loader (callable): A function to load a sample given its path.
74         extensions (tuple[string]): A list of allowed extensions.
75             both extensions and is_valid_file should not be passed.
76         transform (callable, optional): A function/transform that takes in
77             a sample and returns a transformed version.
78             E.g, ``transforms.RandomCrop`` for images.
79         target_transform (callable, optional): A function/transform that takes
80             in the target and transforms it.
81         is_valid_file (callable, optional): A function that takes path of an Image file
82             and check if the file is a valid_file (used to check of corrupt files)
83             both extensions and is_valid_file should not be passed.
85      Attributes:
86         classes (list): List of the class names.
87         class_to_idx (dict): Dict with items (class_name, class_index).
88         samples (list): List of (sample path, class_index) tuples
89         targets (list): The class_index value for each image in the dataset
90     """
92     def __init__(self, root, loader, extensions=None, transform=None,
93                  target_transform=None, is_valid_file=None):
94         super(DatasetFolder, self).__init__(root, transform=transform,
95                                             target_transform=target_transform)
96         classes, class_to_idx = self._find_classes(self.root)
97         samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
98         if len(samples) == 0:
99             raise (RuntimeError("Found 0 files in subfolders of: " + self.root + "\n"
100                                 "Supported extensions are: " + ",".join(extensions)))
102         self.loader = loader
103         self.extensions = extensions
105         self.classes = classes
106         self.class_to_idx = class_to_idx
107         self.samples = samples
108         self.targets = [s[1] for s in samples]
110     def _find_classes(self, dir):
111         """
112         Finds the class folders in a dataset.
114         Args:
115             dir (string): Root directory path.
117         Returns:
118             tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
120         Ensures:
121             No class is a subdirectory of another.
122         """
123         if sys.version_info >= (3, 5):
124             # Faster and available in Python 3.5 and above
125             classes = [d.name for d in os.scandir(dir) if d.is_dir()]
126         else:
127             classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
128         classes.sort()
129         class_to_idx = {classes[i]: i for i in range(len(classes))}
130         return classes, class_to_idx
132     def __getitem__(self, index):
133         """
134         Args:
135             index (int): Index
137         Returns:
138             tuple: (sample, target) where target is class_index of the target class.
139         """
140         path, target = self.samples[index]
141         sample = self.loader(path)
142         if self.transform is not None:
143             sample = self.transform(sample)
144         if self.target_transform is not None:
145             target = self.target_transform(target)
147         return sample, target
149     def __len__(self):
150         return len(self.samples)
153 IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
156 def pil_loader(path):
157     # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
158     with open(path, 'rb') as f:
159         img = Image.open(f)
160         return img.convert('RGB')
163 def accimage_loader(path):
164     import accimage
165     try:
166         return accimage.Image(path)
167     except IOError:
168         # Potentially a decoding problem, fall back to PIL.Image
169         return pil_loader(path)
172 def default_loader(path):
173     from ...vision import get_image_backend
174     if get_image_backend() == 'accimage':
175         return accimage_loader(path)
176     else:
177         return pil_loader(path)
180 class ImageFolder(DatasetFolder):
181     """A generic data loader where the images are arranged in this way: ::
183         root/dog/xxx.png
184         root/dog/xxy.png
185         root/dog/xxz.png
187         root/cat/123.png
188         root/cat/nsdf3.png
189         root/cat/asd932_.png
191     Args:
192         root (string): Root directory path.
193         transform (callable, optional): A function/transform that  takes in an PIL image
194             and returns a transformed version. E.g, ``transforms.RandomCrop``
195         target_transform (callable, optional): A function/transform that takes in the
196             target and transforms it.
197         loader (callable, optional): A function to load an image given its path.
198         is_valid_file (callable, optional): A function that takes path of an Image file
199             and check if the file is a valid_file (used to check of corrupt files)
201      Attributes:
202         classes (list): List of the class names.
203         class_to_idx (dict): Dict with items (class_name, class_index).
204         imgs (list): List of (image path, class_index) tuples
205     """
207     def __init__(self, root, transform=None, target_transform=None,
208                  loader=default_loader, is_valid_file=None):
209         super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
210                                           transform=transform,
211                                           target_transform=target_transform,
212                                           is_valid_file=is_valid_file)
213         self.imgs = self.samples