diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py
index 237fba6f645fa5..00d529824f8d16 100644
--- a/src/transformers/models/modernbert/modeling_modernbert.py
+++ b/src/transformers/models/modernbert/modeling_modernbert.py
@@ -307,7 +307,7 @@ def eager_attention_forward(
     dim: int,
     output_attentions: Optional[bool] = False,
     **_kwargs,
-) -> Tuple[torch.Tensor, torch.Tensor] | Tuple[torch.Tensor]:
+) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
     # qkv: [batch_size, seqlen, 3, nheads, headdim]
     cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
     query, key, value = qkv.transpose(3, 1).unbind(dim=2)
@@ -1125,6 +1125,12 @@ def __init__(self, config: ModernBertConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def get_input_embeddings(self):
+        return self.model.embeddings.tok_embeddings
+
+    def set_input_embeddings(self, value):
+        self.model.embeddings.tok_embeddings = value
+
     @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         checkpoint=_CHECKPOINT_FOR_DOC,
@@ -1236,6 +1242,12 @@ def __init__(self, config: ModernBertConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def get_input_embeddings(self):
+        return self.model.embeddings.tok_embeddings
+
+    def set_input_embeddings(self, value):
+        self.model.embeddings.tok_embeddings = value
+
     @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         checkpoint=_CHECKPOINT_FOR_DOC,
diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py
index dac356146f3015..fb65d1d1eae3d9 100644
--- a/src/transformers/models/modernbert/modular_modernbert.py
+++ b/src/transformers/models/modernbert/modular_modernbert.py
@@ -532,7 +532,7 @@ def eager_attention_forward(
     dim: int,
     output_attentions: Optional[bool] = False,
     **_kwargs,
-) -> Tuple[torch.Tensor, torch.Tensor] | Tuple[torch.Tensor]:
+) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
     # qkv: [batch_size, seqlen, 3, nheads, headdim]
     cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
     query, key, value = qkv.transpose(3, 1).unbind(dim=2)
@@ -1278,6 +1278,12 @@ def __init__(self, config: ModernBertConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def get_input_embeddings(self):
+        return self.model.embeddings.tok_embeddings
+
+    def set_input_embeddings(self, value):
+        self.model.embeddings.tok_embeddings = value
+
     @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         checkpoint=_CHECKPOINT_FOR_DOC,
@@ -1389,6 +1395,12 @@ def __init__(self, config: ModernBertConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def get_input_embeddings(self):
+        return self.model.embeddings.tok_embeddings
+
+    def set_input_embeddings(self, value):
+        self.model.embeddings.tok_embeddings = value
+
     @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         checkpoint=_CHECKPOINT_FOR_DOC,
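Two notes on what this patch changes. The `Union[...]` rewrite in `eager_attention_forward` is a compatibility fix: the PEP 604 `X | Y` union syntax is only valid at runtime on Python >= 3.10, so the old annotation would raise a `TypeError` when the module is imported (or the hints evaluated) on Python 3.9. The `get_input_embeddings`/`set_input_embeddings` additions let generic `PreTrainedModel` utilities that need the token embedding matrix, most notably `resize_token_embeddings`, work on the patched task-head classes. Below is a minimal sketch of the flow the accessors enable; it is not part of the patch, it assumes the public `answerdotai/ModernBERT-base` checkpoint, and it assumes `ModernBertForSequenceClassification` is among the patched classes (the hunks only show the `__init__` context, so the class names are inferred, not stated in the diff):

```python
from transformers import AutoTokenizer, ModernBertForSequenceClassification

model_id = "answerdotai/ModernBERT-base"  # assumed checkpoint for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = ModernBertForSequenceClassification.from_pretrained(model_id, num_labels=2)

# The new accessors expose the token embedding module directly on the head class.
embeddings = model.get_input_embeddings()  # an nn.Embedding over the vocabulary

# That is what lets the generic resize path work after extending the vocabulary.
tokenizer.add_tokens(["<custom_token>"])
model.resize_token_embeddings(len(tokenizer))
assert model.get_input_embeddings().num_embeddings == len(tokenizer)
```

The accessors route through `self.model.embeddings.tok_embeddings`, i.e. the same embedding module the backbone uses, so resizing via the head class stays consistent with the underlying `ModernBertModel`.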