14a256c66ab3d4a78b981b737050a8ac84da672c
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / datasets / imagenet.py
1 from __future__ import print_function
2 import os
3 import shutil
4 import tempfile
5 import torch
6 from .folder import ImageFolder
7 from .utils import check_integrity, download_and_extract_archive, extract_archive, \
8 verify_str_arg
# Download location and MD5 checksum of each ILSVRC2012 archive, keyed by the
# role the archive plays: the two image splits plus the development kit.
ARCHIVE_DICT = dict(
    train=dict(
        url='http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar',
        md5='1d675b47d978889d74fa0da5fadfb00e',
    ),
    val=dict(
        url='http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar',
        md5='29b22e2961454d5413ddabcf34fc5622',
    ),
    devkit=dict(
        url='http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_devkit_t12.tar.gz',
        md5='fa75699e90414af021442c21a62c3abf',
    ),
)
class ImageNet(ImageFolder):
    """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.

    Args:
        root (string): Root directory of the ImageNet Dataset.
        split (string, optional): The dataset split, supports ``train``, or ``val``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.

    Attributes:
        classes (list): List of the class name tuples.
        class_to_idx (dict): Dict with items (class_name, class_index).
        wnids (list): List of the WordNet IDs.
        wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
        imgs (list): List of (image path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
    """

    def __init__(self, root, split='train', download=False, **kwargs):
        """Validate the split, optionally download, then index the split folder.

        Extra keyword arguments are forwarded unchanged to the parent
        constructor.

        Raises:
            RuntimeError: If the meta file is missing or fails its integrity
                check (see ``_load_meta_file``).
        """
        root = self.root = os.path.expanduser(root)
        self.split = verify_str_arg(split, "split", ("train", "val"))

        if download:
            self.download()
        wnid_to_classes = self._load_meta_file()[0]

        super(ImageNet, self).__init__(self.split_folder, **kwargs)
        # The parent was handed the split folder, so put the dataset root back
        # (presumably the parent constructor overwrites ``self.root``).
        self.root = root

        # The parent indexes folders by name, and the folders are named after
        # WordNet IDs: keep those under ``wnids`` / ``wnid_to_idx`` and expose
        # the human-readable names through ``classes`` / ``class_to_idx``.
        self.wnids = self.classes
        self.wnid_to_idx = self.class_to_idx
        self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
        # Each entry of ``classes`` is a tuple of names; every name of that
        # tuple maps to the same class index.
        self.class_to_idx = {cls: idx
                             for idx, clss in enumerate(self.classes)
                             for cls in clss}

    def download(self):
        """Fetch the devkit meta data and the archive of the selected split.

        The devkit is only downloaded when the cached meta file does not pass
        the integrity check. The split archive is only downloaded when the
        split folder does not exist yet; otherwise a notice is printed.
        """
        if not check_integrity(self.meta_file):
            tmp_dir = tempfile.mkdtemp()
            try:
                archive_dict = ARCHIVE_DICT['devkit']
                download_and_extract_archive(archive_dict['url'], self.root,
                                             extract_root=tmp_dir,
                                             md5=archive_dict['md5'])
                # The extracted folder is named after the archive without its
                # (possibly multi-part) extension.
                devkit_folder = _splitexts(os.path.basename(archive_dict['url']))[0]
                meta = parse_devkit(os.path.join(tmp_dir, devkit_folder))
                self._save_meta_file(*meta)
            finally:
                # mkdtemp() never cleans up after itself; remove the scratch
                # directory even when the download or the parsing fails.
                shutil.rmtree(tmp_dir)

        if not os.path.isdir(self.split_folder):
            archive_dict = ARCHIVE_DICT[self.split]
            download_and_extract_archive(archive_dict['url'], self.root,
                                         extract_root=self.split_folder,
                                         md5=archive_dict['md5'])

            if self.split == 'train':
                prepare_train_folder(self.split_folder)
            elif self.split == 'val':
                val_wnids = self._load_meta_file()[1]
                prepare_val_folder(self.split_folder, val_wnids)
        else:
            msg = ("You set download=True, but a folder '{}' already exist in "
                   "the root directory. If you want to re-download or re-extract the "
                   "archive, delete the folder.")
            print(msg.format(self.split))

    @property
    def meta_file(self):
        """Path of the cached meta data file inside the dataset root."""
        return os.path.join(self.root, 'meta.bin')

    def _load_meta_file(self):
        """Return the ``(wnid_to_class, val_wnids)`` pair stored in the meta file.

        Raises:
            RuntimeError: If the meta file does not pass the integrity check.
        """
        if check_integrity(self.meta_file):
            return torch.load(self.meta_file)
        # One message string: passing the two sentences as separate arguments
        # made the exception render as a tuple instead of readable text.
        raise RuntimeError("Meta file not found or corrupted. "
                           "You can use download=True to create it.")

    def _save_meta_file(self, wnid_to_class, val_wnids):
        """Store ``(wnid_to_class, val_wnids)`` in the meta file."""
        torch.save((wnid_to_class, val_wnids), self.meta_file)

    @property
    def split_folder(self):
        """Directory of the selected split inside the dataset root."""
        return os.path.join(self.root, self.split)

    def extra_repr(self):
        """Return the split description line for the dataset's repr."""
        return "Split: {split}".format(**self.__dict__)
def parse_devkit(root):
    """Extract the class meta data and validation labels from a devkit folder.

    Args:
        root (string): Path of the extracted devkit folder.

    Returns:
        tuple: ``(wnid_to_classes, val_wnids)`` where ``wnid_to_classes`` is
        the mapping produced by ``parse_meta`` and ``val_wnids`` holds one
        WordNet ID per entry of the validation ground truth, in file order.
    """
    wnid_by_idx, class_names_by_wnid = parse_meta(root)
    # The ground truth stores numeric indices; translate each one to its wnid.
    ground_truth_wnids = []
    for class_idx in parse_val_groundtruth(root):
        ground_truth_wnids.append(wnid_by_idx[class_idx])
    return class_names_by_wnid, ground_truth_wnids
def parse_meta(devkit_root, path='data', filename='meta.mat'):
    """Read the ``synsets`` table of the devkit meta file.

    Only records whose fifth column (used here as the number of children) is
    zero are kept. The first three columns of those records are used as the
    class index, the WordNet ID and a ``', '``-separated list of class names.

    Args:
        devkit_root (string): Path of the extracted devkit folder.
        path (string, optional): Sub folder that contains the meta file.
        filename (string, optional): Name of the meta file.

    Returns:
        tuple: ``(idx_to_wnid, wnid_to_classes)``; the first dict maps a class
        index to its WordNet ID, the second maps a WordNet ID to the tuple of
        its class names.
    """
    import scipy.io as sio

    mat_path = os.path.join(devkit_root, path, filename)
    synsets = sio.loadmat(mat_path, squeeze_me=True)['synsets']

    # Transpose the records into columns to get at the child counts.
    child_counts = list(zip(*synsets))[4]
    leaf_records = [record
                    for record, count in zip(synsets, child_counts)
                    if count == 0]

    indices, wnid_column, name_column = list(zip(*leaf_records))[:3]
    idx_to_wnid = dict(zip(indices, wnid_column))
    wnid_to_classes = {}
    for wnid, names in zip(wnid_column, name_column):
        wnid_to_classes[wnid] = tuple(names.split(', '))
    return idx_to_wnid, wnid_to_classes
def parse_val_groundtruth(devkit_root, path='data',
                          filename='ILSVRC2012_validation_ground_truth.txt'):
    """Read the validation ground truth file as a list of integers.

    The file is expected to hold one integer per line. Whitespace-only lines
    (for example a trailing blank line) are skipped instead of aborting the
    whole parse with a ``ValueError``.

    Args:
        devkit_root (string): Path of the extracted devkit folder.
        path (string, optional): Sub folder that contains the ground truth file.
        filename (string, optional): Name of the ground truth file.

    Returns:
        list: The integers of the file, in file order.

    Raises:
        ValueError: If a non-blank line is not a valid integer.
    """
    with open(os.path.join(devkit_root, path, filename), 'r') as txtfh:
        # Stream the file line by line; int() tolerates surrounding whitespace.
        return [int(line) for line in txtfh if line.strip()]
def prepare_train_folder(folder):
    """Unpack every archive found directly inside ``folder``.

    Each entry of ``folder`` is extracted into a sibling directory named after
    the entry without its last extension; ``remove_finished=True`` is passed
    to ``extract_archive`` for every entry.

    Args:
        folder (string): Directory that holds the per-class archives.
    """
    for entry in os.listdir(folder):
        archive_path = os.path.join(folder, entry)
        target_dir, _ = os.path.splitext(archive_path)
        extract_archive(archive_path, target_dir, remove_finished=True)
def prepare_val_folder(folder, wnids):
    """Sort the flat validation images into one sub folder per WordNet ID.

    The entries of ``folder`` are paired, in sorted order, with ``wnids``;
    every entry is moved into the sub folder named after its paired ID.

    Args:
        folder (string): Directory that holds the validation images.
        wnids (list): WordNet ID of each image, ordered like the sorted files.
    """
    # Snapshot the images before any class folder is created.
    image_paths = [os.path.join(folder, entry) for entry in os.listdir(folder)]
    image_paths.sort()

    for class_dir in {os.path.join(folder, wnid) for wnid in wnids}:
        os.mkdir(class_dir)

    for wnid, source in zip(wnids, image_paths):
        destination = os.path.join(folder, wnid, os.path.basename(source))
        shutil.move(source, destination)
def _splitexts(root):
    """Split off every trailing extension of ``root``.

    ``os.path.splitext`` is applied repeatedly until it finds no further
    extension.

    Args:
        root (string): File name or path.

    Returns:
        tuple: ``(stem, extensions)`` where ``extensions`` is the concatenation
        of all removed extensions in their original order, e.g.
        ``('archive', '.tar.gz')`` for ``'archive.tar.gz'``.
    """
    suffixes = []
    while True:
        root, suffix = os.path.splitext(root)
        suffixes.append(suffix)
        if not suffix:
            break
    # Suffixes were collected right-to-left; restore the original order.
    return root, ''.join(suffixes[::-1])