release commit
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / examples / quantization_example.py
1 # ----------------------------------
2 # Quantization Aware Training (QAT) Example
3 # Texas Instruments (C) 2018-2020
4 # All Rights Reserved
5 # ----------------------------------
6 # this original code is from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
7 # the changes required for quantizing the model are under the flag args.quantize
8 import argparse
9 import os
10 import random
11 import shutil
12 import time
13 import warnings
15 import torch
16 import torch.nn as nn
17 import torch.nn.parallel
18 import torch.backends.cudnn as cudnn
19 import torch.distributed as dist
20 import torch.optim
21 import torch.multiprocessing as mp
22 import torch.utils.data
23 import torch.utils.data.distributed
24 import torchvision.transforms as transforms
25 import torchvision.datasets as datasets
27 # some of the default torchvision models need some minor tweaks to be friendly for
28 # quantization aware training. so use models from pytorch_jacinto_ai.vision insead
29 #import torchvision.models as models
31 from pytorch_jacinto_ai import xnn
32 from pytorch_jacinto_ai import vision as xvision
33 from pytorch_jacinto_ai.vision import models as models
35 model_names = sorted(name for name in models.__dict__
36     if name.islower() and not name.startswith("__")
37     and callable(models.__dict__[name]))
39 parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
40 parser.add_argument('data', metavar='DIR',
41                     help='path to dataset')
42 parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
43                     choices=model_names,
44                     help='model architecture: ' +
45                         ' | '.join(model_names) +
46                         ' (default: resnet18)')
47 parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
48                     help='number of data loading workers (default: 4)')
49 parser.add_argument('--epochs', default=90, type=int, metavar='N',
50                     help='number of total epochs to run')
51 parser.add_argument('--epoch_size', default=0, type=int, metavar='N',
52                     help='number of iterations in one training epoch. 0 (default) means full training epoch')
53 parser.add_argument('--epoch_size_val', default=0, type=int, metavar='N',
54                     help='number of iterations in one validation epoch. 0 (default) means full validation epoch')
55 parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
56                     help='manual epoch number (useful on restarts)')
57 parser.add_argument('-b', '--batch_size', default=256, type=int,
58                     metavar='N',
59                     help='mini-batch size (default: 256), this is the total '
60                          'batch size of all GPUs on the current node when '
61                          'using Data Parallel or Distributed Data Parallel')
62 parser.add_argument('--lr', '--learning_rate', default=0.1, type=float,
63                     metavar='LR', help='initial learning rate', dest='lr')
64 parser.add_argument('--lr_step_size', default=30, type=int,
65                     metavar='N', help='number of steps before learning rate is reduced')
66 parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
67                     help='momentum')
68 parser.add_argument('--wd', '--weight_decay', default=1e-4, type=float,
69                     metavar='W', help='weight decay (default: 1e-4)',
70                     dest='weight_decay')
71 parser.add_argument('-p', '--print_freq', default=100, type=int,
72                     metavar='N', help='print frequency (default: 10)')
73 parser.add_argument('--resume', default='', type=str, metavar='PATH',
74                     help='path to latest checkpoint (default: none)')
75 parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
76                     help='evaluate model on validation set')
77 parser.add_argument('--pretrained', type=str, default=None,
78                     help='use pre-trained model')
79 parser.add_argument('--world_size', default=-1, type=int,
80                     help='number of nodes for distributed training')
81 parser.add_argument('--rank', default=-1, type=int,
82                     help='node rank for distributed training')
83 parser.add_argument('--dist_url', default='tcp://224.66.41.62:23456', type=str,
84                     help='url used to set up distributed training')
85 parser.add_argument('--dist_backend', default='nccl', type=str,
86                     help='distributed backend')
87 parser.add_argument('--seed', default=None, type=int,
88                     help='seed for initializing training. ')
89 parser.add_argument('--gpu', default=None, type=int,
90                     help='GPU id to use.')
91 parser.add_argument('--multiprocessing_distributed', action='store_true',
92                     help='Use multi-processing distributed training to launch '
93                          'N processes per node, which has N GPUs. This is the '
94                          'fastest way to use PyTorch for either single node or '
95                          'multi node data parallel training')
96 parser.add_argument('--save_path', type=str, default='./data/checkpoints/quantization',
97                     help='path to save the logs and models')
98 parser.add_argument('--quantize', action='store_true',
99                     help='Enable Quantization')
100 best_acc1 = 0
103 def main():
104     args = parser.parse_args()
106     args.cur_lr = args.lr
108     if args.seed is not None:
109         random.seed(args.seed)
110         torch.manual_seed(args.seed)
111         cudnn.deterministic = True
112         warnings.warn('You have chosen to seed training. '
113                       'This will turn on the CUDNN deterministic setting, '
114                       'which can slow down your training considerably! '
115                       'You may see unexpected behavior when restarting '
116                       'from checkpoints.')
118     if args.gpu is not None:
119         warnings.warn('You have chosen a specific GPU. This will completely '
120                       'disable data parallelism.')
122     if args.dist_url == "env://" and args.world_size == -1:
123         args.world_size = int(os.environ["WORLD_SIZE"])
125     args.distributed = args.world_size > 1 or args.multiprocessing_distributed
127     ngpus_per_node = torch.cuda.device_count()
128     if args.multiprocessing_distributed:
129         # Since we have ngpus_per_node processes per node, the total world_size
130         # needs to be adjusted accordingly
131         args.world_size = ngpus_per_node * args.world_size
132         # Use torch.multiprocessing.spawn to launch distributed processes: the
133         # main_worker process function
134         mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
135     else:
136         # Simply call main_worker function
137         main_worker(args.gpu, ngpus_per_node, args)
140 def main_worker(gpu, ngpus_per_node, args):
141     global best_acc1
142     args.gpu = gpu
144     if args.gpu is not None:
145         print("Use GPU: {} for training".format(args.gpu))
147     if args.distributed:
148         if args.dist_url == "env://" and args.rank == -1:
149             args.rank = int(os.environ["RANK"])
150         if args.multiprocessing_distributed:
151             # For multiprocessing distributed training, rank needs to be the
152             # global rank among all the processes
153             args.rank = args.rank * ngpus_per_node + gpu
154         dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
155                                 world_size=args.world_size, rank=args.rank)
156     # create model
157     print("=> creating model '{}'".format(args.arch))
158     model = models.__dict__[args.arch]()
160     if args.quantize:
161         # DistributedDataParallel / DataParallel are not supported with quantization
162         dummy_input = torch.rand((1, 3, 224, 224))
163         if args.evaluate:
164             # for validation accuracy check with quantization - can be used to estimate approximate accuracy achieved with quantization
165             model = xnn.quantize.QuantTestModule(model, dummy_input=dummy_input).cuda(args.gpu)
166         else:
167             # for quantization aware training
168             model = xnn.quantize.QuantTrainModule(model, dummy_input=dummy_input).cuda(args.gpu)
169         #
170     else:
171         if args.distributed:
172             # For multiprocessing distributed, DistributedDataParallel constructor
173             # should always set the single device scope, otherwise,
174             # DistributedDataParallel will use all available devices.
175             if args.gpu is not None:
176                 torch.cuda.set_device(args.gpu)
177                 model.cuda(args.gpu)
178                 # When using a single GPU per process and per
179                 # DistributedDataParallel, we need to divide the batch size
180                 # ourselves based on the total number of GPUs we have
181                 args.batch_size = int(args.batch_size / ngpus_per_node)
182                 args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
183                 model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
184             else:
185                 model.cuda()
186                 # DistributedDataParallel will divide and allocate batch_size to all
187                 # available GPUs if device_ids are not set
188                 model = torch.nn.parallel.DistributedDataParallel(model)
189         elif args.gpu is not None:
190             torch.cuda.set_device(args.gpu)
191             model = model.cuda(args.gpu)
192         else:
193             # DataParallel will divide and allocate batch_size to all available GPUs
194             if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
195                 model.features = torch.nn.DataParallel(model.features)
196                 model.cuda()
197             else:
198                 model = torch.nn.DataParallel(model).cuda()
200     if args.pretrained is not None:
201         model_orig = model.module if isinstance(model, (torch.nn.parallel.DistributedDataParallel, torch.nn.parallel.DataParallel)) else model
202         model_orig = model_orig.module if args.quantize else model_orig
203         print("=> using pre-trained model for {} from {}".format(args.arch, args.pretrained))
204         if hasattr(model_orig, 'load_weights'):
205             model_orig.load_weights(args.pretrained, download_root='./data/downloads')
206         else:
207             xnn.utils.load_weights(model_orig, args.pretrained, download_root='./data/downloads')
208         #
210     # define loss function (criterion) and optimizer
211     criterion = nn.CrossEntropyLoss().cuda(args.gpu)
213     optimizer = torch.optim.SGD(model.parameters(), args.lr,
214                                 momentum=args.momentum,
215                                 weight_decay=args.weight_decay)
217     # optionally resume from a checkpoint
218     if args.resume:
219         if os.path.isfile(args.resume):
220             print("=> loading checkpoint '{}'".format(args.resume))
221             if args.gpu is None:
222                 checkpoint = torch.load(args.resume)
223             else:
224                 # Map model to be loaded to specified single gpu.
225                 loc = 'cuda:{}'.format(args.gpu)
226                 checkpoint = torch.load(args.resume, map_location=loc)
227             args.start_epoch = checkpoint['epoch']
228             best_acc1 = checkpoint['best_acc1']
229             if args.gpu is not None:
230                 # best_acc1 may be from a checkpoint from a different GPU
231                 best_acc1 = best_acc1.to(args.gpu)
232             model.load_state_dict(checkpoint['state_dict'])
233             optimizer.load_state_dict(checkpoint['optimizer'])
234             print("=> loaded checkpoint '{}' (epoch {})"
235                   .format(args.resume, checkpoint['epoch']))
236         else:
237             print("=> no checkpoint found at '{}'".format(args.resume))
239     cudnn.benchmark = True
241     # Data loading code
242     traindir = os.path.join(args.data, 'train')
243     valdir = os.path.join(args.data, 'val')
244     normalize = xvision.transforms.NormalizeMeanScale(mean=[123.675, 116.28, 103.53], scale=[0.017125, 0.017507, 0.017429])
246     train_dataset = datasets.ImageFolder(
247         traindir,
248         transforms.Compose([
249             transforms.RandomResizedCrop(224),
250             transforms.RandomHorizontalFlip(),
251             xvision.transforms.ToFloat(),   # converting to float avoids the division by 255 in ToTensor()
252             transforms.ToTensor(),
253             normalize,
254         ]))
256     if args.distributed:
257         train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
258     else:
259         train_sampler = None
261     train_loader = torch.utils.data.DataLoader(
262         train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
263         num_workers=args.workers, pin_memory=True, sampler=train_sampler)
265     val_loader = torch.utils.data.DataLoader(
266         datasets.ImageFolder(valdir, transforms.Compose([
267             transforms.Resize(256),
268             transforms.CenterCrop(224),
269             xvision.transforms.ToFloat(),   # converting to float avoids the division by 255 in ToTensor()
270             transforms.ToTensor(),
271             normalize,
272         ])),
273         batch_size=args.batch_size, shuffle=False,
274         num_workers=args.workers, pin_memory=True)
276     validate(val_loader, model, criterion, args)
278     if args.evaluate:
279         return
281     for epoch in range(args.start_epoch, args.epochs):
282         if args.distributed:
283             train_sampler.set_epoch(epoch)
284         adjust_learning_rate(optimizer, epoch, args)
286         # train for one epoch
287         train(train_loader, model, criterion, optimizer, epoch, args)
289         # evaluate on validation set
290         acc1 = validate(val_loader, model, criterion, args)
292         # remember best acc@1 and save checkpoint
293         is_best = acc1 > best_acc1
294         best_acc1 = max(acc1, best_acc1)
296         model_orig = model.module if isinstance(model, (torch.nn.parallel.DistributedDataParallel, torch.nn.parallel.DataParallel)) else model
297         model_orig = model_orig.module if args.quantize else model_orig
298         if not args.multiprocessing_distributed or (args.multiprocessing_distributed
299                 and args.rank % ngpus_per_node == 0):
300             out_basename = args.arch + ('_checkpoint_quantized.pth' if args.quantize else '_checkpoint.pth')
301             save_filename = os.path.join(args.save_path, out_basename)
302             checkpoint_dict = {
303                 'epoch': epoch + 1,
304                 'arch': args.arch,
305                 'state_dict': model_orig.state_dict(),
306                 'best_acc1': best_acc1,
307                 'optimizer' : optimizer.state_dict(),
308             }
309             save_checkpoint(checkpoint_dict, is_best, filename=save_filename)
310             save_onnxname = os.path.splitext(save_filename)[0]+'.onnx'
311             write_onnx_model(model, is_best, filename=save_onnxname)
314 def train(train_loader, model, criterion, optimizer, epoch, args):
315     batch_time = AverageMeter('Time', ':6.3f')
316     data_time = AverageMeter('Data', ':6.3f')
317     losses = AverageMeter('Loss', ':.4e')
318     top1 = AverageMeter('Acc@1', ':6.2f')
319     top5 = AverageMeter('Acc@5', ':6.2f')
320     progress = ProgressMeter(
321         len(train_loader),
322         [batch_time, data_time, losses, top1, top5],
323         prefix="Epoch: [{}]".format(epoch))
325     # switch to train mode
326     model.train()
328     end = time.time()
329     for i, (images, target) in enumerate(train_loader):
330         # break the epoch at at the iteration epoch_size
331         if args.epoch_size != 0 and i >= args.epoch_size:
332             break
333         # measure data loading time
334         data_time.update(time.time() - end)
336         images = images.cuda(args.gpu, non_blocking=True)
337         target = target.cuda(args.gpu, non_blocking=True)
339         # compute output
340         output = model(images)
341         loss = criterion(output, target)
343         # measure accuracy and record loss
344         acc1, acc5 = accuracy(output, target, topk=(1, 5))
345         losses.update(loss.item(), images.size(0))
346         top1.update(acc1[0], images.size(0))
347         top5.update(acc5[0], images.size(0))
349         # compute gradient and do SGD step
350         optimizer.zero_grad()
351         loss.backward()
352         optimizer.step()
354         # measure elapsed time
355         batch_time.update(time.time() - end)
356         end = time.time()
358         if i % args.print_freq == 0:
359             progress.display(i, args.cur_lr)
362 def validate(val_loader, model, criterion, args):
363     batch_time = AverageMeter('Time', ':6.3f')
364     losses = AverageMeter('Loss', ':.4e')
365     top1 = AverageMeter('Acc@1', ':6.2f')
366     top5 = AverageMeter('Acc@5', ':6.2f')
367     progress = ProgressMeter(
368         len(val_loader),
369         [batch_time, losses, top1, top5],
370         prefix='Test: ')
372     # switch to evaluate mode
373     model.eval()
375     with torch.no_grad():
376         end = time.time()
377         for i, (images, target) in enumerate(val_loader):
378             # break the epoch at at the iteration epoch_size_val
379             if args.epoch_size_val != 0 and i >= args.epoch_size_val:
380                 break
381             images = images.cuda(args.gpu, non_blocking=True)
382             target = target.cuda(args.gpu, non_blocking=True)
384             # compute output
385             output = model(images)
386             loss = criterion(output, target)
388             # measure accuracy and record loss
389             acc1, acc5 = accuracy(output, target, topk=(1, 5))
390             losses.update(loss.item(), images.size(0))
391             top1.update(acc1[0], images.size(0))
392             top5.update(acc5[0], images.size(0))
394             # measure elapsed time
395             batch_time.update(time.time() - end)
396             end = time.time()
398             if i % args.print_freq == 0:
399                 progress.display(i, args.cur_lr)
401         # TODO: this should also be done with the ProgressMeter
402         print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
403               .format(top1=top1, top5=top5))
405     return top1.avg
408 def save_checkpoint(state, is_best, filename='checkpoint.pth'):
409     dirname = os.path.dirname(filename)
410     xnn.utils.makedir_exist_ok(dirname)
411     torch.save(state, filename)
412     if is_best:
413         shutil.copyfile(filename, os.path.splitext(filename)[0]+'_best.pth')
416 def create_rand_inputs(is_cuda):
417     dummy_input = torch.rand((1, 3, 224, 224))
418     dummy_input = dummy_input.cuda() if is_cuda else dummy_input
419     return dummy_input
422 def write_onnx_model(model, is_best, filename='checkpoint.onnx'):
423     model.eval()
424     is_cuda = next(model.parameters()).is_cuda
425     dummy_input = create_rand_inputs(is_cuda)
426     torch.onnx.export(model, dummy_input, filename, export_params=True, verbose=False)
427     if is_best:
428         shutil.copyfile(filename, os.path.splitext(filename)[0]+'_best.onnx')
431 class AverageMeter(object):
432     """Computes and stores the average and current value"""
433     def __init__(self, name, fmt=':f'):
434         self.name = name
435         self.fmt = fmt
436         self.reset()
438     def reset(self):
439         self.val = 0
440         self.avg = 0
441         self.sum = 0
442         self.count = 0
444     def update(self, val, n=1):
445         self.val = val
446         self.sum += val * n
447         self.count += n
448         self.avg = self.sum / self.count
450     def __str__(self):
451         fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
452         return fmtstr.format(**self.__dict__)
455 class ProgressMeter(object):
456     def __init__(self, num_batches, meters, prefix=""):
457         self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
458         self.lr_fmtstr = self._get_lr_fmtstr()
459         self.meters = meters
460         self.prefix = prefix
462     def display(self, batch, cur_lr):
463         entries = [self.prefix + self.batch_fmtstr.format(batch), self.lr_fmtstr.format(cur_lr)]
464         entries += [str(meter) for meter in self.meters]
465         print('\t'.join(entries))
467     def _get_batch_fmtstr(self, num_batches):
468         num_digits = len(str(num_batches // 1))
469         fmt = '{:' + str(num_digits) + 'd}'
470         return '[' + fmt + '/' + fmt.format(num_batches) + ']'
472     def _get_lr_fmtstr(self):
473         fmt = 'LR {:g}'
474         return fmt
476 def adjust_learning_rate(optimizer, epoch, args):
477     """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
478     lr = args.lr * (0.1 ** (epoch // args.lr_step_size))
479     args.cur_lr = lr
480     for param_group in optimizer.param_groups:
481         param_group['lr'] = lr
484 def accuracy(output, target, topk=(1,)):
485     """Computes the accuracy over the k top predictions for the specified values of k"""
486     with torch.no_grad():
487         maxk = max(topk)
488         batch_size = target.size(0)
490         _, pred = output.topk(maxk, 1, True, True)
491         pred = pred.t()
492         correct = pred.eq(target.view(1, -1).expand_as(pred))
494         res = []
495         for k in topk:
496             correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
497             res.append(correct_k.mul_(100.0 / batch_size))
498         return res
501 if __name__ == '__main__':
502     main()