]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blobdiff - modules/pytorch_jacinto_ai/xvision/datasets/cifar.py
added mobilenetv3 from torchvision and also mobilenetv3_lite models, updated docs
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / datasets / cifar.py
index 6230a64fa15ecd5f2b185bdfda2f03fc59cf6102..9d939326c76a627fb3fe116af09ab8270332cc63 100644 (file)
@@ -1,14 +1,9 @@
-from __future__ import print_function
 from PIL import Image
 import os
 import os.path
 import numpy as np
-import sys
-
-if sys.version_info[0] == 2:
-    import cPickle as pickle
-else:
-    import pickle
+import pickle
+from typing import Any, Callable, Optional, Tuple
 
 from .vision import VisionDataset
 from .utils import check_integrity, download_and_extract_archive
@@ -52,8 +47,14 @@ class CIFAR10(VisionDataset):
         'md5': '5ff9c542aee3614f3951f8cda6e48888',
     }
 
-    def __init__(self, root, train=True, transform=None, target_transform=None,
-                 download=False):
+    def __init__(
+            self,
+            root: str,
+            train: bool = True,
+            transform: Optional[Callable] = None,
+            target_transform: Optional[Callable] = None,
+            download: bool = False,
+    ) -> None:
 
         super(CIFAR10, self).__init__(root, transform=transform,
                                       target_transform=target_transform)
@@ -72,17 +73,14 @@ class CIFAR10(VisionDataset):
         else:
             downloaded_list = self.test_list
 
-        self.data = []
+        self.data: Any = []
         self.targets = []
 
         # now load the picked numpy arrays
         for file_name, checksum in downloaded_list:
             file_path = os.path.join(self.root, self.base_folder, file_name)
             with open(file_path, 'rb') as f:
-                if sys.version_info[0] == 2:
-                    entry = pickle.load(f)
-                else:
-                    entry = pickle.load(f, encoding='latin1')
+                entry = pickle.load(f, encoding='latin1')
                 self.data.append(entry['data'])
                 if 'labels' in entry:
                     self.targets.extend(entry['labels'])
@@ -94,20 +92,17 @@ class CIFAR10(VisionDataset):
 
         self._load_meta()
 
-    def _load_meta(self):
+    def _load_meta(self) -> None:
         path = os.path.join(self.root, self.base_folder, self.meta['filename'])
         if not check_integrity(path, self.meta['md5']):
             raise RuntimeError('Dataset metadata file not found or corrupted.' +
                                ' You can use download=True to download it')
         with open(path, 'rb') as infile:
-            if sys.version_info[0] == 2:
-                data = pickle.load(infile)
-            else:
-                data = pickle.load(infile, encoding='latin1')
+            data = pickle.load(infile, encoding='latin1')
             self.classes = data[self.meta['key']]
         self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
 
-    def __getitem__(self, index):
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
         """
         Args:
             index (int): Index
@@ -129,10 +124,10 @@ class CIFAR10(VisionDataset):
 
         return img, target
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.data)
 
-    def _check_integrity(self):
+    def _check_integrity(self) -> bool:
         root = self.root
         for fentry in (self.train_list + self.test_list):
             filename, md5 = fentry[0], fentry[1]
@@ -141,13 +136,13 @@ class CIFAR10(VisionDataset):
                 return False
         return True
 
-    def download(self):
+    def download(self) -> None:
         if self._check_integrity():
             print('Files already downloaded and verified')
             return
         download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
 
-    def extra_repr(self):
+    def extra_repr(self) -> str:
         return "Split: {}".format("Train" if self.train is True else "Test")