88ef557658e2a3270271efaa0a61d90c2995995d
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / datasets / folder.py
1 from .vision import VisionDataset
3 from PIL import Image
5 import os
6 import os.path
7 import sys
8 import warnings
# Silence UserWarnings whose message starts with one of the EXIF-corruption
# phrases below (presumably emitted by PIL while decoding images -- confirm).
warnings.filterwarnings('ignore', message='Corrupt EXIF data', category=UserWarning)
warnings.filterwarnings('ignore', message='Possibly corrupt EXIF data', category=UserWarning)
def has_file_allowed_extension(filename, extensions):
    """Checks if a file is an allowed extension.

    The comparison is case-insensitive with respect to ``filename`` only:
    the filename is lowercased, the extensions are used as given.

    Args:
        filename (string): path to a file
        extensions (string or tuple/list/set of strings): extensions to
            consider (lowercase)

    Returns:
        bool: True if the filename ends with one of given extensions
    """
    # str.endswith only accepts a str or a tuple of str; convert the other
    # common containers so callers passing a list or set do not get a TypeError.
    if isinstance(extensions, (list, set, frozenset)):
        extensions = tuple(extensions)
    return filename.lower().endswith(extensions)
def is_image_file(filename):
    """Tells whether a path names an image, judged by its extension only.

    Args:
        filename (string): path to a file

    Returns:
        bool: True if the filename ends with one of the extensions listed
            in ``IMG_EXTENSIONS``
    """
    known_extensions = IMG_EXTENSIONS
    return has_file_allowed_extension(filename, known_extensions)
def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None):
    """Collects ``(path, class_index)`` pairs for every accepted file.

    Each key of ``class_to_idx`` is treated as a sub-directory of ``dir``;
    keys that are not existing directories are skipped. Sub-directories are
    walked recursively, and classes, directories and file names are all
    visited in sorted order.

    Args:
        dir (string): root directory; ``~`` is expanded.
        class_to_idx (dict): maps class (sub-directory) name to class index.
        extensions (tuple of strings, optional): accepted file extensions.
        is_valid_file (callable, optional): predicate taking a file path.

    Returns:
        list: ``(path, class_index)`` tuples for the accepted files.

    Raises:
        ValueError: if both or neither of ``extensions`` and
            ``is_valid_file`` are given.
    """
    dir = os.path.expanduser(dir)

    # Exactly one of the two filtering mechanisms has to be supplied.
    if (extensions is None) == (is_valid_file is None):
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is None:
        accept = is_valid_file
    else:
        def accept(candidate):
            return has_file_allowed_extension(candidate, extensions)

    samples = []
    for class_name in sorted(class_to_idx.keys()):
        class_dir = os.path.join(dir, class_name)
        if os.path.isdir(class_dir):
            for folder, _, file_names in sorted(os.walk(class_dir)):
                candidates = (os.path.join(folder, name) for name in sorted(file_names))
                samples.extend(
                    (path, class_to_idx[class_name])
                    for path in candidates
                    if accept(path)
                )

    return samples
class DatasetFolder(VisionDataset):
    """A generic data loader where the samples are arranged in this way: ::

        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/xxz.ext

        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/asd932_.ext

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (tuple[string]): Allowed extensions.
            Exactly one of extensions and is_valid_file must be passed.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        is_valid_file (callable, optional): A function that takes the path of a
            file and checks whether it is a valid file (used to check for
            corrupt files). Exactly one of extensions and is_valid_file must
            be passed.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
        targets (list): The class_index value for each image in the dataset

    Raises:
        RuntimeError: If no sample is found in the sub-folders of ``root``.
        ValueError: Propagated from ``make_dataset`` when both or neither of
            extensions and is_valid_file are given.
    """

    def __init__(self, root, loader, extensions=None, transform=None,
                 target_transform=None, is_valid_file=None):
        super(DatasetFolder, self).__init__(root, transform=transform,
                                            target_transform=target_transform)
        classes, class_to_idx = self._find_classes(self.root)
        samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
        if not samples:
            message = "Found 0 files in subfolders of: " + self.root
            # extensions is None when filtering is done through is_valid_file;
            # joining None would raise TypeError and hide the real problem.
            if extensions is not None:
                message += "\nSupported extensions are: " + ",".join(extensions)
            raise RuntimeError(message)

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    def _find_classes(self, dir):
        """
        Finds the class folders in a dataset.

        Args:
            dir (string): Root directory path.

        Returns:
            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.

        Ensures:
            No class is a subdirectory of another.
        """
        if sys.version_info >= (3, 5):
            # Faster and available in Python 3.5 and above
            classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        else:
            classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
        # Sorting makes the name -> index assignment deterministic.
        classes.sort()
        class_to_idx = {class_name: index for index, class_name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        """Returns the number of samples found at construction time."""
        return len(self.samples)
# Lowercase file extensions recognised as images by is_image_file and used
# as the default extension filter of ImageFolder.
IMG_EXTENSIONS = (
    '.jpg',
    '.jpeg',
    '.png',
    '.ppm',
    '.bmp',
    '.pgm',
    '.tif',
    '.tiff',
    '.webp',
)
def pil_loader(path):
    """Opens the image at ``path`` with PIL and returns it converted to RGB.

    Args:
        path (string): path to the image file

    Returns:
        The result of ``Image.open(...).convert('RGB')``.
    """
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')
def accimage_loader(path):
    """Loads an image through ``accimage``, falling back to ``pil_loader``.

    Args:
        path (string): path to the image file

    Returns:
        An ``accimage.Image`` when it can be constructed, otherwise whatever
        ``pil_loader`` returns for the same path.
    """
    import accimage
    try:
        image = accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)
    return image
def default_loader(path):
    """Loads an image with the loader matching the configured image backend.

    Args:
        path (string): path to the image file

    Returns:
        The result of ``accimage_loader`` when the backend reported by
        ``get_image_backend()`` is ``'accimage'``, else that of ``pil_loader``.
    """
    from ...vision import get_image_backend
    use_accimage = get_image_backend() == 'accimage'
    load = accimage_loader if use_accimage else pil_loader
    return load(path)
class ImageFolder(DatasetFolder):
    """A generic data loader where the images are arranged in this way: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        is_valid_file (callable, optional): A function that takes the path of an image file
            and checks whether it is a valid file (used to check for corrupt files)

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader, is_valid_file=None):
        # A caller-supplied validity check replaces the extension filter;
        # the base class is given exactly one of the two.
        if is_valid_file is None:
            allowed_extensions = IMG_EXTENSIONS
        else:
            allowed_extensions = None
        super(ImageFolder, self).__init__(
            root,
            loader,
            allowed_extensions,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file
        )
        # Alias of the samples list under the image-specific name.
        self.imgs = self.samples