[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / datasets / mnist.py
1 from __future__ import print_function
2 from .vision import VisionDataset
3 import warnings
4 from PIL import Image
5 import os
6 import os.path
7 import numpy as np
8 import torch
9 import codecs
10 from .utils import download_url, download_and_extract_archive, extract_archive, \
11 makedir_exist_ok, verify_str_arg
class MNIST(VisionDataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``MNIST/processed/training.pt``
            and ``MNIST/processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    # Archives fetched by download(); the filenames encode the idx format
    # (idx3 = images, idx1 = labels).
    urls = [
        'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
    ]
    # Filenames of the processed (data, targets) tensor pairs written by download().
    training_file = 'training.pt'
    test_file = 'test.pt'
    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    @property
    def train_labels(self):
        # Deprecated alias for ``targets`` kept for backward compatibility.
        warnings.warn("train_labels has been renamed targets")
        return self.targets

    @property
    def test_labels(self):
        # Deprecated alias for ``targets`` kept for backward compatibility.
        warnings.warn("test_labels has been renamed targets")
        return self.targets

    @property
    def train_data(self):
        # Deprecated alias for ``data`` kept for backward compatibility.
        warnings.warn("train_data has been renamed data")
        return self.data

    @property
    def test_data(self):
        # Deprecated alias for ``data`` kept for backward compatibility.
        warnings.warn("test_data has been renamed data")
        return self.data

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False):
        super(MNIST, self).__init__(root, transform=transform,
                                    target_transform=target_transform)
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        if self.train:
            data_file = self.training_file
        else:
            data_file = self.test_file
        # Load the (data, targets) tuple that download() saved with torch.save.
        self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)

    @property
    def raw_folder(self):
        # <root>/<ClassName>/raw -- uses the subclass name, so derived
        # datasets (FashionMNIST, KMNIST, ...) get their own directories.
        return os.path.join(self.root, self.__class__.__name__, 'raw')

    @property
    def processed_folder(self):
        # <root>/<ClassName>/processed -- see raw_folder for the naming scheme.
        return os.path.join(self.root, self.__class__.__name__, 'processed')

    @property
    def class_to_idx(self):
        return {_class: i for i, _class in enumerate(self.classes)}

    def _check_exists(self):
        # Both processed files must exist for the dataset to be usable.
        return (os.path.exists(os.path.join(self.processed_folder,
                                            self.training_file)) and
                os.path.exists(os.path.join(self.processed_folder,
                                            self.test_file)))

    def download(self):
        """Download the MNIST data if it doesn't exist in processed_folder already."""

        if self._check_exists():
            return

        makedir_exist_ok(self.raw_folder)
        makedir_exist_ok(self.processed_folder)

        # download files
        for url in self.urls:
            filename = url.rpartition('/')[2]
            download_and_extract_archive(url, download_root=self.raw_folder, filename=filename)

        # process and save as torch files
        print('Processing...')

        training_set = (
            read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
            read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
        )
        test_set = (
            read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
            read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
        )
        with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
            torch.save(test_set, f)

        print('Done!')

    def extra_repr(self):
        return "Split: {}".format("Train" if self.train is True else "Test")
class FashionMNIST(MNIST):
    """`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``Fashion-MNIST/processed/training.pt``
            and ``Fashion-MNIST/processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    # Only the data sources and class names differ from MNIST; all loading
    # and processing logic is inherited.
    urls = [
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz',
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz',
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz',
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
    ]
    classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
               'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
class KMNIST(MNIST):
    """`Kuzushiji-MNIST <https://github.com/rois-codh/kmnist>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``KMNIST/processed/training.pt``
            and ``KMNIST/processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    # Only the data sources and class names differ from MNIST; all loading
    # and processing logic is inherited.
    urls = [
        'http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-images-idx3-ubyte.gz',
        'http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-labels-idx1-ubyte.gz',
        'http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-images-idx3-ubyte.gz',
        'http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-labels-idx1-ubyte.gz',
    ]
    classes = ['o', 'ki', 'su', 'tsu', 'na', 'ha', 'ma', 'ya', 're', 'wo']
class EMNIST(MNIST):
    """`EMNIST <https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``EMNIST/processed/training.pt``
            and ``EMNIST/processed/test.pt`` exist.
        split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
            ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
            which one to use.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """
    # Updated URL from https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist
    # NOTE(review): cloudstor mirrors have been taken down in the past -- verify this
    # link still resolves before relying on download=True.
    url = 'https://cloudstor.aarnet.edu.au/plus/index.php/s/54h3OuGJhFLwAlQ/download'
    splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist')

    def __init__(self, root, split, **kwargs):
        # Validate the split and derive the processed filenames before the
        # base-class __init__ tries to load them.
        self.split = verify_str_arg(split, "split", self.splits)
        self.training_file = self._training_file(split)
        self.test_file = self._test_file(split)
        super(EMNIST, self).__init__(root, **kwargs)

    @staticmethod
    def _training_file(split):
        # e.g. 'training_letters.pt'
        return 'training_{}.pt'.format(split)

    @staticmethod
    def _test_file(split):
        # e.g. 'test_letters.pt'
        return 'test_{}.pt'.format(split)

    def download(self):
        """Download the EMNIST data if it doesn't exist in processed_folder already."""
        import shutil

        if self._check_exists():
            return

        makedir_exist_ok(self.raw_folder)
        makedir_exist_ok(self.processed_folder)

        # download files
        print('Downloading and extracting zip archive')
        download_and_extract_archive(self.url, download_root=self.raw_folder, filename="emnist.zip",
                                     remove_finished=True)
        # The zip expands to a 'gzip' subdirectory holding one .gz per split/file.
        gzip_folder = os.path.join(self.raw_folder, 'gzip')
        for gzip_file in os.listdir(gzip_folder):
            if gzip_file.endswith('.gz'):
                extract_archive(os.path.join(gzip_folder, gzip_file), gzip_folder)

        # process and save as torch files; note that ALL splits are processed,
        # not only self.split, so later instantiations need no re-download
        for split in self.splits:
            print('Processing ' + split)
            training_set = (
                read_image_file(os.path.join(gzip_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))),
                read_label_file(os.path.join(gzip_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split)))
            )
            test_set = (
                read_image_file(os.path.join(gzip_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))),
                read_label_file(os.path.join(gzip_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split)))
            )
            with open(os.path.join(self.processed_folder, self._training_file(split)), 'wb') as f:
                torch.save(training_set, f)
            with open(os.path.join(self.processed_folder, self._test_file(split)), 'wb') as f:
                torch.save(test_set, f)
        # Raw extracted files are no longer needed once the .pt files exist.
        shutil.rmtree(gzip_folder)

        print('Done!')
class QMNIST(MNIST):
    """`QMNIST <https://github.com/facebookresearch/qmnist>`_ Dataset.

    Args:
        root (string): Root directory of dataset whose ``processed``
            subdir contains torch binary files with the datasets.
        what (string, optional): Can be 'train', 'test', 'test10k',
            'test50k', or 'nist' for respectively the mnist compatible
            training set, the 60k qmnist testing set, the 10k qmnist
            examples that match the mnist testing set, the 50k
            remaining qmnist testing examples, or all the nist
            digits. The default is to select 'train' or 'test'
            according to the compatibility argument 'train'.
        compat (bool, optional): A boolean that says whether the target
            for each example is class number (for compatibility with
            the MNIST dataloader) or a torch vector containing the
            full qmnist information. Default=True.
        download (bool, optional): If true, downloads the dataset from
            the internet and puts it in root directory. If dataset is
            already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that
            takes in an PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform
            that takes in the target and transforms it.
        train (bool, optional, compatibility): When argument 'what' is
            not specified, this boolean decides whether to load the
            training set or the testing set. Default: True.
    """

    # Maps each user-facing subset name to the URL group it is sliced from:
    # 'test10k' and 'test50k' are both carved out of the full 'test' files.
    subsets = {
        'train': 'train',
        'test': 'test', 'test10k': 'test', 'test50k': 'test',
        'nist': 'nist'
    }
    urls = {
        'train': ['https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz',
                  'https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz'],
        'test': ['https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz',
                 'https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz'],
        'nist': ['https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz',
                 'https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz']
    }
    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    def __init__(self, root, what=None, compat=True, train=True, **kwargs):
        if what is None:
            # Legacy behaviour: derive the subset from the MNIST-style 'train' flag.
            what = 'train' if train else 'test'
        self.what = verify_str_arg(what, "what", tuple(self.subsets.keys()))
        self.compat = compat
        # A single processed file per subset; aliased to both base-class
        # attributes so MNIST.__init__ loads it regardless of 'train'.
        self.data_file = what + '.pt'
        self.training_file = self.data_file
        self.test_file = self.data_file
        super(QMNIST, self).__init__(root, train, **kwargs)

    def download(self):
        """Download the QMNIST data if it doesn't exist in processed_folder already.
        Note that we only download what has been asked for (argument 'what').
        """
        if self._check_exists():
            return
        makedir_exist_ok(self.raw_folder)
        makedir_exist_ok(self.processed_folder)
        urls = self.urls[self.subsets[self.what]]
        files = []

        # download data files if not already there
        for url in urls:
            filename = url.rpartition('/')[2]
            file_path = os.path.join(self.raw_folder, filename)
            if not os.path.isfile(file_path):
                download_url(url, root=self.raw_folder, filename=filename, md5=None)
            files.append(file_path)

        # process and save as torch files
        print('Processing...')
        data = read_sn3_pascalvincent_tensor(files[0])
        assert(data.dtype == torch.uint8)
        assert(data.ndimension() == 3)
        # QMNIST labels are 2-D: one row of extended metadata per example.
        targets = read_sn3_pascalvincent_tensor(files[1]).long()
        assert(targets.ndimension() == 2)
        # test10k/test50k are the first 10000 / remaining examples of the
        # full test files; .clone() detaches the slice from the big tensor.
        if self.what == 'test10k':
            data = data[0:10000, :, :].clone()
            targets = targets[0:10000, :].clone()
        if self.what == 'test50k':
            data = data[10000:, :, :].clone()
            targets = targets[10000:, :].clone()
        with open(os.path.join(self.processed_folder, self.data_file), 'wb') as f:
            torch.save((data, targets), f)

    def __getitem__(self, index):
        # redefined to handle the compat flag
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img.numpy(), mode='L')
        if self.transform is not None:
            img = self.transform(img)
        if self.compat:
            # Collapse the full qmnist target row to its first entry
            # (the class number), matching the MNIST dataloader contract.
            target = int(target[0])
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def extra_repr(self):
        return "Split: {}".format(self.what)
def get_int(b):
    """Decode the bytes object *b* as a big-endian unsigned integer.

    The bytes are hex-encoded and then parsed as a base-16 number, which
    is independent of the host machine's byte order.
    """
    hex_digits = codecs.encode(b, 'hex')
    return int(hex_digits, 16)
def open_maybe_compressed_file(path):
    """Return a file object that possibly decompresses 'path' on the fly.

    Decompression occurs when argument `path` is a string and ends with
    '.gz' or '.xz'.  Any non-string argument is assumed to already be a
    file-like object and is returned unchanged.
    """
    # Use a plain ``str`` check instead of the private ``torch._six``
    # compatibility shim, which has been removed from recent torch releases
    # and would make this function crash at call time.
    if not isinstance(path, str):
        return path
    if path.endswith('.gz'):
        import gzip
        return gzip.open(path, 'rb')
    if path.endswith('.xz'):
        import lzma
        return lzma.open(path, 'rb')
    return open(path, 'rb')
def read_sn3_pascalvincent_tensor(path, strict=True):
    """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').

    Argument may be a filename, compressed filename, or file object.
    Returns a torch tensor whose dtype is selected by the file's magic
    number and whose shape comes from the header dimension fields.
    """
    # Lazily build and cache (on the function object) the table mapping a
    # magic type code to (torch dtype, big-endian numpy read dtype,
    # native numpy target dtype).
    if not hasattr(read_sn3_pascalvincent_tensor, 'typemap'):
        read_sn3_pascalvincent_tensor.typemap = {
            8: (torch.uint8, np.uint8, np.uint8),
            9: (torch.int8, np.int8, np.int8),
            11: (torch.int16, np.dtype('>i2'), 'i2'),
            12: (torch.int32, np.dtype('>i4'), 'i4'),
            13: (torch.float32, np.dtype('>f4'), 'f4'),
            14: (torch.float64, np.dtype('>f8'), 'f8')}
    # Slurp the whole (possibly decompressed) file into memory.
    with open_maybe_compressed_file(path) as f:
        raw = f.read()
    # The 4-byte magic number encodes both the element type and the rank.
    magic = get_int(raw[0:4])
    num_dims = magic % 256
    type_code = magic // 256
    assert 1 <= num_dims <= 3
    assert 8 <= type_code <= 14
    _torch_dtype, read_dtype, target_dtype = read_sn3_pascalvincent_tensor.typemap[type_code]
    # One big-endian 4-byte size per dimension follows the magic number.
    dims = [get_int(raw[4 * (i + 1): 4 * (i + 2)]) for i in range(num_dims)]
    flat = np.frombuffer(raw, dtype=read_dtype, offset=(4 * (num_dims + 1)))
    assert flat.shape[0] == np.prod(dims) or not strict
    return torch.from_numpy(flat.astype(target_dtype, copy=False)).view(*dims)
def read_label_file(path):
    """Parse the idx1 label file at *path* into a 1-D int64 tensor."""
    with open(path, 'rb') as f:
        labels = read_sn3_pascalvincent_tensor(f, strict=False)
    # Label files are expected to hold a flat vector of unsigned bytes.
    assert labels.dtype == torch.uint8
    assert labels.ndimension() == 1
    return labels.long()
def read_image_file(path):
    """Parse the idx3 image file at *path* into a rank-3 uint8 tensor."""
    with open(path, 'rb') as f:
        images = read_sn3_pascalvincent_tensor(f, strict=False)
    # Image files are expected to hold a 3-D array of unsigned bytes.
    assert images.dtype == torch.uint8
    assert images.ndimension() == 3
    return images