]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/xvision/datasets/imagenet.py
docs - added deprecation notice
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / datasets / imagenet.py
1 import warnings
2 from contextlib import contextmanager
3 import os
4 import shutil
5 import tempfile
6 from typing import Any, Dict, List, Iterator, Optional, Tuple
7 import torch
8 from .folder import ImageFolder
9 from .utils import check_integrity, extract_archive, verify_str_arg
11 ARCHIVE_DICT = {
12     'train': {
13         'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar',
14         'md5': '1d675b47d978889d74fa0da5fadfb00e',
15     },
16     'val': {
17         'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar',
18         'md5': '29b22e2961454d5413ddabcf34fc5622',
19     },
20     'devkit': {
21         'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_devkit_t12.tar.gz',
22         'md5': 'fa75699e90414af021442c21a62c3abf',
23     }
24 }
26 ARCHIVE_META = {
27     'train': ('ILSVRC2012_img_train.tar', '1d675b47d978889d74fa0da5fadfb00e'),
28     'val': ('ILSVRC2012_img_val.tar', '29b22e2961454d5413ddabcf34fc5622'),
29     'devkit': ('ILSVRC2012_devkit_t12.tar.gz', 'fa75699e90414af021442c21a62c3abf')
30 }
32 META_FILE = "meta.bin"
35 class ImageNet(ImageFolder):
36     """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.
38     Args:
39         root (string): Root directory of the ImageNet Dataset.
40         split (string, optional): The dataset split, supports ``train``, or ``val``.
41         transform (callable, optional): A function/transform that  takes in an PIL image
42             and returns a transformed version. E.g, ``transforms.RandomCrop``
43         target_transform (callable, optional): A function/transform that takes in the
44             target and transforms it.
45         loader (callable, optional): A function to load an image given its path.
47      Attributes:
48         classes (list): List of the class name tuples.
49         class_to_idx (dict): Dict with items (class_name, class_index).
50         wnids (list): List of the WordNet IDs.
51         wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
52         imgs (list): List of (image path, class_index) tuples
53         targets (list): The class_index value for each image in the dataset
54     """
56     def __init__(self, root: str, split: str = 'train', download: Optional[str] = None, **kwargs: Any) -> None:
57         if download is True:
58             msg = ("The dataset is no longer publicly accessible. You need to "
59                    "download the archives externally and place them in the root "
60                    "directory.")
61             raise RuntimeError(msg)
62         elif download is False:
63             msg = ("The use of the download flag is deprecated, since the dataset "
64                    "is no longer publicly accessible.")
65             warnings.warn(msg, RuntimeWarning)
67         root = self.root = os.path.expanduser(root)
68         self.split = verify_str_arg(split, "split", ("train", "val"))
70         self.parse_archives()
71         wnid_to_classes = load_meta_file(self.root)[0]
73         super(ImageNet, self).__init__(self.split_folder, **kwargs)
74         self.root = root
76         self.wnids = self.classes
77         self.wnid_to_idx = self.class_to_idx
78         self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
79         self.class_to_idx = {cls: idx
80                              for idx, clss in enumerate(self.classes)
81                              for cls in clss}
83     def parse_archives(self) -> None:
84         if not check_integrity(os.path.join(self.root, META_FILE)):
85             parse_devkit_archive(self.root)
87         if not os.path.isdir(self.split_folder):
88             if self.split == 'train':
89                 parse_train_archive(self.root)
90             elif self.split == 'val':
91                 parse_val_archive(self.root)
93     @property
94     def split_folder(self) -> str:
95         return os.path.join(self.root, self.split)
97     def extra_repr(self) -> str:
98         return "Split: {split}".format(**self.__dict__)
101 def load_meta_file(root: str, file: Optional[str] = None) -> Tuple[Dict[str, str], List[str]]:
102     if file is None:
103         file = META_FILE
104     file = os.path.join(root, file)
106     if check_integrity(file):
107         return torch.load(file)
108     else:
109         msg = ("The meta file {} is not present in the root directory or is corrupted. "
110                "This file is automatically created by the ImageNet dataset.")
111         raise RuntimeError(msg.format(file, root))
114 def _verify_archive(root: str, file: str, md5: str) -> None:
115     if not check_integrity(os.path.join(root, file), md5):
116         msg = ("The archive {} is not present in the root directory or is corrupted. "
117                "You need to download it externally and place it in {}.")
118         raise RuntimeError(msg.format(file, root))
121 def parse_devkit_archive(root: str, file: Optional[str] = None) -> None:
122     """Parse the devkit archive of the ImageNet2012 classification dataset and save
123     the meta information in a binary file.
125     Args:
126         root (str): Root directory containing the devkit archive
127         file (str, optional): Name of devkit archive. Defaults to
128             'ILSVRC2012_devkit_t12.tar.gz'
129     """
130     import scipy.io as sio
132     def parse_meta_mat(devkit_root: str) -> Tuple[Dict[int, str], Dict[str, str]]:
133         metafile = os.path.join(devkit_root, "data", "meta.mat")
134         meta = sio.loadmat(metafile, squeeze_me=True)['synsets']
135         nums_children = list(zip(*meta))[4]
136         meta = [meta[idx] for idx, num_children in enumerate(nums_children)
137                 if num_children == 0]
138         idcs, wnids, classes = list(zip(*meta))[:3]
139         classes = [tuple(clss.split(', ')) for clss in classes]
140         idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
141         wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
142         return idx_to_wnid, wnid_to_classes
144     def parse_val_groundtruth_txt(devkit_root: str) -> List[int]:
145         file = os.path.join(devkit_root, "data",
146                             "ILSVRC2012_validation_ground_truth.txt")
147         with open(file, 'r') as txtfh:
148             val_idcs = txtfh.readlines()
149         return [int(val_idx) for val_idx in val_idcs]
151     @contextmanager
152     def get_tmp_dir() -> Iterator[str]:
153         tmp_dir = tempfile.mkdtemp()
154         try:
155             yield tmp_dir
156         finally:
157             shutil.rmtree(tmp_dir)
159     archive_meta = ARCHIVE_META["devkit"]
160     if file is None:
161         file = archive_meta[0]
162     md5 = archive_meta[1]
164     _verify_archive(root, file, md5)
166     with get_tmp_dir() as tmp_dir:
167         extract_archive(os.path.join(root, file), tmp_dir)
169         devkit_root = os.path.join(tmp_dir, "ILSVRC2012_devkit_t12")
170         idx_to_wnid, wnid_to_classes = parse_meta_mat(devkit_root)
171         val_idcs = parse_val_groundtruth_txt(devkit_root)
172         val_wnids = [idx_to_wnid[idx] for idx in val_idcs]
174         torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE))
177 def parse_train_archive(root: str, file: Optional[str] = None, folder: str = "train") -> None:
178     """Parse the train images archive of the ImageNet2012 classification dataset and
179     prepare it for usage with the ImageNet dataset.
181     Args:
182         root (str): Root directory containing the train images archive
183         file (str, optional): Name of train images archive. Defaults to
184             'ILSVRC2012_img_train.tar'
185         folder (str, optional): Optional name for train images folder. Defaults to
186             'train'
187     """
188     archive_meta = ARCHIVE_META["train"]
189     if file is None:
190         file = archive_meta[0]
191     md5 = archive_meta[1]
193     _verify_archive(root, file, md5)
195     train_root = os.path.join(root, folder)
196     extract_archive(os.path.join(root, file), train_root)
198     archives = [os.path.join(train_root, archive) for archive in os.listdir(train_root)]
199     for archive in archives:
200         extract_archive(archive, os.path.splitext(archive)[0], remove_finished=True)
203 def parse_val_archive(
204     root: str, file: Optional[str] = None, wnids: Optional[List[str]] = None, folder: str = "val"
205 ) -> None:
206     """Parse the validation images archive of the ImageNet2012 classification dataset
207     and prepare it for usage with the ImageNet dataset.
209     Args:
210         root (str): Root directory containing the validation images archive
211         file (str, optional): Name of validation images archive. Defaults to
212             'ILSVRC2012_img_val.tar'
213         wnids (list, optional): List of WordNet IDs of the validation images. If None
214             is given, the IDs are loaded from the meta file in the root directory
215         folder (str, optional): Optional name for validation images folder. Defaults to
216             'val'
217     """
218     archive_meta = ARCHIVE_META["val"]
219     if file is None:
220         file = archive_meta[0]
221     md5 = archive_meta[1]
222     if wnids is None:
223         wnids = load_meta_file(root)[1]
225     _verify_archive(root, file, md5)
227     val_root = os.path.join(root, folder)
228     extract_archive(os.path.join(root, file), val_root)
230     images = sorted([os.path.join(val_root, image) for image in os.listdir(val_root)])
232     for wnid in set(wnids):
233         os.mkdir(os.path.join(val_root, wnid))
235     for wnid, img_file in zip(wnids, images):
236         shutil.move(img_file, os.path.join(val_root, wnid, os.path.basename(img_file)))