]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/pytorch-jacinto-ai-devkit.git/blob - modules/pytorch_jacinto_ai/vision/models/classification/__init__.py
0950ce29d1727e06060713d2d4bb23b79428e62a
[jacinto-ai/pytorch-jacinto-ai-devkit.git] / modules / pytorch_jacinto_ai / vision / models / classification / __init__.py
1 from .. import mobilenetv2
2 from .. import mobilenetv1
3 from .. import resnet
5 try: from .. import mobilenetv2_gws_internal
6 except: pass
8 try: from .. import mobilenetv2_ericsun_internal
9 except: pass
11 try: from .. import mobilenetv2_shicai_internal
12 except: pass
14 try: from .. import flownetbase_internal
15 except: pass
17 from .... import xnn
__all__ = [
    'mobilenetv1_x1',
    'mobilenetv2_tv_x1',
    'mobilenetv2_x1',
    'mobilenetv2_tv_x2_t2',
    'resnet50_x1',
    'resnet50_xp5',
    # experimental: these builders depend on the optional *_internal modules
    # that are imported best-effort (try/except) at the top of this file
    'mobilenetv2_ericsun_x1',
    'mobilenetv2_shicai_x1',
    'mobilenetv2_tv_gws_x1',
    'flownetslite_base_x1',
]
26 #####################################################################
def resnet50_x1(model_config, pretrained=None, width_mult=None):
    """Build a ResNet-50 classifier, optionally loading pretrained weights.

    Args:
        model_config: overrides merged into the defaults returned by
            ``resnet.get_config()`` (via ``merge_from``).
        pretrained: forwarded to ``xnn.utils.load_weights`` when truthy;
            a falsy value skips weight loading.
        width_mult: when not None, assigned to ``width_mult`` on the merged
            config before the model is built, overriding any value that came
            from ``model_config``. The default (None) leaves the merged config
            untouched.

    Returns:
        tuple: ``(model, change_names_dict)``. Unlike the other builders in
        this module, the parameter-name remapping dict is returned too.
    """
    model_config = resnet.get_config().merge_from(model_config)
    # This argument used to be accepted but ignored, which made resnet50_xp5
    # build the same network as resnet50_x1. None (not 1.0) is the default so
    # callers that never passed it keep the old "don't override" behaviour.
    if width_mult is not None:
        # NOTE(review): assumes resnet.resnet50_with_model_config reads
        # ``width_mult`` from the config to scale channel widths -- confirm.
        model_config.width_mult = width_mult
    model = resnet.resnet50_with_model_config(model_config)

    # The torchvision pretrained ResNet and the model defined here name their
    # parameters slightly differently: here the backbone lives under
    # ``features.`` and the final layer under ``classifier.``.
    # note: this change_names_dict takes effect only if the direct load fails.
    # The ``^`` anchors suggest xnn.utils.load_weights treats the keys as regex
    # patterns (so the unescaped ``.`` matches any char) -- presumably; confirm.
    change_names_dict = {'^conv1.': 'features.conv1.', '^bn1.': 'features.bn1.',
                         '^relu.': 'features.relu.', '^maxpool.': 'features.maxpool.',
                         '^layer': 'features.layer', '^fc.': 'classifier.'}
    if pretrained:
        model = xnn.utils.load_weights(model, pretrained, change_names_dict=change_names_dict)
    return model, change_names_dict
def resnet50_xp5(model_config, pretrained=None):
    """Build ResNet-50 through ``resnet50_x1`` with ``width_mult=0.5``.

    Args:
        model_config: forwarded unchanged to ``resnet50_x1``.
        pretrained: forwarded unchanged to ``resnet50_x1``.

    Returns:
        Whatever ``resnet50_x1`` returns (a ``(model, change_names_dict)`` tuple).

    NOTE(review): the network is only half-width if ``resnet50_x1`` actually
    applies its ``width_mult`` argument to the model config -- confirm.
    """
    half_width = 0.5
    return resnet50_x1(
        model_config=model_config,
        pretrained=pretrained,
        width_mult=half_width,
    )
45 #####################################################################
def mobilenetv1_x1(model_config, pretrained=None):
    """Build a ``mobilenetv1.MobileNetV1`` model, optionally loading weights.

    Args:
        model_config: overrides merged into ``mobilenetv1.get_config()``
            (via ``merge_from``).
        pretrained: forwarded to ``xnn.utils.load_weights`` when truthy;
            a falsy value skips weight loading.

    Returns:
        The constructed (and possibly weight-loaded) model.
    """
    merged_config = mobilenetv1.get_config().merge_from(model_config)
    net = mobilenetv1.MobileNetV1(model_config=merged_config)
    return xnn.utils.load_weights(net, pretrained) if pretrained else net
54 #####################################################################
def mobilenetv2_tv_x1(model_config, pretrained=None):
    """Build a ``mobilenetv2.MobileNetV2TV`` model, optionally loading weights.

    Args:
        model_config: overrides merged into ``mobilenetv2.get_config()``
            (via ``merge_from``).
        pretrained: forwarded to ``xnn.utils.load_weights`` when truthy;
            a falsy value skips weight loading.

    Returns:
        The constructed (and possibly weight-loaded) model.
    """
    merged_config = mobilenetv2.get_config().merge_from(model_config)
    net = mobilenetv2.MobileNetV2TV(model_config=merged_config)
    return xnn.utils.load_weights(net, pretrained) if pretrained else net


# alias: ``mobilenetv2_x1`` is the very same function object as ``mobilenetv2_tv_x1``
mobilenetv2_x1 = mobilenetv2_tv_x1
def mobilenetv2_tv_x2_t2(model_config, pretrained=None):
    """Build ``mobilenetv2.MobileNetV2TV`` with width_mult=2.0 and expand_ratio=2.0.

    Args:
        model_config: overrides merged into ``mobilenetv2.get_config()``
            (via ``merge_from``); its ``width_mult`` and ``expand_ratio`` are
            always overwritten with 2.0.
        pretrained: forwarded to ``xnn.utils.load_weights`` when truthy;
            a falsy value skips weight loading.

    Returns:
        The constructed (and possibly weight-loaded) model.
    """
    cfg = mobilenetv2.get_config().merge_from(model_config)
    # forced regardless of what the caller supplied -- this builder's name pins them
    cfg.width_mult, cfg.expand_ratio = 2.0, 2.0
    net = mobilenetv2.MobileNetV2TV(model_config=cfg)
    if not pretrained:
        return net
    return xnn.utils.load_weights(net, pretrained)
76 #####################################################################
def mobilenetv2_tv_gws_x1(model_config, pretrained=None):
    """Build the experimental ``mobilenetv2_gws_internal.MobileNetV2TVGWS`` model.

    ``mobilenetv2_gws_internal`` is imported best-effort at the top of this
    file (``try``/``except: pass``), so calling this builder when that import
    failed raises ``NameError``.

    Args:
        model_config: overrides merged into
            ``mobilenetv2_gws_internal.get_config()`` (via ``merge_from``).
        pretrained: forwarded to ``xnn.utils.load_weights`` when truthy;
            a falsy value skips weight loading.

    Returns:
        The constructed (and possibly weight-loaded) model.
    """
    gws_module = mobilenetv2_gws_internal
    cfg = gws_module.get_config().merge_from(model_config)
    net = gws_module.MobileNetV2TVGWS(model_config=cfg)
    return xnn.utils.load_weights(net, pretrained) if pretrained else net
85 #####################################################################
def mobilenetv2_ericsun_x1(model_config, pretrained=None):
    """Build the experimental ``mobilenetv2_ericsun_internal.MobileNetV2Ericsun`` model.

    ``mobilenetv2_ericsun_internal`` is imported best-effort at the top of
    this file (``try``/``except: pass``), so calling this builder when that
    import failed raises ``NameError``.

    Args:
        model_config: overrides merged into
            ``mobilenetv2_ericsun_internal.get_config()`` (via ``merge_from``).
        pretrained: forwarded to ``xnn.utils.load_weights`` when truthy;
            a falsy value skips weight loading.

    Returns:
        The constructed (and possibly weight-loaded) model.
    """
    ericsun_module = mobilenetv2_ericsun_internal
    cfg = ericsun_module.get_config().merge_from(model_config)
    net = ericsun_module.MobileNetV2Ericsun(model_config=cfg)
    return xnn.utils.load_weights(net, pretrained) if pretrained else net
def mobilenetv2_shicai_x1(model_config, pretrained=None):
    """Build the experimental model from ``mobilenetv2_shicai_internal.mobilenetv2_shicai``.

    ``mobilenetv2_shicai_internal`` is imported best-effort at the top of this
    file (``try``/``except: pass``), so calling this builder when that import
    failed raises ``NameError``.

    Args:
        model_config: overrides merged into
            ``mobilenetv2_shicai_internal.get_config()`` (via ``merge_from``).
        pretrained: forwarded to ``xnn.utils.load_weights`` when truthy;
            a falsy value skips weight loading.

    Returns:
        The constructed (and possibly weight-loaded) model.
    """
    shicai_module = mobilenetv2_shicai_internal
    cfg = shicai_module.get_config().merge_from(model_config)
    net = shicai_module.mobilenetv2_shicai(model_config=cfg)
    if not pretrained:
        return net
    return xnn.utils.load_weights(net, pretrained)
def flownetslite_base_x1(model_config, pretrained=None):
    """Build the experimental ``flownetbase_internal.flownetslite_base`` network.

    ``flownetbase_internal`` is imported best-effort at the top of this file
    (``try``/``except: pass``), so calling this builder when that import
    failed raises ``NameError``.

    Args:
        model_config: overrides merged into ``flownetbase_internal.get_config()``
            (via ``merge_from``).
        pretrained: passed to the ``flownetslite_base`` constructor, and also
            forwarded to ``xnn.utils.load_weights`` when truthy.

    Returns:
        The constructed (and possibly weight-loaded) model.
    """
    flownet_module = flownetbase_internal
    cfg = flownet_module.get_config().merge_from(model_config)
    # NOTE(review): ``pretrained`` is handed to the constructor *and* to
    # load_weights below; if flownetslite_base already loads it, the weights
    # are loaded twice -- confirm whether one of the two is redundant.
    net = flownet_module.flownetslite_base(cfg, pretrained=pretrained)
    if not pretrained:
        return net
    return xnn.utils.load_weights(net, pretrained)