Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/xvision/datasets/sbu.py
renamed pytorch_jacinto_ai.vision to pytorch_jacinto_ai.xvision
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / datasets / sbu.py
1 from PIL import Image
2 from six.moves import zip
3 from .utils import download_url, check_integrity
5 import os
6 from .vision import VisionDataset
class SBU(VisionDataset):
    """`SBU Captioned Photo <http://www.cs.virginia.edu/~vicente/sbucaptions/>`_ Dataset.

    The tarball contains a list of photo URLs and a line-aligned list of
    captions; :meth:`download` then fetches every photo individually into
    ``root/dataset``. Only photos that are actually present on disk are
    indexed, so ``len(dataset)`` can be smaller than the number of captions
    (the photos are public Flickr images that users may have removed).

    Args:
        root (string): Root directory of dataset where tarball
            ``SBUCaptionedPhotoDataset.tar.gz`` exists.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target (the caption string) and transforms it.
        download (bool, optional): If True (the default), downloads the dataset from
            the internet and puts it in root directory. If the tarball is already
            downloaded and verified, nothing is downloaded again.

    Raises:
        RuntimeError: If the tarball is missing or fails its MD5 check.
    """
    url = "http://www.cs.virginia.edu/~vicente/sbucaptions/SBUCaptionedPhotoDataset.tar.gz"
    filename = "SBUCaptionedPhotoDataset.tar.gz"
    md5_checksum = '9aec147b3488753cf758b4d493422285'

    def __init__(self, root, transform=None, target_transform=None, download=True):
        super(SBU, self).__init__(root, transform=transform,
                                  target_transform=target_transform)

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        # Read the caption for each photo. The URL file and the caption file
        # are consumed in lockstep: line i of one pairs with line i of the other.
        self.photos = []
        self.captions = []

        dataset_dir = os.path.join(self.root, 'dataset')
        urls_path = os.path.join(dataset_dir, 'SBU_captioned_photo_dataset_urls.txt')
        captions_path = os.path.join(dataset_dir, 'SBU_captioned_photo_dataset_captions.txt')

        # Context managers guarantee both handles are closed, even on error.
        with open(urls_path) as url_fh, open(captions_path) as caption_fh:
            for url_line, caption_line in zip(url_fh, caption_fh):
                photo = os.path.basename(url_line.rstrip())
                # Index only photos that were actually downloaded. ``isfile``
                # (rather than ``exists``) rejects an empty basename, which
                # would otherwise resolve to the ``dataset`` directory itself.
                if os.path.isfile(os.path.join(dataset_dir, photo)):
                    self.photos.append(photo)
                    self.captions.append(caption_line.rstrip())

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where image is the photo converted to RGB
            (after ``transform``, if any) and target is its caption string
            (after ``target_transform``, if any).
        """
        path = os.path.join(self.root, 'dataset', self.photos[index])
        # Open through an explicit handle so the file is closed as soon as
        # ``convert`` has decoded the pixels (same pattern as torchvision's
        # default PIL loader).
        with open(path, 'rb') as fh:
            img = Image.open(fh).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        target = self.captions[index]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """The number of photos in the dataset (only photos present on disk)."""
        return len(self.photos)

    def _check_integrity(self):
        """Return True if the downloaded tarball exists and matches its MD5 checksum."""
        fpath = os.path.join(self.root, self.filename)
        return bool(check_integrity(fpath, self.md5_checksum))

    def download(self):
        """Download and extract the tarball, and download each individual photo.

        Returns immediately if the tarball is already present and verified; in
        that case individual photos are not re-fetched either. Photos whose
        download raises ``OSError`` (e.g. removed from Flickr) are skipped, and
        the number of skipped photos is reported at the end.
        """
        import tarfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        download_url(self.url, self.root, self.filename, self.md5_checksum)

        # Extract file.
        # NOTE(review): ``extractall`` trusts the member paths stored in the
        # archive. ``download_url`` is given the expected MD5, presumably so
        # a tampered archive is rejected before this point -- confirm.
        with tarfile.open(os.path.join(self.root, self.filename), 'r:gz') as tar:
            tar.extractall(path=self.root)

        # Download individual photos (best effort).
        dataset_dir = os.path.join(self.root, 'dataset')
        failures = 0
        with open(os.path.join(dataset_dir, 'SBU_captioned_photo_dataset_urls.txt')) as fh:
            for line in fh:
                url = line.rstrip()
                if not url:
                    # Blank line: there is nothing to fetch.
                    continue
                try:
                    download_url(url, dataset_dir)
                except OSError:
                    # The images point to public images on Flickr.
                    # Note: Images might be removed by users at anytime,
                    # so a failed photo is skipped rather than aborting.
                    failures += 1

        if failures:
            print('{} photos could not be downloaded and were skipped'.format(failures))