chore: pipeline parallel with Ray accelerated dag
AlpinDale committed Sep 1, 2024
1 parent 141672a commit 523ac99
Showing 7 changed files with 135 additions and 54 deletions.
2 changes: 2 additions & 0 deletions Dockerfile
@@ -35,6 +35,7 @@ WORKDIR /workspace

# install build and runtime dependencies
COPY requirements-common.txt requirements-common.txt
COPY requirements-adag.txt requirements-adag.txt
COPY requirements-cuda.txt requirements-cuda.txt
RUN pip install packaging wheel
RUN --mount=type=cache,target=/root/.cache/pip \
@@ -68,6 +69,7 @@ COPY setup.py setup.py
COPY cmake cmake
COPY CMakeLists.txt CMakeLists.txt
COPY requirements-common.txt requirements-common.txt
COPY requirements-adag.txt requirements-adag.txt
COPY requirements-cuda.txt requirements-cuda.txt
COPY pyproject.toml pyproject.toml
COPY aphrodite aphrodite
5 changes: 4 additions & 1 deletion MANIFEST.in
@@ -1,9 +1,12 @@
include LICENSE
include requirements-adag.txt
include requirements-common.txt
include requirements-cuda.txt
include requirements-rocm.txt
include requirements-neuron.txt
include requirements-cpu.txt
include CMakeLists.txt
include aphrodite/endpoints/kobold/klite.embd

recursive-include kernels *
recursive-include kernels *
recursive-include cmake *
141 changes: 96 additions & 45 deletions aphrodite/executor/ray_gpu_executor.py
@@ -4,6 +4,8 @@
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from loguru import logger

from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.common.utils import (_run_task_with_lock,
get_aphrodite_instance_id,
@@ -28,6 +30,9 @@
APHRODITE_USE_RAY_SPMD_WORKER = bool(
os.getenv("APHRODITE_USE_RAY_SPMD_WORKER", 0))

APHRODITE_USE_RAY_COMPILED_DAG_NCCL_CHANNEL = bool(
int(os.getenv("APHRODITE_USE_RAY_COMPILED_DAG_NCCL_CHANNEL", 1)))
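A quick aside on this toggle: the value arrives from the environment as a string such as "0" or "1", so it is passed through int() before bool(). The snippet below is a standalone sketch of that parsing; only the variable name comes from this commit, the rest is illustrative.

import os

os.environ["APHRODITE_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"] = "0"
use_nccl_channel = bool(
    int(os.getenv("APHRODITE_USE_RAY_COMPILED_DAG_NCCL_CHANNEL", 1)))
print(use_nccl_channel)  # False -> intermediate tensors use the "auto" channel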


class RayGPUExecutor(DistributedGPUExecutor):

@@ -111,12 +116,19 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
# The remaining workers are the actual ray actors.
self.workers: List[RayWorkerWrapper] = []

# Used in ray compiled DAG: indexed first by PP rank,
# and then TP rank. In other words, the inner list is
# the TP group of workers for a PP rank.
self.pp_tp_workers: List[List[RayWorkerWrapper]] = []

if self.parallel_config.ray_workers_use_nsight:
ray_remote_kwargs = self._configure_ray_workers_use_nsight(
ray_remote_kwargs)

logger.info(f"use_ray_spmd_worker: {self.use_ray_spmd_worker}")
# Create the workers.
driver_ip = get_ip()
logger.info(f"driver_ip: {driver_ip}")
worker_wrapper_kwargs = self._get_worker_wrapper_args()
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("GPU", 0):
@@ -148,42 +160,49 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
# Else, added to the list of workers.
self.workers.append(worker)

logger.debug(f"workers: {self.workers}")
logger.debug(f"driver_dummy_worker: {self.driver_dummy_worker}")
if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
raise ValueError(
"Ray does not allocate any GPUs on the driver node. Consider "
"adjusting the Ray placement group or running the driver on a "
"GPU node.")

worker_ips = [
ray.get(worker.get_node_ip.remote()) # type: ignore[attr-defined]
for worker in self.workers
]
ip_counts: Dict[str, int] = {}
for ip in worker_ips:
ip_counts[ip] = ip_counts.get(ip, 0) + 1

def sort_by_driver_then_worker_ip(worker):
"""
Sort the workers based on 3 properties:
1. If the worker is on the same node as the driver (the Aphrodite
engine), it should be placed first.
2. Then, if the worker is on a node with fewer workers, it should
be placed first.
3. Finally, if the worker is on a node with a smaller IP address,
it should be placed first.
"""
ip = ray.get(worker.get_node_ip.remote())
return (ip != driver_ip, ip_counts[ip], ip)

# After sorting, the workers on the same node will be
# close to each other, and the workers on the driver
# node will be placed first.
self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)
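For reference, the ordering this key produces can be checked in isolation. The sketch below uses hypothetical IP addresses and plain strings in place of Ray worker handles, so no remote get_node_ip calls are needed.

# Standalone sketch of the ordering produced by sort_by_driver_then_worker_ip,
# using hypothetical IPs instead of Ray actors.
driver_ip = "10.0.0.1"
worker_ips = ["10.0.0.2", "10.0.0.1", "10.0.0.2", "10.0.0.3"]

ip_counts = {}
for ip in worker_ips:
    ip_counts[ip] = ip_counts.get(ip, 0) + 1


def sort_key(ip):
    # (not on the driver node, workers on that node, IP) -- smaller sorts first
    return (ip != driver_ip, ip_counts[ip], ip)


print(sorted(worker_ips, key=sort_key))
# ['10.0.0.1', '10.0.0.3', '10.0.0.2', '10.0.0.2']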

# Get the set of GPU IDs used on each node.
worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
use_dummy_driver=True)

# the order in `worker_node_and_gpu_ids` does not necessarily match
# the machine boundaries. We need to make sure that workers in the
# same node are assigned consecutive ranks.
# examples:
# [('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [0]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [0]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [1]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [2]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [3]), ('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [1]), ('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [2]), ('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [3])] # noqa

# initialize worker ranks with -1 (unassigned)
worker_ranks = [-1 for x in worker_node_and_gpu_ids]
current_rank = 0
while -1 in worker_ranks:
# whenever we find an unassigned worker, find the node
index = worker_ranks.index(-1)
current_node_id = worker_node_and_gpu_ids[index][0]
# assign ranks to all workers in the same node
for i, (node_id, _) in enumerate(worker_node_and_gpu_ids):
if node_id == current_node_id:
worker_ranks[i] = current_rank
current_rank += 1
# with the above example, worker_ranks will be [0, 4, 5, 6, 7, 1, 2, 3]

node_workers = defaultdict(list) # node id -> list of worker ranks
node_gpus = defaultdict(list) # node id -> list of gpu ids

for worker_rank, (node_id, gpu_ids) in zip(worker_ranks,
worker_node_and_gpu_ids):
node_workers[node_id].append(worker_rank)
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
node_workers[node_id].append(i)
# `gpu_ids` can be a list of strings or integers.
# convert them to integers for consistency.
# NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
@@ -207,16 +226,6 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
self._run_workers("update_environment_variables",
all_args=all_args_to_update_environment_variables)

if len(node_gpus) == 1:
# in single node case, we don't need to get the IP address.
# the loopback address is sufficient
# NOTE: a node may have several IP addresses, one for each
# network interface. `get_ip()` might return any of them,
# while they might not work for communication inside the node
# if the network setup is complicated. Using the loopback address
# solves this issue, as it always works for communication inside
# the node.
driver_ip = "127.0.0.1"
distributed_init_method = get_distributed_init_method(
driver_ip, get_open_port())

@@ -226,8 +235,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
local_rank=node_workers[node_id].index(rank),
rank=rank,
distributed_init_method=distributed_init_method,
) for rank, (node_id,
_) in zip(worker_ranks, worker_node_and_gpu_ids)
) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
]
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

@@ -236,6 +244,19 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)

if self.use_ray_spmd_worker:
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
self.pp_tp_workers.append([])
for tp_rank in range(
self.parallel_config.tensor_parallel_size):
# PP=2, TP=4
# pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
rank = (pp_rank * self.parallel_config.tensor_parallel_size
) + tp_rank
assert len(self.pp_tp_workers[pp_rank]) == tp_rank
assert pp_rank < len(self.pp_tp_workers)
self.pp_tp_workers[pp_rank].append(self.workers[rank])
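The grouping built here can be reproduced with plain integers standing in for the RayWorkerWrapper handles; this is only a sketch of the indexing, with PP=2 and TP=4 as in the inline comment above.

# Sketch of the pp_tp_workers layout for PP=2, TP=4, with integer ranks in
# place of the actual RayWorkerWrapper handles.
pipeline_parallel_size = 2
tensor_parallel_size = 4
workers = list(range(pipeline_parallel_size * tensor_parallel_size))

pp_tp_workers = [
    [workers[pp_rank * tensor_parallel_size + tp_rank]
     for tp_rank in range(tensor_parallel_size)]
    for pp_rank in range(pipeline_parallel_size)
]
print(pp_tp_workers)  # [[0, 1, 2, 3], [4, 5, 6, 7]]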

# This is the list of workers that are rank 0 of each TP group EXCEPT
# global rank 0. These are the workers that will broadcast to the
# rest of the workers.
@@ -246,9 +267,9 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
self.non_driver_workers: List[RayWorkerWrapper] = []

# Enforce rank order for correct rank to return final output.
for rank, worker in sorted(zip(worker_ranks[1:], self.workers)):
# We need to skip the driver worker, which we
# do by skipping worker_ranks[0] which is always 0.
for index, worker in enumerate(self.workers):
# The driver worker is rank 0 and not in self.workers.
rank = index + 1
if rank % self.parallel_config.tensor_parallel_size == 0:
self.tp_driver_workers.append(worker)
else:
@@ -381,16 +402,46 @@ def _compiled_ray_dag(self, enable_asyncio: bool):
raise ValueError(f"Ray version {required_version} or greater is "
f"required, but found {current_version}")

from ray.dag import InputNode, MultiOutputNode
assert self.parallel_config.use_ray
from ray.dag import InputNode, MultiOutputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType

# Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon.
logger.info(f"APHRODITE_USE_RAY_COMPILED_DAG_NCCL_CHANNEL = "
f"{APHRODITE_USE_RAY_COMPILED_DAG_NCCL_CHANNEL}")
with InputNode() as input_data:
forward_dag = MultiOutputNode([
worker.execute_model_spmd.bind( # type: ignore[attr-defined]
input_data) for worker in self.workers
])
# Example DAG: PP=2, TP=4
# (ExecuteModelReq, None) -> 0 -> (ExecuteModelReq, IntermediateOutput) -> 4 -> SamplerOutput # noqa: E501
# -> 1 -> (ExecuteModelReq, IntermediateOutput) -> 5 -> SamplerOutput # noqa: E501
# -> 2 -> (ExecuteModelReq, IntermediateOutput) -> 6 -> SamplerOutput # noqa: E501
# -> 3 -> (ExecuteModelReq, IntermediateOutput) -> 7 -> SamplerOutput # noqa: E501

# All workers in the first TP group will take in the
# ExecuteModelRequest as input.
outputs = [input_data for _ in self.pp_tp_workers[0]]
for pp_rank, tp_group in enumerate(self.pp_tp_workers):
# Each PP worker takes in the output of the previous PP worker,
# and the TP group executes in SPMD fashion.
outputs = [
worker.execute_model_spmd.
bind( # type: ignore[attr-defined]
outputs[i]) for i, worker in enumerate(tp_group)
]

last_pp_rank = len(self.pp_tp_workers) - 1
if pp_rank < last_pp_rank:
# Specify how intermediate tensors should be passed
# between pp stages, no need to specify for the last
# pp stage.
transport = "nccl" \
if APHRODITE_USE_RAY_COMPILED_DAG_NCCL_CHANNEL \
else "auto"
outputs = [
output.with_type_hint(
TorchTensorType(transport=transport))
for output in outputs
]

forward_dag = MultiOutputNode(outputs)
return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)
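For readers unfamiliar with the accelerated DAG API used here, a minimal end-to-end sketch follows. It assumes Ray >= 2.32 (per requirements-adag.txt); EchoStage and its forward method are hypothetical stand-ins for the workers' execute_model_spmd, and transport="auto" is chosen so the example does not require NCCL-capable GPUs. Only the DAG wiring mirrors the method above.

# Minimal sketch of a two-stage pipeline compiled into a Ray accelerated DAG.
# EchoStage is a hypothetical stand-in for RayWorkerWrapper.
import ray
import torch
from ray.dag import InputNode, MultiOutputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType

ray.init()


@ray.remote
class EchoStage:
    def forward(self, x):
        return x + 1


stage0 = [EchoStage.remote() for _ in range(2)]  # "PP rank 0", TP=2
stage1 = [EchoStage.remote() for _ in range(2)]  # "PP rank 1", TP=2

with InputNode() as input_data:
    outputs = [w.forward.bind(input_data) for w in stage0]
    # Declare how intermediate tensors travel between the two stages.
    outputs = [
        o.with_type_hint(TorchTensorType(transport="auto")) for o in outputs
    ]
    outputs = [w.forward.bind(outputs[i]) for i, w in enumerate(stage1)]
    dag = MultiOutputNode(outputs)

compiled_dag = dag.experimental_compile()
print(ray.get(compiled_dag.execute(torch.zeros(2))))  # two tensors of twos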

def __del__(self):
29 changes: 23 additions & 6 deletions aphrodite/executor/ray_utils.py
@@ -1,7 +1,7 @@
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

from aphrodite.common.config import ParallelConfig
from aphrodite.common.sequence import ExecuteModelRequest
from aphrodite.common.sequence import ExecuteModelRequest, IntermediateTensors
from aphrodite.common.utils import get_ip, is_hip, is_tpu, is_xpu
from aphrodite.task_handler.worker_base import WorkerWrapperBase

@@ -28,9 +28,16 @@ def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
gpu_ids = ray.get_gpu_ids()
return node_id, gpu_ids

def execute_model_spmd(self, execute_model_req: ExecuteModelRequest):
"""Used only when SPMD worker and compiled DAG are both
enabled."""
def execute_model_spmd(
self, req_or_tuple: Union[ExecuteModelRequest,
Tuple[ExecuteModelRequest,
IntermediateTensors]]):
"""Execute model in SPMD fashion: used only when SPMD worker and
compiled DAG are both enabled.
Args:
req_or_tuple: The request to execute the model, or a tuple
containing the request and intermediate tensors.
"""
# TODO: This is needed right now because Ray DAG executes
# on a background thread, so we need to reset torch's current
# device.
@@ -39,7 +46,17 @@ def execute_model_spmd(self, execute_model_req: ExecuteModelRequest):
torch.cuda.set_device(self.worker.device)
self.compiled_dag_cuda_device_set = True

return self.worker._execute_model_spmd(execute_model_req)
if isinstance(req_or_tuple, tuple):
execute_model_req, intermediate_tensors = req_or_tuple
else:
execute_model_req = req_or_tuple
intermediate_tensors = None

output = self.worker._execute_model_spmd(execute_model_req,
intermediate_tensors)
if isinstance(output, IntermediateTensors):
return execute_model_req, output
return output
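The hand-off convention this method implements -- every stage except the last returns (request, intermediate tensors) so the next stage receives both, while the last stage returns the sampler output -- can be illustrated without Ray or a real model. The stage_step function below is a hypothetical sketch, not the actual worker code; strings and a dict stand in for ExecuteModelRequest and IntermediateTensors.

# Hypothetical sketch of the (request, intermediate_tensors) hand-off between
# pipeline stages.
from typing import Tuple, Union

Request = str
Intermediate = dict


def stage_step(
    req_or_tuple: Union[Request, Tuple[Request, Intermediate]],
    is_last_stage: bool,
):
    if isinstance(req_or_tuple, tuple):
        request, intermediate = req_or_tuple
    else:
        request, intermediate = req_or_tuple, None
    # ... run the model here, feeding `intermediate` into the first layer ...
    if is_last_stage:
        return f"sampler_output({request})"
    return request, {"hidden_states": f"h({request})"}


out = stage_step("req-0", is_last_stage=False)  # -> (request, intermediate)
out = stage_step(out, is_last_stage=True)
print(out)  # sampler_output(req-0)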

ray_import_err = None

6 changes: 4 additions & 2 deletions aphrodite/task_handler/worker_base.py
@@ -282,7 +282,9 @@ def execute_model(
return output

def _execute_model_spmd(
self, execute_model_req: ExecuteModelRequest
self,
execute_model_req: ExecuteModelRequest,
intermediate_tensors: Optional[IntermediateTensors] = None
) -> Optional[List[SamplerOutput]]:
"""
Execute model in Single Program Multiple Data (SPMD) fashion.
@@ -306,7 +308,7 @@ def _execute_model_spmd(

return self.model_runner.execute_model(
model_input, self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None)
if self.kv_cache is not None else None, intermediate_tensors)


class WorkerWrapperBase:
3 changes: 3 additions & 0 deletions requirements-adag.txt
@@ -0,0 +1,3 @@
# Dependencies for Ray accelerated DAG
cupy-cuda12x
ray >= 2.32
3 changes: 3 additions & 0 deletions requirements-test.txt
@@ -1,3 +1,6 @@
# Needed for Ray accelerated DAG tests
-r requirements-adag.txt

# testing
pytest
tensorizer>=2.9.0