epoch_size and shuffling arguments for validation
authorManu Mathew <a0393608@ti.com>
Wed, 13 May 2020 12:50:56 +0000 (18:20 +0530)
committerManu Mathew <a0393608@ti.com>
Thu, 14 May 2020 07:56:39 +0000 (13:26 +0530)
modules/pytorch_jacinto_ai/engine/train_classification.py
scripts/train_classification_main.py
scripts/train_segmentation_main.py

index a40907acd93ecab7ea94ead77f18108787490a49..a11798fd1192bbe22aeb6f121f48c5f24d773ab5 100644 (file)
@@ -52,6 +52,7 @@ def get_config():
     args.warmup_epochs = None                           # number of epochs to warm up by linearly increasing lr
 
     args.epoch_size = 0                                 # fraction of training epoch to use each time. 0 indicates full
+    args.epoch_size_val = 0                             # manual epoch size (will match dataset size if not specified)
     args.start_epoch = 0                                # manual epoch number to start
     args.stop_epoch = None                              # manual epoch number to stop
     args.batch_size = 256                               # mini_batch size (default: 256)
@@ -65,6 +66,9 @@ def get_config():
     args.weight_decay = 1e-4                            # weight decay (default: 1e-4)
     args.bias_decay = None                              # bias decay (default: 0.0)
 
+    args.shuffle = True                                 # shuffle or not
+    args.shuffle_val = True                             # shuffle val dataset or not
+
     args.rand_seed = 1                                  # random seed
     args.print_freq = 100                               # print frequency (default: 100)
     args.resume = None                                  # path to latest checkpoint (default: none)
@@ -775,16 +779,19 @@ def get_data_loaders(args):
 
     if args.distributed:
         train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
+        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
     else:
         train_sampler = get_dataset_sampler(train_dataset, args.epoch_size) if args.epoch_size != 0 else None
+        val_sampler = get_dataset_sampler(val_dataset, args.epoch_size_val) if args.epoch_size_val != 0 else None
     #
 
-    train_shuffle = (train_sampler is None)
+    train_shuffle = args.shuffle and (train_sampler is None)
     train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=train_shuffle, num_workers=args.workers,
-        pin_memory=True, sampler=train_sampler)
+                                               pin_memory=True, sampler=train_sampler)
 
-    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers,
-                                             pin_memory=True, drop_last=False)
+    val_shuffle = args.shuffle_val and (val_sampler is None)
+    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=val_shuffle, num_workers=args.workers,
+                                             pin_memory=True, drop_last=False, sampler=val_sampler)
 
     return train_loader, val_loader
 
index a04ba2399495db473d6eef16d295a88d5bd4e6e4..94e7524e9631bb708f3400668021ff6868740865 100755 (executable)
@@ -20,7 +20,6 @@ parser.add_argument('--lr_calib', type=float, default=None, help='Learning rate
 parser.add_argument('--model_name', type=str, default=None, help='model name')
 parser.add_argument('--dataset_name', type=str, default=None, help='dataset name')
 parser.add_argument('--data_path', type=str, default=None, help='data path')
-parser.add_argument('--epoch_size', type=float, default=None, help='epoch size. using a fraction will reduce the data used for one epoch')
 parser.add_argument('--epochs', type=int, default=None, help='number of epochs')
 parser.add_argument('--warmup_epochs', type=int, default=None, help='number of epochs for the learning rate to increase and reach base value')
 parser.add_argument('--milestones', type=int, nargs='*', default=None, help='change lr at these milestones')
@@ -41,7 +40,18 @@ parser.add_argument('--bitwidth_weights', type=int, default=None, help='bitwidth
 parser.add_argument('--bitwidth_activations', type=int, default=None, help='bitwidth for activation quantization')
 #
 parser.add_argument('--freeze_bn', type=str2bool, default=None, help='freeze the bn stats or not')
-
+#
+parser.add_argument('--shuffle', type=str2bool, default=None, help='whether to shuffle the training set or not')
+parser.add_argument('--shuffle_val', type=str2bool, default=None, help='whether to shuffle the validation set or not')
+parser.add_argument('--epoch_size', type=float, default=None, help='epoch size. options are: 0, fraction or number. '
+                                                                   '0 will use th full epoch. '
+                                                                   'using a number will cause the epoch to have that many iterations'
+                                                                   'using a fraction will reduce the iterations used for one epoch to that fraction of the whole. ')
+parser.add_argument('--epoch_size_val', type=float, default=None, help='epoch size for validation. options are: 0, fraction or number. '
+                                                                   '0 will use th full epoch. '
+                                                                   'using a number will cause the epoch to have that many iterations. '
+                                                                   'using a fraction will reduce the iterations used for one epoch to that fraction of the whole. ')
+#
 cmds = parser.parse_args()
 
 ################################
index aca1965672f129b939aaac441812eef65e015c14..93e18fcddfddaf2fd568101f05652cc51736f452 100755 (executable)
@@ -19,7 +19,6 @@ parser.add_argument('--lr_calib', type=float, default=None, help='Learning rate
 parser.add_argument('--model_name', type=str, default=None, help='model name')
 parser.add_argument('--dataset_name', type=str, default=None, help='dataset name')
 parser.add_argument('--data_path', type=str, default=None, help='data path')
-parser.add_argument('--epoch_size', type=float, default=None, help='epoch size. using a fraction will reduce the data used for one epoch')
 parser.add_argument('--epochs', type=int, default=None, help='number of epochs')
 parser.add_argument('--warmup_epochs', type=int, default=None, help='number of epochs for the learning rate to increase and reach base value')
 parser.add_argument('--milestones', type=int, nargs='*', default=None, help='change lr at these milestones')
@@ -41,6 +40,18 @@ parser.add_argument('--bitwidth_weights', type=int, default=None, help='bitwidth
 parser.add_argument('--bitwidth_activations', type=int, default=None, help='bitwidth for activation quantization')
 #
 parser.add_argument('--freeze_bn', type=str2bool, default=None, help='freeze the bn stats or not')
+#
+parser.add_argument('--shuffle', type=str2bool, default=None, help='whether to shuffle the training set or not')
+parser.add_argument('--shuffle_val', type=str2bool, default=None, help='whether to shuffle the validation set or not')
+parser.add_argument('--epoch_size', type=float, default=None, help='epoch size. options are: 0, fraction or number. '
+                                                                   '0 will use th full epoch. '
+                                                                   'using a number will cause the epoch to have that many iterations'
+                                                                   'using a fraction will reduce the iterations used for one epoch to that fraction of the whole. ')
+parser.add_argument('--epoch_size_val', type=float, default=None, help='epoch size for validation. options are: 0, fraction or number. '
+                                                                   '0 will use th full epoch. '
+                                                                   'using a number will cause the epoch to have that many iterations. '
+                                                                   'using a fraction will reduce the iterations used for one epoch to that fraction of the whole. ')
+#
 cmds = parser.parse_args()
 
 ################################