]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blobdiff - modules/pytorch_jacinto_ai/xvision/datasets/usps.py
docs - added deprecation notice
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / datasets / usps.py
index 8a3cad0bd3d4fdd582f958c45f4fc30f04cf0964..c315b8d3111f522dc0ce1d5253afed01878b609d 100644 (file)
@@ -1,7 +1,7 @@
-from __future__ import print_function
 from PIL import Image
 import os
 import numpy as np
+from typing import Any, Callable, cast, Optional, Tuple
 
 from .utils import download_url
 from .vision import VisionDataset
@@ -37,8 +37,14 @@ class USPS(VisionDataset):
         ],
     }
 
-    def __init__(self, root, train=True, transform=None, target_transform=None,
-                 download=False):
+    def __init__(
+            self,
+            root: str,
+            train: bool = True,
+            transform: Optional[Callable] = None,
+            target_transform: Optional[Callable] = None,
+            download: bool = False,
+    ) -> None:
         super(USPS, self).__init__(root, transform=transform,
                                    target_transform=target_transform)
         split = 'train' if train else 'test'
@@ -50,16 +56,16 @@ class USPS(VisionDataset):
 
         import bz2
         with bz2.open(full_path) as fp:
-            raw_data = [l.decode().split() for l in fp.readlines()]
-            imgs = [[x.split(':')[-1] for x in data[1:]] for data in raw_data]
-            imgs = np.asarray(imgs, dtype=np.float32).reshape((-1, 16, 16))
-            imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8)
+            raw_data = [line.decode().split() for line in fp.readlines()]
+            tmp_list = [[x.split(':')[-1] for x in data[1:]] for data in raw_data]
+            imgs = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16))
+            imgs = ((cast(np.ndarray, imgs) + 1) / 2 * 255).astype(dtype=np.uint8)
             targets = [int(d[0]) - 1 for d in raw_data]
 
         self.data = imgs
         self.targets = targets
 
-    def __getitem__(self, index):
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
         """
         Args:
             index (int): Index
@@ -81,5 +87,5 @@ class USPS(VisionDataset):
 
         return img, target
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.data)