Add attentive layer to Jepa #927

Merged 40 commits on Dec 27, 2024
Commits (40)
7fca3ca
add init function to the builders
Dec 19, 2024
7623e1b
add builder skeleton for the AttentivePooler
Dec 19, 2024
7b959fe
refactor init_module function
Dec 19, 2024
cef6687
refactor init_module function
Dec 19, 2024
c7a9021
update cross attention layer
Dec 19, 2024
ab93be9
update cross attn layer
Dec 19, 2024
f745bea
Cosmetic updates
cbalioglu Dec 19, 2024
60ec4c1
add forward() function
Dec 19, 2024
854b68e
Can's comments
Dec 19, 2024
328f8ca
fix git rebase
Dec 19, 2024
f57108d
fix git rebase
Dec 19, 2024
2d27bab
lint
Dec 19, 2024
8d7dfaf
lint
Dec 19, 2024
4b69946
rebase
Dec 19, 2024
e86afaa
flake8
Dec 19, 2024
2a76edd
remove commits remnant
Dec 19, 2024
250f1ee
black
Dec 19, 2024
f4aaf33
black
Dec 19, 2024
e4d9a0a
add builder func
Dec 20, 2024
b8cfd50
revert remnant codes
Dec 20, 2024
a6987ab
revert remnant codes
Dec 20, 2024
837cc6c
revert remnant codes
Dec 20, 2024
25eb3f0
lint
Dec 20, 2024
be3b6e2
lint
Dec 20, 2024
8f90974
rebase
Dec 22, 2024
2f685e8
nit import clean
Dec 22, 2024
919b305
nit rename layers
Dec 22, 2024
c771767
update factory
Dec 22, 2024
a43f683
lint
Dec 22, 2024
b76bbe5
fix typo
Dec 24, 2024
88b2d95
allow unstricted model loading
Dec 26, 2024
57987da
Feedback commit
cbalioglu Dec 23, 2024
4690e6b
update cross_attn build func
Dec 24, 2024
a345c4c
lint
Dec 26, 2024
f96344b
update AttentivePooler param names
Dec 26, 2024
75472ab
decouple #938
Dec 26, 2024
8993698
lint
Dec 26, 2024
4394a0a
lint
Dec 26, 2024
05c979f
lint
Dec 26, 2024
fc3b5c9
lint
Dec 26, 2024
29 changes: 29 additions & 0 deletions src/fairseq2/models/jepa/classifier/__init__.py
@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import fairseq2.models.jepa.classifier.archs  # Register architectures
from fairseq2.models.jepa.classifier.factory import (
    JEPA_CLASSIFIER_FAMILY as JEPA_CLASSIFIER_FAMILY,
)
from fairseq2.models.jepa.classifier.factory import (
    JepaClassifierBuilder as JepaClassifierBuilder,
)
from fairseq2.models.jepa.classifier.factory import (
    JepaClassifierConfig as JepaClassifierConfig,
)
from fairseq2.models.jepa.classifier.factory import (
    create_jepa_classifier_model as create_jepa_classifier_model,
)
from fairseq2.models.jepa.classifier.factory import (
    jepa_classifier_archs as jepa_classifier_archs,
)
from fairseq2.models.jepa.classifier.model import (
    JepaClassifierModel as JepaClassifierModel,
)

# isort: split
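The redundant "X as X" import form above marks these names as explicit re-exports for type checkers, so downstream code can pull the classifier API straight from the package root. A minimal import sketch (illustrative only, not part of this diff):

# Illustrative only: the package root re-exports the factory symbols, so
# callers do not need to reach into the factory or model submodules.
from fairseq2.models.jepa.classifier import (
    JepaClassifierBuilder,
    JepaClassifierConfig,
    JepaClassifierModel,
    create_jepa_classifier_model,
)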
33 changes: 33 additions & 0 deletions src/fairseq2/models/jepa/classifier/archs.py
@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from fairseq2.models.jepa.archs import base as jepa_base
from fairseq2.models.jepa.archs import huge as jepa_huge
from fairseq2.models.jepa.archs import large as jepa_large
from fairseq2.models.jepa.classifier.factory import (
    JepaClassifierConfig,
    jepa_classifier_arch,
)


@jepa_classifier_arch("base")
def base() -> JepaClassifierConfig:
    pretrain_config = jepa_base()
    return JepaClassifierConfig(encoder_config=pretrain_config.encoder_config)


@jepa_classifier_arch("large")
def large() -> JepaClassifierConfig:
    pretrain_config = jepa_large()
    return JepaClassifierConfig(encoder_config=pretrain_config.encoder_config)


@jepa_classifier_arch("huge")
def huge() -> JepaClassifierConfig:
    pretrain_config = jepa_huge()
    return JepaClassifierConfig(encoder_config=pretrain_config.encoder_config)
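Each decorated function above registers a named classifier configuration that reuses the encoder settings of the corresponding pretrained JEPA architecture. A rough sketch of how a registered configuration could be looked up by name, assuming fairseq2's ConfigRegistry exposes a get() accessor (not shown in this diff):

# Hypothetical usage sketch, not part of this PR.
from fairseq2.models.jepa.classifier import jepa_classifier_archs

# Assumes ConfigRegistry.get(name) returns the registered config instance.
config = jepa_classifier_archs.get("base")

# The classifier config wraps the pretrained encoder config unchanged.
print(config.num_classes, config.encoder_config.model_dim)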
184 changes: 184 additions & 0 deletions src/fairseq2/models/jepa/classifier/factory.py
@@ -0,0 +1,184 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from typing import final

from fairseq2.config_registry import ConfigRegistry
from fairseq2.models.factory import model_factories
from fairseq2.models.jepa import JepaEncoderBuilder, JepaEncoderConfig
from fairseq2.models.jepa.classifier.model import (
    AttentivePooler,
    CrossAttentionDecoderLayer,
    JepaClassifierModel,
)
from fairseq2.nn.projection import IdentityProjection, Linear, Projection
from fairseq2.nn.transformer import (
    MultiheadAttention,
    StandardMultiheadAttention,
    create_default_sdpa,
)
from fairseq2.typing import DataType, Device

JEPA_CLASSIFIER_FAMILY = "jepa_classifier"


@dataclass(kw_only=True)
class JepaClassifierConfig:
    encoder_config: JepaEncoderConfig = field(
        default_factory=lambda: JepaEncoderConfig()
    )
    """The configuration of the vision encoder."""

    pool_depth: int = 1
    """The depth of the attentive pooler (at least one decoder layer)."""

    decoder_projection: bool = True
    """If ``True``, the cross-attention decoder has a linear output projection."""

    num_queries: int = 1
    """The number of query tokens in the attentive pooler."""

    num_classes: int = 1000
    """The number of output classes (i.e. the size of the classification logits)."""


jepa_classifier_archs = ConfigRegistry[JepaClassifierConfig]()

jepa_classifier_arch = jepa_classifier_archs.decorator


@final
class JepaClassifierBuilder:
    """Builds a JEPA model fine-tuned for classification."""

    _config: JepaClassifierConfig
    _encoder_builder: JepaEncoderBuilder
    _device: Device | None
    _dtype: DataType | None

    def __init__(
        self,
        config: JepaClassifierConfig,
        *,
        device: Device | None = None,
        dtype: DataType | None = None,
    ) -> None:
        self._config = config

        self._encoder_builder = JepaEncoderBuilder(
            config.encoder_config, device=device, dtype=dtype
        )

        self._device, self._dtype = device, dtype

    def build_model(self) -> JepaClassifierModel:
        encoder_frontend = self._encoder_builder.build_frontend()
        encoder = self._encoder_builder.build_encoder()
        pooler = self.build_pooler()
        head = self.build_head()

        return JepaClassifierModel(encoder_frontend, encoder, pooler, head)

    def build_pooler(self) -> AttentivePooler:
        config = self._config

        if config.pool_depth > 1:
            encoder = self._encoder_builder.build_encoder(config.pool_depth)
        else:
            encoder = None

        decoder = self.build_decoder_layer()

        return AttentivePooler(
            decoder=decoder,
            encoder=encoder,
            num_queries=config.num_queries,
            init_std=config.encoder_config.init_std,
            device=self._device,
            dtype=self._dtype,
        )

    def build_head(self) -> Projection:
        config = self._config
        return Linear(
            config.encoder_config.model_dim,
            config.num_classes,
            device=self._device,
            dtype=self._dtype,
            bias=True,
        )

    def build_decoder_layer(self) -> CrossAttentionDecoderLayer:
        config = self._config

        cross_attn = self.build_cross_attention()

        ffn = self._encoder_builder.build_ffn(config.pool_depth)

        return CrossAttentionDecoderLayer(
            cross_attn,
            ffn,
            layer_norm_factory=self._encoder_builder.build_layer_norm,
            device=self._device,
            dtype=self._dtype,
        )

    def build_cross_attention(self) -> MultiheadAttention:
        config = self._config.encoder_config

        model_dim = config.model_dim

        sdpa = create_default_sdpa(attn_dropout_p=config.attn_dropout_p)

        output_proj = self.build_cross_attn_output_projection()

        return StandardMultiheadAttention(
            model_dim,
            config.num_encoder_attn_heads,
            sdpa=sdpa,
            bias=config.qkv_bias,
            output_proj=output_proj,
            device=self._device,
            dtype=self._dtype,
        )

    def build_cross_attn_output_projection(self) -> Projection:
        config = self._config

        model_dim = config.encoder_config.model_dim

        if config.decoder_projection:
            return Linear(
                model_dim,
                model_dim,
                bias=True,
                device=self._device,
                dtype=self._dtype,
            )
        else:
            return IdentityProjection(model_dim, model_dim)


def create_jepa_classifier_model(
    config: JepaClassifierConfig,
    *,
    device: Device | None = None,
    dtype: DataType | None = None,
) -> JepaClassifierModel:
    return JepaClassifierBuilder(
        config,
        device=device,
        dtype=dtype,
    ).build_model()


model_factories.register(
    JEPA_CLASSIFIER_FAMILY,
    create_jepa_classifier_model,
    JepaClassifierConfig,
    jepa_classifier_archs,
)
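Registering the factory under the jepa_classifier family makes the model constructible through fairseq2's generic model-creation machinery; the factory function can also be called directly. A minimal sketch, assuming default config values, an illustrative num_classes override, and that the Device/DataType aliases accept torch.device and torch.dtype values (the snippet is not part of this diff):

# Hypothetical usage sketch, not part of this PR.
import torch

from fairseq2.models.jepa.classifier import (
    JepaClassifierConfig,
    create_jepa_classifier_model,
)

# Start from the default configuration and override the head size.
config = JepaClassifierConfig(num_classes=400)

# device and dtype are the same optional arguments the builder accepts.
model = create_jepa_classifier_model(
    config, device=torch.device("cpu"), dtype=torch.float32
)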