1f28e09786719985debe90f87d3c09ae1f7af830
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / datasets / voc.py
1 '''
2 Dataset loader for the PASCAL VOC dataset:
4 - Everingham, M., Van Gool, L., Williams, C. K. I., Winn, J. and Zisserman, A.,
5 The PASCAL Visual Object Classes (VOC) Challenge, International Journal of Computer Vision,
6 88(2), 303-338, 2010
- Everingham, M., Van Gool, L., Williams, C. K. I., Winn, J. and Zisserman, A.,
The PASCAL Visual Object Classes Challenge 2012, VOC2012,
http://www.pascal-network.org/challenges/VOC/voc2012/workshop/index.html
11 '''
13 import os
14 import sys
15 import tarfile
16 import numpy as np
17 import collections
18 from collections import namedtuple
19 from .vision import VisionDataset
21 if sys.version_info[0] == 2:
22 import xml.etree.cElementTree as ET
23 else:
24 import xml.etree.ElementTree as ET
26 from PIL import Image
27 from .utils import download_url, check_integrity, verify_str_arg
# Per-year download metadata for the PASCAL VOC trainval archives.
# Each entry provides:
#   'url'      - location of the tar archive,
#   'filename' - name under which the archive is stored locally; it matches
#                the last path component of 'url',
#   'md5'      - checksum handed to the download helper,
#   'base_dir' - dataset directory, relative to the dataset root, that the
#                loaders below read from once the archive is extracted.
DATASET_YEAR_DICT = {
    '2012': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
        'filename': 'VOCtrainval_11-May-2012.tar',
        'md5': '6cd6e144f989b92b3379bac3b3de84fd',
        'base_dir': 'VOCdevkit/VOC2012'
    },
    '2011': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar',
        'filename': 'VOCtrainval_25-May-2011.tar',
        'md5': '6c3384ef61512963050cb5d687e5bf1e',
        # The 2011 entry is the only one with an extra leading 'TrainVal' directory.
        'base_dir': 'TrainVal/VOCdevkit/VOC2011'
    },
    '2010': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar',
        'filename': 'VOCtrainval_03-May-2010.tar',
        'md5': 'da459979d0c395079b5c75ee67908abb',
        'base_dir': 'VOCdevkit/VOC2010'
    },
    '2009': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar',
        'filename': 'VOCtrainval_11-May-2009.tar',
        'md5': '59065e4b188729180974ef6572f6a212',
        'base_dir': 'VOCdevkit/VOC2009'
    },
    '2008': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar',
        # Was 'VOCtrainval_11-May-2012.tar' (copied from the 2012 entry), which
        # made the 2008 archive share a local file name with the 2012 one.
        'filename': 'VOCtrainval_14-Jul-2008.tar',
        'md5': '2629fa636546599198acfcfbfcf1904a',
        'base_dir': 'VOCdevkit/VOC2008'
    },
    '2007': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
        'filename': 'VOCtrainval_06-Nov-2007.tar',
        'md5': 'c52e279531787c972589f7e41ab4ae64',
        'base_dir': 'VOCdevkit/VOC2007'
    }
}
class VOCSegmentation(VisionDataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years 2007 to 2012.
        image_set (string, optional): Select the image_set to use: ``train``, ``trainval``,
            ``val``, ``trainaug`` or ``trainaug_noval``. For the ``trainaug`` variants the
            masks are read from ``SegmentationClassAug`` instead of ``SegmentationClass``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.

    Note: Made minor modifications to load additional annotations described in:
    @InProceedings{BharathICCV2011,
        author = "Bharath Hariharan and Pablo Arbelaez and Lubomir Bourdev and Subhransu Maji and Jitendra Malik",
        title = "Semantic Contours from Inverse Detectors",
        booktitle = "International Conference on Computer Vision (ICCV)",
        year = "2011",
    }
    And also described in: http://home.bharathh.info/pubs/codes/SBD/download.html
    Those annotations were converted to VOC format by: https://github.com/DrSleep/tensorflow-deeplab-resnet
    Further description is in: https://www.sun11.me/blog/2018/how-to-use-10582-trainaug-images-on-DeeplabV3-code/
    """

    # Based on Voc2012 devkit
    VOCSegmentationClass = namedtuple('VOCSegmentationClass', ['name', 'id', 'train_id', 'category', 'category_id',
                                                               'has_instances', 'ignore_in_eval', 'color'])
    # One entry per class; the position in this list is the label value that
    # decode_segmap() paints with the entry's ``color``.
    classes = [
        VOCSegmentationClass('aeroplane', 0, 0, 'object', 0, True, False, (0, 60, 100)),
        VOCSegmentationClass('bicycle', 1, 1, 'object', 0, True, False, (119, 11, 32)),
        VOCSegmentationClass('bird', 2, 2, 'object', 0, True, False, (0, 0, 230)),
        VOCSegmentationClass('boat', 3, 3, 'object', 0, True, False, (0, 80, 100)),
        VOCSegmentationClass('bottle', 4, 4, 'object', 0, True, False, (0, 0, 110)),
        VOCSegmentationClass('bus', 5, 5, 'object', 0, True, False, (111, 74, 0)),
        VOCSegmentationClass('car', 6, 6, 'object', 0, True, False, (0, 0, 142)),
        VOCSegmentationClass('cat', 7, 7, 'object', 0, True, False, (128, 64, 128)),
        VOCSegmentationClass('chair', 8, 8, 'object', 0, True, False, (244, 35, 232)),
        VOCSegmentationClass('cow', 9, 9, 'object', 0, True, False, (250, 170, 160)),
        VOCSegmentationClass('diningtable', 10, 10, 'object', 0, True, False, (230, 150, 140)),
        VOCSegmentationClass('dog', 11, 11, 'object', 0, True, False, (70, 70, 70)),
        VOCSegmentationClass('horse', 12, 12, 'object', 0, True, False, (102, 102, 156)),
        VOCSegmentationClass('motorbike', 13, 13, 'object', 0, True, False, (190, 153, 153)),
        VOCSegmentationClass('person', 14, 14, 'object', 0, True, False, (220, 20, 60)),
        VOCSegmentationClass('pottedplant', 15, 15, 'object', 0, False, True, (150, 100, 100)),
        VOCSegmentationClass('sheep', 16, 16, 'object', 0, True, False, (150, 120, 90)),
        VOCSegmentationClass('sofa', 17, 17, 'object', 0, True, False, (153, 153, 153)),
        VOCSegmentationClass('train', 18, 18, 'object', 0, True, False, (153, 153, 153)),
        VOCSegmentationClass('tvmonitor', 19, 19, 'object', 0, True, False, (250, 170, 30)),
    ]

    def __init__(self,
                 root,
                 year='2012',
                 image_set='train',
                 download=False,
                 transform=None,
                 target_transform=None,
                 transforms=None):
        """Resolve the image and mask paths of the requested split.

        Raises:
            KeyError: If ``year`` is not a key of ``DATASET_YEAR_DICT``.
            RuntimeError: If the dataset directory for ``year`` does not exist under ``root``.
        """
        super(VOCSegmentation, self).__init__(root, transforms, transform, target_transform)
        self.year = year
        self.url = DATASET_YEAR_DICT[year]['url']
        self.filename = DATASET_YEAR_DICT[year]['filename']
        self.md5 = DATASET_YEAR_DICT[year]['md5']
        self.image_set = verify_str_arg(image_set, "image_set",
                                        ("train", "trainval", "val", "trainaug", "trainaug_noval"))
        base_dir = DATASET_YEAR_DICT[year]['base_dir']
        voc_root = os.path.join(self.root, base_dir)
        image_dir = os.path.join(voc_root, 'JPEGImages')
        # The augmented splits take their masks from a separate directory.
        seg_name = 'SegmentationClassAug' if 'trainaug' in image_set else 'SegmentationClass'
        mask_dir = os.path.join(voc_root, seg_name)
        # Label value written over out-of-range mask pixels in __getitem__.
        self.ignore_index = 255

        if download:
            download_extract(self.url, self.root, self.filename, self.md5)

        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')

        split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')

        # The split file lists one image identifier per line.
        with open(split_f, "r") as f:
            file_names = [x.strip() for x in f.readlines()]

        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
        assert (len(self.images) == len(self.masks))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the image segmentation.
        """
        img = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.masks[index])

        # ignore potential invalid pixels - this is strictly not required
        # but there are few pixels falling outside the valid range.
        target = np.array(target)
        # num_classes() returns a one-element list; the comparison broadcasts it
        # against the whole mask.
        # NOTE(review): every label >= 20 becomes ignore_index - confirm this matches
        # the label encoding of the mask files in use.
        target[target >= self.num_classes()] = self.ignore_index
        target = Image.fromarray(target)

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Return the number of samples listed in the selected split file."""
        return len(self.images)

    def num_classes(self):
        """Return the number of classes as a one-element list (``[20]``)."""
        return [20]

    @classmethod
    def decode_segmap(cls, lbl, year='2012'):
        """Colorize a label map using the ``color`` of each entry in ``classes``.

        Args:
            lbl (numpy.ndarray): 2-D array of label values.
            year (string, optional): Only ``'2012'`` is supported.

        Returns:
            numpy.ndarray or None: Array of shape ``(H, W, 3)`` holding the per-channel
            colors divided by 255, or ``None`` when ``year`` is not ``'2012'``.
            Pixels whose label is not an index into ``classes`` keep their label value
            (divided by 255) in all three channels.
        """
        if year != '2012':
            return None
        #
        r = lbl.copy()
        g = lbl.copy()
        b = lbl.copy()
        for label, voc_class in enumerate(cls.classes):
            # Masks are always computed from the untouched input labels, so a color
            # written for one class can never be mistaken for another label.
            selected = lbl == label
            r[selected] = voc_class.color[0]
            g[selected] = voc_class.color[1]
            b[selected] = voc_class.color[2]
        #
        rgb = np.zeros((lbl.shape[0], lbl.shape[1], 3))
        rgb[:, :, 0] = r / 255.0
        rgb[:, :, 1] = g / 255.0
        rgb[:, :, 2] = b / 255.0
        return rgb
class VOCDetection(VisionDataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years 2007 to 2012.
        image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    def __init__(self,
                 root,
                 year='2012',
                 image_set='train',
                 download=False,
                 transform=None,
                 target_transform=None,
                 transforms=None):
        """Resolve the image and annotation paths of the requested split.

        Raises:
            KeyError: If ``year`` is not a key of ``DATASET_YEAR_DICT``.
            RuntimeError: If the dataset directory for ``year`` does not exist under ``root``.
        """
        super(VOCDetection, self).__init__(root, transforms, transform, target_transform)
        self.year = year
        year_info = DATASET_YEAR_DICT[year]
        self.url = year_info['url']
        self.filename = year_info['filename']
        self.md5 = year_info['md5']
        self.image_set = verify_str_arg(image_set, "image_set",
                                        ("train", "trainval", "val"))

        if download:
            download_extract(self.url, self.root, self.filename, self.md5)

        voc_root = os.path.join(self.root, year_info['base_dir'])
        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted.'
                               ' You can use download=True to download it')

        # The split file lists one image identifier per line.
        split_path = os.path.join(voc_root, 'ImageSets/Main', image_set.rstrip('\n') + '.txt')
        with open(split_path, "r") as split_file:
            stems = [line.strip() for line in split_file]

        image_dir = os.path.join(voc_root, 'JPEGImages')
        annotation_dir = os.path.join(voc_root, 'Annotations')
        self.images = []
        self.annotations = []
        for stem in stems:
            self.images.append(os.path.join(image_dir, stem + ".jpg"))
            self.annotations.append(os.path.join(annotation_dir, stem + ".xml"))
        assert (len(self.images) == len(self.annotations))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is a dictionary of the XML tree.
        """
        image = Image.open(self.images[index]).convert('RGB')
        xml_root = ET.parse(self.annotations[index]).getroot()
        target = self.parse_voc_xml(xml_root)

        if self.transforms is None:
            return image, target

        image, target = self.transforms(image, target)
        return image, target

    def __len__(self):
        """Return the number of samples listed in the selected split file."""
        return len(self.images)

    def parse_voc_xml(self, node):
        """Recursively convert an XML element into nested dictionaries.

        Args:
            node: XML element to convert.

        Returns:
            dict: ``{tag: stripped text}`` for an element without children (an empty
            dict when such an element has no text), otherwise ``{tag: {...}}`` holding
            the converted children. A child tag that occurs once maps to its single
            value; a child tag that occurs several times maps to a list of values,
            so the same tag can yield either a value or a list depending on the file.
        """
        child_nodes = list(node)

        if not child_nodes:
            # Leaf element: keep its text, or drop the tag when there is none.
            if node.text:
                return {node.tag: node.text.strip()}
            return {}

        # Collect the values of every child tag, preserving document order.
        grouped = collections.defaultdict(list)
        for child in child_nodes:
            for key, value in self.parse_voc_xml(child).items():
                grouped[key].append(value)

        merged = {}
        for key, values in grouped.items():
            merged[key] = values[0] if len(values) == 1 else values
        return {node.tag: merged}
def download_extract(url, root, filename, md5):
    """Download a tar archive into ``root`` and unpack it there.

    Args:
        url (string): Location of the archive; passed to ``download_url``.
        root (string): Directory that receives the archive and its extracted contents.
        filename (string): Name of the archive file inside ``root``.
        md5 (string): Checksum passed to ``download_url``.

    Raises:
        RuntimeError: If an archive member would be written outside ``root``.
    """
    download_url(url, root, filename, md5)

    root_real = os.path.realpath(root)
    # Joining with '' appends a trailing separator, so that a sibling directory
    # such as '<root>_other' is not accepted as being inside '<root>'.
    root_prefix = os.path.join(root_real, '')

    with tarfile.open(os.path.join(root, filename), "r") as tar:
        # The archive comes from the network: refuse member names (absolute paths
        # or '..' components) that would be written outside of root.
        # NOTE(review): only member names are checked, not the targets of
        # symlink/hardlink members.
        for member in tar.getmembers():
            destination = os.path.realpath(os.path.join(root_real, member.name))
            if destination != root_real and not destination.startswith(root_prefix):
                raise RuntimeError('Archive member would be extracted outside of the '
                                   'target directory: ' + member.name)
        tar.extractall(path=root)