Skip to content

Commit

Permalink
Add more pali(2) weights. Switch rest of models adapting open_clip weights to their own weight instances.
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Dec 27, 2024
1 parent 01cf0f7 commit 790decc
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 107 deletions.
32 changes: 8 additions & 24 deletions timm/models/byobnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2282,107 +2282,91 @@ def _cfgr(url='', **kwargs):
# original attention pool head variants
'resnet50_clip.openai': _cfgr(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=1024, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7),
classifier='head.proj',
),
'resnet101_clip.openai': _cfgr(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=512, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7),
classifier='head.proj',
),
'resnet50x4_clip.openai': _cfgr(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=640, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 288, 288), pool_size=(9, 9),
classifier='head.proj',
),
'resnet50x16_clip.openai': _cfgr(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=768, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 384, 384), pool_size=(12, 12),
classifier='head.proj',
),
'resnet50x64_clip.openai': _cfgr(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=1024, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 448, 448), pool_size=(14, 14),
classifier='head.proj',
),
'resnet50_clip.cc12m': _cfgr(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=1024, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7),
classifier='head.proj',
),
'resnet50_clip.yfcc15m': _cfgr(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=1024, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7),
classifier='head.proj',
),
'resnet101_clip.yfcc15m': _cfgr(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=512, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7),
classifier='head.proj',
),

# avg-pool w/ optional standard classifier head variants
'resnet50_clip_gap.openai': _cfgr(
hf_hub_id='timm/resnet50_clip.openai',
hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 224, 224), pool_size=(7, 7),
),
'resnet101_clip_gap.openai': _cfgr(
hf_hub_id='timm/resnet101_clip.openai',
hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 224, 224), pool_size=(7, 7),
),
'resnet50x4_clip_gap.openai': _cfgr(
hf_hub_id='timm/resnet50x4_clip.openai',
hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 288, 288), pool_size=(9, 9),
),
'resnet50x16_clip_gap.openai': _cfgr(
hf_hub_id='timm/resnet50x16_clip.openai',
hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 384, 384), pool_size=(12, 12),
),
'resnet50x64_clip_gap.openai': _cfgr(
hf_hub_id='timm/resnet50x64_clip.openai',
hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 448, 448), pool_size=(14, 14),
),
'resnet50_clip_gap.cc12m': _cfgr(
hf_hub_id='timm/resnet50_clip.cc12m',
hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 224, 224), pool_size=(7, 7),
),
'resnet50_clip_gap.yfcc15m': _cfgr(
hf_hub_id='timm/resnet50_clip.yfcc15m',
hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 224, 224), pool_size=(7, 7),
),
'resnet101_clip_gap.yfcc15m': _cfgr(
hf_hub_id='timm/resnet101_clip.yfcc15m',
hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 224, 224), pool_size=(7, 7),
),
Expand Down
35 changes: 21 additions & 14 deletions timm/models/eva.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,45 +912,52 @@ def _cfg(url='', **kwargs):
# EVA01 and EVA02 CLIP image towers
'eva_giant_patch14_clip_224.laion400m': _cfg(
# hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA01_CLIP_g_14_plus_psz14_s11B.pt',
hf_hub_id='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k', # float16 weights
hf_hub_filename='open_clip_pytorch_model.bin',
# hf_hub_id='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k', # float16 weights
# hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
num_classes=1024,
),
'eva_giant_patch14_clip_224.merged2b': _cfg(
# hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA01_CLIP_g_14_plus_psz14_s11B.pt',
hf_hub_id='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k', # float16 weights
hf_hub_filename='open_clip_pytorch_model.bin',
# hf_hub_id='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k', # float16 weights
# hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
num_classes=1024,
),
'eva02_base_patch16_clip_224.merged2b': _cfg(
# hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
hf_hub_id='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k', # float16 weights
hf_hub_filename='open_clip_pytorch_model.bin',
# hf_hub_id='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k', # float16 weights
# hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
num_classes=512,
),
'eva02_large_patch14_clip_224.merged2b': _cfg(
# hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
hf_hub_id='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k', # float16 weights
hf_hub_filename='open_clip_pytorch_model.bin',
# hf_hub_id='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k', # float16 weights
# hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
num_classes=768,
),
'eva02_large_patch14_clip_336.merged2b': _cfg(
# hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
hf_hub_id='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k', # float16 weights
hf_hub_filename='open_clip_pytorch_model.bin',
# hf_hub_id='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k', # float16 weights
# hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
input_size=(3, 336, 336), crop_pct=1.0,
num_classes=768,
),
'eva02_enormous_patch14_clip_224.laion2b': _cfg(
# hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_E_psz14_plus_s9B.pt',
hf_hub_id='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k', # float16 weights
hf_hub_filename='open_clip_pytorch_model.bin',
# hf_hub_id='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k', # float16 weights
# hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
num_classes=1024,
),
'eva02_enormous_patch14_clip_224.laion2b_plus': _cfg(
# hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_E_psz14_plus_s9B.pt',
hf_hub_id='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k', # bfloat16 weights
hf_hub_filename='open_clip_pytorch_model.bin',
# hf_hub_id='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k', # bfloat16 weights
# hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
num_classes=1024,
),
'eva02_enormous_patch14_clip_224.pretrain': _cfg(
Expand Down
65 changes: 43 additions & 22 deletions timm/models/hieradet_sam2.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,26 +530,47 @@ def _cfg(url='', **kwargs):


default_cfgs = generate_default_cfgs({
"sam2_hiera_tiny.r224": _cfg(
hf_hub_id='facebook/sam2-hiera-tiny',
hf_hub_filename='sam2_hiera_tiny.pt',
input_size=(3, 224, 224), pool_size=(7, 7),
), # FIXME reduced res for testing
"sam2_hiera_tiny.r896": _cfg(
hf_hub_id='facebook/sam2-hiera-tiny',
hf_hub_filename='sam2_hiera_tiny.pt',
"sam2_hiera_tiny.fb_r896": _cfg(
# hf_hub_id='facebook/sam2-hiera-tiny',
# hf_hub_filename='sam2_hiera_tiny.pt',
hf_hub_id='timm/',
),
"sam2_hiera_small": _cfg(
hf_hub_id='facebook/sam2-hiera-small',
hf_hub_filename='sam2_hiera_small.pt',
"sam2_hiera_tiny.fb_r896_2pt1": _cfg(
# hf_hub_id='facebook/sam2.1-hiera-tiny',
# hf_hub_filename='sam2.1_hiera_tiny.pt',
hf_hub_id='timm/',
),
"sam2_hiera_base_plus": _cfg(
hf_hub_id='facebook/sam2-hiera-base-plus',
hf_hub_filename='sam2_hiera_base_plus.pt',
"sam2_hiera_small.fb_r896": _cfg(
# hf_hub_id='facebook/sam2-hiera-small',
# hf_hub_filename='sam2_hiera_small.pt',
hf_hub_id='timm/',
),
"sam2_hiera_large": _cfg(
hf_hub_id='facebook/sam2-hiera-large',
hf_hub_filename='sam2_hiera_large.pt',
"sam2_hiera_small.fb_r896_2pt1": _cfg(
# hf_hub_id='facebook/sam2.1-hiera-small',
# hf_hub_filename='sam2.1_hiera_small.pt',
hf_hub_id='timm/',
),
"sam2_hiera_base_plus.fb_r896": _cfg(
# hf_hub_id='facebook/sam2-hiera-base-plus',
# hf_hub_filename='sam2_hiera_base_plus.pt',
hf_hub_id='timm/',
),
"sam2_hiera_base_plus.fb_r896_2pt1": _cfg(
# hf_hub_id='facebook/sam2.1-hiera-base-plus',
# hf_hub_filename='sam2.1_hiera_base_plus.pt',
hf_hub_id='timm/',
),
"sam2_hiera_large.fb_r1024": _cfg(
# hf_hub_id='facebook/sam2-hiera-large',
# hf_hub_filename='sam2_hiera_large.pt',
hf_hub_id='timm/',
min_input_size=(3, 256, 256),
input_size=(3, 1024, 1024), pool_size=(32, 32),
),
"sam2_hiera_large.fb_r1024_2pt1": _cfg(
# hf_hub_id='facebook/sam2.1-hiera-large',
# hf_hub_filename='sam2.1_hiera_large.pt',
hf_hub_id='timm/',
min_input_size=(3, 256, 256),
input_size=(3, 1024, 1024), pool_size=(32, 32),
),
Expand Down Expand Up @@ -578,11 +599,11 @@ def checkpoint_filter_fn(state_dict, model=None, prefix=''):
def _create_hiera_det(variant: str, pretrained: bool = False, **kwargs) -> HieraDet:
out_indices = kwargs.pop('out_indices', 4)
checkpoint_prefix = ''
if 'sam2' in variant:
# SAM2 pretrained weights have no classifier or final norm-layer (`head.norm`)
# This is workaround loading with num_classes=0 w/o removing norm-layer.
kwargs.setdefault('pretrained_strict', False)
checkpoint_prefix = 'image_encoder.trunk.'
# if 'sam2' in variant:
# # SAM2 pretrained weights have no classifier or final norm-layer (`head.norm`)
# # This is workaround loading with num_classes=0 w/o removing norm-layer.
# kwargs.setdefault('pretrained_strict', False)
# checkpoint_prefix = 'image_encoder.trunk.'
return build_model_with_cfg(
HieraDet,
variant,
Expand Down
Loading

0 comments on commit 790decc

Please sign in to comment.