]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blobdiff - modules/pytorch_jacinto_ai/xvision/datasets/imagenet.py
docs - added deprecation notice
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / datasets / imagenet.py
index 14a256c66ab3d4a78b981b737050a8ac84da672c..6dfc9bfebfd66d3f9cac016812a77269d9947191 100644 (file)
@@ -1,27 +1,21 @@
-from __future__ import print_function
+import warnings
+from contextlib import contextmanager
 import os
 import shutil
 import tempfile
+from typing import Any, Dict, List, Iterator, Optional, Tuple
 import torch
 from .folder import ImageFolder
-from .utils import check_integrity, download_and_extract_archive, extract_archive, \
-    verify_str_arg
-
-ARCHIVE_DICT = {
-    'train': {
-        'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar',
-        'md5': '1d675b47d978889d74fa0da5fadfb00e',
-    },
-    'val': {
-        'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar',
-        'md5': '29b22e2961454d5413ddabcf34fc5622',
-    },
-    'devkit': {
-        'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_devkit_t12.tar.gz',
-        'md5': 'fa75699e90414af021442c21a62c3abf',
-    }
+from .utils import check_integrity, extract_archive, verify_str_arg
+
+ARCHIVE_META = {
+    'train': ('ILSVRC2012_img_train.tar', '1d675b47d978889d74fa0da5fadfb00e'),
+    'val': ('ILSVRC2012_img_val.tar', '29b22e2961454d5413ddabcf34fc5622'),
+    'devkit': ('ILSVRC2012_devkit_t12.tar.gz', 'fa75699e90414af021442c21a62c3abf')
 }
 
+META_FILE = "meta.bin"
+
 
 class ImageNet(ImageFolder):
     """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.
@@ -29,9 +23,6 @@ class ImageNet(ImageFolder):
     Args:
         root (string): Root directory of the ImageNet Dataset.
         split (string, optional): The dataset split, supports ``train``, or ``val``.
-        download (bool, optional): If true, downloads the dataset from the internet and
-            puts it in root directory. If dataset is already downloaded, it is not
-            downloaded again.
         transform (callable, optional): A function/transform that  takes in an PIL image
             and returns a transformed version. E.g, ``transforms.RandomCrop``
         target_transform (callable, optional): A function/transform that takes in the
@@ -47,13 +38,22 @@ class ImageNet(ImageFolder):
         targets (list): The class_index value for each image in the dataset
     """
 
-    def __init__(self, root, split='train', download=False, **kwargs):
+    def __init__(self, root: str, split: str = 'train', download: Optional[str] = None, **kwargs: Any) -> None:
+        if download is True:
+            msg = ("The dataset is no longer publicly accessible. You need to "
+                   "download the archives externally and place them in the root "
+                   "directory.")
+            raise RuntimeError(msg)
+        elif download is False:
+            msg = ("The use of the download flag is deprecated, since the dataset "
+                   "is no longer publicly accessible.")
+            warnings.warn(msg, RuntimeWarning)
+
         root = self.root = os.path.expanduser(root)
         self.split = verify_str_arg(split, "split", ("train", "val"))
 
-        if download:
-            self.download()
-        wnid_to_classes = self._load_meta_file()[0]
+        self.parse_archives()
+        wnid_to_classes = load_meta_file(self.root)[0]
 
         super(ImageNet, self).__init__(self.split_folder, **kwargs)
         self.root = root
@@ -65,107 +65,157 @@ class ImageNet(ImageFolder):
                              for idx, clss in enumerate(self.classes)
                              for cls in clss}
 
-    def download(self):
-        if not check_integrity(self.meta_file):
-            tmp_dir = tempfile.mkdtemp()
-
-            archive_dict = ARCHIVE_DICT['devkit']
-            download_and_extract_archive(archive_dict['url'], self.root,
-                                         extract_root=tmp_dir,
-                                         md5=archive_dict['md5'])
-            devkit_folder = _splitexts(os.path.basename(archive_dict['url']))[0]
-            meta = parse_devkit(os.path.join(tmp_dir, devkit_folder))
-            self._save_meta_file(*meta)
-
-            shutil.rmtree(tmp_dir)
+    def parse_archives(self) -> None:
+        if not check_integrity(os.path.join(self.root, META_FILE)):
+            parse_devkit_archive(self.root)
 
         if not os.path.isdir(self.split_folder):
-            archive_dict = ARCHIVE_DICT[self.split]
-            download_and_extract_archive(archive_dict['url'], self.root,
-                                         extract_root=self.split_folder,
-                                         md5=archive_dict['md5'])
-
             if self.split == 'train':
-                prepare_train_folder(self.split_folder)
+                parse_train_archive(self.root)
             elif self.split == 'val':
-                val_wnids = self._load_meta_file()[1]
-                prepare_val_folder(self.split_folder, val_wnids)
-        else:
-            msg = ("You set download=True, but a folder '{}' already exist in "
-                   "the root directory. If you want to re-download or re-extract the "
-                   "archive, delete the folder.")
-            print(msg.format(self.split))
+                parse_val_archive(self.root)
 
     @property
-    def meta_file(self):
-        return os.path.join(self.root, 'meta.bin')
+    def split_folder(self) -> str:
+        return os.path.join(self.root, self.split)
 
-    def _load_meta_file(self):
-        if check_integrity(self.meta_file):
-            return torch.load(self.meta_file)
-        else:
-            raise RuntimeError("Meta file not found or corrupted.",
-                               "You can use download=True to create it.")
+    def extra_repr(self) -> str:
+        return "Split: {split}".format(**self.__dict__)
 
-    def _save_meta_file(self, wnid_to_class, val_wnids):
-        torch.save((wnid_to_class, val_wnids), self.meta_file)
 
-    @property
-    def split_folder(self):
-        return os.path.join(self.root, self.split)
+def load_meta_file(root: str, file: Optional[str] = None) -> Tuple[Dict[str, str], List[str]]:
+    if file is None:
+        file = META_FILE
+    file = os.path.join(root, file)
 
-    def extra_repr(self):
-        return "Split: {split}".format(**self.__dict__)
+    if check_integrity(file):
+        return torch.load(file)
+    else:
+        msg = ("The meta file {} is not present in the root directory or is corrupted. "
+               "This file is automatically created by the ImageNet dataset.")
+        raise RuntimeError(msg.format(file, root))
 
 
-def parse_devkit(root):
-    idx_to_wnid, wnid_to_classes = parse_meta(root)
-    val_idcs = parse_val_groundtruth(root)
-    val_wnids = [idx_to_wnid[idx] for idx in val_idcs]
-    return wnid_to_classes, val_wnids
+def _verify_archive(root: str, file: str, md5: str) -> None:
+    if not check_integrity(os.path.join(root, file), md5):
+        msg = ("The archive {} is not present in the root directory or is corrupted. "
+               "You need to download it externally and place it in {}.")
+        raise RuntimeError(msg.format(file, root))
 
 
-def parse_meta(devkit_root, path='data', filename='meta.mat'):
+def parse_devkit_archive(root: str, file: Optional[str] = None) -> None:
+    """Parse the devkit archive of the ImageNet2012 classification dataset and save
+    the meta information in a binary file.
+
+    Args:
+        root (str): Root directory containing the devkit archive
+        file (str, optional): Name of devkit archive. Defaults to
+            'ILSVRC2012_devkit_t12.tar.gz'
+    """
     import scipy.io as sio
 
-    metafile = os.path.join(devkit_root, path, filename)
-    meta = sio.loadmat(metafile, squeeze_me=True)['synsets']
-    nums_children = list(zip(*meta))[4]
-    meta = [meta[idx] for idx, num_children in enumerate(nums_children)
-            if num_children == 0]
-    idcs, wnids, classes = list(zip(*meta))[:3]
-    classes = [tuple(clss.split(', ')) for clss in classes]
-    idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
-    wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
-    return idx_to_wnid, wnid_to_classes
+    def parse_meta_mat(devkit_root: str) -> Tuple[Dict[int, str], Dict[str, str]]:
+        metafile = os.path.join(devkit_root, "data", "meta.mat")
+        meta = sio.loadmat(metafile, squeeze_me=True)['synsets']
+        nums_children = list(zip(*meta))[4]
+        meta = [meta[idx] for idx, num_children in enumerate(nums_children)
+                if num_children == 0]
+        idcs, wnids, classes = list(zip(*meta))[:3]
+        classes = [tuple(clss.split(', ')) for clss in classes]
+        idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
+        wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
+        return idx_to_wnid, wnid_to_classes
+
+    def parse_val_groundtruth_txt(devkit_root: str) -> List[int]:
+        file = os.path.join(devkit_root, "data",
+                            "ILSVRC2012_validation_ground_truth.txt")
+        with open(file, 'r') as txtfh:
+            val_idcs = txtfh.readlines()
+        return [int(val_idx) for val_idx in val_idcs]
+
+    @contextmanager
+    def get_tmp_dir() -> Iterator[str]:
+        tmp_dir = tempfile.mkdtemp()
+        try:
+            yield tmp_dir
+        finally:
+            shutil.rmtree(tmp_dir)
 
+    archive_meta = ARCHIVE_META["devkit"]
+    if file is None:
+        file = archive_meta[0]
+    md5 = archive_meta[1]
 
-def parse_val_groundtruth(devkit_root, path='data',
-                          filename='ILSVRC2012_validation_ground_truth.txt'):
-    with open(os.path.join(devkit_root, path, filename), 'r') as txtfh:
-        val_idcs = txtfh.readlines()
-    return [int(val_idx) for val_idx in val_idcs]
+    _verify_archive(root, file, md5)
 
+    with get_tmp_dir() as tmp_dir:
+        extract_archive(os.path.join(root, file), tmp_dir)
 
-def prepare_train_folder(folder):
-    for archive in [os.path.join(folder, archive) for archive in os.listdir(folder)]:
+        devkit_root = os.path.join(tmp_dir, "ILSVRC2012_devkit_t12")
+        idx_to_wnid, wnid_to_classes = parse_meta_mat(devkit_root)
+        val_idcs = parse_val_groundtruth_txt(devkit_root)
+        val_wnids = [idx_to_wnid[idx] for idx in val_idcs]
+
+        torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE))
+
+
+def parse_train_archive(root: str, file: Optional[str] = None, folder: str = "train") -> None:
+    """Parse the train images archive of the ImageNet2012 classification dataset and
+    prepare it for usage with the ImageNet dataset.
+
+    Args:
+        root (str): Root directory containing the train images archive
+        file (str, optional): Name of train images archive. Defaults to
+            'ILSVRC2012_img_train.tar'
+        folder (str, optional): Optional name for train images folder. Defaults to
+            'train'
+    """
+    archive_meta = ARCHIVE_META["train"]
+    if file is None:
+        file = archive_meta[0]
+    md5 = archive_meta[1]
+
+    _verify_archive(root, file, md5)
+
+    train_root = os.path.join(root, folder)
+    extract_archive(os.path.join(root, file), train_root)
+
+    archives = [os.path.join(train_root, archive) for archive in os.listdir(train_root)]
+    for archive in archives:
         extract_archive(archive, os.path.splitext(archive)[0], remove_finished=True)
 
 
-def prepare_val_folder(folder, wnids):
-    img_files = sorted([os.path.join(folder, file) for file in os.listdir(folder)])
+def parse_val_archive(
+    root: str, file: Optional[str] = None, wnids: Optional[List[str]] = None, folder: str = "val"
+) -> None:
+    """Parse the validation images archive of the ImageNet2012 classification dataset
+    and prepare it for usage with the ImageNet dataset.
 
-    for wnid in set(wnids):
-        os.mkdir(os.path.join(folder, wnid))
+    Args:
+        root (str): Root directory containing the validation images archive
+        file (str, optional): Name of validation images archive. Defaults to
+            'ILSVRC2012_img_val.tar'
+        wnids (list, optional): List of WordNet IDs of the validation images. If None
+            is given, the IDs are loaded from the meta file in the root directory
+        folder (str, optional): Optional name for validation images folder. Defaults to
+            'val'
+    """
+    archive_meta = ARCHIVE_META["val"]
+    if file is None:
+        file = archive_meta[0]
+    md5 = archive_meta[1]
+    if wnids is None:
+        wnids = load_meta_file(root)[1]
 
-    for wnid, img_file in zip(wnids, img_files):
-        shutil.move(img_file, os.path.join(folder, wnid, os.path.basename(img_file)))
+    _verify_archive(root, file, md5)
 
+    val_root = os.path.join(root, folder)
+    extract_archive(os.path.join(root, file), val_root)
+
+    images = sorted([os.path.join(val_root, image) for image in os.listdir(val_root)])
+
+    for wnid in set(wnids):
+        os.mkdir(os.path.join(val_root, wnid))
 
-def _splitexts(root):
-    exts = []
-    ext = '.'
-    while ext:
-        root, ext = os.path.splitext(root)
-        exts.append(ext)
-    return root, ''.join(reversed(exts))
+    for wnid, img_file in zip(wnids, images):
+        shutil.move(img_file, os.path.join(val_root, wnid, os.path.basename(img_file)))