]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/xvision/datasets/voc.py
updated python package requirements (don't need tensorflow for tensorboard). not...
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / datasets / voc.py
1 '''
2 Dataset loader for the PASCAL VOC dataset:
4 - Everingham, M., Van Gool, L., Williams, C. K. I., Winn, J. and Zisserman, A.,
5 The PASCAL Visual Object Classes (VOC) Challenge, International Journal of Computer Vision,
6 88(2), 303-338, 2010
8 - Everingham, M. and Van~Gool, L. and Williams, C. K. I. and Winn, J. and Zisserman, A.",
9 The PASCAL Visual Object Classes Challenge 2012, VOC2012",
10 "http://www.pascal-network.org/challenges/VOC/voc2012/workshop/index.html"
11 '''
13 import os
14 import sys
15 import tarfile
16 import numpy as np
17 import collections
18 from collections import namedtuple
19 from .vision import VisionDataset
21 if sys.version_info[0] == 2:
22     import xml.etree.cElementTree as ET
23 else:
24     import xml.etree.ElementTree as ET
26 from PIL import Image
27 from .utils import download_url, check_integrity, verify_str_arg
29 DATASET_YEAR_DICT = {
30     '2012': {
31         'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
32         'filename': 'VOCtrainval_11-May-2012.tar',
33         'md5': '6cd6e144f989b92b3379bac3b3de84fd',
34         'base_dir': 'VOCdevkit/VOC2012'
35     },
36     '2011': {
37         'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar',
38         'filename': 'VOCtrainval_25-May-2011.tar',
39         'md5': '6c3384ef61512963050cb5d687e5bf1e',
40         'base_dir': 'TrainVal/VOCdevkit/VOC2011'
41     },
42     '2010': {
43         'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar',
44         'filename': 'VOCtrainval_03-May-2010.tar',
45         'md5': 'da459979d0c395079b5c75ee67908abb',
46         'base_dir': 'VOCdevkit/VOC2010'
47     },
48     '2009': {
49         'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar',
50         'filename': 'VOCtrainval_11-May-2009.tar',
51         'md5': '59065e4b188729180974ef6572f6a212',
52         'base_dir': 'VOCdevkit/VOC2009'
53     },
54     '2008': {
55         'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar',
56         'filename': 'VOCtrainval_11-May-2012.tar',
57         'md5': '2629fa636546599198acfcfbfcf1904a',
58         'base_dir': 'VOCdevkit/VOC2008'
59     },
60     '2007': {
61         'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
62         'filename': 'VOCtrainval_06-Nov-2007.tar',
63         'md5': 'c52e279531787c972589f7e41ab4ae64',
64         'base_dir': 'VOCdevkit/VOC2007'
65     }
66 }
69 class VOCSegmentation(VisionDataset):
71     # Based on Voc2012 devkit
72     VOCSegmentationClass = namedtuple('VOCSegmentationClass', ['name', 'id', 'train_id', 'category', 'category_id',
73                                                      'has_instances', 'ignore_in_eval', 'color'])
74     classes = [
75         VOCSegmentationClass('aeroplane', 0, 0, 'object', 0, True, False, (0, 60, 100)),
76         VOCSegmentationClass('bicycle', 1, 1, 'object', 0, True, False, (119, 11, 32)),
77         VOCSegmentationClass('bird', 2, 2, 'object', 0, True, False, (0, 0, 230)),
78         VOCSegmentationClass('boat', 3, 3, 'object', 0, True, False, (0, 80, 100)),
79         VOCSegmentationClass('bottle', 4, 4, 'object', 0, True, False, (0, 0, 110)),
80         VOCSegmentationClass('bus', 5, 5, 'object', 0, True, False, (111, 74, 0)),
81         VOCSegmentationClass('car', 6, 6, 'object', 0, True, False, (0, 0, 142)),
82         VOCSegmentationClass('cat', 7, 7, 'object', 0, True, False, (128, 64, 128)),
83         VOCSegmentationClass('chair', 8, 8, 'object', 0, True, False, (244, 35, 232)),
84         VOCSegmentationClass('cow', 9, 9, 'object', 0, True, False, (250, 170, 160)),
85         VOCSegmentationClass('diningtable', 10, 10, 'object', 0, True, False, (230, 150, 140)),
86         VOCSegmentationClass('dog', 11, 11, 'object', 0, True, False, (70, 70, 70)),
87         VOCSegmentationClass('horse', 12, 12, 'object', 0, True, False, (102, 102, 156)),
88         VOCSegmentationClass('motorbike', 13, 13, 'object', 0, True, False, (190, 153, 153)),
89         VOCSegmentationClass('person', 14, 14, 'object', 0, True, False, (220, 20, 60)),
90         VOCSegmentationClass('pottedplant', 15, 15, 'object', 0, False, True, (150, 100, 100)),
91         VOCSegmentationClass('sheep', 16, 16, 'object', 0, True, False, (150, 120, 90)),
92         VOCSegmentationClass('sofa', 17, 17, 'object', 0, True, False, (153, 153, 153)),
93         VOCSegmentationClass('train', 18, 18, 'object', 0, True, False, (153, 153, 153)),
94         VOCSegmentationClass('tvmonitor', 19, 19, 'object', 0, True, False, (250, 170, 30)),
95     ]
97     """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
99     Args:
100         root (string): Root directory of the VOC Dataset.
101         year (string, optional): The dataset year, supports years 2007 to 2012.
102         image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
103         download (bool, optional): If true, downloads the dataset from the internet and
104             puts it in root directory. If dataset is already downloaded, it is not
105             downloaded again.
106         transform (callable, optional): A function/transform that  takes in an PIL image
107             and returns a transformed version. E.g, ``transforms.RandomCrop``
108         target_transform (callable, optional): A function/transform that takes in the
109             target and transforms it.
110         transforms (callable, optional): A function/transform that takes input sample and its target as entry
111             and returns a transformed version.
112     
113     Note: Made minor modifications to load additional annotations described in:
114             @InProceedings{BharathICCV2011,
115             author = "Bharath Hariharan and Pablo Arbelaez and Lubomir Bourdev and Subhransu Maji and Jitendra Malik",
116             title = "Semantic Contours from Inverse Detectors",
117             booktitle = "International Conference on Computer Vision (ICCV)",
118             year = "2011",
119             }
120         And also described in: http://home.bharathh.info/pubs/codes/SBD/download.html 
121         Those annotations were converted to VOC format by: https://github.com/DrSleep/tensorflow-deeplab-resnet
122         Further description is in: https://www.sun11.me/blog/2018/how-to-use-10582-trainaug-images-on-DeeplabV3-code/
123     """
125     def __init__(self,
126                  root,
127                  year='2012',
128                  image_set='train',
129                  download=False,
130                  transform=None,
131                  target_transform=None,
132                  transforms=None):
133         super(VOCSegmentation, self).__init__(root, transforms, transform, target_transform)
134         self.year = year
135         self.url = DATASET_YEAR_DICT[year]['url']
136         self.filename = DATASET_YEAR_DICT[year]['filename']
137         self.md5 = DATASET_YEAR_DICT[year]['md5']
138         self.image_set = verify_str_arg(image_set, "image_set",
139                                         ("train", "trainval", "val", "trainaug", "trainaug_noval"))
140         base_dir = DATASET_YEAR_DICT[year]['base_dir']
141         voc_root = os.path.join(self.root, base_dir)
142         image_dir = os.path.join(voc_root, 'JPEGImages')
143         seg_name = 'SegmentationClassAug' if 'trainaug' in image_set else 'SegmentationClass'
144         mask_dir = os.path.join(voc_root, seg_name)
145         self.ignore_index = 255
147         if download:
148             download_extract(self.url, self.root, self.filename, self.md5)
150         if not os.path.isdir(voc_root):
151             raise RuntimeError('Dataset not found or corrupted.' +
152                                ' You can use download=True to download it')
154         splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
156         split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
158         with open(os.path.join(split_f), "r") as f:
159             file_names = [x.strip() for x in f.readlines()]
161         self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
162         self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
163         assert (len(self.images) == len(self.masks))
165     def __getitem__(self, index):
166         """
167         Args:
168             index (int): Index
170         Returns:
171             tuple: (image, target) where target is the image segmentation.
172         """
173         img = Image.open(self.images[index]).convert('RGB')
174         target = Image.open(self.masks[index])
176         # ignore potential invalid pixels - this is strictly not required
177         # but there are few pixels falling outside the valid range.
178         target = np.array(target)
179         target[target >= self.num_classes()] = self.ignore_index
180         target = Image.fromarray(target)
182         if self.transforms is not None:
183             img, target = self.transforms(img, target)
185         return img, target
187     def __len__(self):
188         return len(self.images)
190     def num_classes(self):
191         return [20]
193     @classmethod
194     def decode_segmap(cls, lbl, year='2012'):
195         if year != '2012':
196             return None
197         #
198         r = lbl.copy()
199         g = lbl.copy()
200         b = lbl.copy()
201         for l in range(0, len(cls.classes)):
202             r[lbl == l] = cls.classes[l].color[0]
203             g[lbl == l] = cls.classes[l].color[1]
204             b[lbl == l] = cls.classes[l].color[2]
205         #
206         rgb = np.zeros((lbl.shape[0], lbl.shape[1], 3))
207         rgb[:, :, 0] = r / 255.0
208         rgb[:, :, 1] = g / 255.0
209         rgb[:, :, 2] = b / 255.0
210         return rgb
215 class VOCDetection(VisionDataset):
216     """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.
218     Args:
219         root (string): Root directory of the VOC Dataset.
220         year (string, optional): The dataset year, supports years 2007 to 2012.
221         image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
222         download (bool, optional): If true, downloads the dataset from the internet and
223             puts it in root directory. If dataset is already downloaded, it is not
224             downloaded again.
225             (default: alphabetic indexing of VOC's 20 classes).
226         transform (callable, optional): A function/transform that  takes in an PIL image
227             and returns a transformed version. E.g, ``transforms.RandomCrop``
228         target_transform (callable, required): A function/transform that takes in the
229             target and transforms it.
230         transforms (callable, optional): A function/transform that takes input sample and its target as entry
231             and returns a transformed version.
232     """
234     def __init__(self,
235                  root,
236                  year='2012',
237                  image_set='train',
238                  download=False,
239                  transform=None,
240                  target_transform=None,
241                  transforms=None):
242         super(VOCDetection, self).__init__(root, transforms, transform, target_transform)
243         self.year = year
244         self.url = DATASET_YEAR_DICT[year]['url']
245         self.filename = DATASET_YEAR_DICT[year]['filename']
246         self.md5 = DATASET_YEAR_DICT[year]['md5']
247         self.image_set = verify_str_arg(image_set, "image_set",
248                                         ("train", "trainval", "val"))
250         base_dir = DATASET_YEAR_DICT[year]['base_dir']
251         voc_root = os.path.join(self.root, base_dir)
252         image_dir = os.path.join(voc_root, 'JPEGImages')
253         annotation_dir = os.path.join(voc_root, 'Annotations')
255         if download:
256             download_extract(self.url, self.root, self.filename, self.md5)
258         if not os.path.isdir(voc_root):
259             raise RuntimeError('Dataset not found or corrupted.' +
260                                ' You can use download=True to download it')
262         splits_dir = os.path.join(voc_root, 'ImageSets/Main')
264         split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
266         with open(os.path.join(split_f), "r") as f:
267             file_names = [x.strip() for x in f.readlines()]
269         self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
270         self.annotations = [os.path.join(annotation_dir, x + ".xml") for x in file_names]
271         assert (len(self.images) == len(self.annotations))
273     def __getitem__(self, index):
274         """
275         Args:
276             index (int): Index
278         Returns:
279             tuple: (image, target) where target is a dictionary of the XML tree.
280         """
281         img = Image.open(self.images[index]).convert('RGB')
282         target = self.parse_voc_xml(
283             ET.parse(self.annotations[index]).getroot())
285         if self.transforms is not None:
286             img, target = self.transforms(img, target)
288         return img, target
290     def __len__(self):
291         return len(self.images)
293     def parse_voc_xml(self, node):
294         voc_dict = {}
295         children = list(node)
296         if children:
297             def_dic = collections.defaultdict(list)
298             for dc in map(self.parse_voc_xml, children):
299                 for ind, v in dc.items():
300                     def_dic[ind].append(v)
301             voc_dict = {
302                 node.tag:
303                     {ind: v[0] if len(v) == 1 else v
304                      for ind, v in def_dic.items()}
305             }
306         if node.text:
307             text = node.text.strip()
308             if not children:
309                 voc_dict[node.tag] = text
310         return voc_dict
313 def download_extract(url, root, filename, md5):
314     download_url(url, root, filename, md5)
315     with tarfile.open(os.path.join(root, filename), "r") as tar:
316         tar.extractall(path=root)