Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blobdiff - modules/pytorch_jacinto_ai/xvision/datasets/flickr.py
docs - added deprecation notice
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / datasets / flickr.py
index af8b1fee36b6ea3f893a41b417c7562d77a1f14e..a3b3e411b6e57e53a36b7077946e471f1f32dc4d 100644 (file)
@@ -1,41 +1,42 @@
 from collections import defaultdict
 from PIL import Image
-from six.moves import html_parser
+from html.parser import HTMLParser
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import glob
 import os
 from .vision import VisionDataset
 
 
-class Flickr8kParser(html_parser.HTMLParser):
+class Flickr8kParser(HTMLParser):
     """Parser for extracting captions from the Flickr8k dataset web page."""
 
-    def __init__(self, root):
+    def __init__(self, root: str) -> None:
         super(Flickr8kParser, self).__init__()
 
         self.root = root
 
         # Data structure to store captions
-        self.annotations = {}
+        self.annotations: Dict[str, List[str]] = {}
 
         # State variables
         self.in_table = False
-        self.current_tag = None
-        self.current_img = None
+        self.current_tag: Optional[str] = None
+        self.current_img: Optional[str] = None
 
-    def handle_starttag(self, tag, attrs):
+    def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None:
         self.current_tag = tag
 
         if tag == 'table':
             self.in_table = True
 
-    def handle_endtag(self, tag):
+    def handle_endtag(self, tag: str) -> None:
         self.current_tag = None
 
         if tag == 'table':
             self.in_table = False
 
-    def handle_data(self, data):
+    def handle_data(self, data: str) -> None:
         if self.in_table:
             if data == 'Image Not Found':
                 self.current_img = None
@@ -51,7 +52,7 @@ class Flickr8kParser(html_parser.HTMLParser):
 
 
 class Flickr8k(VisionDataset):
-    """`Flickr8k Entities <http://nlp.cs.illinois.edu/HockenmaierGroup/8k-pictures.html>`_ Dataset.
+    """`Flickr8k Entities <http://hockenmaier.cs.illinois.edu/8k-pictures.html>`_ Dataset.
 
     Args:
         root (string): Root directory where images are downloaded to.
@@ -62,7 +63,13 @@ class Flickr8k(VisionDataset):
             target and transforms it.
     """
 
-    def __init__(self, root, ann_file, transform=None, target_transform=None):
+    def __init__(
+            self,
+            root: str,
+            ann_file: str,
+            transform: Optional[Callable] = None,
+            target_transform: Optional[Callable] = None,
+    ) -> None:
         super(Flickr8k, self).__init__(root, transform=transform,
                                        target_transform=target_transform)
         self.ann_file = os.path.expanduser(ann_file)
@@ -75,7 +82,7 @@ class Flickr8k(VisionDataset):
 
         self.ids = list(sorted(self.annotations.keys()))
 
-    def __getitem__(self, index):
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
         """
         Args:
             index (int): Index
@@ -97,7 +104,7 @@ class Flickr8k(VisionDataset):
 
         return img, target
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.ids)
 
 
@@ -113,7 +120,13 @@ class Flickr30k(VisionDataset):
             target and transforms it.
     """
 
-    def __init__(self, root, ann_file, transform=None, target_transform=None):
+    def __init__(
+            self,
+            root: str,
+            ann_file: str,
+            transform: Optional[Callable] = None,
+            target_transform: Optional[Callable] = None,
+    ) -> None:
         super(Flickr30k, self).__init__(root, transform=transform,
                                         target_transform=target_transform)
         self.ann_file = os.path.expanduser(ann_file)
@@ -127,7 +140,7 @@ class Flickr30k(VisionDataset):
 
         self.ids = list(sorted(self.annotations.keys()))
 
-    def __getitem__(self, index):
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
         """
         Args:
             index (int): Index
@@ -150,5 +163,5 @@ class Flickr30k(VisionDataset):
 
         return img, target
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.ids)