[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / datasets / sbu.py
diff --git a/modules/pytorch_jacinto_ai/xvision/datasets/sbu.py b/modules/pytorch_jacinto_ai/xvision/datasets/sbu.py
index 8be27dbf409fd6fec2470289995428b88edc7d71..6c8ad15686b6b1928e4288ff597be496228dcbb8 100644 (file)
from PIL import Image
-from six.moves import zip
from .utils import download_url, check_integrity
+from typing import Any, Callable, Optional, Tuple
import os
from .vision import VisionDataset
filename = "SBUCaptionedPhotoDataset.tar.gz"
md5_checksum = '9aec147b3488753cf758b4d493422285'
- def __init__(self, root, transform=None, target_transform=None, download=True):
+ def __init__(
+ self,
+ root: str,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = True,
+ ) -> None:
super(SBU, self).__init__(root, transform=transform,
target_transform=target_transform)
self.photos.append(photo)
self.captions.append(caption)
- def __getitem__(self, index):
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
return img, target
- def __len__(self):
+ def __len__(self) -> int:
"""The number of photos in the dataset."""
return len(self.photos)
- def _check_integrity(self):
+ def _check_integrity(self) -> bool:
"""Check the md5 checksum of the downloaded tarball."""
root = self.root
fpath = os.path.join(root, self.filename)
return False
return True
- def download(self):
+ def download(self) -> None:
"""Download and extract the tarball, and download each individual photo."""
import tarfile