[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / datasets / omniglot.py
1 from __future__ import print_function
2 from PIL import Image
3 from os.path import join
4 import os
5 from .vision import VisionDataset
6 from .utils import download_and_extract_archive, check_integrity, list_dir, list_files
class Omniglot(VisionDataset):
    """`Omniglot <https://github.com/brendenlake/omniglot>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``omniglot-py`` exists.
        background (bool, optional): If True, creates dataset from the "background" set, otherwise
            creates from the "evaluation" set. This terminology is defined by the authors.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset zip files from the internet and
            puts it in root directory. If the zip files are already downloaded, they are not
            downloaded again.
    """
    folder = 'omniglot-py'
    download_url_prefix = 'https://github.com/brendenlake/omniglot/raw/master/python'
    # Expected MD5 of each split's zip archive, keyed by archive base name.
    zips_md5 = {
        'images_background': '68d2efa1b9178cc56df9314c21c6e718',
        'images_evaluation': '6b91aef0f799c5bb55b94e3f2daec811'
    }

    def __init__(self, root, background=True, transform=None, target_transform=None,
                 download=False):
        # The dataset lives under ``<root>/omniglot-py``; that joined path is what
        # is handed to the base class (presumably stored as ``self.root``, which
        # the code below relies on).
        super(Omniglot, self).__init__(join(root, self.folder), transform=transform,
                                       target_transform=target_transform)
        self.background = background

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        self.target_folder = join(self.root, self._get_target_folder())
        # Directory layout walked below: <target_folder>/<alphabet>/<character>/*.png
        self._alphabets = list_dir(self.target_folder)
        # "alphabet/character" relative paths, alphabet-major order. A nested
        # comprehension is used instead of ``sum(list_of_lists, [])``, which
        # re-copies the accumulated list at every step (quadratic).
        self._characters = [join(alphabet, character)
                            for alphabet in self._alphabets
                            for character in list_dir(join(self.target_folder, alphabet))]
        # One list of (image_name, class_index) pairs per character; the class
        # index is the character's position in ``self._characters``.
        self._character_images = [
            [(image, idx) for image in list_files(join(self.target_folder, character), '.png')]
            for idx, character in enumerate(self._characters)
        ]
        # Flat index used by __len__/__getitem__ (same order as the nested lists).
        self._flat_character_images = [pair
                                       for images in self._character_images
                                       for pair in images]

    def __len__(self):
        """Return the total number of images across all characters."""
        return len(self._flat_character_images)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target character class.
        """
        image_name, character_class = self._flat_character_images[index]
        image_path = join(self.target_folder, self._characters[character_class], image_name)
        # Close the source file deterministically; ``convert`` returns a new,
        # fully loaded image that stays valid after the ``with`` block exits.
        with Image.open(image_path, mode='r') as source:
            image = source.convert('L')

        # Compare against None rather than relying on truthiness, so a callable
        # object that happens to define __len__ is still applied.
        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            character_class = self.target_transform(character_class)

        return image, character_class

    def _check_integrity(self):
        """Return True iff ``<root>/<split>.zip`` passes ``check_integrity`` with the
        MD5 listed in ``zips_md5``.

        Only the archive is checked, not the extracted folder.
        """
        zip_filename = self._get_target_folder()
        return bool(check_integrity(join(self.root, zip_filename + '.zip'),
                                    self.zips_md5[zip_filename]))

    def download(self):
        """Fetch the selected split's zip archive unless a verified copy already exists.

        The archive ``<download_url_prefix>/<split>.zip`` is passed to
        ``download_and_extract_archive`` together with its expected MD5, targeting
        ``self.root``.
        """
        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        filename = self._get_target_folder()
        zip_filename = filename + '.zip'
        url = self.download_url_prefix + '/' + zip_filename
        download_and_extract_archive(url, self.root, filename=zip_filename, md5=self.zips_md5[filename])

    def _get_target_folder(self):
        """Return the split's base name, used both as archive name and extracted folder."""
        return 'images_background' if self.background else 'images_evaluation'