[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / xvision / datasets / cifar.py
diff --git a/modules/pytorch_jacinto_ai/xvision/datasets/cifar.py b/modules/pytorch_jacinto_ai/xvision/datasets/cifar.py
index 6230a64fa15ecd5f2b185bdfda2f03fc59cf6102..9d939326c76a627fb3fe116af09ab8270332cc63 100644 (file)
-from __future__ import print_function
from PIL import Image
import os
import os.path
import numpy as np
-import sys
-
-if sys.version_info[0] == 2:
- import cPickle as pickle
-else:
- import pickle
+import pickle
+from typing import Any, Callable, Optional, Tuple
from .vision import VisionDataset
from .utils import check_integrity, download_and_extract_archive
'md5': '5ff9c542aee3614f3951f8cda6e48888',
}
- def __init__(self, root, train=True, transform=None, target_transform=None,
- download=False):
+ def __init__(
+ self,
+ root: str,
+ train: bool = True,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ download: bool = False,
+ ) -> None:
super(CIFAR10, self).__init__(root, transform=transform,
target_transform=target_transform)
else:
downloaded_list = self.test_list
- self.data = []
+ self.data: Any = []
self.targets = []
# now load the picked numpy arrays
for file_name, checksum in downloaded_list:
file_path = os.path.join(self.root, self.base_folder, file_name)
with open(file_path, 'rb') as f:
- if sys.version_info[0] == 2:
- entry = pickle.load(f)
- else:
- entry = pickle.load(f, encoding='latin1')
+ entry = pickle.load(f, encoding='latin1')
self.data.append(entry['data'])
if 'labels' in entry:
self.targets.extend(entry['labels'])
self._load_meta()
- def _load_meta(self):
+ def _load_meta(self) -> None:
path = os.path.join(self.root, self.base_folder, self.meta['filename'])
if not check_integrity(path, self.meta['md5']):
raise RuntimeError('Dataset metadata file not found or corrupted.' +
' You can use download=True to download it')
with open(path, 'rb') as infile:
- if sys.version_info[0] == 2:
- data = pickle.load(infile)
- else:
- data = pickle.load(infile, encoding='latin1')
+ data = pickle.load(infile, encoding='latin1')
self.classes = data[self.meta['key']]
self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
- def __getitem__(self, index):
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
return img, target
- def __len__(self):
+ def __len__(self) -> int:
return len(self.data)
- def _check_integrity(self):
+ def _check_integrity(self) -> bool:
root = self.root
for fentry in (self.train_list + self.test_list):
filename, md5 = fentry[0], fentry[1]
return False
return True
- def download(self):
+ def download(self) -> None:
if self._check_integrity():
print('Files already downloaded and verified')
return
download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
- def extra_repr(self):
+ def extra_repr(self) -> str:
return "Split: {}".format("Train" if self.train is True else "Test")