better implementation for epoch_size: use dataset samplers instead of truncating the loader
author Manu Mathew <a0393608@ti.com>
Fri, 8 May 2020 15:11:27 +0000 (20:41 +0530)
committer Manu Mathew <a0393608@ti.com>
Fri, 8 May 2020 15:17:20 +0000 (20:47 +0530)
modules/pytorch_jacinto_ai/engine/train_classification.py
modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py
modules/pytorch_jacinto_ai/vision/models/__init__.py
modules/pytorch_jacinto_ai/vision/models/classification/__init__.py

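What changed: the old get_epoch_size() truncated each epoch to the first N batches of the loader, so a shortened epoch always saw the same leading portion of the dataset. The new get_dataset_sampler() instead draws a random subset every epoch. A minimal sketch of the intended semantics (names here are illustrative, not repo code; epoch_size < 1 means a fraction of the dataset, >= 1 an absolute sample count):

    import torch

    def make_epoch_sampler(dataset, epoch_size):
        # epoch_size < 1 selects a fraction of the dataset; >= 1 is an absolute count
        num_samples = len(dataset)
        num_draws = int(epoch_size * num_samples) if epoch_size < 1 else int(epoch_size)
        # replacement=True also allows epoch_size to exceed len(dataset)
        return torch.utils.data.RandomSampler(dataset, replacement=True, num_samples=num_draws)

    # usage: pass the sampler instead of shuffle=True; the epoch is exactly num_draws samples
    # loader = torch.utils.data.DataLoader(dataset, batch_size=32,
    #                                      sampler=make_epoch_sampler(dataset, 0.25))
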
index fafaa45ce6099329d1bcc2836370facee40a7708..07862e7204e64e72fc23a34497c93ffd1f4a08cb 100644 (file)
@@ -20,6 +20,9 @@ import torch.utils.data.distributed
 import sys
 import datetime
 
+import onnx
+from onnx import shape_inference
+
 from .. import xnn
 from .. import vision
 
@@ -193,6 +196,7 @@ def main(args):
     # reset character color, in case it is different
     print('{}'.format(Fore.RESET))
     print("=> args: ", args)
+    print('\n'.join("%s: %s" % item for item in sorted(vars(args).items())))
     print("=> resize resolution: {}".format(args.img_resize))
     print("=> crop resolution  : {}".format(args.img_crop))
     sys.stdout.flush()
@@ -341,9 +345,6 @@ def main(args):
 
     train_loader, val_loader = get_data_loaders(args)
 
-    # number of train iterations per epoch
-    args.iters = get_epoch_size(train_loader, args.epoch_size)
-
     args.cur_lr = adjust_learning_rate(args, optimizer, args.start_epoch)
 
     if args.evaluate_start or args.phase=='validation':
@@ -439,6 +440,11 @@ def write_onnx_model(args, model, save_path, name='checkpoint.onnx'):
     model.eval()
     torch.onnx.export(model, dummy_input, os.path.join(save_path,name), export_params=True, verbose=False,
                       do_constant_folding=True, opset_version=args.opset_version)
+
+    # to see tensor shapes in the ONNX graph; shape inference works only up to opset version 8
+    if args.opset_version <= 8:
+        path = os.path.join(save_path, name)
+        onnx.save(onnx.shape_inference.infer_shapes(onnx.load(path)), path)
 
 
 def train(args, train_loader, model, criterion, optimizer, epoch):
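
For reference, the shape-inference pass added above can also be run standalone on any exported model; a sketch with a placeholder path:

    import onnx
    from onnx import shape_inference

    path = 'checkpoint.onnx'  # placeholder path
    model = onnx.load(path)
    # annotate intermediate tensors (value_info) with inferred shapes,
    # so viewers such as Netron can display them
    onnx.save(shape_inference.infer_shapes(model), path)
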
@@ -454,19 +460,18 @@ def train(args, train_loader, model, criterion, optimizer, epoch):
     if args.freeze_bn:
         xnn.utils.freeze_bn(model)
     #
-    
-    progress_bar = progiter.ProgIter(np.arange(args.iters), chunksize=1)
+
+    num_iters = len(train_loader)
+    progress_bar = progiter.ProgIter(np.arange(num_iters), chunksize=1)
     args.cur_lr = adjust_learning_rate(args, optimizer, epoch)
 
     end = time.time()
-    train_iter = iter(train_loader)
     last_update_iter = -1
 
     progressbar_color = (Fore.YELLOW if (('calibration' in args.phase) or ('training' in args.phase and args.quantize)) else Fore.WHITE)
     print('{}'.format(progressbar_color), end='')
 
-    for iteration in range(args.iters):
-        (input, target) = next(train_iter)
+    for iteration, (input, target) in enumerate(train_loader):
         input = [inp.cuda() for inp in input] if xnn.utils.is_list(input) else input.cuda()
         input_size = input[0].size() if xnn.utils.is_list(input) else input.size()
         target = target.cuda(non_blocking=True)
@@ -512,7 +517,7 @@ def train(args, train_loader, model, criterion, optimizer, epoch):
         # measure elapsed time
         batch_time.update(time.time() - end)
         end = time.time()
-        final_iter = (iteration >= (args.iters-1))
+        final_iter = (iteration >= (num_iters-1))
 
         if ((iteration % args.print_freq) == 0) or final_iter:
             epoch_str = '{}/{}'.format(epoch + 1, args.epochs)
@@ -523,7 +528,8 @@ def train(args, train_loader, model, criterion, optimizer, epoch):
             progress_bar.set_postfix(Epoch='{}'.format(status_str))
             progress_bar.update(iteration-last_update_iter)
             last_update_iter = iteration
-
+        #
+    #
     progress_bar.close()
 
     # to print a new line - do not provide end=''
@@ -551,7 +557,8 @@ def validate(args, val_loader, model, criterion, epoch):
     # switch to evaluate mode
     model.eval()
 
-    progress_bar = progiter.ProgIter(np.arange(len(val_loader)), chunksize=1)
+    num_iters = len(val_loader)
+    progress_bar = progiter.ProgIter(np.arange(num_iters), chunksize=1)
     last_update_iter = -1
 
     # change color to green
@@ -591,7 +598,7 @@ def validate(args, val_loader, model, criterion, epoch):
             # measure elapsed time
             batch_time.update(time.time() - end)
             end = time.time()
-            final_iter = (iteration >= (len(val_loader)-1))
+            final_iter = (iteration >= (num_iters-1))
 
             if ((iteration % args.print_freq) == 0) or final_iter:
                 epoch_str = '{}/{}'.format(epoch+1,args.epochs)
@@ -603,6 +610,8 @@ def validate(args, val_loader, model, criterion, epoch):
                 progress_bar.set_postfix(Epoch='{}'.format(status_str))
                 progress_bar.update(iteration - last_update_iter)
                 last_update_iter = iteration
+            #
+        #
 
         progress_bar.close()
 
@@ -698,15 +707,27 @@ def accuracy(output, target, topk=(1,)):
         return res
 
 
-def get_epoch_size(loader, args_epoch_size):
-    if args_epoch_size == 0:
-        epoch_size = len(loader)
-    elif args_epoch_size < 1:
-        epoch_size = int(len(loader) * args_epoch_size)
+def get_dataset_sampler(dataset_object, epoch_size, balanced_sampler=False):
+    num_samples = len(dataset_object)
+    epoch_size = int(epoch_size * num_samples) if epoch_size < 1 else int(epoch_size)
+    print('=> creating a random sampler as epoch_size is specified')
+    if balanced_sampler:
+        # reading labels by iterating over the whole dataset may take a long time
+        progress_bar = progiter.ProgIter(np.arange(num_samples), chunksize=1,
+            desc='=> reading data to create a balanced data sampler: ')
+        sample_classes = [target for _, target in progress_bar(dataset_object)]
+        num_classes = max(sample_classes) + 1
+        sample_counts = np.zeros(num_classes, dtype=np.int32)
+        for target in sample_classes:
+            sample_counts[target] += 1
+        #
+        train_class_weights = [float(num_samples) / float(cnt) for cnt in sample_counts]
+        train_sample_weights = [train_class_weights[target] for target in sample_classes]
+        dataset_sampler = torch.utils.data.sampler.WeightedRandomSampler(train_sample_weights, epoch_size)
     else:
-        epoch_size = min(len(loader), int(args_epoch_size))
+        dataset_sampler = torch.utils.data.sampler.RandomSampler(data_source=dataset_object, replacement=True, num_samples=epoch_size)
     #
-    return epoch_size
+    return dataset_sampler
     
 
 def get_train_transform(args):
@@ -752,8 +773,13 @@ def get_data_loaders(args):
 
     train_dataset, val_dataset = vision.datasets.classification.__dict__[args.dataset_name](args.dataset_config, args.data_path, transforms=(train_transform,val_transform))
 
-    train_shuffle = (not args.distributed)
-    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
+    if args.distributed:
+        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
+    else:
+        train_sampler = get_dataset_sampler(train_dataset, args.epoch_size) if args.epoch_size != 0 else None
+    #
+
+    train_shuffle = (train_sampler is None)
     train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=train_shuffle, num_workers=args.workers,
         pin_memory=True, sampler=train_sampler)
 
index d9f41ed83b8ce0e01f12cc335d18d29ac8bdd7d8..bd91d70c739c1d9429150ee0e1fe61a4a3620b4e 100644 (file)
@@ -128,7 +128,7 @@ def get_config():
     args.count_flops = True                             # count flops and report
 
     args.shuffle = True                                 # shuffle or not
-    args.shuffle_val = False                            # shuffle val dataset or not
+    args.shuffle_val = True                             # shuffle val dataset or not
 
     args.transform_rotation = 0.                        # apply rotation augmentation; value is rotation in degrees, 0 means no rotation
     args.is_flow = None                                 # whether entries in images and targets lists are optical flow or not
@@ -278,15 +278,17 @@ def main(args):
     train_dataset, val_dataset = vision.datasets.pixel2pixel.__dict__[args.dataset_name](args.dataset_config, args.data_path, split=split_arg, transforms=transforms)
 
     #################################################
-    train_sampler = None
-    val_sampler = None
     print('=> {} samples found, {} train samples and {} test samples '.format(len(train_dataset)+len(val_dataset),
         len(train_dataset), len(val_dataset)))
+    train_sampler = get_dataset_sampler(train_dataset, args.epoch_size) if args.epoch_size != 0 else None
+    shuffle_train = args.shuffle and (train_sampler is None)
     train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
-        num_workers=args.workers, pin_memory=args.use_pinned_memory, sampler=train_sampler, shuffle=args.shuffle)
+        num_workers=args.workers, pin_memory=args.use_pinned_memory, sampler=train_sampler, shuffle=shuffle_train)
 
+    val_sampler = get_dataset_sampler(val_dataset, args.epoch_size_val) if args.epoch_size_val != 0 else None
+    shuffle_val = args.shuffle_val and (val_sampler is None)
     val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
-        num_workers=args.workers, pin_memory=args.use_pinned_memory, sampler=val_sampler, shuffle=args.shuffle_val)
+        num_workers=args.workers, pin_memory=args.use_pinned_memory, sampler=val_sampler, shuffle=shuffle_val)
 
     #################################################
     if (args.model_config.input_channels is None):
@@ -494,8 +496,7 @@ def main(args):
     #
 
     #################################################
-    epoch_size = get_epoch_size(args, train_loader, args.epoch_size)
-    max_iter = args.epochs * epoch_size
+    max_iter = args.epochs * len(train_loader)
     scheduler = xnn.optim.lr_scheduler.SchedulerWrapper(args.scheduler, optimizer, args.epochs, args.start_epoch, \
                                                             args.warmup_epochs, max_iter=max_iter, polystep_power=args.polystep_power,
                                                             milestones=args.milestones, multistep_gamma=args.multistep_gamma)
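
Since the loader length is now determined by the sampler, the scheduler's max_iter follows directly from epoch_size. A worked example with hypothetical numbers (epoch_size=1000 samples, batch_size=32, epochs=30, drop_last=False):

    import math

    iters_per_epoch = math.ceil(1000 / 32)  # len(train_loader) == 32
    max_iter = 30 * iters_per_epoch         # 960 scheduler iterations in total
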
@@ -533,9 +534,11 @@ def main(args):
             validate(args, val_dataset, val_loader, model, args.start_epoch, val_writer)
 
     for epoch in range(args.start_epoch, args.epochs):
-        if train_sampler:
+        # DistributedSampler seeds its shuffling with the epoch, so it must be set every epoch;
+        # otherwise the same seed (0) is reused and every epoch repeats the same shuffle order.
+        if train_sampler and isinstance(train_sampler, torch.utils.data.DistributedSampler):
             train_sampler.set_epoch(epoch)
-        if val_sampler:
+        if val_sampler and isinstance(val_sampler, torch.utils.data.DistributedSampler):
             val_sampler.set_epoch(epoch)
 
         # train for one epoch
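
The isinstance guard above matters because only DistributedSampler has set_epoch(); the RandomSampler/WeightedRandomSampler instances returned by get_dataset_sampler would raise AttributeError. The intended calling pattern, sketched (train_one_epoch is a placeholder):

    for epoch in range(args.start_epoch, args.epochs):
        if isinstance(train_sampler, torch.utils.data.DistributedSampler):
            # reseed the shuffle; without this every epoch reuses seed 0
            train_sampler.set_epoch(epoch)
        train_one_epoch(train_loader, model, optimizer, epoch)  # placeholder
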
@@ -597,7 +600,6 @@ def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writ
     avg_loss = [xnn.utils.AverageMeter(print_avg=(not task_loss[0].info()['is_avg'])) for task_loss in args.loss_modules]
     avg_loss_orig = [xnn.utils.AverageMeter(print_avg=(not task_loss[0].info()['is_avg'])) for task_loss in args.loss_modules]
     avg_metric = [xnn.utils.AverageMeter(print_avg=(not task_metric[0].info()['is_avg'])) for task_metric in args.metric_modules]
-    epoch_size = get_epoch_size(args, train_loader, args.epoch_size)
 
     ##########################
     # switch to train mode
@@ -626,7 +628,8 @@ def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writ
         for midx, metric_fn in enumerate(task_metrics):
             metric_fn.clear()
 
-    progress_bar = progiter.ProgIter(np.arange(epoch_size), chunksize=1)
+    num_iter = len(train_loader)
+    progress_bar = progiter.ProgIter(np.arange(num_iter), chunksize=1)
     metric_name = "Metric"
     metric_ctx = [None] * len(args.metric_modules)
     end_time = time.time()
@@ -638,7 +641,7 @@ def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writ
     print('{}'.format(progressbar_color), end='')
 
     ##########################
-    for iter, (inputs, targets) in enumerate(train_loader):
+    for iter_id, (inputs, targets) in enumerate(train_loader):
         # measure data loading time
         data_time.update(time.time() - end_time)
 
@@ -683,13 +686,13 @@ def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writ
 
         if 'training' in args.phase:
             # zero gradients so that we can accumulate gradients
-            if (iter % args.iter_size) == 0:
+            if (iter_id % args.iter_size) == 0:
                 optimizer.zero_grad()
 
             # accumulate gradients
             loss_total.backward()
             # optimization step
-            if ((iter+1) % args.iter_size) == 0:
+            if ((iter_id+1) % args.iter_size) == 0:
                 optimizer.step()
         #
 
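Aside from no longer shadowing Python's built-in iter(), the iter_id rename leaves the gradient-accumulation logic untouched: gradients are zeroed once every iter_size iterations and the optimizer steps once per iter_size batches, giving an effective batch of batch_size * iter_size. The pattern in isolation (loader, model, criterion, optimizer and iter_size are placeholders):

    for iter_id, (inputs, targets) in enumerate(loader):
        if (iter_id % iter_size) == 0:
            optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()                     # gradients accumulate across iterations
        if ((iter_id + 1) % iter_size) == 0:
            optimizer.step()
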
@@ -708,9 +711,9 @@ def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writ
 
         ##########################
         if args.tensorboard_enable:
-            write_output(args, 'Training_', epoch_size, iter, epoch, train_dataset, train_writer, input_list, task_outputs, target_list, metric_name, writer_idx)
+            write_output(args, 'Training_', num_iter, iter_id, epoch, train_dataset, train_writer, input_list, task_outputs, target_list, metric_name, writer_idx)
 
-        if ((iter % args.print_freq) == 0) or (iter == (epoch_size-1)):
+        if ((iter_id % args.print_freq) == 0) or (iter_id == (num_iter-1)):
             output_string = ''
             for task_idx, task_metrics in enumerate(args.metric_modules):
                 output_string += '[{}={}]'.format(metric_names[task_idx], str(avg_metric[task_idx]))
@@ -719,8 +722,8 @@ def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writ
             progress_bar.set_description("{}=> {}  ".format(progressbar_color, args.phase))
             multi_task_factors_print = ['{:.3f}'.format(float(lmf)) for lmf in args.multi_task_factors] if args.multi_task_factors is not None else None
             progress_bar.set_postfix(Epoch=epoch_str, LR=lr, DataTime=str(data_time), LossMult=multi_task_factors_print, Loss=avg_loss, Output=output_string)
-            progress_bar.update(iter-last_update_iter)
-            last_update_iter = iter
+            progress_bar.update(iter_id-last_update_iter)
+            last_update_iter = iter_id
 
         args.n_iter += 1
         end_time = time.time()
@@ -729,14 +732,12 @@ def train(args, train_dataset, train_loader, model, optimizer, epoch, train_writ
         # add onnx graph to tensorboard
         # commenting out due to issues in transitioning to pytorch 0.4
        # (bilinear mode in upsampling causes hang or crash - may be due to the align_corners change; nearest is fine)
-        #if epoch == 0 and iter == 0:
+        #if epoch == 0 and iter_id == 0:
         #    input_zero = torch.zeros(input_var.shape)
         #    train_writer.add_graph(model, input_zero)
        # This cache operation slows down training
         #torch.cuda.empty_cache()
-
-        if iter >= epoch_size:
-            break
+    #
 
     if args.print_train_class_iou:
         print_class_iou(args=args, confusion_matrix=confusion_matrix, task_idx=task_idx)
@@ -774,7 +775,6 @@ def validate(args, val_dataset, val_loader, model, epoch, val_writer):
     data_time = xnn.utils.AverageMeter()
     # if the loss/ metric is already an average, no need to further average
     avg_metric = [xnn.utils.AverageMeter(print_avg=(not task_metric[0].info()['is_avg'])) for task_metric in args.metric_modules]
-    epoch_size = get_epoch_size(args, val_loader, args.epoch_size_val)
 
     ##########################
     # switch to evaluate mode
@@ -790,13 +790,15 @@ def validate(args, val_dataset, val_loader, model, epoch, val_writer):
     writer_idx = 0
     last_update_iter = -1
     metric_ctx = [None] * len(args.metric_modules)
-    progress_bar = progiter.ProgIter(np.arange(epoch_size), chunksize=1)
+
+    num_iter = len(val_loader)
+    progress_bar = progiter.ProgIter(np.arange(num_iter), chunksize=1)
 
     # change color to green
     print('{}'.format(Fore.GREEN), end='')
 
     ##########################
-    for iter, (inputs, targets) in enumerate(val_loader):
+    for iter_id, (inputs, targets) in enumerate(val_loader):
         data_time.update(time.time() - end_time)
         input_list = [[jj.cuda() for jj in img] if isinstance(img,(list,tuple)) else img.cuda() for img in inputs]
         target_list = [j.cuda(non_blocking=True) for j in targets]
@@ -806,7 +808,6 @@ def validate(args, val_dataset, val_loader, model, epoch, val_writer):
         # compute output
         task_outputs = model(input_list)
 
-
         task_outputs = task_outputs if isinstance(task_outputs, (list, tuple)) else [task_outputs]
         if args.upsample_mode is not None:
            task_outputs = upsample_tensors(task_outputs, target_sizes, args.upsample_mode)
@@ -825,9 +826,9 @@ def validate(args, val_dataset, val_loader, model, epoch, val_writer):
             avg_metric[task_idx].update(float(metric_list[task_idx].cpu()), batch_size_cur)
 
         if args.tensorboard_enable:
-            write_output(args, 'Validation_', epoch_size, iter, epoch, val_dataset, val_writer, input_list, task_outputs, target_list, metric_names, writer_idx)
+            write_output(args, 'Validation_', num_iter, iter_id, epoch, val_dataset, val_writer, input_list, task_outputs, target_list, metric_names, writer_idx)
 
-        if ((iter % args.print_freq) == 0) or (iter == (epoch_size-1)):
+        if ((iter_id % args.print_freq) == 0) or (iter_id == (num_iter-1)):
             output_string = ''
             for task_idx, task_metrics in enumerate(args.metric_modules):
                 output_string += '[{}={}]'.format(metric_names[task_idx], str(avg_metric[task_idx]))
@@ -835,18 +836,18 @@ def validate(args, val_dataset, val_loader, model, epoch, val_writer):
             epoch_str = '{}/{}'.format(epoch + 1, args.epochs)
             progress_bar.set_description("=> validation")
             progress_bar.set_postfix(Epoch=epoch_str, DataTime=data_time, Output="{}".format(output_string))
-            progress_bar.update(iter-last_update_iter)
-            last_update_iter = iter
+            progress_bar.update(iter_id-last_update_iter)
+            last_update_iter = iter_id
+        #
 
         end_time = time.time()
         writer_idx = (writer_idx + 1) % args.tensorboard_num_imgs
-
-        if iter >= epoch_size:
-            break
+    #
 
     if args.print_val_class_iou:
         print_class_iou(args = args, confusion_matrix = confusion_matrix, task_idx=task_idx)
-    
+    #
+
     #print_conf_matrix(conf_matrix=conf_matrix, en=False)
     progress_bar.close()
 
@@ -959,7 +960,7 @@ def write_onnx_model(args, model, save_path, name='checkpoint.onnx', save_traced
 
 
 ###################################################################
-def write_output(args, prefix, val_epoch_size, iter, epoch, dataset, output_writer, input_images, task_outputs, task_targets, metric_names, writer_idx):
+def write_output(args, prefix, val_epoch_size, iter_id, epoch, dataset, output_writer, input_images, task_outputs, task_targets, metric_names, writer_idx):
     write_freq = (args.tensorboard_num_imgs / float(val_epoch_size))
     write_prob = np.random.random()
     if (write_prob > write_freq):
@@ -1111,14 +1112,12 @@ def save_checkpoint(args, save_path, model, checkpoint_dict, is_best, filename='
     #
 
 
-def get_epoch_size(args, loader, args_epoch_size):
-    if args_epoch_size == 0:
-        epoch_size = len(loader)
-    elif args_epoch_size < 1:
-        epoch_size = int(len(loader) * args_epoch_size)
-    else:
-        epoch_size = min(len(loader), int(args_epoch_size))
-    return epoch_size
+def get_dataset_sampler(dataset_object, epoch_size):
+    print('=> creating a random sampler as epoch_size is specified')
+    num_samples = len(dataset_object)
+    epoch_size = int(epoch_size * num_samples) if epoch_size < 1 else int(epoch_size)
+    dataset_sampler = torch.utils.data.sampler.RandomSampler(data_source=dataset_object, replacement=True, num_samples=epoch_size)
+    return dataset_sampler
 
 
 def get_train_transform(args):
index cab21e92b395d4c6c25c28656d72ae30f532d0be..0977ce1d03e620f324c646d3df696de49b68e267 100644 (file)
@@ -22,7 +22,7 @@ from .multi_input_net import *
 try: from .mobilenetv2_ericsun_internal import *
 except: pass
 
-try: from .mobilenetv2_gws_internal import *
+try: from .mobilenetv2_internal import *
 except: pass
 
 try: from .mobilenetv2_shicai_internal import *
index e9e90b9b04cb3b3660ef89e6a92560c027d3b648..690def51a4cf023dd4e4b3476441a2daf2ed7c36 100644 (file)
@@ -2,7 +2,7 @@ from .. import mobilenetv2
 from .. import mobilenetv1
 from .. import resnet
 
-try: from .. import mobilenetv2_gws_internal
+try: from .. import mobilenetv2_internal
 except: pass
 
 try: from .. import mobilenetv2_ericsun_internal
@@ -28,7 +28,7 @@ __all__ = ['mobilenetv1_x1', 'mobilenetv2_tv_x1', 'mobilenetv2_x1', 'mobilenetv2
 
 
 #####################################################################
-def resnet50_x1(model_config, pretrained=None, width_mult=1.0):
+def resnet50_x1(model_config, pretrained=None):
     model_config = resnet.get_config().merge_from(model_config)
     model = resnet.resnet50_with_model_config(model_config)
 
@@ -43,7 +43,8 @@ def resnet50_x1(model_config, pretrained=None, width_mult=1.0):
 
 
 def resnet50_xp5(model_config, pretrained=None):
-    return resnet50_x1(model_config=model_config, pretrained=pretrained, width_mult=0.5)
+    model_config.width_mult = 0.5
+    return resnet50_x1(model_config=model_config, pretrained=pretrained)
 
 
 #####################################################################
@@ -86,8 +87,8 @@ def mobilenetv2_tv_x2_t2(model_config, pretrained=None):
 
 #####################################################################
 def mobilenetv2_tv_gws_x1(model_config, pretrained=None):
-    model_config = mobilenetv2_gws_internal.get_config().merge_from(model_config)
-    model = mobilenetv2_gws_internal.MobileNetV2TVGWS(model_config=model_config)
+    model_config = mobilenetv2_internal.get_config_mnetv2_gws().merge_from(model_config)
+    model = mobilenetv2_internal.MobileNetV2TVGWS(model_config=model_config)
     if pretrained:
         model = xnn.utils.load_weights(model, pretrained)
     return model
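
A note on the try/except import pattern used in these __init__ files: a bare except also hides real failures inside the optional module (syntax errors, missing dependencies). Catching ImportError only is the safer variant, sketched:

    # import the optional internal model if present, skip silently otherwise
    try:
        from .mobilenetv2_internal import *   # noqa: F401,F403
    except ImportError:
        pass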