Predict for Python-Backend #450

Open · wants to merge 2 commits into main

9 changes: 9 additions & 0 deletions backends/core/src/lib.rs
@@ -52,6 +52,15 @@ pub enum ModelType {
Embedding(Pool),
}

impl fmt::Display for ModelType {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
ModelType::Classifier => write!(f, "classifier"),
ModelType::Embedding(_) => write!(f, "embedding"),
}
}
}

#[derive(Debug, PartialEq, Clone)]
#[cfg_attr(feature = "clap", derive(ValueEnum))]
pub enum Pool {
21 changes: 21 additions & 0 deletions backends/grpc-client/src/client.rs
@@ -64,4 +64,25 @@ impl Client {
let response = self.stub.embed(request).await?.into_inner();
Ok(response.embeddings)
}

#[instrument(skip_all)]
pub async fn predict(
&mut self,
input_ids: Vec<u32>,
token_type_ids: Vec<u32>,
position_ids: Vec<u32>,
cu_seq_lengths: Vec<u32>,
max_length: u32,
) -> Result<Vec<Score>> {
let request = tonic::Request::new(EmbedRequest {
input_ids,
token_type_ids,
position_ids,
max_length,
cu_seq_lengths,
})
.inject_context();
let response = self.stub.predict(request).await?.into_inner();
Ok(response.scores)
}
}
10 changes: 10 additions & 0 deletions backends/proto/embed.proto
@@ -7,6 +7,8 @@ service EmbeddingService {
rpc Embed (EmbedRequest) returns (EmbedResponse);
/// Health check
rpc Health (HealthRequest) returns (HealthResponse);
/// Predict
rpc Predict (EmbedRequest) returns (PredictResponse);
}

message HealthRequest {}
@@ -28,3 +30,11 @@ message Embedding {
message EmbedResponse {
repeated Embedding embeddings = 1;
}

message Score {
repeated float values = 1;
}

message PredictResponse {
repeated Score scores = 1;
}
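
As a rough client-side sketch, the new Predict RPC can be exercised over the server's Unix socket as shown below. This assumes stubs generated from embed.proto as embed_pb2/embed_pb2_grpc (standard protoc naming for EmbeddingService); the token ids and socket path are placeholders.

import asyncio

import grpc

from text_embeddings_server.pb import embed_pb2, embed_pb2_grpc


async def main():
    # One sequence of four tokens; cu_seq_lengths holds its start/end offsets.
    request = embed_pb2.EmbedRequest(
        input_ids=[101, 7592, 2088, 102],
        token_type_ids=[0, 0, 0, 0],
        position_ids=[0, 1, 2, 3],
        cu_seq_lengths=[0, 4],
        max_length=4,
    )
    async with grpc.aio.insecure_channel("unix:///tmp/text-embeddings-server-0") as channel:
        stub = embed_pb2_grpc.EmbeddingServiceStub(channel)
        response = await stub.Predict(request)
        for score in response.scores:
            print(score.values)  # one row of logits per sequence in the batch


asyncio.run(main())
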
5 changes: 4 additions & 1 deletion backends/python/server/Makefile
@@ -20,7 +20,10 @@ install: gen-server
pip install -e .

run-dev:
python text_embeddings_server/cli.py serve BAAI/bge-small-en
python text_embeddings_server/cli.py BAAI/bge-small-en

run-reranker-dev:
python text_embeddings_server/cli.py mixedbread-ai/mxbai-rerank-xsmall-v1 --model-type classifier

export-requirements:
poetry export -o requirements.txt --without-hashes
6 changes: 5 additions & 1 deletion backends/python/server/text_embeddings_server/cli.py
@@ -14,6 +14,9 @@ class Dtype(str, Enum):
float16 = "float16"
bloat16 = "bfloat16"

class ModelType(str, Enum):
embedding = "embedding"
classifier = "classifier"

@app.command()
def serve(
@@ -25,6 +28,7 @@ def serve(
otlp_endpoint: Optional[str] = None,
otlp_service_name: str = "text-embeddings-inference.server",
pool: str = "cls",
model_type: ModelType = "embedding",
):
# Remove default handler
logger.remove()
@@ -49,7 +53,7 @@ def serve(
# Downgrade enum into str for easier management later on
dtype = None if dtype is None else dtype.value

server.serve(model_path, dtype, uds_path, pool)
server.serve(model_path, dtype, uds_path, pool, model_type.value)


if __name__ == "__main__":
16 changes: 13 additions & 3 deletions backends/python/server/text_embeddings_server/models/__init__.py
@@ -5,9 +5,11 @@
from typing import Optional
from transformers import AutoConfig
from transformers.models.bert import BertConfig
from typing import TYPE_CHECKING

from text_embeddings_server.models.model import Model
from text_embeddings_server.models.default_model import DefaultModel
from text_embeddings_server.models.predict_model import PredictModel

__all__ = ["Model"]

@@ -25,7 +27,7 @@
__all__.append(FlashBert)


def get_model(model_path: Path, dtype: Optional[str], pool: str):
def get_model(model_path: Path, dtype: Optional[str], pool: str, model_type: str):
if dtype == "float32":
dtype = torch.float32
elif dtype == "float16":
@@ -54,6 +56,14 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
raise ValueError("FlashBert only supports cls pooling")
return FlashBert(model_path, device, dtype)
else:
return DefaultModel(model_path, device, dtype, pool)

return DefaultModel(model_path, device, dtype)
# predict
if model_type == "classifier" and config.model_type in [
"roberta",
"xlm-roberta",
"bert",
"deberta-v2"
]:
return PredictModel(model_path, device, dtype)
raise NotImplementedError
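
As a usage sketch (not part of this diff), the extended factory can be called like the snippet below; the snapshot path is a placeholder and float32 keeps the example CPU-friendly.

from pathlib import Path

from text_embeddings_server.models import get_model

# Hypothetical local snapshot of a sequence-classification (reranker) checkpoint.
model = get_model(
    model_path=Path("/data/mxbai-rerank-xsmall-v1"),
    dtype="float32",
    pool="cls",  # still required by the signature; classifiers don't pool
    model_type="classifier",
)
print(type(model).__name__)  # expected: PredictModel for the supported architectures
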
6 changes: 4 additions & 2 deletions backends/python/server/text_embeddings_server/models/model.py
@@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from typing import List, TypeVar, Type

from text_embeddings_server.models.types import Batch, Embedding
from text_embeddings_server.models.types import Batch, Embedding, Score

B = TypeVar("B", bound=Batch)

@@ -24,6 +24,8 @@ def __init__(
def batch_type(self) -> Type[B]:
raise NotImplementedError

@abstractmethod
def embed(self, batch: B) -> List[Embedding]:
raise NotImplementedError

def predict(self, batch: B) -> List[Score]:
raise NotImplementedError
52 changes: 52 additions & 0 deletions backends/python/server/text_embeddings_server/models/predict_model.py
@@ -0,0 +1,52 @@
import inspect
import torch

from pathlib import Path
from typing import Type, List
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from opentelemetry import trace

from text_embeddings_server.models import Model
from text_embeddings_server.models.types import PaddedBatch, Embedding, Score

tracer = trace.get_tracer(__name__)


class PredictModel(Model):
def __init__(
self, model_path: Path, device: torch.device, dtype: torch.dtype
):
model = (
AutoModelForSequenceClassification.from_pretrained(model_path)
.to(dtype)
.to(device)
)

self.has_position_ids = (
inspect.signature(model.forward).parameters.get("position_ids", None)
is not None
)
self.has_token_type_ids = (
inspect.signature(model.forward).parameters.get("token_type_ids", None)
is not None
)

super(PredictModel, self).__init__(model=model, dtype=dtype, device=device)

@property
def batch_type(self) -> Type[PaddedBatch]:
return PaddedBatch

@tracer.start_as_current_span("predict")
def predict(self, batch: PaddedBatch) -> List[Score]:
kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}
if self.has_token_type_ids:
kwargs["token_type_ids"] = batch.token_type_ids
if self.has_position_ids:
kwargs["position_ids"] = batch.position_ids

logits = self.model(**kwargs).logits

cpu_results = logits.cpu().tolist()

return [Score(values=cpu_results[i]) for i in range(len(batch))]
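
For intuition, this is roughly what PredictModel.predict computes, restated with plain transformers for a single query/passage pair; the model id comes from the Makefile target above and the texts are illustrative.

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "mixedbread-ai/mxbai-rerank-xsmall-v1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id).eval()

inputs = tokenizer(
    "how does the predict endpoint work?",
    "Predict reuses EmbedRequest and returns one Score per sequence.",
    return_tensors="pt",
)
with torch.no_grad():
    logits = model(**inputs).logits  # shape: [batch, num_labels]

# Each Score message carries one row of logits; a single-label reranker yields one value.
print(logits[0].tolist())
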
backends/python/server/text_embeddings_server/models/types.py
@@ -5,7 +5,9 @@
from opentelemetry import trace

from text_embeddings_server.pb import embed_pb2
from text_embeddings_server.pb.embed_pb2 import Embedding
from text_embeddings_server.pb.embed_pb2 import Embedding, Score

__all__ = ["Embedding", "Score"]

tracer = trace.get_tracer(__name__)

15 changes: 12 additions & 3 deletions backends/python/server/text_embeddings_server/server.py
@@ -31,21 +31,30 @@ async def Embed(self, request, context):

return embed_pb2.EmbedResponse(embeddings=embeddings)

async def Predict(self, request, context):
batch = self.model.batch_type.from_pb(request, self.model.device)

scores = self.model.predict(batch)

return embed_pb2.PredictResponse(scores=scores)


def serve(
model_path: Path,
dtype: Optional[str],
uds_path: Path,
pool: str,
model_type: str
):
async def serve_inner(
model_path: Path,
dtype: Optional[str] = None,
dtype: Optional[str],
model_type: str,
):
unix_socket = f"unix://{uds_path}"

try:
model = get_model(model_path, dtype, pool)
model = get_model(model_path, dtype, pool, model_type)
except Exception:
logger.exception("Error when initializing model")
raise
@@ -76,4 +85,4 @@ async def serve_inner(
logger.info("Signal received. Shutting down")
await server.stop(0)

asyncio.run(serve_inner(model_path, dtype))
asyncio.run(serve_inner(model_path, dtype, model_type))
42 changes: 31 additions & 11 deletions backends/python/src/lib.rs
@@ -24,13 +24,9 @@ impl PythonBackend {
otlp_endpoint: Option<String>,
otlp_service_name: String,
) -> Result<Self, BackendError> {
let pool = match model_type {
ModelType::Classifier => {
return Err(BackendError::Start(
"`classifier` model type is not supported".to_string(),
))
}
ModelType::Embedding(pool) => pool,
let pool = match model_type.clone() {
ModelType::Classifier => None,
ModelType::Embedding(pool) => Some(pool),
};

let backend_process = management::BackendProcess::new(
@@ -40,6 +36,7 @@
otlp_endpoint,
otlp_service_name,
pool,
model_type,
)?;
let tokio_runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
@@ -105,9 +102,32 @@ impl Backend for PythonBackend {
Ok(embeddings)
}

fn predict(&self, _batch: Batch) -> Result<Predictions, BackendError> {
Err(BackendError::Inference(
"`predict` is not implemented".to_string(),
))
fn predict(&self, batch: Batch) -> Result<Predictions, BackendError> {
if !batch.raw_indices.is_empty() {
return Err(BackendError::Inference(
"raw embeddings are not supported for the Python backend.".to_string(),
));
}
let batch_size = batch.len();
let results = self
.tokio_runtime
.block_on(self.backend_client.clone().predict(
batch.input_ids,
batch.token_type_ids,
batch.position_ids,
batch.cumulative_seq_lengths,
batch.max_length,
))
.map_err(|err| BackendError::Inference(err.to_string()))?;
let raw_results: Vec<Vec<f32>> = results.into_iter().map(|r| r.values).collect();

let mut predictions =
HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());

for (i, r) in raw_results.into_iter().enumerate() {
predictions.insert(i, r);
}

Ok(predictions)
}
}
25 changes: 16 additions & 9 deletions backends/python/src/management.rs
@@ -8,7 +8,7 @@ use std::sync::mpsc;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::{env, fs, io, thread};
use text_embeddings_backend_core::{BackendError, Pool};
use text_embeddings_backend_core::{BackendError, ModelType, Pool};

#[derive(Debug)]
pub(crate) struct BackendProcess {
@@ -22,7 +22,8 @@ impl BackendProcess {
uds_path: &str,
otlp_endpoint: Option<String>,
otlp_service_name: String,
pool: Pool,
pool: Option<Pool>,
modeltype: ModelType,
) -> Result<Self, BackendError> {
// Get UDS path
let uds = Path::new(uds_path);
@@ -32,13 +33,14 @@
fs::remove_file(uds).expect("could not remove UDS file");
}

let pool = match pool {
Pool::Cls => "cls",
Pool::Mean => "mean",
Pool::LastToken => "lasttoken",
Pool::Splade => {
let pool_str = match pool {
Some(Pool::Cls) => "cls",
Some(Pool::Mean) => "mean",
Some(Pool::LastToken) => "lasttoken",
Some(Pool::Splade) => {
return Err(BackendError::Start(format!("{pool:?} is not supported")));
}
None => "",
};

// Process args
@@ -51,9 +53,14 @@
"--logger-level".to_owned(),
"INFO".to_owned(),
"--json-output".to_owned(),
"--pool".to_owned(),
pool.to_owned(),
"--model-type".to_owned(),
modeltype.to_string(),
];
// Add `--pool` argument only if `pool` is not `None`
if !pool_str.is_empty() {
python_server_args.push("--pool".to_owned());
python_server_args.push(pool_str.to_owned());
}

// OpenTelemetry
if let Some(otlp_endpoint) = otlp_endpoint {
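
Put together, for a classifier the spawned Python server receives an argument vector roughly like the sketch below. The leading model-path and --uds-path entries are truncated out of the hunk above and assumed from the CLI; the concrete paths are placeholders. Because ModelType::Classifier maps pool to None, no --pool flag is appended.

# Illustrative argv assembled by BackendProcess::new for a classifier model.
python_server_args = [
    "/data/mxbai-rerank-xsmall-v1",  # model_path (positional)
    "--uds-path", "/tmp/text-embeddings-server-0",
    "--logger-level", "INFO",
    "--json-output",
    "--model-type", "classifier",  # value produced by ModelType's Display impl
]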