quantization_example - RandomSampler is used when epoch_size!=0. epoch_size=0.1 means...
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / examples / quantization_example.py
index 7d5e5488da26f9c9e6e43d429311ae56a1e38b88..8c7e22667ccd1a2ed1b305bcf47179fa16964137 100644 (file)
@@ -1,4 +1,9 @@
-# this code is modified from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
+# ----------------------------------
+# Quantization Aware Training (QAT) Example
+# Texas Instruments (C) 2018-2020
+# All Rights Reserved
+# ----------------------------------
+# this original code is from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
 # the changes required for quantizing the model are under the flag args.quantize
 import argparse
 import os
@@ -43,23 +48,27 @@ parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                     help='number of data loading workers (default: 4)')
 parser.add_argument('--epochs', default=90, type=int, metavar='N',
                     help='number of total epochs to run')
-parser.add_argument('--epoch-size', default=0, type=int, metavar='N',
-                    help='number of iterations in one training epoch. 0 (default) means full training epoch')
-parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
+parser.add_argument('--epoch_size', default=0, type=float, metavar='N',
+                    help='fraction of training epoch to use. 0 (default) means full training epoch')
+parser.add_argument('--epoch_size_val', default=0, type=float, metavar='N',
+                    help='fraction of validation epoch to use. 0 (default) means full validation epoch')
+parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                     help='manual epoch number (useful on restarts)')
-parser.add_argument('-b', '--batch-size', default=256, type=int,
+parser.add_argument('-b', '--batch_size', default=256, type=int,
                     metavar='N',
                     help='mini-batch size (default: 256), this is the total '
                          'batch size of all GPUs on the current node when '
                          'using Data Parallel or Distributed Data Parallel')
-parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
+parser.add_argument('--lr', '--learning_rate', default=0.1, type=float,
                     metavar='LR', help='initial learning rate', dest='lr')
+parser.add_argument('--lr_step_size', default=30, type=int,
+                    metavar='N', help='number of steps before learning rate is reduced')
 parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                     help='momentum')
-parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
+parser.add_argument('--wd', '--weight_decay', default=1e-4, type=float,
                     metavar='W', help='weight decay (default: 1e-4)',
                     dest='weight_decay')
-parser.add_argument('-p', '--print-freq', default=100, type=int,
+parser.add_argument('-p', '--print_freq', default=100, type=int,
                     metavar='N', help='print frequency (default: 10)')
 parser.add_argument('--resume', default='', type=str, metavar='PATH',
                     help='path to latest checkpoint (default: none)')
@@ -67,31 +76,38 @@ parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                     help='evaluate model on validation set')
 parser.add_argument('--pretrained', type=str, default=None,
                     help='use pre-trained model')
-parser.add_argument('--world-size', default=-1, type=int,
+parser.add_argument('--world_size', default=-1, type=int,
                     help='number of nodes for distributed training')
 parser.add_argument('--rank', default=-1, type=int,
                     help='node rank for distributed training')
-parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
+parser.add_argument('--dist_url', default='tcp://224.66.41.62:23456', type=str,
                     help='url used to set up distributed training')
-parser.add_argument('--dist-backend', default='nccl', type=str,
+parser.add_argument('--dist_backend', default='nccl', type=str,
                     help='distributed backend')
 parser.add_argument('--seed', default=None, type=int,
                     help='seed for initializing training. ')
 parser.add_argument('--gpu', default=None, type=int,
                     help='GPU id to use.')
-parser.add_argument('--multiprocessing-distributed', action='store_true',
+parser.add_argument('--multiprocessing_distributed', action='store_true',
                     help='Use multi-processing distributed training to launch '
                          'N processes per node, which has N GPUs. This is the '
                          'fastest way to use PyTorch for either single node or '
                          'multi node data parallel training')
+parser.add_argument('--save_path', type=str, default='./data/checkpoints/quantization',
+                    help='path to save the logs and models')
 parser.add_argument('--quantize', action='store_true',
                     help='Enable Quantization')
+parser.add_argument('--opset_version', default=9, type=int,
+                    help='opset version for onnx export')
+
 best_acc1 = 0
 
 
 def main():
     args = parser.parse_args()
 
+    args.cur_lr = args.lr
+
     if args.seed is not None:
         random.seed(args.seed)
         torch.manual_seed(args.seed)
@@ -235,30 +251,33 @@ def main_worker(gpu, ngpus_per_node, args):
         transforms.Compose([
             transforms.RandomResizedCrop(224),
             transforms.RandomHorizontalFlip(),
-            xvision.transforms.ToFloat(),
+            xvision.transforms.ToFloat(),   # converting to float avoids the division by 255 in ToTensor()
             transforms.ToTensor(),
             normalize,
         ]))
 
+    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
+        transforms.Resize(256),
+        transforms.CenterCrop(224),
+        xvision.transforms.ToFloat(),  # converting to float avoids the division by 255 in ToTensor()
+        transforms.ToTensor(),
+        normalize,
+    ]))
+
     if args.distributed:
         train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
+        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
     else:
-        train_sampler = None
+        train_sampler = get_dataset_sampler(train_dataset, args.epoch_size) if args.epoch_size != 0 else None
+        val_sampler = get_dataset_sampler(val_dataset, args.epoch_size_val) if args.epoch_size_val != 0 else None
 
     train_loader = torch.utils.data.DataLoader(
         train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
         num_workers=args.workers, pin_memory=True, sampler=train_sampler)
 
     val_loader = torch.utils.data.DataLoader(
-        datasets.ImageFolder(valdir, transforms.Compose([
-            transforms.Resize(256),
-            transforms.CenterCrop(224),
-            xvision.transforms.ToFloat(),
-            transforms.ToTensor(),
-            normalize,
-        ])),
-        batch_size=args.batch_size, shuffle=False,
-        num_workers=args.workers, pin_memory=True)
+        val_dataset, batch_size=args.batch_size, shuffle=(val_sampler is None),
+        num_workers=args.workers, pin_memory=True, sampler=val_sampler)
 
     validate(val_loader, model, criterion, args)
 
@@ -284,16 +303,18 @@ def main_worker(gpu, ngpus_per_node, args):
         model_orig = model_orig.module if args.quantize else model_orig
         if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                 and args.rank % ngpus_per_node == 0):
-            out_basename = args.arch
-            out_basename += ('_quantized_checkpoint.pth.tar' if args.quantize else '_checkpoint.pth.tar')
-            save_filename = os.path.join('./data/checkpoints/quantization', out_basename)
-            save_checkpoint({
+            out_basename = args.arch + ('_checkpoint_quantized.pth' if args.quantize else '_checkpoint.pth')
+            save_filename = os.path.join(args.save_path, out_basename)
+            checkpoint_dict = {
                 'epoch': epoch + 1,
                 'arch': args.arch,
                 'state_dict': model_orig.state_dict(),
                 'best_acc1': best_acc1,
                 'optimizer' : optimizer.state_dict(),
-            }, is_best, filename=save_filename)
+            }
+            save_checkpoint(checkpoint_dict, is_best, filename=save_filename)
+            save_onnxname = os.path.splitext(save_filename)[0]+'.onnx'
+            write_onnx_model(args, model, is_best, filename=save_onnxname)
 
 
 def train(train_loader, model, criterion, optimizer, epoch, args):
@@ -312,9 +333,6 @@ def train(train_loader, model, criterion, optimizer, epoch, args):
 
     end = time.time()
     for i, (images, target) in enumerate(train_loader):
-        # break the epoch at at the iteration epoch_size
-        if args.epoch_size != 0 and i >= args.epoch_size:
-            break
         # measure data loading time
         data_time.update(time.time() - end)
 
@@ -341,7 +359,7 @@ def train(train_loader, model, criterion, optimizer, epoch, args):
         end = time.time()
 
         if i % args.print_freq == 0:
-            progress.display(i)
+            progress.display(i, args.cur_lr)
 
 
 def validate(val_loader, model, criterion, args):
@@ -378,7 +396,7 @@ def validate(val_loader, model, criterion, args):
             end = time.time()
 
             if i % args.print_freq == 0:
-                progress.display(i)
+                progress.display(i, args.cur_lr)
 
         # TODO: this should also be done with the ProgressMeter
         print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
@@ -387,12 +405,28 @@ def validate(val_loader, model, criterion, args):
     return top1.avg
 
 
-def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
+def save_checkpoint(state, is_best, filename='checkpoint.pth'):
     dirname = os.path.dirname(filename)
     xnn.utils.makedir_exist_ok(dirname)
     torch.save(state, filename)
     if is_best:
-        shutil.copyfile(filename, os.path.splitext(filename)[0]+'_best.pth.tar')
+        shutil.copyfile(filename, os.path.splitext(filename)[0]+'_best.pth')
+
+
+def create_rand_inputs(is_cuda):
+    dummy_input = torch.rand((1, 3, 224, 224))
+    dummy_input = dummy_input.cuda() if is_cuda else dummy_input
+    return dummy_input
+
+
+def write_onnx_model(args, model, is_best, filename='checkpoint.onnx'):
+    model.eval()
+    is_cuda = next(model.parameters()).is_cuda
+    dummy_input = create_rand_inputs(is_cuda)
+    torch.onnx.export(model, dummy_input, filename, export_params=True, verbose=False,
+                      do_constant_folding=True, opset_version=args.opset_version)
+    if is_best:
+        shutil.copyfile(filename, os.path.splitext(filename)[0]+'_best.onnx')
 
 
 class AverageMeter(object):
@@ -422,11 +456,12 @@ class AverageMeter(object):
 class ProgressMeter(object):
     def __init__(self, num_batches, meters, prefix=""):
         self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
+        self.lr_fmtstr = self._get_lr_fmtstr()
         self.meters = meters
         self.prefix = prefix
 
-    def display(self, batch):
-        entries = [self.prefix + self.batch_fmtstr.format(batch)]
+    def display(self, batch, cur_lr):
+        entries = [self.prefix + self.batch_fmtstr.format(batch), self.lr_fmtstr.format(cur_lr)]
         entries += [str(meter) for meter in self.meters]
         print('\t'.join(entries))
 
@@ -435,10 +470,14 @@ class ProgressMeter(object):
         fmt = '{:' + str(num_digits) + 'd}'
         return '[' + fmt + '/' + fmt.format(num_batches) + ']'
 
+    def _get_lr_fmtstr(self):
+        fmt = 'LR {:g}'
+        return fmt
 
 def adjust_learning_rate(optimizer, epoch, args):
     """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
-    lr = args.lr * (0.1 ** (epoch // 30))
+    lr = args.lr * (0.1 ** (epoch // args.lr_step_size))
+    args.cur_lr = lr
     for param_group in optimizer.param_groups:
         param_group['lr'] = lr
 
@@ -460,5 +499,13 @@ def accuracy(output, target, topk=(1,)):
         return res
 
 
+def get_dataset_sampler(dataset_object, epoch_size):
+    num_samples = len(dataset_object)
+    epoch_size = int(epoch_size * num_samples) if epoch_size < 1 else int(epoch_size)
+    print('=> creating a random sampler as epoch_size is specified')
+    dataset_sampler = torch.utils.data.sampler.RandomSampler(data_source=dataset_object, replacement=True, num_samples=epoch_size)
+    return dataset_sampler
+
+
 if __name__ == '__main__':
     main()
\ No newline at end of file