Splade enabling (#35)
Signed-off-by: Daniel Huang <[email protected]>
pi314ever authored Nov 25, 2024
1 parent 1871d33 commit 4b38c43
Showing 7 changed files with 84 additions and 36 deletions.
19 changes: 10 additions & 9 deletions README.md
@@ -38,15 +38,16 @@ Not all features of TEI are currently supported as this is still a work in progress.

## Validated Models

| Architecture | Model Type | Models |
|--------------|------------|--------|
| BERT | Embedding | <li>[BAAI/bge-large-en-v1.5](https://huggingface.co/BAAI/bge-large-en-v1.5)</li><li>[sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)</li><li>[sentence-transformers/all-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2)</li><li>[sentence-transformers/multi-qa-MiniLM-L6-cos-v1](https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1)</li><li>[sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2)</li><li>[sentence-transformers/paraphrase-MiniLM-L3-v2](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L3-v2)</li> |
| MPNet | Embedding | <li>[sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)</li><li>[sentence-transformers/paraphrase-multilingual-mpnet-base-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-mpnet-base-v2)</li><li>[sentence-transformers/multi-qa-mpnet-base-dot-v1](https://huggingface.co/sentence-transformers/multi-qa-mpnet-base-dot-v1)</li> |
| ALBERT | Embedding | <li>[sentence-transformers/paraphrase-albert-small-v2](https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2)</li> |
| Mistral | Embedding | <li>[intfloat/e5-mistral-7b-instruct](https://huggingface.co/intfloat/e5-mistral-7b-instruct)</li><li>[Salesforce/SFR-Embedding-2_R](https://huggingface.co/Salesforce/SFR-Embedding-2_R)</li> |
| GTE | Embedding | <li>[Alibaba-NLP/gte-large-en-v1.5](https://huggingface.co/Alibaba-NLP/gte-large-en-v1.5)</li> |
| JinaBERT | Embedding | <li>[jinaai/jina-embeddings-v2-base-en](https://huggingface.co/jinaai/jina-embeddings-v2-base-en)</li> |
| Roberta | Sequence Classification | <li>[SamLowe/roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions)</li> |
| Architecture | Model Type | Pooling | Models |
|--------------|------------|---------|--------|
| BERT | Embedding | Cls | <li>[BAAI/bge-large-en-v1.5](https://huggingface.co/BAAI/bge-large-en-v1.5)</li><li>[sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)</li><li>[sentence-transformers/all-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2)</li><li>[sentence-transformers/multi-qa-MiniLM-L6-cos-v1](https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1)</li><li>[sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2)</li><li>[sentence-transformers/paraphrase-MiniLM-L3-v2](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L3-v2)</li> |
| BERT | Embedding | Splade | <li>[naver/efficient-splade-VI-BT-large-query](https://huggingface.co/naver/efficient-splade-VI-BT-large-query)</li> |
| MPNet | Embedding | Mean | <li>[sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)</li><li>[sentence-transformers/paraphrase-multilingual-mpnet-base-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-mpnet-base-v2)</li><li>[sentence-transformers/multi-qa-mpnet-base-dot-v1](https://huggingface.co/sentence-transformers/multi-qa-mpnet-base-dot-v1)</li> |
| ALBERT | Embedding | Mean | <li>[sentence-transformers/paraphrase-albert-small-v2](https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2)</li> |
| Mistral | Embedding | Last token | <li>[intfloat/e5-mistral-7b-instruct](https://huggingface.co/intfloat/e5-mistral-7b-instruct)</li><li>[Salesforce/SFR-Embedding-2_R](https://huggingface.co/Salesforce/SFR-Embedding-2_R)</li> |
| GTE | Embedding | Cls | <li>[Alibaba-NLP/gte-large-en-v1.5](https://huggingface.co/Alibaba-NLP/gte-large-en-v1.5)</li> |
| JinaBERT | Embedding | Mean | <li>[jinaai/jina-embeddings-v2-base-en](https://huggingface.co/jinaai/jina-embeddings-v2-base-en)</li> |
| Roberta | Sequence Classification | N/A | <li>[SamLowe/roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions)</li> |

> The license for using TEI on Habana Gaudi is TEI's own license: https://github.com/huggingface/text-embeddings-inference/blob/main/LICENSE
>
2 changes: 1 addition & 1 deletion backends/python/server/requirements.txt
@@ -62,7 +62,7 @@ safetensors==0.4.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13"
six==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.12.1 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
@@ -4,7 +4,7 @@
from loguru import logger
from pathlib import Path
from typing import Optional
from transformers import AutoConfig
from transformers import AutoConfig, BertForMaskedLM
from transformers.models.bert import BertConfig
from transformers.models.auto.modeling_auto import (
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
@@ -77,6 +77,10 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()
):
return ClassificationModel(model_path, device, dtype)
elif config.architectures[0] == "BertForMaskedLM":
return DefaultModel(
model_path, device, dtype, pool, trust_remote=TRUST_REMOTE_CODE, model_class=BertForMaskedLM
)
else:
return DefaultModel(
model_path, device, dtype, pool, trust_remote=TRUST_REMOTE_CODE
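The new branch above keys off `config.architectures[0]`. A minimal, standalone check of that assumption (the model id is taken from the README table in this commit; running it downloads the config from the Hub):

```python
# Sketch: confirm that the SPLADE checkpoint advertises a masked-LM architecture,
# which is what routes it to DefaultModel(..., model_class=BertForMaskedLM) above.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("naver/efficient-splade-VI-BT-large-query")
print(config.architectures[0])  # expected: "BertForMaskedLM"
```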
@@ -4,14 +4,14 @@
from loguru import logger
from pathlib import Path
from typing import Type, List
from transformers import AutoModel
from sentence_transformers.models import Pooling
from transformers import AutoModel, PreTrainedModel
from opentelemetry import trace

from habana_frameworks.torch.hpu import wrap_in_hpu_graph
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

from text_embeddings_server.models import Model
from text_embeddings_server.models.pooling import DefaultPooling, SpladePooling
from text_embeddings_server.models.types import PaddedBatch, Embedding

tracer = trace.get_tracer(__name__)
@@ -25,19 +25,26 @@ def __init__(
dtype: torch.dtype,
pool: str = "cls",
trust_remote: bool = False,
model_class: type[PreTrainedModel] = AutoModel, # type: ignore
):
if device == torch.device("hpu"):
adapt_transformers_to_gaudi()
model = (
AutoModel.from_pretrained(model_path, trust_remote_code=trust_remote)
.to(dtype)
.to(device)
model_class.from_pretrained(model_path, trust_remote_code=trust_remote) # type: ignore
.to(dtype=dtype)
.to(device=device)
)

if device == torch.device("hpu"):
logger.info("Use graph mode for HPU")
model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
self.hidden_size = model.config.hidden_size
self.pooling = Pooling(self.hidden_size, pooling_mode=pool)
self.vocab_size = model.config.vocab_size
self.pooling_mode = pool
if pool == "splade":
self.pooling = SpladePooling()
else:
self.pooling = DefaultPooling(self.hidden_size, pooling_mode=pool)
position_offset = 0
model_type = model.config.model_type
if model_type in ["xlm-roberta", "camembert", "roberta"]:
@@ -72,17 +79,19 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
kwargs["position_ids"] = batch.position_ids

output = self.model(**kwargs)
pooling_features = {
"token_embeddings": output[0],
"attention_mask": batch.attention_mask,
}
embedding = self.pooling.forward(pooling_features)["sentence_embedding"]
embedding = self.pooling.forward(output, batch.attention_mask)
cpu_results = embedding.reshape(-1).tolist()

step_size = embedding.shape[-1]
if self.pooling_mode == "splade":
assert (
step_size == self.vocab_size
), f"Step size for splade pooling expected vocab size ({self.vocab_size}) but got {step_size}. Check splade pooling implementation"
else:
assert (
step_size == self.hidden_size
), f"Step size expected hidden size ({self.hidden_size}) but got {step_size}. Please check model outputs."
return [
Embedding(
values=cpu_results[i * self.hidden_size : (i + 1) * self.hidden_size]
)
Embedding(values=cpu_results[i * step_size : (i + 1) * step_size])
for i in range(len(batch))
]

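With splade pooling, `embedding` has one vocabulary-sized row per input, so `step_size` becomes `vocab_size` and the flattened `cpu_results` list is cut into vocab-sized chunks. A toy sketch of that slicing (shapes are illustrative, not taken from the commit):

```python
# Sketch of the per-example slicing in embed(): flatten a (batch, vocab_size)
# tensor, then recover one vocab-sized vector per input with step_size slices.
import torch

batch_size, vocab_size = 2, 30522  # 30522 = bert-base-uncased vocab, for illustration
embedding = torch.rand(batch_size, vocab_size)
cpu_results = embedding.reshape(-1).tolist()

step_size = vocab_size
per_example = [cpu_results[i * step_size : (i + 1) * step_size] for i in range(batch_size)]
assert len(per_example) == batch_size and len(per_example[0]) == vocab_size
```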
40 changes: 40 additions & 0 deletions backends/python/server/text_embeddings_server/models/pooling.py
@@ -0,0 +1,40 @@
from abc import ABC, abstractmethod

import torch
from opentelemetry import trace
from sentence_transformers.models import Pooling
from torch import Tensor

tracer = trace.get_tracer(__name__)


class _Pooling(ABC):
@abstractmethod
def forward(self, model_output, attention_mask) -> Tensor:
pass


class DefaultPooling(_Pooling):
def __init__(self, hidden_size, pooling_mode) -> None:
assert (
pooling_mode != "splade"
), "Splade pooling is not supported for DefaultPooling"
self.pooling = Pooling(hidden_size, pooling_mode=pooling_mode)

@tracer.start_as_current_span("pooling")
def forward(self, model_output, attention_mask) -> Tensor:
pooling_features = {
"token_embeddings": model_output[0],
"attention_mask": attention_mask,
}
return self.pooling.forward(pooling_features)["sentence_embedding"]


class SpladePooling(_Pooling):
@tracer.start_as_current_span("pooling")
def forward(self, model_output, attention_mask) -> Tensor:
        # SPLADE pooling: log-saturated ReLU of the MLM logits, masked by attention, max-pooled over the sequence
hidden_states = torch.relu(model_output[0])
hidden_states = (1 + hidden_states).log()
hidden_states = torch.mul(hidden_states, attention_mask.unsqueeze(-1))
return hidden_states.max(dim=1).values
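
The math in `SpladePooling.forward` can be checked in isolation. A minimal sketch with random tensors standing in for `BertForMaskedLM` output (shapes assumed, not taken from the commit):

```python
# Equivalent of SpladePooling.forward on dummy data: log-saturated ReLU of the
# MLM logits, zeroed on padding positions, then max-pooled over the sequence.
import torch

batch, seq_len, vocab_size = 2, 8, 30522
logits = torch.rand(batch, seq_len, vocab_size)   # stand-in for model_output[0]
attention_mask = torch.ones(batch, seq_len)       # 1 = real token, 0 = padding

weights = torch.log1p(torch.relu(logits)) * attention_mask.unsqueeze(-1)
sparse_embedding = weights.max(dim=1).values      # shape: (batch, vocab_size)
assert sparse_embedding.shape == (batch, vocab_size)
```

`torch.log1p(torch.relu(x))` is the same as the `(1 + relu(x)).log()` used above; with real SPLADE logits most entries end up at zero, which is what makes the vocabulary-sized embedding sparse.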
12 changes: 3 additions & 9 deletions backends/python/src/lib.rs
@@ -24,15 +24,9 @@ impl PythonBackend {
otlp_endpoint: Option<String>,
otlp_service_name: String,
) -> Result<Self, BackendError> {
let mut pool_type = Pool::Cls;
match model_type {
ModelType::Classifier => {}
ModelType::Embedding(pool) => {
if pool == Pool::Splade {
return Err(BackendError::Start(format!("{pool:?} is not supported")));
}
pool_type = pool;
}
let pool_type = match model_type {
ModelType::Classifier => Pool::Cls,
ModelType::Embedding(pool) => pool,
};

let backend_process = management::BackendProcess::new(
2 changes: 1 addition & 1 deletion backends/python/src/management.rs
@@ -37,7 +37,7 @@ impl BackendProcess {
Pool::Mean => "mean",
Pool::LastToken => "lasttoken",
Pool::Splade => {
return Err(BackendError::Start(format!("{pool:?} is not supported")));
"splade"
}
};

