7d5e5488da26f9c9e6e43d429311ae56a1e38b88
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / examples / quantization_example.py
1 # this code is modified from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
2 # the changes required for quantizing the model are under the flag args.quantize
3 import argparse
4 import os
5 import random
6 import shutil
7 import time
8 import warnings
10 import torch
11 import torch.nn as nn
12 import torch.nn.parallel
13 import torch.backends.cudnn as cudnn
14 import torch.distributed as dist
15 import torch.optim
16 import torch.multiprocessing as mp
17 import torch.utils.data
18 import torch.utils.data.distributed
19 import torchvision.transforms as transforms
20 import torchvision.datasets as datasets
22 # some of the default torchvision models need some minor tweaks to be friendly for
23 # quantization aware training. so use models from pytorch_jacinto_ai.vision insead
24 #import torchvision.models as models
26 from pytorch_jacinto_ai import xnn
27 from pytorch_jacinto_ai import vision as xvision
28 from pytorch_jacinto_ai.vision import models as models
30 model_names = sorted(name for name in models.__dict__
31     if name.islower() and not name.startswith("__")
32     and callable(models.__dict__[name]))
34 parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
35 parser.add_argument('data', metavar='DIR',
36                     help='path to dataset')
37 parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
38                     choices=model_names,
39                     help='model architecture: ' +
40                         ' | '.join(model_names) +
41                         ' (default: resnet18)')
42 parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
43                     help='number of data loading workers (default: 4)')
44 parser.add_argument('--epochs', default=90, type=int, metavar='N',
45                     help='number of total epochs to run')
46 parser.add_argument('--epoch-size', default=0, type=int, metavar='N',
47                     help='number of iterations in one training epoch. 0 (default) means full training epoch')
48 parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
49                     help='manual epoch number (useful on restarts)')
50 parser.add_argument('-b', '--batch-size', default=256, type=int,
51                     metavar='N',
52                     help='mini-batch size (default: 256), this is the total '
53                          'batch size of all GPUs on the current node when '
54                          'using Data Parallel or Distributed Data Parallel')
55 parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
56                     metavar='LR', help='initial learning rate', dest='lr')
57 parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
58                     help='momentum')
59 parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
60                     metavar='W', help='weight decay (default: 1e-4)',
61                     dest='weight_decay')
62 parser.add_argument('-p', '--print-freq', default=100, type=int,
63                     metavar='N', help='print frequency (default: 10)')
64 parser.add_argument('--resume', default='', type=str, metavar='PATH',
65                     help='path to latest checkpoint (default: none)')
66 parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
67                     help='evaluate model on validation set')
68 parser.add_argument('--pretrained', type=str, default=None,
69                     help='use pre-trained model')
70 parser.add_argument('--world-size', default=-1, type=int,
71                     help='number of nodes for distributed training')
72 parser.add_argument('--rank', default=-1, type=int,
73                     help='node rank for distributed training')
74 parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
75                     help='url used to set up distributed training')
76 parser.add_argument('--dist-backend', default='nccl', type=str,
77                     help='distributed backend')
78 parser.add_argument('--seed', default=None, type=int,
79                     help='seed for initializing training. ')
80 parser.add_argument('--gpu', default=None, type=int,
81                     help='GPU id to use.')
82 parser.add_argument('--multiprocessing-distributed', action='store_true',
83                     help='Use multi-processing distributed training to launch '
84                          'N processes per node, which has N GPUs. This is the '
85                          'fastest way to use PyTorch for either single node or '
86                          'multi node data parallel training')
87 parser.add_argument('--quantize', action='store_true',
88                     help='Enable Quantization')
89 best_acc1 = 0
92 def main():
93     args = parser.parse_args()
95     if args.seed is not None:
96         random.seed(args.seed)
97         torch.manual_seed(args.seed)
98         cudnn.deterministic = True
99         warnings.warn('You have chosen to seed training. '
100                       'This will turn on the CUDNN deterministic setting, '
101                       'which can slow down your training considerably! '
102                       'You may see unexpected behavior when restarting '
103                       'from checkpoints.')
105     if args.gpu is not None:
106         warnings.warn('You have chosen a specific GPU. This will completely '
107                       'disable data parallelism.')
109     if args.dist_url == "env://" and args.world_size == -1:
110         args.world_size = int(os.environ["WORLD_SIZE"])
112     args.distributed = args.world_size > 1 or args.multiprocessing_distributed
114     ngpus_per_node = torch.cuda.device_count()
115     if args.multiprocessing_distributed:
116         # Since we have ngpus_per_node processes per node, the total world_size
117         # needs to be adjusted accordingly
118         args.world_size = ngpus_per_node * args.world_size
119         # Use torch.multiprocessing.spawn to launch distributed processes: the
120         # main_worker process function
121         mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
122     else:
123         # Simply call main_worker function
124         main_worker(args.gpu, ngpus_per_node, args)
127 def main_worker(gpu, ngpus_per_node, args):
128     global best_acc1
129     args.gpu = gpu
131     if args.gpu is not None:
132         print("Use GPU: {} for training".format(args.gpu))
134     if args.distributed:
135         if args.dist_url == "env://" and args.rank == -1:
136             args.rank = int(os.environ["RANK"])
137         if args.multiprocessing_distributed:
138             # For multiprocessing distributed training, rank needs to be the
139             # global rank among all the processes
140             args.rank = args.rank * ngpus_per_node + gpu
141         dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
142                                 world_size=args.world_size, rank=args.rank)
143     # create model
144     print("=> creating model '{}'".format(args.arch))
145     model = models.__dict__[args.arch]()
147     if args.quantize:
148         # DistributedDataParallel / DataParallel are not supported with quantization
149         dummy_input = torch.rand((1, 3, 224, 224))
150         if args.evaluate:
151             # for validation accuracy check with quantization - can be used to estimate approximate accuracy achieved with quantization
152             model = xnn.quantize.QuantTestModule(model, dummy_input=dummy_input).cuda(args.gpu)
153         else:
154             # for quantization aware training
155             model = xnn.quantize.QuantTrainModule(model, dummy_input=dummy_input).cuda(args.gpu)
156         #
157     else:
158         if args.distributed:
159             # For multiprocessing distributed, DistributedDataParallel constructor
160             # should always set the single device scope, otherwise,
161             # DistributedDataParallel will use all available devices.
162             if args.gpu is not None:
163                 torch.cuda.set_device(args.gpu)
164                 model.cuda(args.gpu)
165                 # When using a single GPU per process and per
166                 # DistributedDataParallel, we need to divide the batch size
167                 # ourselves based on the total number of GPUs we have
168                 args.batch_size = int(args.batch_size / ngpus_per_node)
169                 args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
170                 model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
171             else:
172                 model.cuda()
173                 # DistributedDataParallel will divide and allocate batch_size to all
174                 # available GPUs if device_ids are not set
175                 model = torch.nn.parallel.DistributedDataParallel(model)
176         elif args.gpu is not None:
177             torch.cuda.set_device(args.gpu)
178             model = model.cuda(args.gpu)
179         else:
180             # DataParallel will divide and allocate batch_size to all available GPUs
181             if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
182                 model.features = torch.nn.DataParallel(model.features)
183                 model.cuda()
184             else:
185                 model = torch.nn.DataParallel(model).cuda()
187     if args.pretrained is not None:
188         model_orig = model.module if isinstance(model, (torch.nn.parallel.DistributedDataParallel, torch.nn.parallel.DataParallel)) else model
189         model_orig = model_orig.module if args.quantize else model_orig
190         print("=> using pre-trained model for {} from {}".format(args.arch, args.pretrained))
191         if hasattr(model_orig, 'load_weights'):
192             model_orig.load_weights(args.pretrained, download_root='./data/downloads')
193         else:
194             xnn.utils.load_weights(model_orig, args.pretrained, download_root='./data/downloads')
195         #
197     # define loss function (criterion) and optimizer
198     criterion = nn.CrossEntropyLoss().cuda(args.gpu)
200     optimizer = torch.optim.SGD(model.parameters(), args.lr,
201                                 momentum=args.momentum,
202                                 weight_decay=args.weight_decay)
204     # optionally resume from a checkpoint
205     if args.resume:
206         if os.path.isfile(args.resume):
207             print("=> loading checkpoint '{}'".format(args.resume))
208             if args.gpu is None:
209                 checkpoint = torch.load(args.resume)
210             else:
211                 # Map model to be loaded to specified single gpu.
212                 loc = 'cuda:{}'.format(args.gpu)
213                 checkpoint = torch.load(args.resume, map_location=loc)
214             args.start_epoch = checkpoint['epoch']
215             best_acc1 = checkpoint['best_acc1']
216             if args.gpu is not None:
217                 # best_acc1 may be from a checkpoint from a different GPU
218                 best_acc1 = best_acc1.to(args.gpu)
219             model.load_state_dict(checkpoint['state_dict'])
220             optimizer.load_state_dict(checkpoint['optimizer'])
221             print("=> loaded checkpoint '{}' (epoch {})"
222                   .format(args.resume, checkpoint['epoch']))
223         else:
224             print("=> no checkpoint found at '{}'".format(args.resume))
226     cudnn.benchmark = True
228     # Data loading code
229     traindir = os.path.join(args.data, 'train')
230     valdir = os.path.join(args.data, 'val')
231     normalize = xvision.transforms.NormalizeMeanScale(mean=[123.675, 116.28, 103.53], scale=[0.017125, 0.017507, 0.017429])
233     train_dataset = datasets.ImageFolder(
234         traindir,
235         transforms.Compose([
236             transforms.RandomResizedCrop(224),
237             transforms.RandomHorizontalFlip(),
238             xvision.transforms.ToFloat(),
239             transforms.ToTensor(),
240             normalize,
241         ]))
243     if args.distributed:
244         train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
245     else:
246         train_sampler = None
248     train_loader = torch.utils.data.DataLoader(
249         train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
250         num_workers=args.workers, pin_memory=True, sampler=train_sampler)
252     val_loader = torch.utils.data.DataLoader(
253         datasets.ImageFolder(valdir, transforms.Compose([
254             transforms.Resize(256),
255             transforms.CenterCrop(224),
256             xvision.transforms.ToFloat(),
257             transforms.ToTensor(),
258             normalize,
259         ])),
260         batch_size=args.batch_size, shuffle=False,
261         num_workers=args.workers, pin_memory=True)
263     validate(val_loader, model, criterion, args)
265     if args.evaluate:
266         return
268     for epoch in range(args.start_epoch, args.epochs):
269         if args.distributed:
270             train_sampler.set_epoch(epoch)
271         adjust_learning_rate(optimizer, epoch, args)
273         # train for one epoch
274         train(train_loader, model, criterion, optimizer, epoch, args)
276         # evaluate on validation set
277         acc1 = validate(val_loader, model, criterion, args)
279         # remember best acc@1 and save checkpoint
280         is_best = acc1 > best_acc1
281         best_acc1 = max(acc1, best_acc1)
283         model_orig = model.module if isinstance(model, (torch.nn.parallel.DistributedDataParallel, torch.nn.parallel.DataParallel)) else model
284         model_orig = model_orig.module if args.quantize else model_orig
285         if not args.multiprocessing_distributed or (args.multiprocessing_distributed
286                 and args.rank % ngpus_per_node == 0):
287             out_basename = args.arch
288             out_basename += ('_quantized_checkpoint.pth.tar' if args.quantize else '_checkpoint.pth.tar')
289             save_filename = os.path.join('./data/checkpoints/quantization', out_basename)
290             save_checkpoint({
291                 'epoch': epoch + 1,
292                 'arch': args.arch,
293                 'state_dict': model_orig.state_dict(),
294                 'best_acc1': best_acc1,
295                 'optimizer' : optimizer.state_dict(),
296             }, is_best, filename=save_filename)
299 def train(train_loader, model, criterion, optimizer, epoch, args):
300     batch_time = AverageMeter('Time', ':6.3f')
301     data_time = AverageMeter('Data', ':6.3f')
302     losses = AverageMeter('Loss', ':.4e')
303     top1 = AverageMeter('Acc@1', ':6.2f')
304     top5 = AverageMeter('Acc@5', ':6.2f')
305     progress = ProgressMeter(
306         len(train_loader),
307         [batch_time, data_time, losses, top1, top5],
308         prefix="Epoch: [{}]".format(epoch))
310     # switch to train mode
311     model.train()
313     end = time.time()
314     for i, (images, target) in enumerate(train_loader):
315         # break the epoch at at the iteration epoch_size
316         if args.epoch_size != 0 and i >= args.epoch_size:
317             break
318         # measure data loading time
319         data_time.update(time.time() - end)
321         images = images.cuda(args.gpu, non_blocking=True)
322         target = target.cuda(args.gpu, non_blocking=True)
324         # compute output
325         output = model(images)
326         loss = criterion(output, target)
328         # measure accuracy and record loss
329         acc1, acc5 = accuracy(output, target, topk=(1, 5))
330         losses.update(loss.item(), images.size(0))
331         top1.update(acc1[0], images.size(0))
332         top5.update(acc5[0], images.size(0))
334         # compute gradient and do SGD step
335         optimizer.zero_grad()
336         loss.backward()
337         optimizer.step()
339         # measure elapsed time
340         batch_time.update(time.time() - end)
341         end = time.time()
343         if i % args.print_freq == 0:
344             progress.display(i)
347 def validate(val_loader, model, criterion, args):
348     batch_time = AverageMeter('Time', ':6.3f')
349     losses = AverageMeter('Loss', ':.4e')
350     top1 = AverageMeter('Acc@1', ':6.2f')
351     top5 = AverageMeter('Acc@5', ':6.2f')
352     progress = ProgressMeter(
353         len(val_loader),
354         [batch_time, losses, top1, top5],
355         prefix='Test: ')
357     # switch to evaluate mode
358     model.eval()
360     with torch.no_grad():
361         end = time.time()
362         for i, (images, target) in enumerate(val_loader):
363             images = images.cuda(args.gpu, non_blocking=True)
364             target = target.cuda(args.gpu, non_blocking=True)
366             # compute output
367             output = model(images)
368             loss = criterion(output, target)
370             # measure accuracy and record loss
371             acc1, acc5 = accuracy(output, target, topk=(1, 5))
372             losses.update(loss.item(), images.size(0))
373             top1.update(acc1[0], images.size(0))
374             top5.update(acc5[0], images.size(0))
376             # measure elapsed time
377             batch_time.update(time.time() - end)
378             end = time.time()
380             if i % args.print_freq == 0:
381                 progress.display(i)
383         # TODO: this should also be done with the ProgressMeter
384         print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
385               .format(top1=top1, top5=top5))
387     return top1.avg
390 def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
391     dirname = os.path.dirname(filename)
392     xnn.utils.makedir_exist_ok(dirname)
393     torch.save(state, filename)
394     if is_best:
395         shutil.copyfile(filename, os.path.splitext(filename)[0]+'_best.pth.tar')
398 class AverageMeter(object):
399     """Computes and stores the average and current value"""
400     def __init__(self, name, fmt=':f'):
401         self.name = name
402         self.fmt = fmt
403         self.reset()
405     def reset(self):
406         self.val = 0
407         self.avg = 0
408         self.sum = 0
409         self.count = 0
411     def update(self, val, n=1):
412         self.val = val
413         self.sum += val * n
414         self.count += n
415         self.avg = self.sum / self.count
417     def __str__(self):
418         fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
419         return fmtstr.format(**self.__dict__)
422 class ProgressMeter(object):
423     def __init__(self, num_batches, meters, prefix=""):
424         self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
425         self.meters = meters
426         self.prefix = prefix
428     def display(self, batch):
429         entries = [self.prefix + self.batch_fmtstr.format(batch)]
430         entries += [str(meter) for meter in self.meters]
431         print('\t'.join(entries))
433     def _get_batch_fmtstr(self, num_batches):
434         num_digits = len(str(num_batches // 1))
435         fmt = '{:' + str(num_digits) + 'd}'
436         return '[' + fmt + '/' + fmt.format(num_batches) + ']'
439 def adjust_learning_rate(optimizer, epoch, args):
440     """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
441     lr = args.lr * (0.1 ** (epoch // 30))
442     for param_group in optimizer.param_groups:
443         param_group['lr'] = lr
446 def accuracy(output, target, topk=(1,)):
447     """Computes the accuracy over the k top predictions for the specified values of k"""
448     with torch.no_grad():
449         maxk = max(topk)
450         batch_size = target.size(0)
452         _, pred = output.topk(maxk, 1, True, True)
453         pred = pred.t()
454         correct = pred.eq(target.view(1, -1).expand_as(pred))
456         res = []
457         for k in topk:
458             correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
459             res.append(correct_k.mul_(100.0 / batch_size))
460         return res
463 if __name__ == '__main__':
464     main()