Add Alibaba-NLP/gte-multilingual-reranker-base model support (#37)
Signed-off-by: kaixuanliu <[email protected]>
kaixuanliu authored Dec 19, 2024
1 parent 8b1adeb commit 6ebbec2
Showing 3 changed files with 52 additions and 21 deletions.
54 changes: 38 additions & 16 deletions backends/python/server/text_embeddings_server/models/__init__.py
@@ -6,9 +6,6 @@
 from typing import Optional
 from transformers import AutoConfig, BertForMaskedLM
 from transformers.models.bert import BertConfig
-from transformers.models.auto.modeling_auto import (
-    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
-)

 from text_embeddings_server.models.model import Model, B
 from text_embeddings_server.models.default_model import DefaultModel
@@ -18,6 +15,10 @@

 HTCORE_AVAILABLE = True
 TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "false").lower() in ["true", "1"]
+DISABLE_TENSOR_CACHE = os.getenv("DISABLE_TENSOR_CACHE", "false").lower() in [
+    "true",
+    "1",
+]

 try:
     import habana_frameworks.torch.core as htcore
@@ -72,29 +73,50 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
                 raise ValueError("FlashBert only supports cls pooling")
             return FlashBert(model_path, device, dtype)
         else:
-            if (
-                config.architectures[0]
-                in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()
-            ):
-                return ClassificationModel(model_path, device, dtype)
+            if config.architectures[0].endswith("Classification"):
+                return ClassificationModel(
+                    model_path,
+                    device,
+                    dtype,
+                    disable_tensor_cache=DISABLE_TENSOR_CACHE,
+                    trust_remote=TRUST_REMOTE_CODE,
+                )
             elif config.architectures[0] == "BertForMaskedLM":
                 return DefaultModel(
-                    model_path, device, dtype, pool, trust_remote=TRUST_REMOTE_CODE, model_class=BertForMaskedLM
+                    model_path,
+                    device,
+                    dtype,
+                    pool,
+                    trust_remote=TRUST_REMOTE_CODE,
+                    model_class=BertForMaskedLM,
                 )
             else:
                 return DefaultModel(
-                    model_path, device, dtype, pool, trust_remote=TRUST_REMOTE_CODE
+                    model_path,
+                    device,
+                    dtype,
+                    pool,
+                    disable_tensor_cache=DISABLE_TENSOR_CACHE,
+                    trust_remote=TRUST_REMOTE_CODE,
                 )
     else:
         try:
-            if (
-                config.architectures[0]
-                in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()
-            ):
-                return ClassificationModel(model_path, device, dtype)
+            if config.architectures[0].endswith("Classification"):
+                return ClassificationModel(
+                    model_path,
+                    device,
+                    dtype,
+                    disable_tensor_cache=DISABLE_TENSOR_CACHE,
+                    trust_remote=TRUST_REMOTE_CODE,
+                )
             else:
                 return DefaultModel(
-                    model_path, device, dtype, pool, trust_remote=TRUST_REMOTE_CODE
+                    model_path,
+                    device,
+                    dtype,
+                    pool,
+                    disable_tensor_cache=DISABLE_TENSOR_CACHE,
+                    trust_remote=TRUST_REMOTE_CODE,
                 )
         except:
             raise RuntimeError(f"Unsupported model_type {config.model_type}")
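
The dispatch change above is what actually enables the new model: remote-code rerankers such as gte-multilingual-reranker-base declare architecture names (reportedly "NewForSequenceClassification") that are absent from transformers' built-in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, so the old mapping lookup never matched. The suffix check catches any classification head, built-in or remote. A minimal sketch of the new rule, assuming Hub access; the exact architecture string comes from the model's config.json and is not guaranteed here:

from transformers import AutoConfig

# Load the reranker's config, allowing its custom modeling code.
config = AutoConfig.from_pretrained(
    "Alibaba-NLP/gte-multilingual-reranker-base", trust_remote_code=True
)
arch = config.architectures[0]
# The old membership test missed remote-code architectures; the suffix
# test matches any "*Classification" head.
print(arch, arch.endswith("Classification"))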
16 changes: 12 additions & 4 deletions backends/python/server/text_embeddings_server/models/classification_model.py
@@ -17,15 +17,24 @@


 class ClassificationModel(Model):
-    def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
+    def __init__(
+        self,
+        model_path: Path,
+        device: torch.device,
+        dtype: torch.dtype,
+        disable_tensor_cache: bool = False,
+        trust_remote: bool = False,
+    ):
         if device == torch.device("hpu"):
             adapt_transformers_to_gaudi()

-        model = AutoModelForSequenceClassification.from_pretrained(model_path)
+        model = AutoModelForSequenceClassification.from_pretrained(
+            model_path, trust_remote_code=trust_remote
+        )
         model = model.to(dtype).to(device)
         if device == torch.device("hpu"):
             logger.info("Use graph mode for HPU")
-            model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
+            model = wrap_in_hpu_graph(model, disable_tensor_cache=disable_tensor_cache)

         self.hidden_size = model.config.hidden_size
         position_offset = 0
@@ -68,7 +77,6 @@ def predict(self, batch: PaddedBatch) -> List[Score]:
             kwargs["token_type_ids"] = batch.token_type_ids
         if self.has_position_ids:
             kwargs["position_ids"] = batch.position_ids
-
         output = self.model(**kwargs, return_dict=True)
         all_scores = output.logits.tolist()
         return [Score(values=scores) for scores in all_scores]
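
Outside the server, the path ClassificationModel now takes can be approximated on CPU (no Gaudi graph wrapping). A sketch following the model card's pair-scoring usage, with trust_remote_code enabled the same way the constructor above enables it; the example texts are illustrative:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "Alibaba-NLP/gte-multilingual-reranker-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(
    model_id, trust_remote_code=True
)
model.eval()

# One (query, passage) pair per row; the classification head emits one
# logit per pair, which predict() above returns as a Score.
pairs = [["what is the capital of China?", "Beijing is the capital of China."]]
inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    scores = model(**inputs, return_dict=True).logits.view(-1).float()
print(scores)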
3 changes: 2 additions & 1 deletion backends/python/server/text_embeddings_server/models/default_model.py
@@ -24,6 +24,7 @@ def __init__(
         device: torch.device,
         dtype: torch.dtype,
         pool: str = "cls",
+        disable_tensor_cache: bool = False,
         trust_remote: bool = False,
         model_class: type[PreTrainedModel] = AutoModel,  # type: ignore
     ):
@@ -37,7 +38,7 @@

         if device == torch.device("hpu"):
             logger.info("Use graph mode for HPU")
-            model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
+            model = wrap_in_hpu_graph(model, disable_tensor_cache=disable_tensor_cache)
         self.hidden_size = model.config.hidden_size
         self.vocab_size = model.config.vocab_size
         self.pooling_mode = pool
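
One behavioural detail worth noting: both model classes previously hard-coded disable_tensor_cache=True when wrapping for HPU graph mode, whereas it is now driven by DISABLE_TENSOR_CACHE, which defaults to false, so graph-wrapped models keep the tensor cache unless the variable is set. The flag parsing follows the same convention as TRUST_REMOTE_CODE; a self-contained check of that convention:

import os

# "true"/"1" (case-insensitive) enable the flag; anything else,
# including an unset variable, disables it.
for raw in ["true", "True", "1", "false", "0", ""]:
    os.environ["DISABLE_TENSOR_CACHE"] = raw
    flag = os.getenv("DISABLE_TENSOR_CACHE", "false").lower() in ["true", "1"]
    print(f"{raw!r:>8} -> {flag}")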
