[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / datasets / sbu.py
1 from PIL import Image
2 from six.moves import zip
3 from .utils import download_url, check_integrity
5 import os
6 from .vision import VisionDataset
class SBU(VisionDataset):
    """`SBU Captioned Photo <http://www.cs.virginia.edu/~vicente/sbucaptions/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where tarball
            ``SBUCaptionedPhotoDataset.tar.gz`` exists.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """
    url = "http://www.cs.virginia.edu/~vicente/sbucaptions/SBUCaptionedPhotoDataset.tar.gz"
    filename = "SBUCaptionedPhotoDataset.tar.gz"
    md5_checksum = '9aec147b3488753cf758b4d493422285'

    def __init__(self, root, transform=None, target_transform=None, download=True):
        super(SBU, self).__init__(root, transform=transform,
                                  target_transform=target_transform)

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        # Read the caption for each photo. Only photos that actually exist on
        # disk are kept (Flickr images may have been removed — see download()),
        # so self.photos and self.captions stay aligned index-for-index.
        self.photos = []
        self.captions = []

        file1 = os.path.join(self.root, 'dataset', 'SBU_captioned_photo_dataset_urls.txt')
        file2 = os.path.join(self.root, 'dataset', 'SBU_captioned_photo_dataset_captions.txt')

        # BUGFIX: the original iterated over two bare open() calls
        # (`zip(open(file1), open(file2))`), leaking both file handles.
        # Context managers guarantee the files are closed.
        with open(file1) as urls_fh, open(file2) as captions_fh:
            # The two files are line-aligned: line i of file1 is the URL for
            # the caption on line i of file2.
            for line1, line2 in zip(urls_fh, captions_fh):
                url = line1.rstrip()
                photo = os.path.basename(url)
                filename = os.path.join(self.root, 'dataset', photo)
                if os.path.exists(filename):
                    caption = line2.rstrip()
                    self.photos.append(photo)
                    self.captions.append(caption)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is a caption for the photo.
        """
        filename = os.path.join(self.root, 'dataset', self.photos[index])
        img = Image.open(filename).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        target = self.captions[index]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """The number of photos in the dataset."""
        return len(self.photos)

    def _check_integrity(self):
        """Check the md5 checksum of the downloaded tarball."""
        fpath = os.path.join(self.root, self.filename)
        return check_integrity(fpath, self.md5_checksum)

    def download(self):
        """Download and extract the tarball, and download each individual photo."""
        import tarfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        download_url(self.url, self.root, self.filename, self.md5_checksum)

        # Extract file
        with tarfile.open(os.path.join(self.root, self.filename), 'r:gz') as tar:
            # NOTE(review): extractall() trusts member paths (potential path
            # traversal); tolerated here only because the tarball's md5 was
            # just verified against the pinned checksum above.
            tar.extractall(path=self.root)

        # Download individual photos; each URL points to a public Flickr image.
        with open(os.path.join(self.root, 'dataset', 'SBU_captioned_photo_dataset_urls.txt')) as fh:
            for line in fh:
                url = line.rstrip()
                try:
                    download_url(url, os.path.join(self.root, 'dataset'))
                except OSError:
                    # The images point to public images on Flickr.
                    # Note: Images might be removed by users at anytime.
                    pass