[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / engine / train_classification.py
diff --git a/modules/pytorch_jacinto_ai/engine/train_classification.py b/modules/pytorch_jacinto_ai/engine/train_classification.py
index 06d418bb3460fb2c345a72f5b54c2ed105e6d9bd..bc208a4f6e4d0d88b69ed9cd589b898713d2d7bf 100644 (file)
args.model_config.num_tiles_x = int(1)
args.model_config.num_tiles_y = int(1)
args.model_config.en_make_divisible_by8 = True
-
args.model_config.input_channels = 3 # num input channels
+ args.input_channel_reverse = False # rgb to bgr
args.data_path = './data/datasets/ilsvrc' # path to dataset
args.model_name = 'mobilenetv2_tv_x1' # model architecture
args.model = None #if model is created externally
# load pretrained
if pretrained_data is not None and not is_onnx_model:
- xnn.utils.load_weights(get_model_orig(model), pretrained=pretrained_data, change_names_dict=change_names_dict)
+ model_orig = get_model_orig(model)
+ if hasattr(model_orig, 'load_weights'):
+ model_orig.load_weights(pretrained=pretrained_data, change_names_dict=change_names_dict)
+ else:
+ xnn.utils.load_weights(model_orig, pretrained=pretrained_data, change_names_dict=change_names_dict)
+ #
#
#################################################
normalize = xvision.transforms.NormalizeMeanScale(mean=args.image_mean, scale=args.image_scale) \
if (args.image_mean is not None and args.image_scale is not None) else None
multi_color_transform = xvision.transforms.MultiColor(args.multi_color_modes) if (args.multi_color_modes is not None) else None
+ reverse_channels = xvision.transforms.ReverseChannels() if args.input_channel_reverse else None
train_resize_crop_transform = xvision.transforms.RandomResizedCrop(size=args.img_crop, scale=args.rand_scale) \
if args.rand_scale else xvision.transforms.RandomCrop(size=args.img_crop)
- train_transform = xvision.transforms.Compose([train_resize_crop_transform,
+ train_transform = xvision.transforms.Compose([reverse_channels,
+ train_resize_crop_transform,
xvision.transforms.RandomHorizontalFlip(),
multi_color_transform,
xvision.transforms.ToFloat(),
normalize = xvision.transforms.NormalizeMeanScale(mean=args.image_mean, scale=args.image_scale) \
if (args.image_mean is not None and args.image_scale is not None) else None
multi_color_transform = xvision.transforms.MultiColor(args.multi_color_modes) if (args.multi_color_modes is not None) else None
+ reverse_channels = xvision.transforms.ReverseChannels() if args.input_channel_reverse else None
# pass tuple to Resize() to resize to exact size without respecting aspect ratio (typical caffe style)
val_resize_crop_transform = xvision.transforms.Resize(size=args.img_resize) if args.img_resize else xvision.transforms.Bypass()
- val_transform = xvision.transforms.Compose([val_resize_crop_transform,
+ val_transform = xvision.transforms.Compose([reverse_channels,
+ val_resize_crop_transform,
xvision.transforms.CenterCrop(size=args.img_crop),
multi_color_transform,
xvision.transforms.ToFloat(),