[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / datasets / imagenet.py
diff --git a/modules/pytorch_jacinto_ai/xvision/datasets/imagenet.py b/modules/pytorch_jacinto_ai/xvision/datasets/imagenet.py
index 14a256c66ab3d4a78b981b737050a8ac84da672c..6dfc9bfebfd66d3f9cac016812a77269d9947191 100644 (file)
-from __future__ import print_function
+import warnings
+from contextlib import contextmanager
import os
import shutil
import tempfile
+from typing import Any, Dict, List, Iterator, Optional, Tuple
import torch
from .folder import ImageFolder
-from .utils import check_integrity, download_and_extract_archive, extract_archive, \
- verify_str_arg
-
# (archive filename, md5 checksum) for each piece of the ImageNet2012 release.
# The archives are no longer publicly accessible (see the error raised in
# ImageNet.__init__), so the user must place them in the dataset root manually.
ARCHIVE_META = {
    'train': ('ILSVRC2012_img_train.tar', '1d675b47d978889d74fa0da5fadfb00e'),
    'val': ('ILSVRC2012_img_val.tar', '29b22e2961454d5413ddabcf34fc5622'),
    'devkit': ('ILSVRC2012_devkit_t12.tar.gz', 'fa75699e90414af021442c21a62c3abf')
}

# File written via torch.save by parse_devkit_archive, holding the
# (wnid_to_classes, val_wnids) tuple read back by load_meta_file.
META_FILE = "meta.bin"
class ImageNet(ImageFolder):
"""`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.
Args:
root (string): Root directory of the ImageNet Dataset.
split (string, optional): The dataset split, supports ``train``, or ``val``.
- download (bool, optional): If true, downloads the dataset from the internet and
- puts it in root directory. If dataset is already downloaded, it is not
- downloaded again.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    Attributes:
        targets (list): The class_index value for each image in the dataset
"""
- def __init__(self, root, split='train', download=False, **kwargs):
+ def __init__(self, root: str, split: str = 'train', download: Optional[str] = None, **kwargs: Any) -> None:
+ if download is True:
+ msg = ("The dataset is no longer publicly accessible. You need to "
+ "download the archives externally and place them in the root "
+ "directory.")
+ raise RuntimeError(msg)
+ elif download is False:
+ msg = ("The use of the download flag is deprecated, since the dataset "
+ "is no longer publicly accessible.")
+ warnings.warn(msg, RuntimeWarning)
+
root = self.root = os.path.expanduser(root)
self.split = verify_str_arg(split, "split", ("train", "val"))
- if download:
- self.download()
- wnid_to_classes = self._load_meta_file()[0]
+ self.parse_archives()
+ wnid_to_classes = load_meta_file(self.root)[0]
super(ImageNet, self).__init__(self.split_folder, **kwargs)
self.root = root
for idx, clss in enumerate(self.classes)
for cls in clss}
- def download(self):
- if not check_integrity(self.meta_file):
- tmp_dir = tempfile.mkdtemp()
-
- archive_dict = ARCHIVE_DICT['devkit']
- download_and_extract_archive(archive_dict['url'], self.root,
- extract_root=tmp_dir,
- md5=archive_dict['md5'])
- devkit_folder = _splitexts(os.path.basename(archive_dict['url']))[0]
- meta = parse_devkit(os.path.join(tmp_dir, devkit_folder))
- self._save_meta_file(*meta)
-
- shutil.rmtree(tmp_dir)
+ def parse_archives(self) -> None:
+ if not check_integrity(os.path.join(self.root, META_FILE)):
+ parse_devkit_archive(self.root)
if not os.path.isdir(self.split_folder):
- archive_dict = ARCHIVE_DICT[self.split]
- download_and_extract_archive(archive_dict['url'], self.root,
- extract_root=self.split_folder,
- md5=archive_dict['md5'])
-
if self.split == 'train':
- prepare_train_folder(self.split_folder)
+ parse_train_archive(self.root)
elif self.split == 'val':
- val_wnids = self._load_meta_file()[1]
- prepare_val_folder(self.split_folder, val_wnids)
- else:
- msg = ("You set download=True, but a folder '{}' already exist in "
- "the root directory. If you want to re-download or re-extract the "
- "archive, delete the folder.")
- print(msg.format(self.split))
+ parse_val_archive(self.root)
    @property
    def split_folder(self) -> str:
        # Directory holding the extracted images for the active split:
        # ``<root>/train`` or ``<root>/val``.
        return os.path.join(self.root, self.split)
+ def extra_repr(self) -> str:
+ return "Split: {split}".format(**self.__dict__)
- def _save_meta_file(self, wnid_to_class, val_wnids):
- torch.save((wnid_to_class, val_wnids), self.meta_file)
- @property
- def split_folder(self):
- return os.path.join(self.root, self.split)
def load_meta_file(root: str, file: Optional[str] = None) -> Tuple[Dict[str, str], List[str]]:
    """Load the metadata tuple saved by ``parse_devkit_archive``.

    Args:
        root (str): Root directory containing the meta file.
        file (str, optional): Name of the meta file. Defaults to META_FILE.

    Returns:
        Tuple of (wnid -> class names mapping, list of validation wnids),
        as stored with ``torch.save``.

    Raises:
        RuntimeError: If the meta file is missing or fails the integrity check.
    """
    if file is None:
        file = META_FILE
    file = os.path.join(root, file)

    if check_integrity(file):
        return torch.load(file)
    else:
        msg = ("The meta file {} is not present in the root directory or is corrupted. "
               "This file is automatically created by the ImageNet dataset.")
        # Fix: the template has exactly one placeholder, so pass exactly one
        # argument (the original also passed ``root``, which str.format
        # silently ignored).
        raise RuntimeError(msg.format(file))
def _verify_archive(root: str, file: str, md5: str) -> None:
    """Raise ``RuntimeError`` unless ``file`` exists in ``root`` and matches ``md5``."""
    archive = os.path.join(root, file)
    if check_integrity(archive, md5):
        return
    msg = ("The archive {} is not present in the root directory or is corrupted. "
           "You need to download it externally and place it in {}.")
    raise RuntimeError(msg.format(file, root))
def parse_devkit_archive(root: str, file: Optional[str] = None) -> None:
    """Parse the devkit archive of the ImageNet2012 classification dataset and save
    the meta information in a binary file.

    Args:
        root (str): Root directory containing the devkit archive
        file (str, optional): Name of devkit archive. Defaults to
            'ILSVRC2012_devkit_t12.tar.gz'
    """
    # Deferred import: scipy is only needed when the devkit is (re)parsed.
    import scipy.io as sio

    def parse_meta_mat(devkit_root: str) -> Tuple[Dict[int, str], Dict[str, str]]:
        # meta.mat 'synsets' rows carry (idx, wnid, class names, ..., num_children)
        # per synset; only leaf synsets (num_children == 0) are actual classes.
        metafile = os.path.join(devkit_root, "data", "meta.mat")
        meta = sio.loadmat(metafile, squeeze_me=True)['synsets']
        nums_children = list(zip(*meta))[4]
        meta = [meta[idx] for idx, num_children in enumerate(nums_children)
                if num_children == 0]
        idcs, wnids, classes = list(zip(*meta))[:3]
        # Class names come as a comma-separated string of synonyms.
        classes = [tuple(clss.split(', ')) for clss in classes]
        idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
        wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
        return idx_to_wnid, wnid_to_classes

    def parse_val_groundtruth_txt(devkit_root: str) -> List[int]:
        # One ILSVRC class index per line, in validation image order.
        file = os.path.join(devkit_root, "data",
                            "ILSVRC2012_validation_ground_truth.txt")
        with open(file, 'r') as txtfh:
            val_idcs = txtfh.readlines()
        return [int(val_idx) for val_idx in val_idcs]

    archive_meta = ARCHIVE_META["devkit"]
    if file is None:
        file = archive_meta[0]
    md5 = archive_meta[1]

    _verify_archive(root, file, md5)

    # Fix/idiom: the original hand-rolled a mkdtemp/rmtree @contextmanager;
    # tempfile.TemporaryDirectory provides the identical create-then-cleanup
    # behavior from the standard library.
    with tempfile.TemporaryDirectory() as tmp_dir:
        extract_archive(os.path.join(root, file), tmp_dir)

        devkit_root = os.path.join(tmp_dir, "ILSVRC2012_devkit_t12")
        idx_to_wnid, wnid_to_classes = parse_meta_mat(devkit_root)
        val_idcs = parse_val_groundtruth_txt(devkit_root)
        # Map each validation index to its WordNet ID, in image order.
        val_wnids = [idx_to_wnid[idx] for idx in val_idcs]

    torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE))
+
+
def parse_train_archive(root: str, file: Optional[str] = None, folder: str = "train") -> None:
    """Parse the train images archive of the ImageNet2012 classification dataset and
    prepare it for usage with the ImageNet dataset.

    Args:
        root (str): Root directory containing the train images archive
        file (str, optional): Name of train images archive. Defaults to
            'ILSVRC2012_img_train.tar'
        folder (str, optional): Optional name for train images folder. Defaults to
            'train'
    """
    default_file, md5 = ARCHIVE_META["train"]
    if file is None:
        file = default_file

    _verify_archive(root, file, md5)

    train_root = os.path.join(root, folder)
    extract_archive(os.path.join(root, file), train_root)

    # The outer tar contains one inner tar per class; unpack each into a
    # directory named after the archive stem and remove the inner tar.
    for entry in os.listdir(train_root):
        inner_archive = os.path.join(train_root, entry)
        extract_archive(inner_archive, os.path.splitext(inner_archive)[0], remove_finished=True)
def parse_val_archive(
    root: str, file: Optional[str] = None, wnids: Optional[List[str]] = None, folder: str = "val"
) -> None:
    """Parse the validation images archive of the ImageNet2012 classification dataset
    and prepare it for usage with the ImageNet dataset.

    Args:
        root (str): Root directory containing the validation images archive
        file (str, optional): Name of validation images archive. Defaults to
            'ILSVRC2012_img_val.tar'
        wnids (list, optional): List of WordNet IDs of the validation images. If None
            is given, the IDs are loaded from the meta file in the root directory
        folder (str, optional): Optional name for validation images folder. Defaults to
            'val'
    """
    default_file, md5 = ARCHIVE_META["val"]
    if file is None:
        file = default_file
    if wnids is None:
        wnids = load_meta_file(root)[1]

    _verify_archive(root, file, md5)

    val_root = os.path.join(root, folder)
    extract_archive(os.path.join(root, file), val_root)

    # Snapshot the flat image listing *before* creating the per-class
    # directories so the new folders are not picked up as images.
    images = sorted(os.path.join(val_root, name) for name in os.listdir(val_root))

    for wnid in set(wnids):
        os.mkdir(os.path.join(val_root, wnid))

    # ``wnids`` is ordered like the sorted image names, so zip pairs each
    # image with its class directory.
    for wnid, img_file in zip(wnids, images):
        shutil.move(img_file, os.path.join(val_root, wnid, os.path.basename(img_file)))