Skip to content

Commit

Permalink
Add pool::mean support for python backend (#33)
Browse files Browse the repository at this point in the history
Signed-off-by: kaixuanliu <[email protected]>
  • Loading branch information
kaixuanliu authored Sep 27, 2024
1 parent 03a96b3 commit 6488b73
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 30 deletions.
3 changes: 2 additions & 1 deletion backends/python/server/text_embeddings_server/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def serve(
json_output: bool = False,
otlp_endpoint: Optional[str] = None,
otlp_service_name: str = "text-embeddings-inference.server",
pool: str = "cls",
):
# Remove default handler
logger.remove()
Expand All @@ -48,7 +49,7 @@ def serve(
# Downgrade enum into str for easier management later on
dtype = None if dtype is None else dtype.value

server.serve(model_path, dtype, uds_path)
server.serve(model_path, dtype, uds_path, pool)


if __name__ == "__main__":
Expand Down
26 changes: 20 additions & 6 deletions backends/python/server/text_embeddings_server/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from typing import Optional
from transformers import AutoConfig
from transformers.models.bert import BertConfig
from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
from transformers.models.auto.modeling_auto import (
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
)

from text_embeddings_server.models.model import Model, B
from text_embeddings_server.models.default_model import DefaultModel
Expand Down Expand Up @@ -37,7 +39,7 @@
__all__.append(FlashBert)


def get_model(model_path: Path, dtype: Optional[str]):
def get_model(model_path: Path, dtype: Optional[str], pool: str):
if dtype == "float32":
dtype = torch.float32
elif dtype == "float16":
Expand Down Expand Up @@ -66,17 +68,29 @@ def get_model(model_path: Path, dtype: Optional[str]):
and dtype in [torch.float16, torch.bfloat16]
and FLASH_ATTENTION
):
if pool != "cls":
raise ValueError("FlashBert only supports cls pooling")
return FlashBert(model_path, device, dtype)
else:
if config.architectures[0] in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values():
if (
config.architectures[0]
in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()
):
return ClassificationModel(model_path, device, dtype)
else:
return DefaultModel(model_path, device, dtype, trust_remote=TRUST_REMOTE_CODE)
return DefaultModel(
model_path, device, dtype, pool, trust_remote=TRUST_REMOTE_CODE
)
else:
try:
if config.architectures[0] in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values():
if (
config.architectures[0]
in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()
):
return ClassificationModel(model_path, device, dtype)
else:
return DefaultModel(model_path, device, dtype, trust_remote=TRUST_REMOTE_CODE)
return DefaultModel(
model_path, device, dtype, pool, trust_remote=TRUST_REMOTE_CODE
)
except:
raise RuntimeError(f"Unsupported model_type {config.model_type}")
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ def batch_type(self) -> Type[PaddedBatch]:

@tracer.start_as_current_span("embed")
def embed(self, batch):
raise NotImplementedError(f"Embed is not a valid operation for model type {self.model.config.model_type}")
raise NotImplementedError(
f"Embed is not a valid operation for model type {self.model.config.model_type}"
)

@tracer.start_as_current_span("predict")
def predict(self, batch: PaddedBatch) -> List[Score]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path
from typing import Type, List
from transformers import AutoModel
from sentence_transformers.models import Pooling
from opentelemetry import trace

from habana_frameworks.torch.hpu import wrap_in_hpu_graph
Expand All @@ -17,18 +18,26 @@


class DefaultModel(Model):
def __init__(self,
model_path: Path,
device: torch.device,
dtype: torch.dtype,
trust_remote: bool=False):
def __init__(
self,
model_path: Path,
device: torch.device,
dtype: torch.dtype,
pool: str = "cls",
trust_remote: bool = False,
):
if device == torch.device("hpu"):
adapt_transformers_to_gaudi()
model = AutoModel.from_pretrained(model_path, trust_remote_code=trust_remote).to(dtype).to(device)
model = (
AutoModel.from_pretrained(model_path, trust_remote_code=trust_remote)
.to(dtype)
.to(device)
)
if device == torch.device("hpu"):
logger.info("Use graph mode for HPU")
model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
self.hidden_size = model.config.hidden_size
self.pooling = Pooling(self.hidden_size, pooling_mode=pool)
position_offset = 0
model_type = model.config.model_type
if model_type in ["xlm-roberta", "camembert", "roberta"]:
Expand Down Expand Up @@ -63,7 +72,11 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
kwargs["position_ids"] = batch.position_ids

output = self.model(**kwargs)
embedding = output[0][:, 0]
pooling_features = {
"token_embeddings": output[0],
"attention_mask": batch.attention_mask,
}
embedding = self.pooling.forward(pooling_features)["sentence_embedding"]
cpu_results = embedding.reshape(-1).tolist()

return [
Expand All @@ -75,4 +88,6 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:

@tracer.start_as_current_span("predict")
def predict(self, batch):
raise NotImplementedError(f"Predict is not a valid operation for model type {self.model.config.model_type}")
raise NotImplementedError(
f"Predict is not a valid operation for model type {self.model.config.model_type}"
)
19 changes: 10 additions & 9 deletions backends/python/server/text_embeddings_server/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,19 @@
from text_embeddings_server.pb.embed_pb2 import Embedding, Score

tracer = trace.get_tracer(__name__)
PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 128))
PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 128))


def round_up(number, k):
return (number + k - 1) // k * k


class Batch(ABC):
@classmethod
@abstractmethod
def from_pb(cls, pb: embed_pb2.EmbedRequest, device: torch.device, *args, **kwargs) -> "Batch":
def from_pb(
cls, pb: embed_pb2.EmbedRequest, device: torch.device, *args, **kwargs
) -> "Batch":
raise NotImplementedError

@abstractmethod
Expand All @@ -34,10 +38,9 @@ class PaddedBatch(Batch):

@classmethod
@tracer.start_as_current_span("from_pb")
def from_pb(cls,
pb: embed_pb2.EmbedRequest,
device: torch.device,
max_input_length: int) -> "PaddedBatch":
def from_pb(
cls, pb: embed_pb2.EmbedRequest, device: torch.device, max_input_length: int
) -> "PaddedBatch":
if pb.max_length > max_input_length:
raise RuntimeError(f"input length exceeds model config's max_input_length")

Expand All @@ -46,9 +49,7 @@ def from_pb(cls,
batch_size = len(pb.cu_seq_lengths) - 1
new_bs = 2 ** math.ceil(math.log2(batch_size))
# Allocate padded tensors all at once
all_tensors = torch.zeros(
[4, new_bs, max_length], dtype=torch.int32
)
all_tensors = torch.zeros([4, new_bs, max_length], dtype=torch.int32)
for i, start_index in enumerate(pb.cu_seq_lengths[:-1]):
end_index = pb.cu_seq_lengths[i + 1]
input_length = end_index - start_index
Expand Down
11 changes: 8 additions & 3 deletions backends/python/server/text_embeddings_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,19 @@ async def Health(self, request, context):

async def Embed(self, request, context):
max_input_length = self.model.max_input_length
batch = self.model.batch_type.from_pb(request, self.model.device, max_input_length)
batch = self.model.batch_type.from_pb(
request, self.model.device, max_input_length
)

embeddings = self.model.embed(batch)

return embed_pb2.EmbedResponse(embeddings=embeddings)

async def Predict(self, request, context):
max_input_length = self.model.max_input_length
batch = self.model.batch_type.from_pb(request, self.model.device, max_input_length)
batch = self.model.batch_type.from_pb(
request, self.model.device, max_input_length
)

scores = self.model.predict(batch)

Expand All @@ -46,6 +50,7 @@ def serve(
model_path: Path,
dtype: Optional[str],
uds_path: Path,
pool: str,
):
async def serve_inner(
model_path: Path,
Expand All @@ -54,7 +59,7 @@ async def serve_inner(
unix_socket = f"unix://{uds_path}"

try:
model = get_model(model_path, dtype)
model = get_model(model_path, dtype, pool)
except Exception:
logger.exception("Error when initializing model")
raise
Expand Down
5 changes: 4 additions & 1 deletion backends/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@ impl PythonBackend {
otlp_endpoint: Option<String>,
otlp_service_name: String,
) -> Result<Self, BackendError> {
let mut pool_type = Pool::Cls;
match model_type {
ModelType::Classifier => {}
ModelType::Embedding(pool) => {
if pool != Pool::Cls {
if pool == Pool::Splade {
return Err(BackendError::Start(format!("{pool:?} is not supported")));
}
pool_type = pool;
}
};

Expand All @@ -39,6 +41,7 @@ impl PythonBackend {
&uds_path,
otlp_endpoint,
otlp_service_name,
pool_type,
)?;
let tokio_runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
Expand Down
14 changes: 13 additions & 1 deletion backends/python/src/management.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::sync::mpsc;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::{env, fs, io, thread};
use text_embeddings_backend_core::BackendError;
use text_embeddings_backend_core::{BackendError, Pool};

#[derive(Debug)]
pub(crate) struct BackendProcess {
Expand All @@ -22,6 +22,7 @@ impl BackendProcess {
uds_path: &str,
otlp_endpoint: Option<String>,
otlp_service_name: String,
pool: Pool,
) -> Result<Self, BackendError> {
// Get UDS path
let uds = Path::new(uds_path);
Expand All @@ -31,6 +32,15 @@ impl BackendProcess {
fs::remove_file(uds).expect("could not remove UDS file");
}

let pool = match pool {
Pool::Cls => "cls",
Pool::Mean => "mean",
Pool::LastToken => "lasttoken",
Pool::Splade => {
return Err(BackendError::Start(format!("{pool:?} is not supported")));
}
};

// Process args
let mut python_server_args = vec![
model_path,
Expand All @@ -41,6 +51,8 @@ impl BackendProcess {
"--logger-level".to_owned(),
"INFO".to_owned(),
"--json-output".to_owned(),
"--pool".to_owned(),
pool.to_owned(),
];

// OpenTelemetry
Expand Down

0 comments on commit 6488b73

Please sign in to comment.