# modules/pytorch_jacinto_ai/xvision/datasets/coco.py
# (from jacinto-ai/pytorch-jacinto-ai-devkit.git)
1 from .vision import VisionDataset
2 from PIL import Image
3 import os
4 import os.path
5 from typing import Any, Callable, Optional, Tuple
class CocoCaptions(VisionDataset):
    """`MS Coco Captions <http://mscoco.org/dataset/#captions-challenge2015>`_ Dataset.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): Transform applied to the PIL image,
            e.g. ``transforms.ToTensor``.
        target_transform (callable, optional): Transform applied to the target
            (the list of captions).
        transforms (callable, optional): Joint transform taking the sample and
            its target and returning a transformed pair.

    Example:

        .. code:: python

            import torchvision.datasets as dset
            import torchvision.transforms as transforms
            cap = dset.CocoCaptions(root = 'dir where images are',
                                    annFile = 'json annotation file',
                                    transform=transforms.ToTensor())

            print('Number of samples: ', len(cap))
            img, target = cap[3]  # load 4th sample

            print("Image Size: ", img.size())
            print(target)

        Output: ::

            Number of samples: 82783
            Image Size: (3L, 427L, 640L)
            [u'A plane emitting smoke stream flying over a mountain.',
            u'A plane darts across a bright blue sky behind a mountain covered in snow',
            u'A plane leaves a contrail above the snowy mountain top.',
            u'A mountain that has a plane flying overheard in the distance.',
            u'A mountain view with a plume of smoke in the background']

    """

    def __init__(
        self,
        root: str,
        annFile: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super(CocoCaptions, self).__init__(root, transforms, transform, target_transform)
        # imported lazily so pycocotools is only required when this dataset is used
        from pycocotools.coco import COCO
        self.coco = COCO(annFile)
        self.ids = sorted(self.coco.imgs.keys())

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the sample at ``index``.

        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        img_id = self.ids[index]
        annotations = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
        target = [annotation['caption'] for annotation in annotations]

        file_name = self.coco.loadImgs(img_id)[0]['file_name']
        image = Image.open(os.path.join(self.root, file_name)).convert('RGB')

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        return len(self.ids)
class CocoDetection(VisionDataset):
    """`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): Transform applied to the PIL image,
            e.g. ``transforms.ToTensor``.
        target_transform (callable, optional): Transform applied to the target
            (the annotation list).
        transforms (callable, optional): Joint transform taking the sample and
            its target and returning a transformed pair.
    """

    def __init__(
        self,
        root: str,
        annFile: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super(CocoDetection, self).__init__(root, transforms, transform, target_transform)
        # imported lazily so pycocotools is only required when this dataset is used
        from pycocotools.coco import COCO
        self.coco = COCO(annFile)
        self.ids = sorted(self.coco.imgs.keys())

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the sample at ``index``.

        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
        """
        img_id = self.ids[index]
        target = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))

        file_name = self.coco.loadImgs(img_id)[0]['file_name']
        image = Image.open(os.path.join(self.root, file_name)).convert('RGB')
        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        return len(self.ids)
142 import copy
143 import numpy as np
144 import random
145 import cv2
class COCOSegmentation():
    '''
    Semantic segmentation dataset built from the MS COCO instance annotations.

    Modified from torchvision: https://github.com/pytorch/vision/references/segmentation/coco_utils.py
    Reference: https://github.com/pytorch/vision/blob/master/docs/source/models.rst

    Args:
        root (str): Dataset root. Must contain an 'annotations' folder; images
            live either directly under root/<split> or under root/images/<split>.
        split (str): Split name (e.g. 'train2017'); selects both the image
            folder and the instances_{split}.json annotation file.
        shuffle (bool): If truthy, shuffle the image ids. The shuffle is seeded
            with int(shuffle) so the resulting order is reproducible.
        num_imgs (int, optional): Keep only the first num_imgs images (applied
            after the optional shuffle).
        num_classes (int, optional): Number of classes. 21 selects the Pascal
            VOC subset of COCO categories; otherwise category ids
            0..num_classes-1 are used directly (default 80).
    '''
    def __init__(self, root, split, shuffle=False, num_imgs=None, num_classes=None):
        from pycocotools.coco import COCO
        num_classes = 80 if num_classes is None else num_classes
        if num_classes == 21:
            # Pascal VOC class list mapped to the corresponding COCO category ids.
            self.categories = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72]
            self.class_names = ['__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
                                'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
                                'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
        else:
            # use a list so both branches expose the same attribute type
            self.categories = list(range(num_classes))
            self.class_names = None
        #
        dataset_folders = os.listdir(root)
        # explicit raise (instead of a bare assert) so the check survives python -O
        if 'annotations' not in dataset_folders:
            raise AssertionError('invalid path to coco dataset annotations')
        annotations_dir = os.path.join(root, 'annotations')

        # images may sit under root/images/<split> or directly under root/<split>
        image_base_dir = 'images' if ('images' in dataset_folders) else ''
        image_base_dir = os.path.join(root, image_base_dir)
        image_dir = os.path.join(image_base_dir, split)

        self.coco_dataset = COCO(os.path.join(annotations_dir, f'instances_{split}.json'))

        self.cat_ids = self.coco_dataset.getCatIds()
        img_ids = self.coco_dataset.getImgIds()
        self.img_ids = self._remove_images_without_annotations(img_ids)

        if shuffle:
            random.seed(int(shuffle))
            random.shuffle(self.img_ids)
        #
        if num_imgs is not None:
            self.img_ids = self.img_ids[:num_imgs]
            # keep the underlying COCO image index consistent with the truncated id list
            self.coco_dataset.imgs = {k: self.coco_dataset.imgs[k] for k in self.img_ids}
        #
        imgs = []
        for img_id in self.img_ids:
            img = self.coco_dataset.loadImgs([img_id])[0]
            imgs.append(os.path.join(image_dir, img['file_name']))
        #
        self.imgs = imgs
        self.num_imgs = len(self.imgs)

    def __getitem__(self, idx, with_label=True):
        '''Return (image, target) numpy arrays, or just the image path if with_label is False.

        The image is forced to 3 channels; target is an HxW class-id map where
        255 marks pixels covered by more than one instance.
        '''
        if with_label:
            image = Image.open(self.imgs[idx])
            ann_ids = self.coco_dataset.getAnnIds(imgIds=self.img_ids[idx], iscrowd=None)
            anno = self.coco_dataset.loadAnns(ann_ids)
            image, anno = self._filter_and_remap_categories(image, anno)
            image, target = self._convert_polys_to_mask(image, anno)
            image = np.array(image)
            if image.ndim == 2 or image.shape[2] == 1:
                # grayscale input: replicate to 3 channels
                image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
            #
            target = np.array(target)
            return image, target
        else:
            return self.imgs[idx]
    #
    def __len__(self):
        return self.num_imgs

    def _remove_images_without_annotations(self, img_ids):
        '''Keep only image ids that have at least one usable annotation.'''
        ids = []
        for ds_idx, img_id in enumerate(img_ids):
            ann_ids = self.coco_dataset.getAnnIds(imgIds=img_id, iscrowd=None)
            anno = self.coco_dataset.loadAnns(ann_ids)
            if self.categories:
                anno = [obj for obj in anno if obj["category_id"] in self.categories]
            if self._has_valid_annotation(anno):
                ids.append(img_id)
            #
        #
        return ids

    def _has_valid_annotation(self, anno):
        '''Return True if the annotations cover more than 1000 pixels in total.'''
        # if it's empty, there is no annotation
        if len(anno) == 0:
            return False
        # if more than 1k pixels occupied in the image
        return sum(obj["area"] for obj in anno) > 1000

    def _filter_and_remap_categories(self, image, anno, remap=True):
        '''Drop annotations outside self.categories; optionally remap each
        category_id to its index in self.categories (contiguous train ids).'''
        anno = [obj for obj in anno if obj["category_id"] in self.categories]
        if not remap:
            return image, anno
        #
        # deep-copy before remapping so the COCO object's cached anns stay untouched
        anno = copy.deepcopy(anno)
        for obj in anno:
            obj["category_id"] = self.categories.index(obj["category_id"])
        #
        return image, anno

    def _convert_polys_to_mask(self, image, anno):
        '''Rasterize the instance segmentations into a single HxW class-id map.'''
        w, h = image.size
        segmentations = [obj["segmentation"] for obj in anno]
        cats = [obj["category_id"] for obj in anno]
        if segmentations:
            masks = self._convert_poly_to_mask(segmentations, h, w)
            cats = np.array(cats, dtype=masks.dtype)
            cats = cats.reshape(-1, 1, 1)
            # merge all instance masks into a single segmentation map
            # with its corresponding categories
            target = (masks * cats).max(axis=0)
            # discard overlapping instances
            target[masks.sum(0) > 1] = 255
        else:
            # no annotations: everything is background
            target = np.zeros((h, w), dtype=np.uint8)
        #
        return image, target

    def _convert_poly_to_mask(self, segmentations, height, width):
        '''Decode per-instance polygon/RLE segmentations into a (N, H, W) uint8 mask stack.'''
        from pycocotools import mask as coco_mask
        masks = []
        for polygons in segmentations:
            rles = coco_mask.frPyObjects(polygons, height, width)
            mask = coco_mask.decode(rles)
            if len(mask.shape) < 3:
                mask = mask[..., None]
            # collapse the parts of a multi-polygon instance into one binary mask
            mask = mask.any(axis=2)
            mask = mask.astype(np.uint8)
            masks.append(mask)
        if masks:
            masks = np.stack(masks, axis=0)
        else:
            masks = np.zeros((0, height, width), dtype=np.uint8)
        return masks