[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / datasets / celeba.py
1 from functools import partial
2 import torch
3 import os
4 import PIL
5 from .vision import VisionDataset
6 from .utils import download_file_from_google_drive, check_integrity, verify_str_arg
class CelebA(VisionDataset):
    """`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.

    Args:
        root (string): Root directory where images are downloaded to.
        split (string): One of {'train', 'valid', 'test', 'all'}.
            Accordingly dataset is selected.
        target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
            or ``landmarks``. Can also be a list to output a tuple with all specified target types.
            The targets represent:
                ``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes
                ``identity`` (int): label for each person (data points with the same identity are the same person)
                ``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height)
                ``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
                    righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)
            Defaults to ``attr``. If empty, ``None`` will be returned as target.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """
    base_folder = "celeba"
    # There currently does not appear to be a easy way to extract 7z in python (without introducing additional
    # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
    # right now.
    file_list = [
        # File ID                         MD5 Hash                            Filename
        ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"),
        # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"),
        # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"),
        ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"),
        ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"),
        ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"),
        ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"),
        # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"),
        ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
    ]

    def __init__(self, root, split="train", target_type="attr", transform=None,
                 target_transform=None, download=False):
        # pandas is imported lazily so the rest of the package does not require it.
        import pandas
        super(CelebA, self).__init__(root, transform=transform,
                                     target_transform=target_transform)
        self.split = split
        if isinstance(target_type, list):
            self.target_type = target_type
        else:
            self.target_type = [target_type]

        # Validate the split name *before* any download/integrity work, so an
        # invalid argument fails fast instead of after a multi-GB download.
        split_map = {
            "train": 0,
            "valid": 1,
            "test": 2,
            "all": None,
        }
        split_id = split_map[verify_str_arg(split.lower(), "split",
                                            ("train", "valid", "test", "all"))]

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        # Annotation files live directly under <root>/<base_folder>.
        fn = partial(os.path.join, self.root, self.base_folder)
        splits = pandas.read_csv(fn("list_eval_partition.txt"), delim_whitespace=True, header=None, index_col=0)
        identity = pandas.read_csv(fn("identity_CelebA.txt"), delim_whitespace=True, header=None, index_col=0)
        bbox = pandas.read_csv(fn("list_bbox_celeba.txt"), delim_whitespace=True, header=1, index_col=0)
        landmarks_align = pandas.read_csv(fn("list_landmarks_align_celeba.txt"), delim_whitespace=True, header=1)
        attr = pandas.read_csv(fn("list_attr_celeba.txt"), delim_whitespace=True, header=1)

        # "all" selects every row; otherwise keep rows whose partition id matches.
        mask = slice(None) if split_id is None else (splits[1] == split_id)

        self.filename = splits[mask].index.values
        self.identity = torch.as_tensor(identity[mask].values)
        self.bbox = torch.as_tensor(bbox[mask].values)
        self.landmarks_align = torch.as_tensor(landmarks_align[mask].values)
        self.attr = torch.as_tensor(attr[mask].values)
        self.attr = (self.attr + 1) // 2  # map from {-1, 1} to {0, 1}
        self.attr_names = list(attr.columns)

    def _check_integrity(self):
        """Return True when all required annotation files verify and the image dir exists."""
        for (_, md5, filename) in self.file_list:
            fpath = os.path.join(self.root, self.base_folder, filename)
            _, ext = os.path.splitext(filename)
            # Allow original archive to be deleted (zip and 7z)
            # Only need the extracted images
            if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
                return False

        # Should check a hash of the images
        return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))

    def download(self):
        """Download all dataset files from Google Drive and extract the image archive."""
        import zipfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        for (file_id, md5, filename) in self.file_list:
            download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)

        with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f:
            f.extractall(os.path.join(self.root, self.base_folder))

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample at ``index``.

        ``target`` is a single value for one target type, a tuple for several,
        or ``None`` when ``target_type`` is an empty list.
        """
        X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))

        target = []
        for t in self.target_type:
            if t == "attr":
                target.append(self.attr[index, :])
            elif t == "identity":
                target.append(self.identity[index, 0])
            elif t == "bbox":
                target.append(self.bbox[index, :])
            elif t == "landmarks":
                target.append(self.landmarks_align[index, :])
            else:
                # TODO: refactor with utils.verify_str_arg
                raise ValueError("Target type \"{}\" is not recognized.".format(t))

        if self.transform is not None:
            X = self.transform(X)

        if target:
            target = tuple(target) if len(target) > 1 else target[0]
            if self.target_transform is not None:
                target = self.target_transform(target)
        else:
            # An empty target_type list means "no target"; previously this
            # raised IndexError on target[0].
            target = None

        return X, target

    def __len__(self):
        return len(self.attr)

    def extra_repr(self):
        lines = ["Target type: {target_type}", "Split: {split}"]
        return '\n'.join(lines).format(**self.__dict__)