8c7e22667ccd1a2ed1b305bcf47179fa16964137
1 # ----------------------------------
2 # Quantization Aware Training (QAT) Example
3 # Texas Instruments (C) 2018-2020
4 # All Rights Reserved
5 # ----------------------------------
6 # this original code is from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
7 # the changes required for quantizing the model are under the flag args.quantize
8 import argparse
9 import os
10 import random
11 import shutil
12 import time
13 import warnings
15 import torch
16 import torch.nn as nn
17 import torch.nn.parallel
18 import torch.backends.cudnn as cudnn
19 import torch.distributed as dist
20 import torch.optim
21 import torch.multiprocessing as mp
22 import torch.utils.data
23 import torch.utils.data.distributed
24 import torchvision.transforms as transforms
25 import torchvision.datasets as datasets
# some of the default torchvision models need some minor tweaks to be friendly for
# quantization aware training, so use the models from pytorch_jacinto_ai.vision instead
#import torchvision.models as models
31 from pytorch_jacinto_ai import xnn
32 from pytorch_jacinto_ai import vision as xvision
33 from pytorch_jacinto_ai.vision import models as models
# Every public (lower-case, non-dunder) callable exported by the models package is a
# selectable architecture for --arch; keep them sorted for a stable help listing.
model_names = sorted(
    attr_name
    for attr_name, attr_value in models.__dict__.items()
    if attr_name.islower()
    and not attr_name.startswith("__")
    and callable(attr_value)
)
# Command-line interface. Option names and defaults follow the upstream
# pytorch/examples ImageNet script, with extra options for quantization and ONNX export.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                         ' | '.join(model_names) +
                         ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--epoch_size', default=0, type=float, metavar='N',
                    help='fraction of training epoch to use. 0 (default) means full training epoch')
parser.add_argument('--epoch_size_val', default=0, type=float, metavar='N',
                    help='fraction of validation epoch to use. 0 (default) means full validation epoch')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch_size', default=256, type=int,
                    metavar='N',
                    help='mini-batch size (default: 256), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning_rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
# the schedule divides the epoch index by this value (see adjust_learning_rate), so it counts epochs
parser.add_argument('--lr_step_size', default=30, type=int,
                    metavar='N', help='number of epochs after which the learning rate is reduced by 10x (default: 30)')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', '--weight_decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)',
                    dest='weight_decay')
parser.add_argument('-p', '--print_freq', default=100, type=int,
                    metavar='N', help='print frequency (default: 100)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', type=str, default=None,
                    help='location of pre-trained weights to load (default: none)')
parser.add_argument('--world_size', default=-1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int,
                    help='node rank for distributed training')
parser.add_argument('--dist_url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist_backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--multiprocessing_distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs. This is the '
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')
parser.add_argument('--save_path', type=str, default='./data/checkpoints/quantization',
                    help='path to save the logs and models')
parser.add_argument('--quantize', action='store_true',
                    help='Enable Quantization')
parser.add_argument('--opset_version', default=9, type=int,
                    help='opset version for onnx export')

# best top-1 validation accuracy seen so far; updated by main_worker and stored in checkpoints
best_acc1 = 0
def main():
    """Entry point: parse arguments, apply seeding/GPU/distributed settings, then run the worker(s).

    With ``--multiprocessing_distributed`` one ``main_worker`` process is spawned per
    visible GPU; otherwise ``main_worker`` runs directly in this process.
    """
    args = parser.parse_args()
    # learning rate currently in effect; adjust_learning_rate() refreshes it for the progress logs
    args.cur_lr = args.lr

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    # with env:// initialisation and no explicit --world_size, take it from the environment
    world_size_from_env = args.dist_url == "env://" and args.world_size == -1
    if world_size_from_env:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()
    if not args.multiprocessing_distributed:
        # single process: just run the worker here
        main_worker(args.gpu, ngpus_per_node, args)
        return

    # one process per GPU, so the global world size scales with the per-node GPU count
    args.world_size = ngpus_per_node * args.world_size
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
def _unwrap_model(model, quantize):
    """Return the model inside the parallel and quantization wrappers applied by main_worker.

    A DistributedDataParallel/DataParallel wrapper is stripped first; when ``quantize``
    is set the xnn.quantize wrapper is stripped as well (its wrapped model is ``.module``).
    """
    if isinstance(model, (torch.nn.parallel.DistributedDataParallel, torch.nn.parallel.DataParallel)):
        model = model.module
    return model.module if quantize else model


def _create_data_loaders(args):
    """Build the train/validation DataLoaders from ``<args.data>/train`` and ``<args.data>/val``.

    Returns:
        ``(train_loader, val_loader, train_sampler)``; ``train_sampler`` is a
        DistributedSampler when ``args.distributed``, a random sub-epoch sampler when
        ``args.epoch_size`` is non-zero, and None otherwise.
    """
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    # ToFloat makes ToTensor skip its /255 scaling, so mean/scale are in 0-255 pixel units
    normalize = xvision.transforms.NormalizeMeanScale(mean=[123.675, 116.28, 103.53], scale=[0.017125, 0.017507, 0.017429])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            xvision.transforms.ToFloat(),  # converting to float avoids the division by 255 in ToTensor()
            transforms.ToTensor(),
            normalize,
        ]))

    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        xvision.transforms.ToFloat(),  # converting to float avoids the division by 255 in ToTensor()
        transforms.ToTensor(),
        normalize,
    ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    else:
        train_sampler = get_dataset_sampler(train_dataset, args.epoch_size) if args.epoch_size != 0 else None
        val_sampler = get_dataset_sampler(val_dataset, args.epoch_size_val) if args.epoch_size_val != 0 else None

    # shuffle only when no explicit sampler decides the order
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=(val_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=val_sampler)
    return train_loader, val_loader, train_sampler


def main_worker(gpu, ngpus_per_node, args):
    """Set up, then train and/or evaluate the model in one worker.

    Args:
        gpu: GPU index assigned to this worker, or None.
        ngpus_per_node: number of CUDA devices visible on this node.
        args: parsed command-line namespace. ``gpu`` and ``rank`` (and, for
            non-quantized distributed runs on a fixed GPU, ``batch_size`` and
            ``workers``) are updated in place.

    Updates the module-level ``best_acc1`` with the best top-1 accuracy seen.
    """
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # create model
    print("=> creating model '{}'".format(args.arch))
    model = models.__dict__[args.arch]()

    if args.quantize:
        # DistributedDataParallel / DataParallel are not supported with quantization
        dummy_input = torch.rand((1, 3, 224, 224))
        if args.evaluate:
            # for validation accuracy check with quantization - can be used to estimate approximate accuracy achieved with quantization
            model = xnn.quantize.QuantTestModule(model, dummy_input=dummy_input).cuda(args.gpu)
        else:
            # for quantization aware training
            model = xnn.quantize.QuantTrainModule(model, dummy_input=dummy_input).cuda(args.gpu)
    elif args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    elif args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        # DataParallel will divide and allocate batch_size to all available GPUs
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()

    if args.pretrained is not None:
        model_orig = _unwrap_model(model, args.quantize)
        print("=> using pre-trained model for {} from {}".format(args.arch, args.pretrained))
        if hasattr(model_orig, 'load_weights'):
            model_orig.load_weights(args.pretrained, download_root='./data/downloads')
        else:
            xnn.utils.load_weights(model_orig, args.pretrained, download_root='./data/downloads')

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None and torch.is_tensor(best_acc1):
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            # checkpoints are written from the unwrapped model (see the save in the epoch
            # loop below), so load into the same unwrapped model to get matching keys
            _unwrap_model(model, args.quantize).load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    train_loader, val_loader, train_sampler = _create_data_loaders(args)

    # report the accuracy of the starting model before any (further) training
    validate(val_loader, model, criterion, args)

    if args.evaluate:
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        # with multiprocessing, only the first process on each node writes files
        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                    and args.rank % ngpus_per_node == 0):
            model_orig = _unwrap_model(model, args.quantize)
            out_basename = args.arch + ('_checkpoint_quantized.pth' if args.quantize else '_checkpoint.pth')
            save_filename = os.path.join(args.save_path, out_basename)
            checkpoint_dict = {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model_orig.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }
            save_checkpoint(checkpoint_dict, is_best, filename=save_filename)
            save_onnxname = os.path.splitext(save_filename)[0] + '.onnx'
            write_onnx_model(args, model, is_best, filename=save_onnxname)
def train(train_loader, model, criterion, optimizer, epoch, args):
    """Run one training epoch over ``train_loader``, taking one optimizer step per batch.

    Loss, top-1/top-5 accuracy, data-wait time and batch time are tracked, and a
    progress line is printed every ``args.print_freq`` batches.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    tracked = [batch_time, data_time, losses, top1, top5]
    progress = ProgressMeter(len(train_loader), tracked, prefix="Epoch: [{}]".format(epoch))

    model.train()  # training-mode behaviour for this epoch

    tick = time.time()
    for step, (images, target) in enumerate(train_loader):
        # time spent waiting on the data loader
        data_time.update(time.time() - tick)

        images = images.cuda(args.gpu, non_blocking=True)
        target = target.cuda(args.gpu, non_blocking=True)

        # forward pass
        output = model(images)
        loss = criterion(output, target)

        # statistics are weighted by the number of samples in the batch
        n_samples = images.size(0)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), n_samples)
        top1.update(acc1[0], n_samples)
        top5.update(acc5[0], n_samples)

        # backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step, args.cur_lr)
def validate(val_loader, model, criterion, args):
    """Evaluate ``model`` over ``val_loader`` without gradients and return the mean top-1 accuracy.

    Progress is printed every ``args.print_freq`` batches and a top-1/top-5 summary
    line is printed once the pass completes.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    tracked = [batch_time, losses, top1, top5]
    progress = ProgressMeter(len(val_loader), tracked, prefix='Test: ')

    model.eval()  # evaluation-mode behaviour for the whole pass

    with torch.no_grad():
        tick = time.time()
        for step, (images, target) in enumerate(val_loader):
            images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

            output = model(images)
            loss = criterion(output, target)

            # statistics are weighted by the number of samples in the batch
            n_samples = images.size(0)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), n_samples)
            top1.update(acc1[0], n_samples)
            top5.update(acc5[0], n_samples)

            batch_time.update(time.time() - tick)
            tick = time.time()

            if step % args.print_freq == 0:
                progress.display(step, args.cur_lr)

        # TODO: this should also be done with the ProgressMeter
        summary = ' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1, top5=top5)
        print(summary)

    return top1.avg
def save_checkpoint(state, is_best, filename='checkpoint.pth'):
    """Save ``state`` with ``torch.save`` and optionally keep a copy as the best checkpoint.

    Args:
        state: object to serialize (here the checkpoint dict built by main_worker).
        is_best: when true, also copy the file to ``<filename without extension>_best.pth``.
        filename: destination path; missing parent directories are created.
    """
    dirname = os.path.dirname(filename)
    # a bare filename (such as the default) has no directory part, and creating '' would fail
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, os.path.splitext(filename)[0] + '_best.pth')
def create_rand_inputs(is_cuda, shape=(1, 3, 224, 224)):
    """Return a tensor of uniform random values in [0, 1), e.g. as a dummy export input.

    Args:
        is_cuda: if true, the tensor is moved to the current CUDA device.
        shape: shape of the tensor; defaults to (1, 3, 224, 224) as before, so
            existing callers are unaffected.
    """
    dummy_input = torch.rand(tuple(shape))
    return dummy_input.cuda() if is_cuda else dummy_input
def write_onnx_model(args, model, is_best, filename='checkpoint.onnx'):
    """Export ``model`` to ONNX and optionally keep a copy as the best export.

    The model is exported in eval mode with a random input from create_rand_inputs
    placed on the same device as the model's first parameter, using
    ``args.opset_version``. The caller's train/eval mode is restored afterwards.

    Args:
        args: namespace providing ``opset_version``.
        model: module to export.
        is_best: when true, also copy the file to ``<filename without extension>_best.onnx``.
        filename: destination path of the .onnx file.
    """
    was_training = model.training
    model.eval()
    try:
        is_cuda = next(model.parameters()).is_cuda
        dummy_input = create_rand_inputs(is_cuda)
        torch.onnx.export(model, dummy_input, filename, export_params=True, verbose=False,
                          do_constant_folding=True, opset_version=args.opset_version)
    finally:
        # don't leave the caller's model silently switched to eval mode
        model.train(was_training)
    if is_best:
        shutil.copyfile(filename, os.path.splitext(filename)[0] + '_best.onnx')
class AverageMeter(object):
    """Track the latest value of a metric and its sample-weighted running average.

    Args:
        name: label printed in front of the values.
        fmt: format spec (including the leading ':') applied to both the latest value
            and the average when the meter is converted with str().
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Forget all accumulated statistics."""
        self.val = self.avg = self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted ``n`` times, and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # e.g. fmt ':.2f' -> '{name} {val:.2f} ({avg:.2f})'
        template = '{{name}} {{val{spec}}} ({{avg{spec}}})'.format(spec=self.fmt)
        return template.format(**vars(self))
class ProgressMeter(object):
    """Print tab-separated progress lines: prefix + [batch/total], learning rate, then each meter.

    Args:
        num_batches: total number of batches, used to pad the batch index.
        meters: objects rendered with str() after the learning rate.
        prefix: text placed before the batch counter.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.lr_fmtstr = self._get_lr_fmtstr()
        self.meters = meters
        self.prefix = prefix

    def display(self, batch, cur_lr):
        """Print the progress line for batch index ``batch`` and learning rate ``cur_lr``."""
        head = self.prefix + self.batch_fmtstr.format(batch)
        fields = [head, self.lr_fmtstr.format(cur_lr)]
        fields.extend(str(meter) for meter in self.meters)
        print('\t'.join(fields))

    def _get_batch_fmtstr(self, num_batches):
        # pad the batch index to the width of the total so the columns line up
        width = len(str(num_batches // 1))
        return '[{{:{w}d}}/{total:{w}d}]'.format(w=width, total=num_batches)

    def _get_lr_fmtstr(self):
        return 'LR {:g}'
def adjust_learning_rate(optimizer, epoch, args):
    """Apply the step learning-rate schedule for ``epoch``.

    The rate is ``args.lr`` multiplied by 0.1 once every ``args.lr_step_size`` epochs,
    i.e. ``args.lr * 0.1 ** (epoch // args.lr_step_size)``. The result is written to
    every param group of ``optimizer`` and stored in ``args.cur_lr`` for logging.
    """
    lr = args.lr * (0.1 ** (epoch // args.lr_step_size))
    args.cur_lr = lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: class scores of shape (batch, num_classes).
        target: class indices of shape (batch,).
        topk: the k values to evaluate.

    Returns:
        list with one 1-element tensor per k: the percentage of samples whose target
        is among the k highest-scoring classes. Any k >= num_classes yields 100.
    """
    with torch.no_grad():
        # topk() cannot return more classes than exist (e.g. top-5 on a 3-class model)
        maxk = min(max(topk), output.size(1))
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # (maxk, batch): row i holds every sample's i-th best class
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape, not view: ``correct`` derives from a transpose and may be non-contiguous
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
def get_dataset_sampler(dataset_object, epoch_size):
    """Create a with-replacement random sampler that draws a reduced "epoch" from a dataset.

    Args:
        dataset_object: dataset to sample from; its length sets the fraction base.
        epoch_size: values in (0, 1] are a fraction of ``len(dataset_object)``
            (1 means a full-size epoch); values above 1 are an absolute sample count.

    Returns:
        torch.utils.data.RandomSampler yielding the requested number of indices.
    """
    num_samples = len(dataset_object)
    if 0 < epoch_size <= 1:
        # fraction of the dataset: at least one sample, at most the dataset size
        epoch_size = min(num_samples, max(1, int(epoch_size * num_samples)))
    else:
        epoch_size = int(epoch_size)
    print('=> creating a random sampler as epoch_size is specified')
    dataset_sampler = torch.utils.data.sampler.RandomSampler(data_source=dataset_object, replacement=True, num_samples=epoch_size)
    return dataset_sampler
if __name__ == "__main__":
    # run the training/evaluation entry point only when executed as a script
    main()