quantization_example - RandomSampler is used when epoch_size!=0. epoch_size=0.1 means...
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / examples / quantization_example.py
index 91dac7325bffe7b29844ef547c23af6b05a0eb17..8c7e22667ccd1a2ed1b305bcf47179fa16964137 100644 (file)
@@ -48,10 +48,10 @@ parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                     help='number of data loading workers (default: 4)')
 parser.add_argument('--epochs', default=90, type=int, metavar='N',
                     help='number of total epochs to run')
-parser.add_argument('--epoch_size', default=0, type=int, metavar='N',
-                    help='number of iterations in one training epoch. 0 (default) means full training epoch')
-parser.add_argument('--epoch_size_val', default=0, type=int, metavar='N',
-                    help='number of iterations in one validation epoch. 0 (default) means full validation epoch')
+parser.add_argument('--epoch_size', default=0, type=float, metavar='N',
+                    help='fraction of training epoch to use. 0 (default) means full training epoch')
+parser.add_argument('--epoch_size_val', default=0, type=float, metavar='N',
+                    help='fraction of validation epoch to use. 0 (default) means full validation epoch')
 parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                     help='manual epoch number (useful on restarts)')
 parser.add_argument('-b', '--batch_size', default=256, type=int,
@@ -97,6 +97,9 @@ parser.add_argument('--save_path', type=str, default='./data/checkpoints/quantiz
                     help='path to save the logs and models')
 parser.add_argument('--quantize', action='store_true',
                     help='Enable Quantization')
+parser.add_argument('--opset_version', default=9, type=int,
+                    help='opset version for onnx export')
+
 best_acc1 = 0
 
 
@@ -253,25 +256,28 @@ def main_worker(gpu, ngpus_per_node, args):
             normalize,
         ]))
 
+    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
+        transforms.Resize(256),
+        transforms.CenterCrop(224),
+        xvision.transforms.ToFloat(),  # converting to float avoids the division by 255 in ToTensor()
+        transforms.ToTensor(),
+        normalize,
+    ]))
+
     if args.distributed:
         train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
+        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
     else:
-        train_sampler = None
+        train_sampler = get_dataset_sampler(train_dataset, args.epoch_size) if args.epoch_size != 0 else None
+        val_sampler = get_dataset_sampler(val_dataset, args.epoch_size_val) if args.epoch_size_val != 0 else None
 
     train_loader = torch.utils.data.DataLoader(
         train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
         num_workers=args.workers, pin_memory=True, sampler=train_sampler)
 
     val_loader = torch.utils.data.DataLoader(
-        datasets.ImageFolder(valdir, transforms.Compose([
-            transforms.Resize(256),
-            transforms.CenterCrop(224),
-            xvision.transforms.ToFloat(),   # converting to float avoids the division by 255 in ToTensor()
-            transforms.ToTensor(),
-            normalize,
-        ])),
-        batch_size=args.batch_size, shuffle=False,
-        num_workers=args.workers, pin_memory=True)
+        val_dataset, batch_size=args.batch_size, shuffle=(val_sampler is None),
+        num_workers=args.workers, pin_memory=True, sampler=val_sampler)
 
     validate(val_loader, model, criterion, args)
 
@@ -308,7 +314,7 @@ def main_worker(gpu, ngpus_per_node, args):
             }
             save_checkpoint(checkpoint_dict, is_best, filename=save_filename)
             save_onnxname = os.path.splitext(save_filename)[0]+'.onnx'
-            write_onnx_model(model, is_best, filename=save_onnxname)
+            write_onnx_model(args, model, is_best, filename=save_onnxname)
 
 
 def train(train_loader, model, criterion, optimizer, epoch, args):
@@ -327,9 +333,6 @@ def train(train_loader, model, criterion, optimizer, epoch, args):
 
     end = time.time()
     for i, (images, target) in enumerate(train_loader):
-        # break the epoch at at the iteration epoch_size
-        if args.epoch_size != 0 and i >= args.epoch_size:
-            break
         # measure data loading time
         data_time.update(time.time() - end)
 
@@ -375,9 +378,6 @@ def validate(val_loader, model, criterion, args):
     with torch.no_grad():
         end = time.time()
         for i, (images, target) in enumerate(val_loader):
-            # break the epoch at at the iteration epoch_size_val
-            if args.epoch_size_val != 0 and i >= args.epoch_size_val:
-                break
             images = images.cuda(args.gpu, non_blocking=True)
             target = target.cuda(args.gpu, non_blocking=True)
 
@@ -419,11 +419,12 @@ def create_rand_inputs(is_cuda):
     return dummy_input
 
 
-def write_onnx_model(model, is_best, filename='checkpoint.onnx'):
+def write_onnx_model(args, model, is_best, filename='checkpoint.onnx'):
     model.eval()
     is_cuda = next(model.parameters()).is_cuda
     dummy_input = create_rand_inputs(is_cuda)
-    torch.onnx.export(model, dummy_input, filename, export_params=True, verbose=False)
+    torch.onnx.export(model, dummy_input, filename, export_params=True, verbose=False,
+                      do_constant_folding=True, opset_version=args.opset_version)
     if is_best:
         shutil.copyfile(filename, os.path.splitext(filename)[0]+'_best.onnx')
 
@@ -498,5 +499,13 @@ def accuracy(output, target, topk=(1,)):
         return res
 
 
+def get_dataset_sampler(dataset_object, epoch_size):
+    num_samples = len(dataset_object)
+    epoch_size = int(epoch_size * num_samples) if epoch_size < 1 else int(epoch_size)
+    print('=> creating a random sampler as epoch_size is specified')
+    dataset_sampler = torch.utils.data.sampler.RandomSampler(data_source=dataset_object, replacement=True, num_samples=epoch_size)
+    return dataset_sampler
+
+
 if __name__ == '__main__':
     main()
\ No newline at end of file