From 7ff39fa75b9e4714ada2e50c4e66333e4874c72e Mon Sep 17 00:00:00 2001 From: Koichi Yasuoka Date: Fri, 27 Dec 2024 01:23:23 +0900 Subject: [PATCH 1/5] Update modular_modernbert.py --- src/transformers/models/modernbert/modular_modernbert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index dac356146f3015..4424e8b2fead5d 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -532,7 +532,7 @@ def eager_attention_forward( dim: int, output_attentions: Optional[bool] = False, **_kwargs, -) -> Tuple[torch.Tensor, torch.Tensor] | Tuple[torch.Tensor]: +) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: # qkv: [batch_size, seqlen, 3, nheads, headdim] cos, sin = module.rotary_emb(qkv, position_ids=position_ids) query, key, value = qkv.transpose(3, 1).unbind(dim=2) From e51b48ab108aea43f9fca0f7dc35efd783c39128 Mon Sep 17 00:00:00 2001 From: Koichi Yasuoka Date: Fri, 27 Dec 2024 17:42:07 +0900 Subject: [PATCH 2/5] support {set,get}_input_embeddings --- .../models/modernbert/modular_modernbert.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 4424e8b2fead5d..90d15c44905a0e 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -1278,6 +1278,12 @@ def __init__(self, config: ModernBertConfig): # Initialize weights and apply final processing self.post_init() + def get_input_embeddings(self): + return self.embeddings.tok_embeddings + + def set_input_embeddings(self, value): + self.embeddings.tok_embeddings = value + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1389,6 +1395,12 @@ def __init__(self, config: ModernBertConfig): # Initialize weights and apply final processing self.post_init() + def get_input_embeddings(self): + return self.embeddings.tok_embeddings + + def set_input_embeddings(self, value): + self.embeddings.tok_embeddings = value + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, From be15c358e07bf638167be8fbe3cfe42b1aea1260 Mon Sep 17 00:00:00 2001 From: Koichi Yasuoka Date: Sun, 29 Dec 2024 17:41:56 +0900 Subject: [PATCH 3/5] sync to modular_modernbert.py --- .../models/modernbert/modeling_modernbert.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 237fba6f645fa5..3cea820ac1fcf1 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -307,7 +307,7 @@ def eager_attention_forward( dim: int, output_attentions: Optional[bool] = False, **_kwargs, -) -> Tuple[torch.Tensor, torch.Tensor] | Tuple[torch.Tensor]: +) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: # qkv: [batch_size, seqlen, 3, nheads, headdim] cos, sin = module.rotary_emb(qkv, position_ids=position_ids) query, key, value = qkv.transpose(3, 1).unbind(dim=2) @@ -1125,6 +1125,12 @@ def __init__(self, config: ModernBertConfig): # Initialize weights and apply final processing self.post_init() + def get_input_embeddings(self): + return self.embeddings.tok_embeddings + + def set_input_embeddings(self, value): + self.embeddings.tok_embeddings = value + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1236,6 +1242,12 @@ def __init__(self, config: ModernBertConfig): # Initialize weights and apply final processing self.post_init() + def get_input_embeddings(self): + return self.embeddings.tok_embeddings + + def set_input_embeddings(self, value): + self.embeddings.tok_embeddings = value + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, From 098f872b8a305faef237db627441a345c5731326 Mon Sep 17 00:00:00 2001 From: Koichi Yasuoka Date: Tue, 31 Dec 2024 18:02:27 +0900 Subject: [PATCH 4/5] embeddings are in self.model --- src/transformers/models/modernbert/modular_modernbert.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 90d15c44905a0e..fb65d1d1eae3d9 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -1279,10 +1279,10 @@ def __init__(self, config: ModernBertConfig): self.post_init() def get_input_embeddings(self): - return self.embeddings.tok_embeddings + return self.model.embeddings.tok_embeddings def set_input_embeddings(self, value): - self.embeddings.tok_embeddings = value + self.model.embeddings.tok_embeddings = value @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) @add_code_sample_docstrings( @@ -1396,10 +1396,10 @@ def __init__(self, config: ModernBertConfig): self.post_init() def get_input_embeddings(self): - return self.embeddings.tok_embeddings + return self.model.embeddings.tok_embeddings def set_input_embeddings(self, value): - self.embeddings.tok_embeddings = value + self.model.embeddings.tok_embeddings = value @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) @add_code_sample_docstrings( From 2929b085a65bec7a9d3b973f02b1cfa53da80c59 Mon Sep 17 00:00:00 2001 From: Koichi Yasuoka Date: Tue, 31 Dec 2024 18:06:41 +0900 Subject: [PATCH 5/5] sync to modular_modernbert.py --- src/transformers/models/modernbert/modeling_modernbert.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 3cea820ac1fcf1..00d529824f8d16 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -1126,10 +1126,10 @@ def __init__(self, config: ModernBertConfig): self.post_init() def get_input_embeddings(self): - return self.embeddings.tok_embeddings + return self.model.embeddings.tok_embeddings def set_input_embeddings(self, value): - self.embeddings.tok_embeddings = value + self.model.embeddings.tok_embeddings = value @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) @add_code_sample_docstrings( @@ -1243,10 +1243,10 @@ def __init__(self, config: ModernBertConfig): self.post_init() def get_input_embeddings(self): - return self.embeddings.tok_embeddings + return self.model.embeddings.tok_embeddings def set_input_embeddings(self, value): - self.embeddings.tok_embeddings = value + self.model.embeddings.tok_embeddings = value @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) @add_code_sample_docstrings(