index 3a9061b45f7cffe28cf77891b9be98023881cca3..8c7e22667ccd1a2ed1b305bcf47179fa16964137 100644 (file)
help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
help='number of total epochs to run')
-parser.add_argument('--epoch_size', default=0, type=int, metavar='N',
- help='number of iterations in one training epoch. 0 (default) means full training epoch')
-parser.add_argument('--epoch_size_val', default=0, type=int, metavar='N',
- help='number of iterations in one validation epoch. 0 (default) means full validation epoch')
+parser.add_argument('--epoch_size', default=0, type=float, metavar='N',
+ help='fraction of training epoch to use. 0 (default) means full training epoch')
+parser.add_argument('--epoch_size_val', default=0, type=float, metavar='N',
+ help='fraction of validation epoch to use. 0 (default) means full validation epoch')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch_size', default=256, type=int,
normalize,
]))
+ val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
+ transforms.Resize(256),
+ transforms.CenterCrop(224),
+ xvision.transforms.ToFloat(), # converting to float avoids the division by 255 in ToTensor()
+ transforms.ToTensor(),
+ normalize,
+ ]))
+
if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
+ val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
else:
- train_sampler = None
+ train_sampler = get_dataset_sampler(train_dataset, args.epoch_size) if args.epoch_size != 0 else None
+ val_sampler = get_dataset_sampler(val_dataset, args.epoch_size_val) if args.epoch_size_val != 0 else None
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
num_workers=args.workers, pin_memory=True, sampler=train_sampler)
val_loader = torch.utils.data.DataLoader(
- datasets.ImageFolder(valdir, transforms.Compose([
- transforms.Resize(256),
- transforms.CenterCrop(224),
- xvision.transforms.ToFloat(), # converting to float avoids the division by 255 in ToTensor()
- transforms.ToTensor(),
- normalize,
- ])),
- batch_size=args.batch_size, shuffle=False,
- num_workers=args.workers, pin_memory=True)
+ val_dataset, batch_size=args.batch_size, shuffle=(val_sampler is None),
+ num_workers=args.workers, pin_memory=True, sampler=val_sampler)
validate(val_loader, model, criterion, args)
end = time.time()
for i, (images, target) in enumerate(train_loader):
- # break the epoch at at the iteration epoch_size
- if args.epoch_size != 0 and i >= args.epoch_size:
- break
# measure data loading time
data_time.update(time.time() - end)
with torch.no_grad():
end = time.time()
for i, (images, target) in enumerate(val_loader):
- # break the epoch at at the iteration epoch_size_val
- if args.epoch_size_val != 0 and i >= args.epoch_size_val:
- break
images = images.cuda(args.gpu, non_blocking=True)
target = target.cuda(args.gpu, non_blocking=True)
is_cuda = next(model.parameters()).is_cuda
dummy_input = create_rand_inputs(is_cuda)
torch.onnx.export(model, dummy_input, filename, export_params=True, verbose=False,
- do_constant_folding=True, opset_version=args.opset_vesion)
+ do_constant_folding=True, opset_version=args.opset_version)
if is_best:
shutil.copyfile(filename, os.path.splitext(filename)[0]+'_best.onnx')
return res
+def get_dataset_sampler(dataset_object, epoch_size):  # epoch_size < 1: fraction of the dataset; >= 1: absolute sample count. NOTE(review): exactly 1 yields one sample, not the full dataset -- confirm intended
+    num_samples = len(dataset_object)  # total number of samples in the dataset
+    epoch_size = int(epoch_size * num_samples) if epoch_size < 1 else int(epoch_size)  # resolve fraction vs. absolute count into a concrete number of samples
+    print('=> creating a random sampler as epoch_size is specified')
+    dataset_sampler = torch.utils.data.sampler.RandomSampler(data_source=dataset_object, replacement=True, num_samples=epoch_size)  # replacement=True: draws may repeat, and some samples may be skipped within an epoch
+    return dataset_sampler
+
+
if __name__ == '__main__':
main()
\ No newline at end of file