8c7e22667ccd1a2ed1b305bcf47179fa16964137
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / examples / quantization_example.py
1 # ----------------------------------
2 # Quantization Aware Training (QAT) Example
3 # Texas Instruments (C) 2018-2020
4 # All Rights Reserved
5 # ----------------------------------
6 # this original code is from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
7 # the changes required for quantizing the model are under the flag args.quantize
8 import argparse
9 import os
10 import random
11 import shutil
12 import time
13 import warnings
15 import torch
16 import torch.nn as nn
17 import torch.nn.parallel
18 import torch.backends.cudnn as cudnn
19 import torch.distributed as dist
20 import torch.optim
21 import torch.multiprocessing as mp
22 import torch.utils.data
23 import torch.utils.data.distributed
24 import torchvision.transforms as transforms
25 import torchvision.datasets as datasets
27 # some of the default torchvision models need some minor tweaks to be friendly for
28 # quantization aware training. so use models from pytorch_jacinto_ai.vision insead
29 #import torchvision.models as models
31 from pytorch_jacinto_ai import xnn
32 from pytorch_jacinto_ai import vision as xvision
33 from pytorch_jacinto_ai.vision import models as models
35 model_names = sorted(name for name in models.__dict__
36     if name.islower() and not name.startswith("__")
37     and callable(models.__dict__[name]))
39 parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
40 parser.add_argument('data', metavar='DIR',
41                     help='path to dataset')
42 parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
43                     choices=model_names,
44                     help='model architecture: ' +
45                         ' | '.join(model_names) +
46                         ' (default: resnet18)')
47 parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
48                     help='number of data loading workers (default: 4)')
49 parser.add_argument('--epochs', default=90, type=int, metavar='N',
50                     help='number of total epochs to run')
51 parser.add_argument('--epoch_size', default=0, type=float, metavar='N',
52                     help='fraction of training epoch to use. 0 (default) means full training epoch')
53 parser.add_argument('--epoch_size_val', default=0, type=float, metavar='N',
54                     help='fraction of validation epoch to use. 0 (default) means full validation epoch')
55 parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
56                     help='manual epoch number (useful on restarts)')
57 parser.add_argument('-b', '--batch_size', default=256, type=int,
58                     metavar='N',
59                     help='mini-batch size (default: 256), this is the total '
60                          'batch size of all GPUs on the current node when '
61                          'using Data Parallel or Distributed Data Parallel')
62 parser.add_argument('--lr', '--learning_rate', default=0.1, type=float,
63                     metavar='LR', help='initial learning rate', dest='lr')
64 parser.add_argument('--lr_step_size', default=30, type=int,
65                     metavar='N', help='number of steps before learning rate is reduced')
66 parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
67                     help='momentum')
68 parser.add_argument('--wd', '--weight_decay', default=1e-4, type=float,
69                     metavar='W', help='weight decay (default: 1e-4)',
70                     dest='weight_decay')
71 parser.add_argument('-p', '--print_freq', default=100, type=int,
72                     metavar='N', help='print frequency (default: 10)')
73 parser.add_argument('--resume', default='', type=str, metavar='PATH',
74                     help='path to latest checkpoint (default: none)')
75 parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
76                     help='evaluate model on validation set')
77 parser.add_argument('--pretrained', type=str, default=None,
78                     help='use pre-trained model')
79 parser.add_argument('--world_size', default=-1, type=int,
80                     help='number of nodes for distributed training')
81 parser.add_argument('--rank', default=-1, type=int,
82                     help='node rank for distributed training')
83 parser.add_argument('--dist_url', default='tcp://224.66.41.62:23456', type=str,
84                     help='url used to set up distributed training')
85 parser.add_argument('--dist_backend', default='nccl', type=str,
86                     help='distributed backend')
87 parser.add_argument('--seed', default=None, type=int,
88                     help='seed for initializing training. ')
89 parser.add_argument('--gpu', default=None, type=int,
90                     help='GPU id to use.')
91 parser.add_argument('--multiprocessing_distributed', action='store_true',
92                     help='Use multi-processing distributed training to launch '
93                          'N processes per node, which has N GPUs. This is the '
94                          'fastest way to use PyTorch for either single node or '
95                          'multi node data parallel training')
96 parser.add_argument('--save_path', type=str, default='./data/checkpoints/quantization',
97                     help='path to save the logs and models')
98 parser.add_argument('--quantize', action='store_true',
99                     help='Enable Quantization')
100 parser.add_argument('--opset_version', default=9, type=int,
101                     help='opset version for onnx export')
103 best_acc1 = 0
106 def main():
107     args = parser.parse_args()
109     args.cur_lr = args.lr
111     if args.seed is not None:
112         random.seed(args.seed)
113         torch.manual_seed(args.seed)
114         cudnn.deterministic = True
115         warnings.warn('You have chosen to seed training. '
116                       'This will turn on the CUDNN deterministic setting, '
117                       'which can slow down your training considerably! '
118                       'You may see unexpected behavior when restarting '
119                       'from checkpoints.')
121     if args.gpu is not None:
122         warnings.warn('You have chosen a specific GPU. This will completely '
123                       'disable data parallelism.')
125     if args.dist_url == "env://" and args.world_size == -1:
126         args.world_size = int(os.environ["WORLD_SIZE"])
128     args.distributed = args.world_size > 1 or args.multiprocessing_distributed
130     ngpus_per_node = torch.cuda.device_count()
131     if args.multiprocessing_distributed:
132         # Since we have ngpus_per_node processes per node, the total world_size
133         # needs to be adjusted accordingly
134         args.world_size = ngpus_per_node * args.world_size
135         # Use torch.multiprocessing.spawn to launch distributed processes: the
136         # main_worker process function
137         mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
138     else:
139         # Simply call main_worker function
140         main_worker(args.gpu, ngpus_per_node, args)
143 def main_worker(gpu, ngpus_per_node, args):
144     global best_acc1
145     args.gpu = gpu
147     if args.gpu is not None:
148         print("Use GPU: {} for training".format(args.gpu))
150     if args.distributed:
151         if args.dist_url == "env://" and args.rank == -1:
152             args.rank = int(os.environ["RANK"])
153         if args.multiprocessing_distributed:
154             # For multiprocessing distributed training, rank needs to be the
155             # global rank among all the processes
156             args.rank = args.rank * ngpus_per_node + gpu
157         dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
158                                 world_size=args.world_size, rank=args.rank)
159     # create model
160     print("=> creating model '{}'".format(args.arch))
161     model = models.__dict__[args.arch]()
163     if args.quantize:
164         # DistributedDataParallel / DataParallel are not supported with quantization
165         dummy_input = torch.rand((1, 3, 224, 224))
166         if args.evaluate:
167             # for validation accuracy check with quantization - can be used to estimate approximate accuracy achieved with quantization
168             model = xnn.quantize.QuantTestModule(model, dummy_input=dummy_input).cuda(args.gpu)
169         else:
170             # for quantization aware training
171             model = xnn.quantize.QuantTrainModule(model, dummy_input=dummy_input).cuda(args.gpu)
172         #
173     else:
174         if args.distributed:
175             # For multiprocessing distributed, DistributedDataParallel constructor
176             # should always set the single device scope, otherwise,
177             # DistributedDataParallel will use all available devices.
178             if args.gpu is not None:
179                 torch.cuda.set_device(args.gpu)
180                 model.cuda(args.gpu)
181                 # When using a single GPU per process and per
182                 # DistributedDataParallel, we need to divide the batch size
183                 # ourselves based on the total number of GPUs we have
184                 args.batch_size = int(args.batch_size / ngpus_per_node)
185                 args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
186                 model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
187             else:
188                 model.cuda()
189                 # DistributedDataParallel will divide and allocate batch_size to all
190                 # available GPUs if device_ids are not set
191                 model = torch.nn.parallel.DistributedDataParallel(model)
192         elif args.gpu is not None:
193             torch.cuda.set_device(args.gpu)
194             model = model.cuda(args.gpu)
195         else:
196             # DataParallel will divide and allocate batch_size to all available GPUs
197             if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
198                 model.features = torch.nn.DataParallel(model.features)
199                 model.cuda()
200             else:
201                 model = torch.nn.DataParallel(model).cuda()
203     if args.pretrained is not None:
204         model_orig = model.module if isinstance(model, (torch.nn.parallel.DistributedDataParallel, torch.nn.parallel.DataParallel)) else model
205         model_orig = model_orig.module if args.quantize else model_orig
206         print("=> using pre-trained model for {} from {}".format(args.arch, args.pretrained))
207         if hasattr(model_orig, 'load_weights'):
208             model_orig.load_weights(args.pretrained, download_root='./data/downloads')
209         else:
210             xnn.utils.load_weights(model_orig, args.pretrained, download_root='./data/downloads')
211         #
213     # define loss function (criterion) and optimizer
214     criterion = nn.CrossEntropyLoss().cuda(args.gpu)
216     optimizer = torch.optim.SGD(model.parameters(), args.lr,
217                                 momentum=args.momentum,
218                                 weight_decay=args.weight_decay)
220     # optionally resume from a checkpoint
221     if args.resume:
222         if os.path.isfile(args.resume):
223             print("=> loading checkpoint '{}'".format(args.resume))
224             if args.gpu is None:
225                 checkpoint = torch.load(args.resume)
226             else:
227                 # Map model to be loaded to specified single gpu.
228                 loc = 'cuda:{}'.format(args.gpu)
229                 checkpoint = torch.load(args.resume, map_location=loc)
230             args.start_epoch = checkpoint['epoch']
231             best_acc1 = checkpoint['best_acc1']
232             if args.gpu is not None:
233                 # best_acc1 may be from a checkpoint from a different GPU
234                 best_acc1 = best_acc1.to(args.gpu)
235             model.load_state_dict(checkpoint['state_dict'])
236             optimizer.load_state_dict(checkpoint['optimizer'])
237             print("=> loaded checkpoint '{}' (epoch {})"
238                   .format(args.resume, checkpoint['epoch']))
239         else:
240             print("=> no checkpoint found at '{}'".format(args.resume))
242     cudnn.benchmark = True
244     # Data loading code
245     traindir = os.path.join(args.data, 'train')
246     valdir = os.path.join(args.data, 'val')
247     normalize = xvision.transforms.NormalizeMeanScale(mean=[123.675, 116.28, 103.53], scale=[0.017125, 0.017507, 0.017429])
249     train_dataset = datasets.ImageFolder(
250         traindir,
251         transforms.Compose([
252             transforms.RandomResizedCrop(224),
253             transforms.RandomHorizontalFlip(),
254             xvision.transforms.ToFloat(),   # converting to float avoids the division by 255 in ToTensor()
255             transforms.ToTensor(),
256             normalize,
257         ]))
259     val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
260         transforms.Resize(256),
261         transforms.CenterCrop(224),
262         xvision.transforms.ToFloat(),  # converting to float avoids the division by 255 in ToTensor()
263         transforms.ToTensor(),
264         normalize,
265     ]))
267     if args.distributed:
268         train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
269         val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
270     else:
271         train_sampler = get_dataset_sampler(train_dataset, args.epoch_size) if args.epoch_size != 0 else None
272         val_sampler = get_dataset_sampler(val_dataset, args.epoch_size_val) if args.epoch_size_val != 0 else None
274     train_loader = torch.utils.data.DataLoader(
275         train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
276         num_workers=args.workers, pin_memory=True, sampler=train_sampler)
278     val_loader = torch.utils.data.DataLoader(
279         val_dataset, batch_size=args.batch_size, shuffle=(val_sampler is None),
280         num_workers=args.workers, pin_memory=True, sampler=val_sampler)
282     validate(val_loader, model, criterion, args)
284     if args.evaluate:
285         return
287     for epoch in range(args.start_epoch, args.epochs):
288         if args.distributed:
289             train_sampler.set_epoch(epoch)
290         adjust_learning_rate(optimizer, epoch, args)
292         # train for one epoch
293         train(train_loader, model, criterion, optimizer, epoch, args)
295         # evaluate on validation set
296         acc1 = validate(val_loader, model, criterion, args)
298         # remember best acc@1 and save checkpoint
299         is_best = acc1 > best_acc1
300         best_acc1 = max(acc1, best_acc1)
302         model_orig = model.module if isinstance(model, (torch.nn.parallel.DistributedDataParallel, torch.nn.parallel.DataParallel)) else model
303         model_orig = model_orig.module if args.quantize else model_orig
304         if not args.multiprocessing_distributed or (args.multiprocessing_distributed
305                 and args.rank % ngpus_per_node == 0):
306             out_basename = args.arch + ('_checkpoint_quantized.pth' if args.quantize else '_checkpoint.pth')
307             save_filename = os.path.join(args.save_path, out_basename)
308             checkpoint_dict = {
309                 'epoch': epoch + 1,
310                 'arch': args.arch,
311                 'state_dict': model_orig.state_dict(),
312                 'best_acc1': best_acc1,
313                 'optimizer' : optimizer.state_dict(),
314             }
315             save_checkpoint(checkpoint_dict, is_best, filename=save_filename)
316             save_onnxname = os.path.splitext(save_filename)[0]+'.onnx'
317             write_onnx_model(args, model, is_best, filename=save_onnxname)
320 def train(train_loader, model, criterion, optimizer, epoch, args):
321     batch_time = AverageMeter('Time', ':6.3f')
322     data_time = AverageMeter('Data', ':6.3f')
323     losses = AverageMeter('Loss', ':.4e')
324     top1 = AverageMeter('Acc@1', ':6.2f')
325     top5 = AverageMeter('Acc@5', ':6.2f')
326     progress = ProgressMeter(
327         len(train_loader),
328         [batch_time, data_time, losses, top1, top5],
329         prefix="Epoch: [{}]".format(epoch))
331     # switch to train mode
332     model.train()
334     end = time.time()
335     for i, (images, target) in enumerate(train_loader):
336         # measure data loading time
337         data_time.update(time.time() - end)
339         images = images.cuda(args.gpu, non_blocking=True)
340         target = target.cuda(args.gpu, non_blocking=True)
342         # compute output
343         output = model(images)
344         loss = criterion(output, target)
346         # measure accuracy and record loss
347         acc1, acc5 = accuracy(output, target, topk=(1, 5))
348         losses.update(loss.item(), images.size(0))
349         top1.update(acc1[0], images.size(0))
350         top5.update(acc5[0], images.size(0))
352         # compute gradient and do SGD step
353         optimizer.zero_grad()
354         loss.backward()
355         optimizer.step()
357         # measure elapsed time
358         batch_time.update(time.time() - end)
359         end = time.time()
361         if i % args.print_freq == 0:
362             progress.display(i, args.cur_lr)
365 def validate(val_loader, model, criterion, args):
366     batch_time = AverageMeter('Time', ':6.3f')
367     losses = AverageMeter('Loss', ':.4e')
368     top1 = AverageMeter('Acc@1', ':6.2f')
369     top5 = AverageMeter('Acc@5', ':6.2f')
370     progress = ProgressMeter(
371         len(val_loader),
372         [batch_time, losses, top1, top5],
373         prefix='Test: ')
375     # switch to evaluate mode
376     model.eval()
378     with torch.no_grad():
379         end = time.time()
380         for i, (images, target) in enumerate(val_loader):
381             images = images.cuda(args.gpu, non_blocking=True)
382             target = target.cuda(args.gpu, non_blocking=True)
384             # compute output
385             output = model(images)
386             loss = criterion(output, target)
388             # measure accuracy and record loss
389             acc1, acc5 = accuracy(output, target, topk=(1, 5))
390             losses.update(loss.item(), images.size(0))
391             top1.update(acc1[0], images.size(0))
392             top5.update(acc5[0], images.size(0))
394             # measure elapsed time
395             batch_time.update(time.time() - end)
396             end = time.time()
398             if i % args.print_freq == 0:
399                 progress.display(i, args.cur_lr)
401         # TODO: this should also be done with the ProgressMeter
402         print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
403               .format(top1=top1, top5=top5))
405     return top1.avg
408 def save_checkpoint(state, is_best, filename='checkpoint.pth'):
409     dirname = os.path.dirname(filename)
410     xnn.utils.makedir_exist_ok(dirname)
411     torch.save(state, filename)
412     if is_best:
413         shutil.copyfile(filename, os.path.splitext(filename)[0]+'_best.pth')
416 def create_rand_inputs(is_cuda):
417     dummy_input = torch.rand((1, 3, 224, 224))
418     dummy_input = dummy_input.cuda() if is_cuda else dummy_input
419     return dummy_input
422 def write_onnx_model(args, model, is_best, filename='checkpoint.onnx'):
423     model.eval()
424     is_cuda = next(model.parameters()).is_cuda
425     dummy_input = create_rand_inputs(is_cuda)
426     torch.onnx.export(model, dummy_input, filename, export_params=True, verbose=False,
427                       do_constant_folding=True, opset_version=args.opset_version)
428     if is_best:
429         shutil.copyfile(filename, os.path.splitext(filename)[0]+'_best.onnx')
432 class AverageMeter(object):
433     """Computes and stores the average and current value"""
434     def __init__(self, name, fmt=':f'):
435         self.name = name
436         self.fmt = fmt
437         self.reset()
439     def reset(self):
440         self.val = 0
441         self.avg = 0
442         self.sum = 0
443         self.count = 0
445     def update(self, val, n=1):
446         self.val = val
447         self.sum += val * n
448         self.count += n
449         self.avg = self.sum / self.count
451     def __str__(self):
452         fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
453         return fmtstr.format(**self.__dict__)
456 class ProgressMeter(object):
457     def __init__(self, num_batches, meters, prefix=""):
458         self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
459         self.lr_fmtstr = self._get_lr_fmtstr()
460         self.meters = meters
461         self.prefix = prefix
463     def display(self, batch, cur_lr):
464         entries = [self.prefix + self.batch_fmtstr.format(batch), self.lr_fmtstr.format(cur_lr)]
465         entries += [str(meter) for meter in self.meters]
466         print('\t'.join(entries))
468     def _get_batch_fmtstr(self, num_batches):
469         num_digits = len(str(num_batches // 1))
470         fmt = '{:' + str(num_digits) + 'd}'
471         return '[' + fmt + '/' + fmt.format(num_batches) + ']'
473     def _get_lr_fmtstr(self):
474         fmt = 'LR {:g}'
475         return fmt
477 def adjust_learning_rate(optimizer, epoch, args):
478     """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
479     lr = args.lr * (0.1 ** (epoch // args.lr_step_size))
480     args.cur_lr = lr
481     for param_group in optimizer.param_groups:
482         param_group['lr'] = lr
485 def accuracy(output, target, topk=(1,)):
486     """Computes the accuracy over the k top predictions for the specified values of k"""
487     with torch.no_grad():
488         maxk = max(topk)
489         batch_size = target.size(0)
491         _, pred = output.topk(maxk, 1, True, True)
492         pred = pred.t()
493         correct = pred.eq(target.view(1, -1).expand_as(pred))
495         res = []
496         for k in topk:
497             correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
498             res.append(correct_k.mul_(100.0 / batch_size))
499         return res
502 def get_dataset_sampler(dataset_object, epoch_size):
503     num_samples = len(dataset_object)
504     epoch_size = int(epoch_size * num_samples) if epoch_size < 1 else int(epoch_size)
505     print('=> creating a random sampler as epoch_size is specified')
506     dataset_sampler = torch.utils.data.sampler.RandomSampler(data_source=dataset_object, replacement=True, num_samples=epoch_size)
507     return dataset_sampler
510 if __name__ == '__main__':
511     main()