[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / datasets / classification / __init__.py
1 import os
2 from .. import folder
3 from .. import cifar
4 from .. import imagenet
# Public dataset factories exported by this package. Each factory takes
# (dataset_config, root, split=None, transforms=None).
__all__ = [
    # generic ImageFolder-style datasets (one sub-folder per split)
    'image_folder_classification_train',
    'image_folder_classification_validation',
    'image_folder_classification',
    # ImageNet
    'imagenet_classification_train',
    'imagenet_classification_validation',
    'imagenet_classification',
    # CIFAR
    'cifar10_classification',
    'cifar100_classification',
]
10 ########################################################################
def image_folder_classification_train(dataset_config, root, split=None, transforms=None):
    """Build the training split of an ImageFolder-style classification dataset.

    Args:
        dataset_config: Not used by this builder; accepted so all factories in
            this module share one signature.
        root: Dataset root directory; the split is read from ``root/split``.
        split: Training sub-folder name. Defaults to ``'train'``.
        transforms: Either a list/tuple whose element 0 is the training
            transform, or a single transform (or None) used as-is.

    Returns:
        ``folder.ImageFolder`` built over the training sub-folder.

    Raises:
        AssertionError: if the training sub-folder does not exist (only when
            assertions are enabled, i.e. not under ``python -O``).
    """
    if split is None:
        split = 'train'
    split_dir = os.path.join(root, split)
    assert os.path.exists(split_dir), f'dataset training folder does not exist {split_dir}'
    # A (train, val) pair selects its first entry; anything else is used directly.
    if isinstance(transforms, (list, tuple)):
        chosen_transform = transforms[0]
    else:
        chosen_transform = transforms
    return folder.ImageFolder(split_dir, chosen_transform)
def image_folder_classification_validation(dataset_config, root, split=None, transforms=None):
    """Build the validation split of an ImageFolder-style classification dataset.

    Args:
        dataset_config: Not used by this builder; accepted so all factories in
            this module share one signature.
        root: Dataset root directory; the split is read from ``root/split``.
        split: Validation sub-folder name. Defaults to ``'val'``; when the
            requested name is ``'val'`` and ``root/val`` is missing,
            ``root/validation`` is used instead.
        transforms: Either a list/tuple whose element 1 is the validation
            transform, or a single transform (or None) used as-is.

    Returns:
        ``folder.ImageFolder`` built over the validation sub-folder.

    Raises:
        AssertionError: if the resolved validation sub-folder does not exist
            (only when assertions are enabled).
    """
    split = split if split is not None else 'val'
    split_dir = os.path.join(root, split)
    # Datasets name their validation folder either 'val' or 'validation'.
    if split == 'val' and not os.path.exists(split_dir):
        split_dir = os.path.join(root, 'validation')
    assert os.path.exists(split_dir), f'dataset validation folder does not exist {split_dir}'
    # A (train, val) pair selects its second entry; anything else is used directly.
    if isinstance(transforms, (list, tuple)):
        chosen_transform = transforms[1]
    else:
        chosen_transform = transforms
    return folder.ImageFolder(split_dir, chosen_transform)
def image_folder_classification(dataset_config, root, split=None, transforms=None):
    """Build the (train, validation) pair for an ImageFolder-style dataset.

    Args:
        dataset_config: Forwarded unchanged to both per-split builders.
        root: Dataset root directory.
        split: ``(train_split, val_split)`` pair of sub-folder names.
            Defaults to ``('train', 'val')``.
        transforms: A ``(train_transform, val_transform)`` list/tuple, a single
            transform shared by both splits, or None for no transform.

    Returns:
        Tuple ``(train_dataset, val_dataset)``.
    """
    split = ('train', 'val') if split is None else split
    # Only unpack when a pair was actually given: the declared default
    # ``transforms=None`` (or a single shared transform) cannot be unpacked,
    # and the per-split builders already treat non-sequences this way.
    if isinstance(transforms, (list, tuple)):
        train_transform, val_transform = transforms
    else:
        train_transform = val_transform = transforms
    train_dataset = image_folder_classification_train(dataset_config, root, split[0], train_transform)
    val_dataset = image_folder_classification_validation(dataset_config, root, split[1], val_transform)
    return train_dataset, val_dataset
38 ########################################################################
def imagenet_classification_train(dataset_config, root, split=None, transforms=None):
    """Build the ImageNet training dataset.

    ``dataset_config`` and ``split`` are accepted for signature compatibility
    with the other factories in this module but are not used here.

    Args:
        root: Dataset root passed to ``imagenet.ImageNet``.
        transforms: Either a list/tuple whose element 0 is the training
            transform, or a single transform (or None) used as-is.

    Returns:
        ``imagenet.ImageNet`` constructed with ``train=True`` and ``download=True``.
    """
    if isinstance(transforms, (list, tuple)):
        chosen_transform = transforms[0]
    else:
        chosen_transform = transforms
    return imagenet.ImageNet(root, train=True, transform=chosen_transform,
                             target_transform=None, download=True)
def imagenet_classification_validation(dataset_config, root, split=None, transforms=None):
    """Build the ImageNet validation dataset.

    ``dataset_config`` and ``split`` are accepted for signature compatibility
    with the other factories in this module but are not used here.

    Args:
        root: Dataset root passed to ``imagenet.ImageNet``.
        transforms: Either a list/tuple whose element 1 is the validation
            transform, or a single transform (or None) used as-is.

    Returns:
        ``imagenet.ImageNet`` constructed with ``train=False`` and ``download=True``.
    """
    if isinstance(transforms, (list, tuple)):
        chosen_transform = transforms[1]
    else:
        chosen_transform = transforms
    return imagenet.ImageNet(root, train=False, transform=chosen_transform,
                             target_transform=None, download=True)
def imagenet_classification(dataset_config, root, split=None, transforms=None):
    """Build the ImageNet (train, validation) dataset pair.

    All arguments are forwarded unchanged to the per-split builders; the
    training set is constructed before the validation set.

    Returns:
        Tuple ``(train_dataset, val_dataset)``.
    """
    return (
        imagenet_classification_train(dataset_config, root, split, transforms),
        imagenet_classification_validation(dataset_config, root, split, transforms),
    )
55 ########################################################################
def cifar10_classification(dataset_config, root, split=None, transforms=None):
    """Build the CIFAR-10 (train, validation) dataset pair.

    ``dataset_config`` and ``split`` are accepted for signature compatibility
    with the other factories in this module but are not used here.

    Args:
        root: Dataset root passed to ``cifar.CIFAR10``.
        transforms: A list/tuple whose elements 0 and 1 are the train and
            validation transforms, a single transform shared by both splits,
            or None (the default) for no transform.

    Returns:
        Tuple ``(train_dataset, val_dataset)``; both are built with
        ``download=True``.
    """
    # Indexing unconditionally would fail on the default ``transforms=None``;
    # resolve transforms the same way the other builders in this module do.
    if isinstance(transforms, (list, tuple)):
        train_transform, val_transform = transforms[0], transforms[1]
    else:
        train_transform = val_transform = transforms
    train_dataset = cifar.CIFAR10(root, train=True, transform=train_transform, target_transform=None, download=True)
    val_dataset = cifar.CIFAR10(root, train=False, transform=val_transform, target_transform=None, download=True)
    return train_dataset, val_dataset
def cifar100_classification(dataset_config, root, split=None, transforms=None):
    """Build the CIFAR-100 (train, validation) dataset pair.

    ``dataset_config`` and ``split`` are accepted for signature compatibility
    with the other factories in this module but are not used here.

    Args:
        root: Dataset root passed to ``cifar.CIFAR100``.
        transforms: A list/tuple whose elements 0 and 1 are the train and
            validation transforms, a single transform shared by both splits,
            or None (the default) for no transform.

    Returns:
        Tuple ``(train_dataset, val_dataset)``; both are built with
        ``download=True``.
    """
    # Indexing unconditionally would fail on the default ``transforms=None``;
    # resolve transforms the same way the other builders in this module do.
    if isinstance(transforms, (list, tuple)):
        train_transform, val_transform = transforms[0], transforms[1]
    else:
        train_transform = val_transform = transforms
    train_dataset = cifar.CIFAR100(root, train=True, transform=train_transform, target_transform=None, download=True)
    val_dataset = cifar.CIFAR100(root, train=False, transform=val_transform, target_transform=None, download=True)
    return train_dataset, val_dataset