Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perceiver #768

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# transformer models don't support many of the spatial / feature based model functionalities
NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*']
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'perceiver*']
NUM_NON_STD = len(NON_STD_FILTERS)

# exclude models that cause specific test failures
Expand All @@ -26,7 +26,7 @@
EXCLUDE_FILTERS = [
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*',
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*']
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'perceiver_l*']
else:
EXCLUDE_FILTERS = []

Expand Down Expand Up @@ -218,11 +218,12 @@ def test_model_default_cfgs_non_std(model_name, batch_size):

# check first conv(s) names match default_cfg
first_conv = cfg['first_conv']
if isinstance(first_conv, str):
first_conv = (first_conv,)
assert isinstance(first_conv, (tuple, list))
for fc in first_conv:
assert fc + ".weight" in state_dict.keys(), f'{fc} not in model params'
if first_conv is not None:
if isinstance(first_conv, str):
first_conv = (first_conv,)
assert isinstance(first_conv, (tuple, list))
for fc in first_conv:
assert fc + ".weight" in state_dict.keys(), f'{fc} not in model params'


if 'GITHUB_ACTIONS' not in os.environ:
Expand Down
3 changes: 2 additions & 1 deletion timm/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .nasnet import *
from .nest import *
from .nfnet import *
from .perceiver import *
from .pit import *
from .pnasnet import *
from .regnet import *
Expand All @@ -36,6 +37,7 @@
from .swin_transformer import *
from .tnt import *
from .tresnet import *
from .twins import *
from .vgg import *
from .visformer import *
from .vision_transformer import *
Expand All @@ -44,7 +46,6 @@
from .xception import *
from .xception_aligned import *
from .xcit import *
from .twins import *

from .factory import create_model, split_model_name, safe_model_name
from .helpers import load_checkpoint, resume_checkpoint, model_parameters
Expand Down
6 changes: 0 additions & 6 deletions timm/models/nfnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,6 @@

Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets

Status:
* These models are a work in progress, experiments ongoing.
* Pretrained weights for two models so far, more to come.
* Model details updated to closer match official JAX code now that it's released
* NF-ResNet, NF-RegNet-B, and NFNet-F models supported

Hacked together by / copyright Ross Wightman, 2021.
"""
import math
Expand Down
Loading