]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blobdiff - modules/pytorch_jacinto_ai/xvision/datasets/sbd.py
added mobilenetv3 from torchvision and also mobilenetv3_lite models, updated docs
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / datasets / sbd.py
index c4713f7257632b8c2a4488e18e4f5fb81e211094..1c3e221f4953258252911fd3d67bb0c47fea6d65 100644 (file)
@@ -1,6 +1,7 @@
 import os
 import shutil
 from .vision import VisionDataset
+from typing import Any, Callable, Optional, Tuple
 
 import numpy as np
 
@@ -49,12 +50,14 @@ class SBDataset(VisionDataset):
     voc_split_filename = "train_noval.txt"
     voc_split_md5 = "79bff800c5f0b1ec6b21080a3c066722"
 
-    def __init__(self,
-                 root,
-                 image_set='train',
-                 mode='boundaries',
-                 download=False,
-                 transforms=None):
+    def __init__(
+            self,
+            root: str,
+            image_set: str = "train",
+            mode: str = "boundaries",
+            download: bool = False,
+            transforms: Optional[Callable] = None,
+    ) -> None:
 
         try:
             from scipy.io import loadmat
@@ -88,8 +91,8 @@ class SBDataset(VisionDataset):
 
         split_f = os.path.join(sbd_root, image_set.rstrip('\n') + '.txt')
 
-        with open(os.path.join(split_f), "r") as f:
-            file_names = [x.strip() for x in f.readlines()]
+        with open(os.path.join(split_f), "r") as fh:
+            file_names = [x.strip() for x in fh.readlines()]
 
         self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
         self.masks = [os.path.join(mask_dir, x + ".mat") for x in file_names]
@@ -98,16 +101,16 @@ class SBDataset(VisionDataset):
         self._get_target = self._get_segmentation_target \
             if self.mode == "segmentation" else self._get_boundaries_target
 
-    def _get_segmentation_target(self, filepath):
+    def _get_segmentation_target(self, filepath: str) -> Image.Image:
         mat = self._loadmat(filepath)
         return Image.fromarray(mat['GTcls'][0]['Segmentation'][0])
 
-    def _get_boundaries_target(self, filepath):
+    def _get_boundaries_target(self, filepath: str) -> np.ndarray:
         mat = self._loadmat(filepath)
         return np.concatenate([np.expand_dims(mat['GTcls'][0]['Boundaries'][0][i][0].toarray(), axis=0)
                                for i in range(self.num_classes)], axis=0)
 
-    def __getitem__(self, index):
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
         img = Image.open(self.images[index]).convert('RGB')
         target = self._get_target(self.masks[index])
 
@@ -116,9 +119,9 @@ class SBDataset(VisionDataset):
 
         return img, target
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.images)
 
-    def extra_repr(self):
+    def extra_repr(self) -> str:
         lines = ["Image set: {image_set}", "Mode: {mode}"]
         return '\n'.join(lines).format(**self.__dict__)