[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / datasets / sbd.py
diff --git a/modules/pytorch_jacinto_ai/xvision/datasets/sbd.py b/modules/pytorch_jacinto_ai/xvision/datasets/sbd.py
index c4713f7257632b8c2a4488e18e4f5fb81e211094..1c3e221f4953258252911fd3d67bb0c47fea6d65 100644 (file)
import os
import shutil
from .vision import VisionDataset
+from typing import Any, Callable, Optional, Tuple
import numpy as np
voc_split_filename = "train_noval.txt"
voc_split_md5 = "79bff800c5f0b1ec6b21080a3c066722"
- def __init__(self,
- root,
- image_set='train',
- mode='boundaries',
- download=False,
- transforms=None):
+ def __init__(
+ self,
+ root: str,
+ image_set: str = "train",
+ mode: str = "boundaries",
+ download: bool = False,
+ transforms: Optional[Callable] = None,
+ ) -> None:
try:
from scipy.io import loadmat
split_f = os.path.join(sbd_root, image_set.rstrip('\n') + '.txt')
- with open(os.path.join(split_f), "r") as f:
- file_names = [x.strip() for x in f.readlines()]
+ with open(os.path.join(split_f), "r") as fh:
+ file_names = [x.strip() for x in fh.readlines()]
self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
self.masks = [os.path.join(mask_dir, x + ".mat") for x in file_names]
self._get_target = self._get_segmentation_target \
if self.mode == "segmentation" else self._get_boundaries_target
- def _get_segmentation_target(self, filepath):
+ def _get_segmentation_target(self, filepath: str) -> Image.Image:
mat = self._loadmat(filepath)
return Image.fromarray(mat['GTcls'][0]['Segmentation'][0])
- def _get_boundaries_target(self, filepath):
+ def _get_boundaries_target(self, filepath: str) -> np.ndarray:
mat = self._loadmat(filepath)
return np.concatenate([np.expand_dims(mat['GTcls'][0]['Boundaries'][0][i][0].toarray(), axis=0)
for i in range(self.num_classes)], axis=0)
- def __getitem__(self, index):
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
img = Image.open(self.images[index]).convert('RGB')
target = self._get_target(self.masks[index])
return img, target
- def __len__(self):
+ def __len__(self) -> int:
return len(self.images)
- def extra_repr(self):
+ def extra_repr(self) -> str:
lines = ["Image set: {image_set}", "Mode: {mode}"]
return '\n'.join(lines).format(**self.__dict__)