release commit
authorManu Mathew <a0393608@ti.com>
Sat, 25 Jan 2020 04:40:28 +0000 (10:10 +0530)
committerManu Mathew <a0393608@ti.com>
Sat, 25 Jan 2020 04:40:28 +0000 (10:10 +0530)
modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py

index 2eec922c978a3c207144224b3da0afb8658289e2..c9b0285ca3ca11d642d802a0079246d22c9e2023 100644 (file)
@@ -54,6 +54,8 @@ def get_config():
     args.model = None                                   # the model itself can be given from ouside
     args.model_name = 'deeplabv2lite_mobilenetv2'       # model architecture, overwritten if pretrained is specified
     args.dataset_name = 'cityscapes_segmentation'       # dataset type
+    args.transforms = None                              # the transforms itself can be given from outside
+
     args.data_path = './data/cityscapes'                # 'path to dataset'
     args.save_path = None                               # checkpoints save path
     args.phase = 'training'                             # training/calibration/validation
@@ -258,7 +260,8 @@ def main(args):
     #################################################
     train_writer = SummaryWriter(os.path.join(save_path,'train'))
     val_writer = SummaryWriter(os.path.join(save_path,'val'))
-    transforms = get_transforms(args)
+    transforms = get_transforms(args) if args.transforms is None else args.transforms
+    assert isinstance(transforms, (list,tuple)) and len(transforms) == 2, 'incorrect transforms were given'
 
     print("=> fetching images in '{}'".format(args.data_path))
     split_arg = args.split_file if args.split_file else (args.split_files if args.split_files else args.split_value)