[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / datasets / coco.py
1 from .vision import VisionDataset
2 from PIL import Image
3 import os
4 import os.path
class CocoCaptions(VisionDataset):
    """`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.

    Each sample is an RGB image paired with the list of caption strings
    annotated for it. Samples are indexed in ascending order of COCO image id.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes the input sample
            and its target and returns a transformed version of both.

    Example:

        .. code:: python

            import torchvision.datasets as dset
            import torchvision.transforms as transforms
            cap = dset.CocoCaptions(root = 'dir where images are',
                                    annFile = 'json annotation file',
                                    transform=transforms.ToTensor())

            print('Number of samples: ', len(cap))
            img, target = cap[3] # load 4th sample

            print("Image Size: ", img.size())
            print(target)

        Output: ::

            Number of samples: 82783
            Image Size: torch.Size([3, 427, 640])
            ['A plane emitting smoke stream flying over a mountain.',
             'A plane darts across a bright blue sky behind a mountain covered in snow',
             'A plane leaves a contrail above the snowy mountain top.',
             'A mountain that has a plane flying overheard in the distance.',
             'A mountain view with a plume of smoke in the background']

    """

    def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):
        super(CocoCaptions, self).__init__(root, transforms, transform, target_transform)
        # Imported here rather than at module level, so importing this module
        # does not require pycocotools to be installed.
        from pycocotools.coco import COCO
        self.coco = COCO(annFile)
        # Stable, ascending image-id order; sorted() already returns a list.
        self.ids = sorted(self.coco.imgs.keys())

    def _load_image(self, img_id):
        """Open the image file for ``img_id`` under ``root`` and return it as RGB.

        The file is opened in a ``with`` block so its handle is released
        deterministically instead of being left to PIL / garbage collection.
        """
        path = self.coco.loadImgs(img_id)[0]['file_name']
        with Image.open(os.path.join(self.root, path)) as img:
            return img.convert('RGB')

    def _load_target(self, img_id):
        """Return the list of ``'caption'`` strings annotated for ``img_id``."""
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        return [ann['caption'] for ann in self.coco.loadAnns(ann_ids)]

    def __getitem__(self, index):
        """
        Args:
            index (int): Index into the sorted list of image ids.

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.

        Raises:
            IndexError: If ``index`` is out of range.
        """
        img_id = self.ids[index]
        target = self._load_target(img_id)
        img = self._load_image(img_id)

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Return the number of images in the annotation file."""
        return len(self.ids)
class CocoDetection(VisionDataset):
    """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.

    Each sample is an RGB image paired with the annotations returned by
    ``coco.loadAnns`` for that image. Samples are indexed in ascending order
    of COCO image id.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes the input sample
            and its target and returns a transformed version of both.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):
        super(CocoDetection, self).__init__(root, transforms, transform, target_transform)
        # Imported here rather than at module level, so importing this module
        # does not require pycocotools to be installed.
        from pycocotools.coco import COCO
        self.coco = COCO(annFile)
        # Stable, ascending image-id order; sorted() already returns a list.
        self.ids = sorted(self.coco.imgs.keys())

    def _load_image(self, img_id):
        """Open the image file for ``img_id`` under ``root`` and return it as RGB.

        The file is opened in a ``with`` block so its handle is released
        deterministically instead of being left to PIL / garbage collection.
        """
        path = self.coco.loadImgs(img_id)[0]['file_name']
        with Image.open(os.path.join(self.root, path)) as img:
            return img.convert('RGB')

    def _load_target(self, img_id):
        """Return the annotations for ``img_id`` exactly as ``coco.loadAnns`` gives them."""
        return self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index into the sorted list of image ids.

        Returns:
            tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.

        Raises:
            IndexError: If ``index`` is out of range.
        """
        img_id = self.ids[index]
        target = self._load_target(img_id)
        img = self._load_image(img_id)

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Return the number of images in the annotation file."""
        return len(self.ids)