[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / datasets / flickr.py
diff --git a/modules/pytorch_jacinto_ai/xvision/datasets/flickr.py b/modules/pytorch_jacinto_ai/xvision/datasets/flickr.py
index af8b1fee36b6ea3f893a41b417c7562d77a1f14e..a3b3e411b6e57e53a36b7077946e471f1f32dc4d 100644 (file)
from collections import defaultdict
from PIL import Image
-from six.moves import html_parser
+from html.parser import HTMLParser
+from typing import Any, Callable, Dict, List, Optional, Tuple
import glob
import os
from .vision import VisionDataset
-class Flickr8kParser(html_parser.HTMLParser):
+class Flickr8kParser(HTMLParser):
"""Parser for extracting captions from the Flickr8k dataset web page."""
- def __init__(self, root):
+ def __init__(self, root: str) -> None:
super(Flickr8kParser, self).__init__()
self.root = root
# Data structure to store captions
- self.annotations = {}
+ self.annotations: Dict[str, List[str]] = {}
# State variables
self.in_table = False
- self.current_tag = None
- self.current_img = None
+ self.current_tag: Optional[str] = None
+ self.current_img: Optional[str] = None
- def handle_starttag(self, tag, attrs):
+ def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None:
self.current_tag = tag
if tag == 'table':
self.in_table = True
- def handle_endtag(self, tag):
+ def handle_endtag(self, tag: str) -> None:
self.current_tag = None
if tag == 'table':
self.in_table = False
- def handle_data(self, data):
+ def handle_data(self, data: str) -> None:
if self.in_table:
if data == 'Image Not Found':
self.current_img = None
class Flickr8k(VisionDataset):
- """`Flickr8k Entities <http://nlp.cs.illinois.edu/HockenmaierGroup/8k-pictures.html>`_ Dataset.
+ """`Flickr8k Entities <http://hockenmaier.cs.illinois.edu/8k-pictures.html>`_ Dataset.
Args:
root (string): Root directory where images are downloaded to.
target and transforms it.
"""
- def __init__(self, root, ann_file, transform=None, target_transform=None):
+ def __init__(
+ self,
+ root: str,
+ ann_file: str,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ ) -> None:
super(Flickr8k, self).__init__(root, transform=transform,
target_transform=target_transform)
self.ann_file = os.path.expanduser(ann_file)
self.ids = list(sorted(self.annotations.keys()))
- def __getitem__(self, index):
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
return img, target
- def __len__(self):
+ def __len__(self) -> int:
return len(self.ids)
target and transforms it.
"""
- def __init__(self, root, ann_file, transform=None, target_transform=None):
+ def __init__(
+ self,
+ root: str,
+ ann_file: str,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ ) -> None:
super(Flickr30k, self).__init__(root, transform=transform,
target_transform=target_transform)
self.ann_file = os.path.expanduser(ann_file)
self.ids = list(sorted(self.annotations.keys()))
- def __getitem__(self, index):
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
return img, target
- def __len__(self):
+ def __len__(self) -> int:
return len(self.ids)