quantization_example - RandomSampler is used when epoch_size!=0. epoch_size=0.1 means...
authorManu Mathew <a0393608@ti.com>
Thu, 14 May 2020 07:29:35 +0000 (12:59 +0530)
committerManu Mathew <a0393608@ti.com>
Thu, 14 May 2020 07:57:15 +0000 (13:27 +0530)
examples/quantization_example.py
run_quantization_example.sh

index 3a9061b45f7cffe28cf77891b9be98023881cca3..8c7e22667ccd1a2ed1b305bcf47179fa16964137 100644 (file)
@@ -48,10 +48,10 @@ parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                     help='number of data loading workers (default: 4)')
 parser.add_argument('--epochs', default=90, type=int, metavar='N',
                     help='number of total epochs to run')
-parser.add_argument('--epoch_size', default=0, type=int, metavar='N',
-                    help='number of iterations in one training epoch. 0 (default) means full training epoch')
-parser.add_argument('--epoch_size_val', default=0, type=int, metavar='N',
-                    help='number of iterations in one validation epoch. 0 (default) means full validation epoch')
+parser.add_argument('--epoch_size', default=0, type=float, metavar='N',
+                    help='fraction of training epoch to use. 0 (default) means full training epoch')
+parser.add_argument('--epoch_size_val', default=0, type=float, metavar='N',
+                    help='fraction of validation epoch to use. 0 (default) means full validation epoch')
 parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                     help='manual epoch number (useful on restarts)')
 parser.add_argument('-b', '--batch_size', default=256, type=int,
@@ -256,25 +256,28 @@ def main_worker(gpu, ngpus_per_node, args):
             normalize,
         ]))
 
+    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
+        transforms.Resize(256),
+        transforms.CenterCrop(224),
+        xvision.transforms.ToFloat(),  # converting to float avoids the division by 255 in ToTensor()
+        transforms.ToTensor(),
+        normalize,
+    ]))
+
     if args.distributed:
         train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
+        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
     else:
-        train_sampler = None
+        train_sampler = get_dataset_sampler(train_dataset, args.epoch_size) if args.epoch_size != 0 else None
+        val_sampler = get_dataset_sampler(val_dataset, args.epoch_size_val) if args.epoch_size_val != 0 else None
 
     train_loader = torch.utils.data.DataLoader(
         train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
         num_workers=args.workers, pin_memory=True, sampler=train_sampler)
 
     val_loader = torch.utils.data.DataLoader(
-        datasets.ImageFolder(valdir, transforms.Compose([
-            transforms.Resize(256),
-            transforms.CenterCrop(224),
-            xvision.transforms.ToFloat(),   # converting to float avoids the division by 255 in ToTensor()
-            transforms.ToTensor(),
-            normalize,
-        ])),
-        batch_size=args.batch_size, shuffle=False,
-        num_workers=args.workers, pin_memory=True)
+        val_dataset, batch_size=args.batch_size, shuffle=(val_sampler is None),
+        num_workers=args.workers, pin_memory=True, sampler=val_sampler)
 
     validate(val_loader, model, criterion, args)
 
@@ -330,9 +333,6 @@ def train(train_loader, model, criterion, optimizer, epoch, args):
 
     end = time.time()
     for i, (images, target) in enumerate(train_loader):
-        # break the epoch at at the iteration epoch_size
-        if args.epoch_size != 0 and i >= args.epoch_size:
-            break
         # measure data loading time
         data_time.update(time.time() - end)
 
@@ -378,9 +378,6 @@ def validate(val_loader, model, criterion, args):
     with torch.no_grad():
         end = time.time()
         for i, (images, target) in enumerate(val_loader):
-            # break the epoch at at the iteration epoch_size_val
-            if args.epoch_size_val != 0 and i >= args.epoch_size_val:
-                break
             images = images.cuda(args.gpu, non_blocking=True)
             target = target.cuda(args.gpu, non_blocking=True)
 
@@ -427,7 +424,7 @@ def write_onnx_model(args, model, is_best, filename='checkpoint.onnx'):
     is_cuda = next(model.parameters()).is_cuda
     dummy_input = create_rand_inputs(is_cuda)
     torch.onnx.export(model, dummy_input, filename, export_params=True, verbose=False,
-                      do_constant_folding=True, opset_version=args.opset_vesion)
+                      do_constant_folding=True, opset_version=args.opset_version)
     if is_best:
         shutil.copyfile(filename, os.path.splitext(filename)[0]+'_best.onnx')
 
@@ -502,5 +499,13 @@ def accuracy(output, target, topk=(1,)):
         return res
 
 
+def get_dataset_sampler(dataset_object, epoch_size):
+    num_samples = len(dataset_object)
+    epoch_size = int(epoch_size * num_samples) if epoch_size < 1 else int(epoch_size)
+    print('=> creating a random sampler as epoch_size is specified')
+    dataset_sampler = torch.utils.data.sampler.RandomSampler(data_source=dataset_object, replacement=True, num_samples=epoch_size)
+    return dataset_sampler
+
+
 if __name__ == '__main__':
     main()
\ No newline at end of file
index 4ee691658fa99ca6a361b7ce36278fe2aba46895..519bb8e3f09fc8ef9261e9758a6d2045f2c7808a 100755 (executable)
@@ -32,7 +32,7 @@ declare -A model_pretrained=(
 lr=1e-5             # initial learning rate for quantization aware training - recommend to use 1e-5 (or at max 5e-5)
 batch_size=64       # use a relatively smaller batch size as quantization aware training does not use multi-gpu
 epochs=10           # numerb of epochs to train
-epoch_size=0.1     # artificially limit one training epoch to this many iterations - this argument is only used to limit the training time and may hurt acuracy - set to 0 to use the full training epoch
+epoch_size=0.1      # artificially limit one training epoch to this many iterations - this argument is only used to limit the training time and may hurt acuracy - set to 0 to use the full training epoch
 epoch_size_val=0    # validation epoch size - set to 0 for full validation epoch