feat: introduce MQAphroditeEngine
AlpinDale committed Dec 27, 2024
1 parent 9bdf8d5 commit b47a390
Showing 20 changed files with 1,000 additions and 833 deletions.
6 changes: 3 additions & 3 deletions aphrodite/common/envs.py
@@ -54,7 +54,7 @@
APHRODITE_DYNAMIC_ROPE_SCALING: bool = False
APHRODITE_TEST_FORCE_FP8_MARLIN: bool = False
APHRODITE_PLUGINS: Optional[List[str]] = None
APHRODITE_RPC_GET_DATA_TIMEOUT_MS: int = 5000
APHRODITE_RPC_TIMEOUT: int = 5000
APHRODITE_FORCE_SINGLE_USER_PREFIX_CACHE: bool = False
APHRODITE_TEST_DYNAMO_GRAPH_CAPTURE: int = 0
APHRODITE_TEST_DYNAMO_FULLGRAPH_CAPTURE: int = 0
@@ -383,8 +383,8 @@ def get_default_config_root():

# Time in ms for the zmq client to wait for a response from the backend
# server for simple data operations
"APHRODITE_RPC_GET_DATA_TIMEOUT_MS":
lambda: int(os.getenv("APHRODITE_RPC_GET_DATA_TIMEOUT_MS", "5000")),
"APHRODITE_RPC_TIMEOUT":
lambda: int(os.getenv("APHRODITE_RPC_TIMEOUT", "5000")),

# a list of plugin names to load, separated by commas.
# if this is not set, it means all plugins will be loaded
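
For context on how these entries are consumed: envs.py maps each variable name to a zero-argument lambda, so the environment is re-read on every attribute access. Below is a minimal sketch of that lazy-lookup pattern; the module-level __getattr__ hook is an assumption about code outside this diff, not part of the commit:

import os
from typing import Any, Callable, Dict

# Name -> lazy reader, mirroring the dictionary entries in the hunk above.
environment_variables: Dict[str, Callable[[], Any]] = {
    "APHRODITE_RPC_TIMEOUT":
    lambda: int(os.getenv("APHRODITE_RPC_TIMEOUT", "5000")),
}

def __getattr__(name: str) -> Any:
    # PEP 562 module __getattr__: lets `envs.APHRODITE_RPC_TIMEOUT`
    # evaluate the lambda (and hence os.getenv) on each access.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module has no attribute {name!r}")
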
147 changes: 69 additions & 78 deletions aphrodite/endpoints/openai/api_server.py
@@ -46,8 +46,6 @@
KAIGenerationInputSchema,
TokenizeRequest,
TokenizeResponse)
from aphrodite.endpoints.openai.rpc.client import AsyncEngineRPCClient
from aphrodite.endpoints.openai.rpc.server import run_rpc_server
# yapf: enable
from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
from aphrodite.endpoints.openai.serving_completions import (
@@ -59,7 +57,9 @@
OpenAIServingTokenization)
from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.engine.protocol import AsyncEngineClient
from aphrodite.engine.protocol import EngineClient
from aphrodite.engine.multiprocessing.client import MQAphroditeEngineClient
from aphrodite.engine.multiprocessing.engine import run_mp_engine
from aphrodite.modeling.model_loader.weight_utils import get_model_config_yaml
from aphrodite.server import serve_http
from aphrodite.transformers_utils.tokenizer import get_tokenizer
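
The import swap above replaces the frontend-owned AsyncEngineRPCClient/run_rpc_server pair with the shared multiprocessing engine. As a rough sketch of the new wiring, using only calls that appear later in this diff (the get_open_zmq_ipc_path import path is an assumption, and readiness retries and error handling are omitted):

import multiprocessing

from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.multiprocessing.client import MQAphroditeEngineClient
from aphrodite.engine.multiprocessing.engine import run_mp_engine
from aphrodite.common.utils import get_open_zmq_ipc_path  # assumed path

async def start_mp_engine(
        engine_args: AsyncEngineArgs) -> MQAphroditeEngineClient:
    ipc_path = get_open_zmq_ipc_path()

    # Spawn, not fork: the frontend process may already hold a CUDA context.
    context = multiprocessing.get_context("spawn")
    engine_process = context.Process(target=run_mp_engine,
                                     args=(engine_args, ipc_path))
    engine_process.start()

    # The client talks to the engine process over ZMQ at the same IPC path.
    engine_config = engine_args.create_engine_config()
    client = MQAphroditeEngineClient(ipc_path, engine_config)
    await client.setup()
    return client
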
@@ -85,27 +85,16 @@
_running_tasks: Set[asyncio.Task] = set()


def model_is_embedding(model_name: str, trust_remote_code: bool,
quantization: Optional[str]) -> bool:
return ModelConfig(model=model_name,
tokenizer=model_name,
tokenizer_mode="auto",
trust_remote_code=trust_remote_code,
quantization=quantization,
seed=0,
dtype="auto").embedding_mode


@asynccontextmanager
async def lifespan(app: FastAPI):

try:
if app.state.log_stats:
async_engine_client = app.state.engine_client
engine_client: EngineClient = app.state.engine_client
async def _force_log():
while True:
await asyncio.sleep(10)
await async_engine_client.do_log_stats()
await asyncio.sleep(10.)
await engine_client.do_log_stats()
task = asyncio.create_task(_force_log())
_running_tasks.add(task)
task.add_done_callback(_running_tasks.remove)
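
A subtlety in the hunk above: asyncio.create_task keeps only a weak reference to its task, so the _running_tasks set plus the done-callback exist to hold a strong reference until the logging loop finishes. The same pattern in isolation:

import asyncio
from typing import Set

_running_tasks: Set[asyncio.Task] = set()

def start_stats_logging(engine_client) -> None:
    async def _force_log() -> None:
        while True:
            await asyncio.sleep(10.)
            await engine_client.do_log_stats()

    task = asyncio.create_task(_force_log())
    # Keep a strong reference so the event loop cannot garbage-collect
    # the task mid-flight; drop it automatically once the task completes.
    _running_tasks.add(task)
    task.add_done_callback(_running_tasks.remove)
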
@@ -122,36 +111,35 @@ async def _force_log():


@asynccontextmanager
async def build_async_engine_client(
args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]:
async def build_engine_client(
args: Namespace) -> AsyncIterator[Optional[EngineClient]]:

# Context manager to handle async_engine_client lifecycle
# Context manager to handle engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
engine_args = AsyncEngineArgs.from_cli_args(args)

async with build_async_engine_client_from_engine_args(
async with build_engine_client_from_engine_args(
engine_args, args.disable_frontend_multiprocessing) as engine:

yield engine


@asynccontextmanager
async def build_async_engine_client_from_engine_args(
async def build_engine_client_from_engine_args(
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
) -> AsyncIterator[Optional[AsyncEngineClient]]:
) -> AsyncIterator[Optional[EngineClient]]:
"""
Create AsyncEngineClient, either:
Create EngineClient, either:
- in-process using AsyncAphrodite directly
- multiprocess using AsyncAphrodite RPC
Returns the Client or None if the creation failed.
"""

# If manually triggered or embedding model, use AsyncAphrodite in process.
# TODO: support embedding model via RPC.
if (model_is_embedding(engine_args.model, engine_args.trust_remote_code,
engine_args.quantization)
# Fall back to the in-process AsyncAphrodite when the config is unsupported.
# TODO: fill out feature matrix.
if (MQAphroditeEngineClient.is_unsupported_config(engine_args)
or disable_frontend_multiprocessing):
engine_config = engine_args.create_engine_config()
uses_ray = getattr(AsyncAphrodite._get_executor_cls(engine_config),
@@ -186,63 +174,64 @@ async def build_async_engine_client_from_engine_args(
"and Aphrodite will properly handle cleanup.")

# Select random path for IPC.
rpc_path = get_open_zmq_ipc_path()
logger.info(f"Multiprocessing frontend to use {rpc_path} for RPC Path."
)

# Build RPCClient, which conforms to AsyncEngineClient Protocol.
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
rpc_client = AsyncEngineRPCClient(rpc_path)
ipc_path = get_open_zmq_ipc_path()
logger.info(
f"Multiprocessing frontend to use {ipc_path} for IPC Path.")

# Start RPCServer in separate process (holds the AsyncAphrodite).
context = multiprocessing.get_context("spawn")
# Start MQAphroditeEngine in separate process (holds the AphroditeEngine).
# the current process might have CUDA context,
# so we need to spawn a new process
rpc_server_process = context.Process(
target=run_rpc_server,
args=(engine_args, rpc_path))
rpc_server_process.start()
logger.info(
f"Started engine process with PID {rpc_server_process.pid}")
context = multiprocessing.get_context("spawn")
engine_process = context.Process(target=run_mp_engine,
args=(engine_args,
ipc_path))
engine_process.start()
logger.info(f"Started engine process with PID {engine_process.pid}")
# Build RPCClient, which conforms to EngineClient Protocol.
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
engine_config = engine_args.create_engine_config()
mp_engine_client = MQAphroditeEngineClient(ipc_path, engine_config)

try:
while True:
try:
await rpc_client.setup()
await mp_engine_client.setup()
break
except TimeoutError:
if not rpc_server_process.is_alive():
logger.error(
"RPCServer process died before responding "
"to readiness probe")
if not engine_process.is_alive():
logger.error("Engine process died before responding "
"to readiness probe")
yield None
return

yield rpc_client # type: ignore[misc]
yield mp_engine_client # type: ignore[misc]
finally:
# Ensure rpc server process was terminated
rpc_server_process.terminate()
engine_process.terminate()

# Close all open connections to the backend
rpc_client.close()
mp_engine_client.close()

# Wait for server process to join
rpc_server_process.join()
# Wait for engine process to join
engine_process.join(4)
if engine_process.exitcode is None:
# Kill if it has not stopped within the 4-second join timeout
engine_process.kill()

# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import multiprocess
multiprocess.mark_process_dead(rpc_server_process.pid)
multiprocess.mark_process_dead(engine_process.pid)
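
mark_process_dead above is the teardown half of prometheus_client's multiprocess mode; the setup half is exporting PROMETHEUS_MULTIPROC_DIR before prometheus_client is first imported, as the comment notes. A minimal sketch of both halves (the directory path is illustrative):

import os

# Must be set before prometheus_client is imported anywhere in this process.
os.environ["PROMETHEUS_MULTIPROC_DIR"] = "/tmp/aphrodite_prom"  # illustrative

from prometheus_client import multiprocess

def reap_engine_metrics(engine_pid: int) -> None:
    # Retires the dead engine process's metric files so aggregated
    # gauge/counter values stay correct after it exits.
    multiprocess.mark_process_dead(engine_pid)
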


async def _maybe_switch_model(
request_model: str, app_state,
raw_request: Request) -> Optional[ErrorResponse]:
"""Switch to requested model if different from currently loaded one."""
global model_is_loaded, async_engine_client, engine_args, served_model_names
global model_is_loaded, engine_client, engine_args, served_model_names

if not model_is_loaded:
return None
@@ -334,8 +323,7 @@ async def _maybe_switch_model(
engine_args = AsyncEngineArgs(model=request_model)

# Create new engine client without context manager
if (model_is_embedding(engine_args.model, engine_args.trust_remote_code,
engine_args.quantization)
if (MQAphroditeEngineClient.is_unsupported_config(engine_args)
or args.disable_frontend_multiprocessing):
new_engine_client = AsyncAphrodite.from_engine_args(engine_args)
await new_engine_client.setup()
Expand Down Expand Up @@ -440,7 +428,7 @@ def tokenization(request: Request) -> OpenAIServingTokenization:
def embedding(request: Request) -> OpenAIServingEmbedding:
return request.app.state.openai_serving_embedding

def engine_client(request: Request) -> AsyncEngineClient:
def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client


@@ -450,14 +438,18 @@ async def unload_model(raw_request: Request):
logger.info("Received request to unload model.")

try:
args = raw_request.app.state.args
if not args.disable_frontend_multiprocessing:
await engine_client(raw_request).kill()
else:
await engine_client(raw_request).shutdown_background_loop()
current_client = engine_client(raw_request)

try:
await current_client.shutdown_background_loop()
finally:
global model_is_loaded
model_is_loaded = False

# Clean up the client reference from app state
if hasattr(raw_request.app.state, 'engine_client'):
del raw_request.app.state.engine_client

global model_is_loaded
model_is_loaded = False
return JSONResponse(
content={
"status": "success",
@@ -477,7 +469,7 @@ async def unload_model(raw_request: Request):
@router.post("/v1/model/load")
async def load_model(config_file: UploadFile, raw_request: Request):
"""Load a model using a YAML configuration file."""
global model_is_loaded, async_engine_client, engine_args
global model_is_loaded, engine_client, engine_args

if model_is_loaded:
return JSONResponse(
@@ -1123,7 +1115,7 @@ async def authentication(request: Request, call_next):


def init_app_state(
async_engine_client: AsyncEngineClient,
engine_client: EngineClient,
model_config: ModelConfig,
state: State,
args: Namespace,
@@ -1149,11 +1141,11 @@ def init_app_state(
else:
request_logger = RequestLogger(max_log_len=args.max_log_len)

state.engine_client = async_engine_client
state.engine_client = engine_client
state.log_stats = not args.disable_log_stats

state.openai_serving_chat = OpenAIServingChat(
async_engine_client,
engine_client,
model_config,
served_model_names,
args.response_role,
@@ -1166,7 +1158,7 @@
tool_parser=args.tool_call_parser
)
state.openai_serving_completion = OpenAIServingCompletion(
async_engine_client,
engine_client,
model_config,
served_model_names,
lora_modules=args.lora_modules,
@@ -1175,13 +1167,13 @@
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
state.openai_serving_embedding = OpenAIServingEmbedding(
async_engine_client,
engine_client,
model_config,
served_model_names,
request_logger=request_logger,
)
state.openai_serving_tokenization = OpenAIServingTokenization(
async_engine_client,
engine_client,
model_config,
served_model_names,
lora_modules=args.lora_modules,
@@ -1207,13 +1199,13 @@ def signal_handler(*_) -> None:
raise KeyboardInterrupt("terminated")
signal.signal(signal.SIGTERM, signal_handler)

async with build_async_engine_client(args) as async_engine_client:
async with build_engine_client(args) as engine_client:
# If None, creation of the client failed and we exit.
if async_engine_client is None:
if engine_client is None:
return
app = build_app(args)
model_config = await async_engine_client.get_model_config()
init_app_state(async_engine_client, model_config, app.state, args)
model_config = await engine_client.get_model_config()
init_app_state(engine_client, model_config, app.state, args)

protocol = "https" if args.ssl_certfile else "http"
root_path = args.root_path.rstrip("/") if args.root_path else ""
@@ -1233,7 +1225,6 @@ def signal_handler(*_) -> None:

shutdown_task = await serve_http(
app,
limit_concurrency=async_engine_client.limit_concurrency,
host=args.host,
port=args.port,
log_level=args.uvicorn_log_level,
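
To exercise the rewritten unload handler above from the client side, a small sketch; the POST /v1/model/unload route is an assumption (its decorator sits outside the shown hunks) mirroring the /v1/model/load route, and the host/port are illustrative:

import requests

# Hypothetical local server; adjust host/port to your deployment.
resp = requests.post("http://localhost:2242/v1/model/unload")
print(resp.status_code, resp.json())  # expect {"status": "success", ...}
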
44 changes: 0 additions & 44 deletions aphrodite/endpoints/openai/rpc/__init__.py

This file was deleted.

(Remaining changed files not shown.)
