Cleaned up STE for QAT. Added RegNetX models.
diff --git a/modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py b/modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py
index d87a850117a12ae66fd829e2752a2d2fd95b1945..bd031238d63a9297672ed19c622517d9a05d6314 100644
--- a/modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py
+++ b/modules/pytorch_jacinto_ai/engine/train_pixel2pixel.py
@@ -57,6 +57,7 @@ def get_config():
     args.model_name = 'deeplabv2lite_mobilenetv2'       # model architecture, overwritten if pretrained is specified
     args.dataset_name = 'cityscapes_segmentation'       # dataset type
     args.transforms = None                              # the transforms itself can be given from outside
+    args.input_channel_reverse = False                  # reverse input channels, for example RGB to BGR
 
     args.data_path = './data/cityscapes'                # 'path to dataset'
     args.save_path = None                               # checkpoints save path
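
The new input_channel_reverse flag defaults to False, so existing configs are unaffected. A minimal usage sketch, assuming the package is importable from the path shown above (get_config() and main(args) are taken from the hunk headers; the data path is the config's own default):

    from pytorch_jacinto_ai.engine import train_pixel2pixel

    args = train_pixel2pixel.get_config()
    args.input_channel_reverse = True    # e.g. feed BGR data to an RGB-trained pipeline
    args.data_path = './data/cityscapes'
    train_pixel2pixel.main(args)
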
@@ -376,9 +377,15 @@ def main(args):
 
     # load pretrained model
     if pretrained_data is not None and not is_onnx_model:
+        model_orig = get_model_orig(model)
         for (p_data,p_file) in zip(pretrained_data, pretrained_files):
             print("=> using pretrained weights from: {}".format(p_file))
-            xnn.utils.load_weights(get_model_orig(model), pretrained=p_data, change_names_dict=change_names_dict)
+            if hasattr(model_orig, 'load_weights'):
+                model_orig.load_weights(pretrained=p_data, change_names_dict=change_names_dict)
+            else:
+                xnn.utils.load_weights(model_orig, pretrained=p_data, change_names_dict=change_names_dict)
+            #
+        #
     #
 
     #################################################
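
The loader now prefers a model-specific load_weights() hook over the generic xnn.utils.load_weights() fallback, letting custom models control how pretrained weights are mapped in. A hedged sketch of a model opting into this hook; the signature is inferred from the call site, and the exact-key remap below merely stands in for whatever renaming the real helper performs:

    import torch

    class CustomModel(torch.nn.Module):   # hypothetical example model
        def load_weights(self, pretrained=None, change_names_dict=None):
            # 'pretrained' is assumed to be either a checkpoint dict with a
            # 'state_dict' entry or a bare state_dict.
            state_dict = pretrained.get('state_dict', pretrained)
            if change_names_dict is not None:
                # simple exact-key remap; the real helper may match prefixes
                state_dict = {change_names_dict.get(k, k): v for k, v in state_dict.items()}
            self.load_state_dict(state_dict, strict=False)
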
@@ -1129,11 +1136,13 @@ def get_train_transform(args):
     image_scale = np.array(args.image_scale, dtype=np.float32)
     image_prenorm = xvision.transforms.image_transforms.NormalizeMeanScale(mean=image_mean, scale=image_scale) if args.image_prenorm else None
     image_postnorm = xvision.transforms.image_transforms.NormalizeMeanScale(mean=image_mean, scale=image_scale) if (not image_prenorm) else None
+    reverse_channels = xvision.transforms.image_transforms.ReverseImageChannels() if args.input_channel_reverse else None
 
     # crop size used only for training
     image_train_output_scaling = xvision.transforms.image_transforms.Scale(args.rand_resize, target_size=args.rand_output_size, is_flow=args.is_flow) \
         if (args.rand_output_size is not None and args.rand_output_size != args.rand_resize) else None
     train_transform = xvision.transforms.image_transforms.Compose([
+        reverse_channels,
         image_prenorm,
         xvision.transforms.image_transforms.AlignImages(),
         xvision.transforms.image_transforms.MaskTarget(args.target_mask, 0),
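
Both transform builders prepend an optional ReverseImageChannels() step; Compose() evidently tolerates None entries, since image_prenorm can already be None. The transform's actual implementation lives in xvision, but its effect is presumably just a flip of the channel axis, roughly:

    import numpy as np

    # Plausible stand-in for ReverseImageChannels: flip the last (channel)
    # axis of each HWC image, turning RGB into BGR and back.
    def reverse_image_channels(images):
        return [np.ascontiguousarray(img[..., ::-1]) for img in images]
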
@@ -1156,9 +1165,11 @@ def get_validation_transform(args):
     image_scale = np.array(args.image_scale, dtype=np.float32)
     image_prenorm = xvision.transforms.image_transforms.NormalizeMeanScale(mean=image_mean, scale=image_scale) if args.image_prenorm else None
     image_postnorm = xvision.transforms.image_transforms.NormalizeMeanScale(mean=image_mean, scale=image_scale) if (not image_prenorm) else None
+    reverse_channels = xvision.transforms.image_transforms.ReverseImageChannels() if args.input_channel_reverse else None
 
     # prediction is resized to output_size before evaluation.
     val_transform = xvision.transforms.image_transforms.Compose([
+        reverse_channels,
         image_prenorm,
         xvision.transforms.image_transforms.AlignImages(),
         xvision.transforms.image_transforms.MaskTarget(args.target_mask, 0),
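
A practical caveat, continuing the sketch above: channel reversal runs before NormalizeMeanScale in both the train and validation pipelines, so any per-channel image_mean/image_scale must be specified in the reversed (e.g. BGR) order. Illustrative values only, not defaults from this file:

    args.input_channel_reverse = True
    args.image_mean = [103.53, 116.28, 123.675]       # BGR-order means (illustrative)
    args.image_scale = [1/57.375, 1/57.12, 1/58.395]  # BGR-order scales (illustrative)
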