534c1110678945e487ba29f903e2a5dd4d4e454f
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / engine / test_classification.py
1 import os
2 import sys
3 import shutil
4 import time
5 import datetime
7 import random
8 import numpy as np
9 from colorama import Fore
10 import random
11 import progiter
12 import warnings
14 import torch
15 import torch.nn.parallel
16 import torch.backends.cudnn as cudnn
17 import torch.optim
18 import torch.utils.data
19 import torch.utils.data.distributed
21 from .. import xnn
22 from .. import vision
25 # ################################################
def get_config():
    """Build and return the default configuration for classification evaluation.

    Returns an ``xnn.utils.ConfigNode`` populated with model, dataset,
    preprocessing and quantization settings. Callers typically override
    fields (and must set ``args.phase``) before calling ``main()``.
    """
    args = xnn.utils.ConfigNode()
    args.model_config = xnn.utils.ConfigNode()
    args.dataset_config = xnn.utils.ConfigNode()

    args.model_name = 'mobilenet_v2_classification'     # model architecture
    args.dataset_name = 'imagenet_classification'       # image folder classification

    args.data_path = './data/datasets/ilsvrc'           # path to dataset
    args.save_path = None                               # checkpoints save path
    args.pretrained = './data/modelzoo/pretrained/pytorch/imagenet_classification/ericsun99/MobileNet-V2-Pytorch/mobilenetv2_Top1_71.806_Top2_90.410.pth.tar' # path to pre_trained model

    args.workers = 8                                    # number of data loading workers (default: 4)
    args.batch_size = 256                               # mini_batch size (default: 256)
    args.print_freq = 100                               # print frequency (default: 100)

    args.img_resize = 256                               # image resize
    args.img_crop = 224                                 # image crop

    args.image_mean = (123.675, 116.28, 103.53)         # image mean for input image normalization
    args.image_scale = (0.017125, 0.017507, 0.017429)   # image scaling/mult for input image normalization

    args.logger = None                                  # logger stream to output into

    args.data_augument = 'inception'                    # data augmentation method, choices=['inception','resize','adaptive_resize']
    args.dataset_format = 'folder'                      # dataset format, choices=['folder','lmdb']
    args.count_flops = True                             # count flops and report

    args.lr_calib = 0.1                                 # lr for bias calibration

    args.rand_seed = 1                                  # random seed (fix: was assigned twice in the original)
    args.generate_onnx = False                          # export an onnx model of the network or not
    args.print_model = False                            # print the model to text
    args.run_soon = True                                # set to false if only cfg files/onnx models needed but no training
    args.parallel_model = True                          # parallel or not
    args.shuffle = True                                 # shuffle or not
    args.epoch_size = 0                                 # epoch size
    args.date = None                                    # date to add to save path. if this is None, current date will be added.
    args.write_layer_ip_op = False                      # dump layer inputs/outputs via forward hooks

    args.quantize = False                               # apply quantized inference or not
    #args.model_surgery = None                          # deprecated - see the assert at the top of main()
    args.bitwidth_weights = 8                           # bitwidth for weights
    args.bitwidth_activations = 8                       # bitwidth for activations
    args.histogram_range = True                         # histogram range for calibration
    args.per_channel_q = False                          # apply separate quantization factor for each channel in depthwise or not
    args.bias_calibration = False                       # apply bias correction during quantized inference calibration
    return args
def main(args):
    """Build the model and data loader from `args` and run a validation pass.

    Expects a config from get_config() with args.phase set; when args.quantize
    is enabled the phase must contain 'training', 'calibration' or 'validation'.
    Side effects: creates the save directory, installs a TeeLogger, seeds all
    RNGs, optionally exports ONNX and registers tensor-dump hooks.
    NOTE(review): requires CUDA - model and criterion are moved to GPU.
    """
    assert not hasattr(args, 'model_surgery'), 'the argument model_surgery is deprecated, it is not needed now - remove it'

    # bias calibration only makes sense while calibrating, not while validating
    if (args.phase == 'validation' and args.bias_calibration):
        args.bias_calibration = False
        warnings.warn('switching off bias calibration in validation')
    #

    #################################################
    # onnx generation is failing for post-quantized modules, so disable it there
    args.generate_onnx = False if (args.quantize) else args.generate_onnx

    if args.save_path is None:
        save_path = get_save_path(args)
    else:
        save_path = args.save_path
    #

    if not os.path.exists(save_path):
        os.makedirs(save_path)
    #

    #################################################
    # default logger: tee console output into <script_name>.log in save_path
    if args.logger is None:
        log_file = os.path.splitext(os.path.basename(__file__))[0] + '.log'
        args.logger = xnn.utils.TeeLogger(filename=os.path.join(save_path,log_file))

    #################################################
    # global settings. rand seeds for repeatability
    random.seed(args.rand_seed)
    np.random.seed(args.rand_seed)
    torch.manual_seed(args.rand_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

    ################################
    # print everything for log
    # reset character color, in case it is different
    print('{}'.format(Fore.RESET))
    print("=> args: ", args)
    print("=> resize resolution: {}".format(args.img_resize))
    print("=> crop resolution : {}".format(args.img_crop))
    sys.stdout.flush()

    #################################################
    # resolve the pretrained weights: remote URL is downloaded, local path used as-is
    pretrained_data = None
    model_surgery_quantize = False
    if args.pretrained and args.pretrained != "None":
        if args.pretrained.startswith('http://') or args.pretrained.startswith('https://'):
            pretrained_file = vision.datasets.utils.download_url(args.pretrained, './data/downloads')
        else:
            pretrained_file = args.pretrained
        #
        print(f'=> using pre-trained weights from: {args.pretrained}')
        pretrained_data = torch.load(pretrained_file)
        # checkpoints saved from quantized training carry a 'quantize' flag
        model_surgery_quantize = pretrained_data['quantize'] if 'quantize' in pretrained_data else False
    #

    ################################
    # create model
    print("=> creating model '{}'".format(args.model_name))
    model = vision.models.classification.__dict__[args.model_name](args.model_config)

    # check if we got the model as well as parameters to change the names in pretrained
    model, change_names_dict = model if isinstance(model, (list,tuple)) else (model,None)

    #################################################
    # wrap the model in the quantization module matching the requested phase
    if args.quantize:
        # dummy input is used by quantized models to analyze graph
        is_cuda = next(model.parameters()).is_cuda
        dummy_input = create_rand_inputs(args, is_cuda=is_cuda)
        #
        if 'training' in args.phase:
            model = xnn.quantize.QuantTrainModule(model, per_channel_q=args.per_channel_q,
                        bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
                        dummy_input=dummy_input)
        elif 'calibration' in args.phase:
            model = xnn.quantize.QuantCalibrateModule(model, per_channel_q=args.per_channel_q,
                        bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
                        bias_calibration=args.bias_calibration, lr_calib=args.lr_calib,
                        dummy_input=dummy_input)
        elif 'validation' in args.phase:
            # Note: bias_calibration is not enabled in test
            model = xnn.quantize.QuantTestModule(model, per_channel_q=args.per_channel_q,
                        bitwidth_weights=args.bitwidth_weights, bitwidth_activations=args.bitwidth_activations,
                        histogram_range=args.histogram_range, model_surgery_quantize=model_surgery_quantize,
                        dummy_input=dummy_input)
        else:
            assert False, f'invalid phase {args.phase}'
    #

    # load pretrained weights into the unwrapped (non-parallel, non-quant) model
    xnn.utils.load_weights(get_model_orig(model), pretrained=pretrained_data, change_names_dict=change_names_dict)

    #################################################
    if args.count_flops:
        count_flops(args, model)

    #################################################
    if args.generate_onnx:
        write_onnx_model(args, get_model_orig(model), save_path)
    #

    #################################################
    if args.print_model:
        print(model)
    else:
        args.logger.debug(str(model))

    #################################################
    if (not args.run_soon):
        print("Training not needed for now")
        exit()

    #################################################
    # multi gpu mode is not yet supported with quantization in evaluate
    if args.parallel_model and ('training' in args.phase):
        model = torch.nn.DataParallel(model)

    #################################################
    model = model.cuda()

    #################################################
    if args.write_layer_ip_op:
        # for dumping module outputs: tag every module with its name and hook it
        for name, module in model.named_modules():
            module.name = name
            print(name)
            #if 'module.encoder.features.0.' in name:
            module.register_forward_hook(write_tensor_hook_function)

        # column header for the per-tensor lines printed by the dump hooks
        print('{:7} {:33} {:12} {:8} {:6} {:30} : {:17} : {:4} : {:11} : {:7} : {:7}'.format("type", "name", "layer", "min", "max", "tensor_shape", "dtype", "scale", "dtype", "min", "max"))

    #################################################
    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss().cuda()

    val_loader = get_data_loaders(args)
    validate(args, val_loader, model, criterion)
def validate(args, val_loader, model, criterion):
    """Run one evaluation pass over val_loader and return the average top-1 precision.

    Accumulates loss / top-1 / top-5 with AverageMeter, limits the pass to
    get_epoch_size(...) batches, and reports progress either via a progiter
    progress bar or plain colored console prints.
    """
    # switch to evaluate mode
    model.eval()

    # change color to green
    print('{}'.format(Fore.GREEN), end='')

    with torch.no_grad():
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()
        use_progressbar = True
        # honor args.epoch_size: 0 = full loader, fraction = part of it, N = capped count
        epoch_size = get_epoch_size(args, val_loader, args.epoch_size)

        if use_progressbar:
            progress_bar = progiter.ProgIter(np.arange(epoch_size), chunksize=1)
            last_update_iter = -1

        end = time.time()
        for iteration, (input, target) in enumerate(val_loader):
            target = target.cuda(non_blocking=True)
            # list/tuple inputs are concatenated along the channel dimension
            input = torch.cat([j.cuda() for j in input], dim=1) if (type(input) in (list,tuple)) else input.cuda()

            # compute output
            output = model(input)
            if type(output) in (list, tuple):
                output = output[0]
            #

            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1[0], input.size(0))
            top5.update(prec5[0], input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            final_iter = (iteration >= (epoch_size-1))

            if ((iteration % args.print_freq) == 0) or final_iter:
                # cur_lr is a placeholder (0.0): no optimizer/lr exists in validation
                status_str = 'Time {batch_time.val:.2f}({batch_time.avg:.2f}) LR {cur_lr:.4f} ' \
                             'Loss {loss.val:.2f}({loss.avg:.2f}) Prec@1 {top1.val:.2f}({top1.avg:.2f}) Prec@5 {top5.val:.2f}({top5.avg:.2f})' \
                             .format(batch_time=batch_time, cur_lr=0.0, loss=losses, top1=top1, top5=top5)
                #
                prefix = '**' if final_iter else '=>'
                if use_progressbar:
                    progress_bar.set_description('{} validation'.format(prefix))
                    progress_bar.set_postfix(Epoch='{}'.format(status_str))
                    progress_bar.update(iteration - last_update_iter)
                    last_update_iter = iteration
                else:
                    iter_str = '{:6}/{:6} : '.format(iteration+1, len(val_loader))
                    status_str = prefix + ' ' + iter_str + status_str
                    if final_iter:
                        xnn.utils.print_color(status_str, color=Fore.GREEN)
                    else:
                        xnn.utils.print_color(status_str)

            # stop once epoch_size batches have been processed
            if final_iter:
                break

        if use_progressbar:
            progress_bar.close()

        # to print a new line - do not provide end=''
        print('{}'.format(Fore.RESET), end='')

        return top1.avg
293 #######################################################################
def get_save_path(args, phase=None):
    """Compose the default checkpoint/output directory for a run.

    Layout: ./data/checkpoints/<dataset>/<date>_<dataset>_<model>_resize<R>_crop<C>/<phase>
    The date comes from args.date when set, otherwise the current timestamp.
    """
    run_date = args.date if args.date else datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    run_name = run_date + '_' + args.dataset_name + '_' + args.model_name
    base_dir = os.path.join('./data/checkpoints', args.dataset_name, run_name)
    sized_dir = base_dir + '_resize{}_crop{}'.format(args.img_resize, args.img_crop)
    # fall back to the phase recorded in args when none is given explicitly
    if phase is None:
        phase = args.phase
    return os.path.join(sized_dir, phase)
def get_model_orig(model):
    """Unwrap (Distributed)DataParallel and quantization wrappers to reach the plain model."""
    inner = model
    # peel off the data-parallel wrapper, if any
    if isinstance(inner, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
        inner = inner.module
    # peel off the quantization wrapper, if any
    if isinstance(inner, (xnn.quantize.QuantBaseModule)):
        inner = inner.module
    return inner
def create_rand_inputs(args, is_cuda=True):
    """Create one random input batch of shape (1, C, crop, crop) for graph analysis / flops / onnx export."""
    shape = (1, args.model_config.input_channels, args.img_crop, args.img_crop)
    dummy = torch.rand(shape)
    return dummy.cuda() if is_cuda else dummy
def count_flops(args, model):
    """Measure and print GFLOPs / GMACs of the model at the configured crop resolution."""
    on_gpu = next(model.parameters()).is_cuda
    dummy = create_rand_inputs(args, on_gpu)
    model.eval()
    total_flops = xnn.utils.forward_count_flops(model, dummy)
    gflops = total_flops / 1e9
    # one MAC = one multiply + one add, hence GMACs = GFLOPs / 2
    print('=> Resize = {}, Crop = {}, GFLOPs = {}, GMACs = {}'.format(args.img_resize, args.img_crop, gflops, gflops/2))
def write_onnx_model(args, model, save_path, name='checkpoint.onnx'):
    """Export the model to ONNX at save_path/name, tracing it with a random dummy input."""
    on_gpu = next(model.parameters()).is_cuda
    dummy = create_rand_inputs(args, on_gpu)
    #
    model.eval()
    out_file = os.path.join(save_path, name)
    torch.onnx.export(model, dummy, out_file, export_params=True, verbose=False)
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Save a checkpoint and, when is_best, copy it to model_best.pth.tar.

    Fix: the best-model copy is now placed next to `filename` instead of the
    current working directory, so a checkpoint saved into a run directory
    keeps its best-model copy in that same directory. Behavior for the
    default (bare) filename is unchanged.
    """
    torch.save(state, filename)
    if is_best:
        best_name = os.path.join(os.path.dirname(filename) or '.', 'model_best.pth.tar')
        shutil.copyfile(filename, best_name)
class AverageMeter(object):
    """Tracks the latest value, running sum, observation count and mean of a metric."""
    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record `val` observed `n` times and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: (batch, num_classes) score tensor.
        target: (batch,) ground-truth class indices.
        topk: tuple of k values to report.

    Returns:
        list of 1-element tensors with precision@k in percent, one per k.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # top-maxk class indices per sample, transposed to (maxk, batch)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # fix: use reshape, not view - correct[:k] can be non-contiguous in
            # recent PyTorch versions, where .view(-1) raises a RuntimeError
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
def get_epoch_size(args, loader, args_epoch_size):
    """Resolve the requested epoch size against the loader length.

    0 -> full loader; a value below 1 -> that fraction of the loader;
    otherwise the requested batch count, capped at the loader length.
    """
    num_batches = len(loader)
    if args_epoch_size == 0:
        return num_batches
    if args_epoch_size < 1:
        return int(num_batches * args_epoch_size)
    return min(num_batches, int(args_epoch_size))
def get_data_loaders(args):
    """Build and return the validation DataLoader: resize -> center-crop -> float -> tensor -> normalize."""
    # normalization is optional: applied only when both mean and scale are configured
    if args.image_mean is not None and args.image_scale is not None:
        normalize = vision.transforms.NormalizeMeanScale(mean=args.image_mean, scale=args.image_scale)
    else:
        normalize = None

    # pass tuple to Resize() to resize to exact size without respecting aspect ratio (typical caffe style)
    val_transform = vision.transforms.Compose([
        vision.transforms.Resize(size=args.img_resize),
        vision.transforms.CenterCrop(size=args.img_crop),
        vision.transforms.ToFloat(),
        vision.transforms.ToTensor(),
        normalize])

    # dataset factory returns (train, val); only the val split is used here
    train_dataset, val_dataset = vision.datasets.classification.__dict__[args.dataset_name](
        args.dataset_config, args.data_path, transforms=(None, val_transform))

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=args.shuffle,
        num_workers=args.workers, pin_memory=True, drop_last=False)

    return val_loader
403 #################################################
def shape_as_string(shape=()):
    """Render a tensor shape as an underscore-joined string, e.g. (1, 3, 224) -> '_1_3_224'.

    Used to embed shapes in dump file names. Fixes: the default is now an
    immutable empty tuple (was a mutable list), and the string is built with
    join instead of repeated concatenation. Output is unchanged for all inputs.
    """
    return ''.join('_' + str(dim) for dim in shape)
def write_tensor_int(m=[], tensor=[], suffix='op', bitwidth=8, power2_scaling=True, file_format='bin',
                     rnd_type='rnd_sym'):
    """Quantize a tensor to `bitwidth`-bit integers and dump it to a file for debug.

    The tensor is scaled via xnn.quantize.compute_tensor_scale, rounded
    (symmetric rounding for 'rnd_sym' - used for weights/bias - or HW-friendly
    upward rounding otherwise), clamped, and written as .bin or .npy under
    ./data/checkpoints/debug/test_vecs/. A status line is printed per tensor.

    NOTE(review): `m` is a module tagged with a `.name` attribute in main();
    the `[]` defaults are placeholders, callers always pass real values.
    """
    mn = tensor.min()
    mx = tensor.max()

    print(
        '{:6}, {:32}, {:10}, {:7.2f}, {:7.2f}'.format(suffix, m.name, m.__class__.__name__, tensor.min(), tensor.max()),
        end=" ")

    [tensor_scale, clamp_limits] = xnn.quantize.compute_tensor_scale(tensor, mn, mx, bitwidth, power2_scaling)
    print("{:30} : {:15} : {:8.2f}".format(str(tensor.shape), str(tensor.dtype), tensor_scale), end=" ")

    # debug switch: when True, prints the first weight slice before/after rounding
    print_weight_bias = False
    if rnd_type == 'rnd_sym':
        # use best rounding for offline quantities
        if suffix == 'weight' and print_weight_bias:
            no_idx = 0
            torch.set_printoptions(precision=32)
            print("tensor_scale: ", tensor_scale)
            print(tensor[no_idx])
        # int64 tensors are assumed to be already-quantized and are not re-rounded
        if tensor.dtype != torch.int64:
            tensor = xnn.quantize.symmetric_round_tensor(tensor * tensor_scale)
        if suffix == 'weight' and print_weight_bias:
            print(tensor[no_idx])
    else:
        # for activation use HW friendly rounding
        if tensor.dtype != torch.int64:
            tensor = xnn.quantize.upward_round_tensor(tensor * tensor_scale)
    tensor = tensor.clamp(clamp_limits[0], clamp_limits[1]).float()

    if bitwidth == 8:
        data_type = np.int8
    elif bitwidth == 16:
        data_type = np.int16
    elif bitwidth == 32:
        data_type = np.int32
    else:
        # hard-stops the whole process on unsupported bitwidths
        exit("Bit width other 8,16,32 not supported for writing layer level op")

    tensor = tensor.cpu().numpy().astype(data_type)

    print("{:7} : {:7d} : {:7d}".format(str(tensor.dtype), tensor.min(), tensor.max()))

    # one directory per (module, tensor-kind, scale) combination
    tensor_dir = './data/checkpoints/debug/test_vecs/' + '{}_{}_{}_scale_{:010.4f}'.format(m.name,
                                                                                          m.__class__.__name__,
                                                                                          suffix, tensor_scale)

    if not os.path.exists(tensor_dir):
        os.makedirs(tensor_dir)

    if file_format == 'bin':
        tensor_name = tensor_dir + "/{}_shape{}.bin".format(m.name, shape_as_string(shape=tensor.shape))
        tensor.tofile(tensor_name)
    elif file_format == 'npy':
        tensor_name = tensor_dir + "/{}_shape{}.npy".format(m.name, shape_as_string(shape=tensor.shape))
        np.save(tensor_name, tensor)
    else:
        warnings.warn('unknown file_format for write_tensor - no file written')
    #

    # utils_hist.comp_hist_tensor3d(x=tensor, name=m.name, en=True, dir=m.name, log=True, ch_dim=0)
def write_tensor_float(m=[], tensor=[], suffix='op'):
    """Dump a float tensor to an .npy file under <cwd>/checkpoints/debug/test_vecs for debug.

    Prints one status line (suffix, module name, module type, min, max) per tensor.
    """
    lo = tensor.min()
    hi = tensor.max()

    print(
        '{:6}, {:32}, {:10}, {:7.2f}, {:7.2f}'.format(suffix, m.name, m.__class__.__name__, lo, hi))

    # one directory per (module, tensor-kind) combination, rooted at the cwd
    out_dir = os.getcwd() + '/checkpoints/debug/test_vecs/' + '{}_{}_{}'.format(m.name, m.__class__.__name__, suffix)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    out_name = out_dir + "/{}_shape{}.npy".format(m.name, shape_as_string(shape=tensor.shape))
    np.save(out_name, tensor.data)
def write_tensor(data_type='int', m=[], tensor=[], suffix='op', bitwidth=8, power2_scaling=True, file_format='bin',
                 rnd_type='rnd_sym'):
    """Dispatch tensor dumping to the int or float writer.

    Fix: `bitwidth` and `power2_scaling` are now forwarded to
    write_tensor_int - previously they were accepted here but silently
    dropped, so the int writer always used its own defaults. Defaults are
    identical, so existing callers see unchanged behavior.
    """
    if data_type == 'int':
        write_tensor_int(m=m, tensor=tensor, suffix=suffix, bitwidth=bitwidth, power2_scaling=power2_scaling,
                         rnd_type=rnd_type, file_format=file_format)
    elif data_type == 'float':
        write_tensor_float(m=m, tensor=tensor, suffix=suffix)
# module-level switch to disable all dumping done by write_tensor_hook_function
enable_hook_function = True
def write_tensor_hook_function(m, inp, out, file_format='bin'):
    """Forward hook that dumps a module's output, input(s), weight and bias via write_tensor().

    Registered on every module in main() when args.write_layer_ip_op is set;
    each module is tagged with a `.name` attribute there before registration.
    """
    if not enable_hook_function:
        return

    # Output
    if isinstance(out, (torch.Tensor)):
        write_tensor(m=m, tensor=out, suffix='op', rnd_type ='rnd_up', file_format=file_format)

    # Input(s)
    if type(inp) is tuple:
        # if there are more than 1 inputs
        # NOTE(review): this enumerates inp[0] (the first input tensor), not inp
        # itself - for a multi-input module it appears to iterate the first
        # input's leading dimension rather than the input tuple; confirm intent.
        for index, sub_ip in enumerate(inp[0]):
            if isinstance(sub_ip, (torch.Tensor)):
                write_tensor(m=m, tensor=sub_ip, suffix='ip_{}'.format(index), rnd_type ='rnd_up', file_format=file_format)
    elif isinstance(inp, (torch.Tensor)):
        write_tensor(m=m, tensor=inp, suffix='ip', rnd_type ='rnd_up', file_format=file_format)

    # weights
    if hasattr(m, 'weight'):
        if isinstance(m.weight,torch.Tensor):
            write_tensor(m=m, tensor=m.weight, suffix='weight', rnd_type ='rnd_sym', file_format=file_format)

    # bias
    if hasattr(m, 'bias'):
        if m.bias is not None:
            write_tensor(m=m, tensor=m.bias, suffix='bias', rnd_type ='rnd_sym', file_format=file_format)
if __name__ == '__main__':
    # Fix: main() requires a config argument; calling main() with no argument
    # (as the original did) raised a TypeError immediately.
    # NOTE(review): main() also reads args.phase, which get_config() does not
    # set - default to 'validation' for standalone runs of this test script.
    args = get_config()
    args.phase = 'validation'
    main(args)