3a9061b45f7cffe28cf77891b9be98023881cca3
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / examples / quantization_example.py
1 # ----------------------------------
2 # Quantization Aware Training (QAT) Example
3 # Texas Instruments (C) 2018-2020
4 # All Rights Reserved
5 # ----------------------------------
6 # this original code is from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
7 # the changes required for quantizing the model are under the flag args.quantize
8 import argparse
9 import os
10 import random
11 import shutil
12 import time
13 import warnings
15 import torch
16 import torch.nn as nn
17 import torch.nn.parallel
18 import torch.backends.cudnn as cudnn
19 import torch.distributed as dist
20 import torch.optim
21 import torch.multiprocessing as mp
22 import torch.utils.data
23 import torch.utils.data.distributed
24 import torchvision.transforms as transforms
25 import torchvision.datasets as datasets
27 # some of the default torchvision models need some minor tweaks to be friendly for
28 # quantization aware training. so use models from pytorch_jacinto_ai.vision insead
29 #import torchvision.models as models
31 from pytorch_jacinto_ai import xnn
32 from pytorch_jacinto_ai import vision as xvision
33 from pytorch_jacinto_ai.vision import models as models
35 model_names = sorted(name for name in models.__dict__
36     if name.islower() and not name.startswith("__")
37     and callable(models.__dict__[name]))
39 parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
40 parser.add_argument('data', metavar='DIR',
41                     help='path to dataset')
42 parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
43                     choices=model_names,
44                     help='model architecture: ' +
45                         ' | '.join(model_names) +
46                         ' (default: resnet18)')
47 parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
48                     help='number of data loading workers (default: 4)')
49 parser.add_argument('--epochs', default=90, type=int, metavar='N',
50                     help='number of total epochs to run')
51 parser.add_argument('--epoch_size', default=0, type=int, metavar='N',
52                     help='number of iterations in one training epoch. 0 (default) means full training epoch')
53 parser.add_argument('--epoch_size_val', default=0, type=int, metavar='N',
54                     help='number of iterations in one validation epoch. 0 (default) means full validation epoch')
55 parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
56                     help='manual epoch number (useful on restarts)')
57 parser.add_argument('-b', '--batch_size', default=256, type=int,
58                     metavar='N',
59                     help='mini-batch size (default: 256), this is the total '
60                          'batch size of all GPUs on the current node when '
61                          'using Data Parallel or Distributed Data Parallel')
62 parser.add_argument('--lr', '--learning_rate', default=0.1, type=float,
63                     metavar='LR', help='initial learning rate', dest='lr')
64 parser.add_argument('--lr_step_size', default=30, type=int,
65                     metavar='N', help='number of steps before learning rate is reduced')
66 parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
67                     help='momentum')
68 parser.add_argument('--wd', '--weight_decay', default=1e-4, type=float,
69                     metavar='W', help='weight decay (default: 1e-4)',
70                     dest='weight_decay')
71 parser.add_argument('-p', '--print_freq', default=100, type=int,
72                     metavar='N', help='print frequency (default: 10)')
73 parser.add_argument('--resume', default='', type=str, metavar='PATH',
74                     help='path to latest checkpoint (default: none)')
75 parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
76                     help='evaluate model on validation set')
77 parser.add_argument('--pretrained', type=str, default=None,
78                     help='use pre-trained model')
79 parser.add_argument('--world_size', default=-1, type=int,
80                     help='number of nodes for distributed training')
81 parser.add_argument('--rank', default=-1, type=int,
82                     help='node rank for distributed training')
83 parser.add_argument('--dist_url', default='tcp://224.66.41.62:23456', type=str,
84                     help='url used to set up distributed training')
85 parser.add_argument('--dist_backend', default='nccl', type=str,
86                     help='distributed backend')
87 parser.add_argument('--seed', default=None, type=int,
88                     help='seed for initializing training. ')
89 parser.add_argument('--gpu', default=None, type=int,
90                     help='GPU id to use.')
91 parser.add_argument('--multiprocessing_distributed', action='store_true',
92                     help='Use multi-processing distributed training to launch '
93                          'N processes per node, which has N GPUs. This is the '
94                          'fastest way to use PyTorch for either single node or '
95                          'multi node data parallel training')
96 parser.add_argument('--save_path', type=str, default='./data/checkpoints/quantization',
97                     help='path to save the logs and models')
98 parser.add_argument('--quantize', action='store_true',
99                     help='Enable Quantization')
100 parser.add_argument('--opset_version', default=9, type=int,
101                     help='opset version for onnx export')
103 best_acc1 = 0
106 def main():
107     args = parser.parse_args()
109     args.cur_lr = args.lr
111     if args.seed is not None:
112         random.seed(args.seed)
113         torch.manual_seed(args.seed)
114         cudnn.deterministic = True
115         warnings.warn('You have chosen to seed training. '
116                       'This will turn on the CUDNN deterministic setting, '
117                       'which can slow down your training considerably! '
118                       'You may see unexpected behavior when restarting '
119                       'from checkpoints.')
121     if args.gpu is not None:
122         warnings.warn('You have chosen a specific GPU. This will completely '
123                       'disable data parallelism.')
125     if args.dist_url == "env://" and args.world_size == -1:
126         args.world_size = int(os.environ["WORLD_SIZE"])
128     args.distributed = args.world_size > 1 or args.multiprocessing_distributed
130     ngpus_per_node = torch.cuda.device_count()
131     if args.multiprocessing_distributed:
132         # Since we have ngpus_per_node processes per node, the total world_size
133         # needs to be adjusted accordingly
134         args.world_size = ngpus_per_node * args.world_size
135         # Use torch.multiprocessing.spawn to launch distributed processes: the
136         # main_worker process function
137         mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
138     else:
139         # Simply call main_worker function
140         main_worker(args.gpu, ngpus_per_node, args)
143 def main_worker(gpu, ngpus_per_node, args):
144     global best_acc1
145     args.gpu = gpu
147     if args.gpu is not None:
148         print("Use GPU: {} for training".format(args.gpu))
150     if args.distributed:
151         if args.dist_url == "env://" and args.rank == -1:
152             args.rank = int(os.environ["RANK"])
153         if args.multiprocessing_distributed:
154             # For multiprocessing distributed training, rank needs to be the
155             # global rank among all the processes
156             args.rank = args.rank * ngpus_per_node + gpu
157         dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
158                                 world_size=args.world_size, rank=args.rank)
159     # create model
160     print("=> creating model '{}'".format(args.arch))
161     model = models.__dict__[args.arch]()
163     if args.quantize:
164         # DistributedDataParallel / DataParallel are not supported with quantization
165         dummy_input = torch.rand((1, 3, 224, 224))
166         if args.evaluate:
167             # for validation accuracy check with quantization - can be used to estimate approximate accuracy achieved with quantization
168             model = xnn.quantize.QuantTestModule(model, dummy_input=dummy_input).cuda(args.gpu)
169         else:
170             # for quantization aware training
171             model = xnn.quantize.QuantTrainModule(model, dummy_input=dummy_input).cuda(args.gpu)
172         #
173     else:
174         if args.distributed:
175             # For multiprocessing distributed, DistributedDataParallel constructor
176             # should always set the single device scope, otherwise,
177             # DistributedDataParallel will use all available devices.
178             if args.gpu is not None:
179                 torch.cuda.set_device(args.gpu)
180                 model.cuda(args.gpu)
181                 # When using a single GPU per process and per
182                 # DistributedDataParallel, we need to divide the batch size
183                 # ourselves based on the total number of GPUs we have
184                 args.batch_size = int(args.batch_size / ngpus_per_node)
185                 args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
186                 model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
187             else:
188                 model.cuda()
189                 # DistributedDataParallel will divide and allocate batch_size to all
190                 # available GPUs if device_ids are not set
191                 model = torch.nn.parallel.DistributedDataParallel(model)
192         elif args.gpu is not None:
193             torch.cuda.set_device(args.gpu)
194             model = model.cuda(args.gpu)
195         else:
196             # DataParallel will divide and allocate batch_size to all available GPUs
197             if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
198                 model.features = torch.nn.DataParallel(model.features)
199                 model.cuda()
200             else:
201                 model = torch.nn.DataParallel(model).cuda()
203     if args.pretrained is not None:
204         model_orig = model.module if isinstance(model, (torch.nn.parallel.DistributedDataParallel, torch.nn.parallel.DataParallel)) else model
205         model_orig = model_orig.module if args.quantize else model_orig
206         print("=> using pre-trained model for {} from {}".format(args.arch, args.pretrained))
207         if hasattr(model_orig, 'load_weights'):
208             model_orig.load_weights(args.pretrained, download_root='./data/downloads')
209         else:
210             xnn.utils.load_weights(model_orig, args.pretrained, download_root='./data/downloads')
211         #
213     # define loss function (criterion) and optimizer
214     criterion = nn.CrossEntropyLoss().cuda(args.gpu)
216     optimizer = torch.optim.SGD(model.parameters(), args.lr,
217                                 momentum=args.momentum,
218                                 weight_decay=args.weight_decay)
220     # optionally resume from a checkpoint
221     if args.resume:
222         if os.path.isfile(args.resume):
223             print("=> loading checkpoint '{}'".format(args.resume))
224             if args.gpu is None:
225                 checkpoint = torch.load(args.resume)
226             else:
227                 # Map model to be loaded to specified single gpu.
228                 loc = 'cuda:{}'.format(args.gpu)
229                 checkpoint = torch.load(args.resume, map_location=loc)
230             args.start_epoch = checkpoint['epoch']
231             best_acc1 = checkpoint['best_acc1']
232             if args.gpu is not None:
233                 # best_acc1 may be from a checkpoint from a different GPU
234                 best_acc1 = best_acc1.to(args.gpu)
235             model.load_state_dict(checkpoint['state_dict'])
236             optimizer.load_state_dict(checkpoint['optimizer'])
237             print("=> loaded checkpoint '{}' (epoch {})"
238                   .format(args.resume, checkpoint['epoch']))
239         else:
240             print("=> no checkpoint found at '{}'".format(args.resume))
242     cudnn.benchmark = True
244     # Data loading code
245     traindir = os.path.join(args.data, 'train')
246     valdir = os.path.join(args.data, 'val')
247     normalize = xvision.transforms.NormalizeMeanScale(mean=[123.675, 116.28, 103.53], scale=[0.017125, 0.017507, 0.017429])
249     train_dataset = datasets.ImageFolder(
250         traindir,
251         transforms.Compose([
252             transforms.RandomResizedCrop(224),
253             transforms.RandomHorizontalFlip(),
254             xvision.transforms.ToFloat(),   # converting to float avoids the division by 255 in ToTensor()
255             transforms.ToTensor(),
256             normalize,
257         ]))
259     if args.distributed:
260         train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
261     else:
262         train_sampler = None
264     train_loader = torch.utils.data.DataLoader(
265         train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
266         num_workers=args.workers, pin_memory=True, sampler=train_sampler)
268     val_loader = torch.utils.data.DataLoader(
269         datasets.ImageFolder(valdir, transforms.Compose([
270             transforms.Resize(256),
271             transforms.CenterCrop(224),
272             xvision.transforms.ToFloat(),   # converting to float avoids the division by 255 in ToTensor()
273             transforms.ToTensor(),
274             normalize,
275         ])),
276         batch_size=args.batch_size, shuffle=False,
277         num_workers=args.workers, pin_memory=True)
279     validate(val_loader, model, criterion, args)
281     if args.evaluate:
282         return
284     for epoch in range(args.start_epoch, args.epochs):
285         if args.distributed:
286             train_sampler.set_epoch(epoch)
287         adjust_learning_rate(optimizer, epoch, args)
289         # train for one epoch
290         train(train_loader, model, criterion, optimizer, epoch, args)
292         # evaluate on validation set
293         acc1 = validate(val_loader, model, criterion, args)
295         # remember best acc@1 and save checkpoint
296         is_best = acc1 > best_acc1
297         best_acc1 = max(acc1, best_acc1)
299         model_orig = model.module if isinstance(model, (torch.nn.parallel.DistributedDataParallel, torch.nn.parallel.DataParallel)) else model
300         model_orig = model_orig.module if args.quantize else model_orig
301         if not args.multiprocessing_distributed or (args.multiprocessing_distributed
302                 and args.rank % ngpus_per_node == 0):
303             out_basename = args.arch + ('_checkpoint_quantized.pth' if args.quantize else '_checkpoint.pth')
304             save_filename = os.path.join(args.save_path, out_basename)
305             checkpoint_dict = {
306                 'epoch': epoch + 1,
307                 'arch': args.arch,
308                 'state_dict': model_orig.state_dict(),
309                 'best_acc1': best_acc1,
310                 'optimizer' : optimizer.state_dict(),
311             }
312             save_checkpoint(checkpoint_dict, is_best, filename=save_filename)
313             save_onnxname = os.path.splitext(save_filename)[0]+'.onnx'
314             write_onnx_model(args, model, is_best, filename=save_onnxname)
317 def train(train_loader, model, criterion, optimizer, epoch, args):
318     batch_time = AverageMeter('Time', ':6.3f')
319     data_time = AverageMeter('Data', ':6.3f')
320     losses = AverageMeter('Loss', ':.4e')
321     top1 = AverageMeter('Acc@1', ':6.2f')
322     top5 = AverageMeter('Acc@5', ':6.2f')
323     progress = ProgressMeter(
324         len(train_loader),
325         [batch_time, data_time, losses, top1, top5],
326         prefix="Epoch: [{}]".format(epoch))
328     # switch to train mode
329     model.train()
331     end = time.time()
332     for i, (images, target) in enumerate(train_loader):
333         # break the epoch at at the iteration epoch_size
334         if args.epoch_size != 0 and i >= args.epoch_size:
335             break
336         # measure data loading time
337         data_time.update(time.time() - end)
339         images = images.cuda(args.gpu, non_blocking=True)
340         target = target.cuda(args.gpu, non_blocking=True)
342         # compute output
343         output = model(images)
344         loss = criterion(output, target)
346         # measure accuracy and record loss
347         acc1, acc5 = accuracy(output, target, topk=(1, 5))
348         losses.update(loss.item(), images.size(0))
349         top1.update(acc1[0], images.size(0))
350         top5.update(acc5[0], images.size(0))
352         # compute gradient and do SGD step
353         optimizer.zero_grad()
354         loss.backward()
355         optimizer.step()
357         # measure elapsed time
358         batch_time.update(time.time() - end)
359         end = time.time()
361         if i % args.print_freq == 0:
362             progress.display(i, args.cur_lr)
365 def validate(val_loader, model, criterion, args):
366     batch_time = AverageMeter('Time', ':6.3f')
367     losses = AverageMeter('Loss', ':.4e')
368     top1 = AverageMeter('Acc@1', ':6.2f')
369     top5 = AverageMeter('Acc@5', ':6.2f')
370     progress = ProgressMeter(
371         len(val_loader),
372         [batch_time, losses, top1, top5],
373         prefix='Test: ')
375     # switch to evaluate mode
376     model.eval()
378     with torch.no_grad():
379         end = time.time()
380         for i, (images, target) in enumerate(val_loader):
381             # break the epoch at at the iteration epoch_size_val
382             if args.epoch_size_val != 0 and i >= args.epoch_size_val:
383                 break
384             images = images.cuda(args.gpu, non_blocking=True)
385             target = target.cuda(args.gpu, non_blocking=True)
387             # compute output
388             output = model(images)
389             loss = criterion(output, target)
391             # measure accuracy and record loss
392             acc1, acc5 = accuracy(output, target, topk=(1, 5))
393             losses.update(loss.item(), images.size(0))
394             top1.update(acc1[0], images.size(0))
395             top5.update(acc5[0], images.size(0))
397             # measure elapsed time
398             batch_time.update(time.time() - end)
399             end = time.time()
401             if i % args.print_freq == 0:
402                 progress.display(i, args.cur_lr)
404         # TODO: this should also be done with the ProgressMeter
405         print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
406               .format(top1=top1, top5=top5))
408     return top1.avg
411 def save_checkpoint(state, is_best, filename='checkpoint.pth'):
412     dirname = os.path.dirname(filename)
413     xnn.utils.makedir_exist_ok(dirname)
414     torch.save(state, filename)
415     if is_best:
416         shutil.copyfile(filename, os.path.splitext(filename)[0]+'_best.pth')
419 def create_rand_inputs(is_cuda):
420     dummy_input = torch.rand((1, 3, 224, 224))
421     dummy_input = dummy_input.cuda() if is_cuda else dummy_input
422     return dummy_input
425 def write_onnx_model(args, model, is_best, filename='checkpoint.onnx'):
426     model.eval()
427     is_cuda = next(model.parameters()).is_cuda
428     dummy_input = create_rand_inputs(is_cuda)
429     torch.onnx.export(model, dummy_input, filename, export_params=True, verbose=False,
430                       do_constant_folding=True, opset_version=args.opset_vesion)
431     if is_best:
432         shutil.copyfile(filename, os.path.splitext(filename)[0]+'_best.onnx')
435 class AverageMeter(object):
436     """Computes and stores the average and current value"""
437     def __init__(self, name, fmt=':f'):
438         self.name = name
439         self.fmt = fmt
440         self.reset()
442     def reset(self):
443         self.val = 0
444         self.avg = 0
445         self.sum = 0
446         self.count = 0
448     def update(self, val, n=1):
449         self.val = val
450         self.sum += val * n
451         self.count += n
452         self.avg = self.sum / self.count
454     def __str__(self):
455         fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
456         return fmtstr.format(**self.__dict__)
459 class ProgressMeter(object):
460     def __init__(self, num_batches, meters, prefix=""):
461         self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
462         self.lr_fmtstr = self._get_lr_fmtstr()
463         self.meters = meters
464         self.prefix = prefix
466     def display(self, batch, cur_lr):
467         entries = [self.prefix + self.batch_fmtstr.format(batch), self.lr_fmtstr.format(cur_lr)]
468         entries += [str(meter) for meter in self.meters]
469         print('\t'.join(entries))
471     def _get_batch_fmtstr(self, num_batches):
472         num_digits = len(str(num_batches // 1))
473         fmt = '{:' + str(num_digits) + 'd}'
474         return '[' + fmt + '/' + fmt.format(num_batches) + ']'
476     def _get_lr_fmtstr(self):
477         fmt = 'LR {:g}'
478         return fmt
480 def adjust_learning_rate(optimizer, epoch, args):
481     """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
482     lr = args.lr * (0.1 ** (epoch // args.lr_step_size))
483     args.cur_lr = lr
484     for param_group in optimizer.param_groups:
485         param_group['lr'] = lr
488 def accuracy(output, target, topk=(1,)):
489     """Computes the accuracy over the k top predictions for the specified values of k"""
490     with torch.no_grad():
491         maxk = max(topk)
492         batch_size = target.size(0)
494         _, pred = output.topk(maxk, 1, True, True)
495         pred = pred.t()
496         correct = pred.eq(target.view(1, -1).expand_as(pred))
498         res = []
499         for k in topk:
500             correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
501             res.append(correct_k.mul_(100.0 / batch_size))
502         return res
505 if __name__ == '__main__':
506     main()