[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / datasets / usps.py
diff --git a/modules/pytorch_jacinto_ai/xvision/datasets/usps.py b/modules/pytorch_jacinto_ai/xvision/datasets/usps.py
index 8a3cad0bd3d4fdd582f958c45f4fc30f04cf0964..c315b8d3111f522dc0ce1d5253afed01878b609d 100644 (file)
-from __future__ import print_function
from PIL import Image
import os
import numpy as np
+from typing import Any, Callable, cast, Optional, Tuple
from .utils import download_url
from .vision import VisionDataset
],
}
- def __init__(self, root, train=True, transform=None, target_transform=None,
- download=False):
+ def __init__(
+ self,
+ root: str,
+ train: bool = True,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ ) -> None:
super(USPS, self).__init__(root, transform=transform,
target_transform=target_transform)
split = 'train' if train else 'test'
import bz2
with bz2.open(full_path) as fp:
- raw_data = [l.decode().split() for l in fp.readlines()]
- imgs = [[x.split(':')[-1] for x in data[1:]] for data in raw_data]
- imgs = np.asarray(imgs, dtype=np.float32).reshape((-1, 16, 16))
- imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8)
+ raw_data = [line.decode().split() for line in fp.readlines()]
+ tmp_list = [[x.split(':')[-1] for x in data[1:]] for data in raw_data]
+ imgs = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16))
+ imgs = ((cast(np.ndarray, imgs) + 1) / 2 * 255).astype(dtype=np.uint8)
targets = [int(d[0]) - 1 for d in raw_data]
self.data = imgs
self.targets = targets
- def __getitem__(self, index):
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
return img, target
- def __len__(self):
+ def __len__(self) -> int:
return len(self.data)