[core][compiled graphs] Support reduce scatter collective in compiled graph #49404

Open · wants to merge 9 commits into base: master
Changes from 2 commits
32 changes: 15 additions & 17 deletions python/ray/dag/collective_node.py
@@ -12,7 +12,12 @@
from ray.experimental.channel import ChannelContext
from ray.experimental.channel.torch_tensor_nccl_channel import _init_communicator
from ray.experimental.channel.torch_tensor_type import Communicator, TorchTensorType
from ray.experimental.util.types import _CollectiveOp, ReduceOp
from ray.experimental.util.types import (
_CollectiveOp,
ReduceOp,
AllReduceReduceOp,
ReduceScatterReduceOp,
)
from ray.util.annotations import DeveloperAPI


@@ -31,13 +36,9 @@ class _CollectiveOperation:
3. Actor handles match the custom NCCL group if specified.
"""

ALLREDUCE = "ar"
REDUCESCATTER = "rs"

def __init__(
self,
input_nodes: List[DAGNode],
comm_op: str,
op: _CollectiveOp,
transport: Optional[Union[str, Communicator]] = None,
):
@@ -64,12 +65,6 @@ def __init__(
f"{invalid_input_nodes}"
)

self.comm_op = comm_op
if self.comm_op not in [self.ALLREDUCE, self.REDUCESCATTER]:
raise NotImplementedError(
"Only all-reduce and reduce-scatter are implemented"
)

self._op = op
if not isinstance(self._op, ReduceOp):
raise NotImplementedError("Only ReduceOp is implemented")
@@ -134,21 +129,24 @@ def execute(self, send_buf: "torch.Tensor") -> "torch.Tensor":
if not isinstance(send_buf, torch.Tensor):
raise ValueError("Expected a torch tensor")
communicator = self.get_communicator()
if self.comm_op == self.ALLREDUCE:
if isinstance(self._op, AllReduceReduceOp):
recv_buf = torch.empty_like(send_buf)
communicator.allreduce(send_buf, recv_buf, self._op)
elif self.comm_op == self.REDUCESCATTER:
elif isinstance(self._op, ReduceScatterReduceOp):
world_size = len(self._actor_handles)
assert (
send_buf.shape[0] % world_size == 0
), "Input tensor's first dimension should be divisible by "
"the number of ators participated"
if send_buf.shape[0] % world_size != 0:
raise ValueError(
"Input tensor's first dimension should be divisible by "
"the number of participating actors."
)
recv_buf = torch.empty(
(send_buf.shape[0] // world_size, *send_buf.shape[1:]),
dtype=send_buf.dtype,
device=send_buf.device,
)
communicator.reducescatter(send_buf, recv_buf, self._op)
else:
raise ValueError("Unsupported ReduceOp type")
return recv_buf


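For context on the new reduce-scatter branch in execute(): every participant contributes a tensor of the same shape, the tensors are reduced elementwise, and each rank receives one contiguous chunk of the result along the first dimension, which is why the receive buffer is allocated as (send_buf.shape[0] // world_size, *send_buf.shape[1:]). Below is a minimal CPU-only emulation of that behavior (illustrative only, no Ray or NCCL involved; the shapes and values mirror the tests in this PR):

import torch

# Emulate reduce-scatter(SUM) across 2 participants on CPU.
# Each rank contributes a (20, 10) tensor; each rank receives a (10, 10)
# chunk of the elementwise sum, matching recv_buf's shape in execute().
world_size = 2
send_bufs = [torch.full((20, 10), float(rank + 1)) for rank in range(world_size)]

reduced = torch.stack(send_bufs).sum(dim=0)       # elementwise sum across ranks
chunks = torch.chunk(reduced, world_size, dim=0)  # split along the first dimension

for rank, recv_buf in enumerate(chunks):
    assert recv_buf.shape == (20 // world_size, 10)
    # Every element equals 1 + 2 = 3, the reduced value the tests check for.
    assert torch.equal(recv_buf, torch.full((10, 10), 3.0))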
98 changes: 77 additions & 21 deletions python/ray/dag/tests/experimental/test_torch_tensor_dag.py
@@ -13,7 +13,7 @@
import time
from ray.air._internal import torch_utils
from ray.dag import InputNode
from ray.exceptions import RayChannelError
from ray.exceptions import RayChannelError, RayTaskError
from ray.dag.output_node import MultiOutputNode
from ray.experimental.channel.communicator import (
Communicator,
@@ -27,7 +27,10 @@

from ray.experimental.channel.torch_tensor_type import TorchTensorType
from ray.tests.conftest import * # noqa
from ray.experimental.util.types import ReduceOp
from ray.experimental.util.types import (
AllReduceReduceOp,
ReduceScatterReduceOp,
)

logger = logging.getLogger(__name__)

@@ -432,7 +435,7 @@ def allreduce(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
op: ReduceOp = ReduceOp.SUM,
op: AllReduceReduceOp = AllReduceReduceOp.SUM,
) -> None:
self._inner.allreduce(send_buf, recv_buf, op)
recv_buf += 1
@@ -441,7 +444,7 @@ def reducescatter(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
op: ReduceOp = ReduceOp.SUM,
op: ReduceScatterReduceOp = ReduceScatterReduceOp.SUM,
) -> None:
self._inner.reducescatter(send_buf, recv_buf, op)
recv_buf += 1
@@ -551,15 +554,15 @@ def allreduce(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
op: ReduceOp,
op: AllReduceReduceOp = AllReduceReduceOp.SUM,
) -> None:
raise NotImplementedError

def reducescatter(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
op: ReduceOp = ReduceOp.SUM,
op: ReduceScatterReduceOp = ReduceScatterReduceOp.SUM,
) -> None:
raise NotImplementedError

@@ -713,15 +716,15 @@ def allreduce(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
op: ReduceOp,
op: AllReduceReduceOp = AllReduceReduceOp.SUM,
) -> None:
raise NotImplementedError

def reducescatter(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
op: ReduceOp = ReduceOp.SUM,
op: ReduceScatterReduceOp = ReduceScatterReduceOp.SUM,
) -> None:
raise NotImplementedError

@@ -987,7 +990,7 @@ def test_torch_tensor_nccl_all_reduce(ray_start_regular):
worker.compute_with_tuple_args.bind(inp, i)
for i, worker in enumerate(workers)
]
collectives = collective.allreduce.bind(computes, ReduceOp.SUM)
collectives = collective.allreduce.bind(computes, AllReduceReduceOp.SUM)
recvs = [
worker.recv.bind(collective)
for worker, collective in zip(workers, collectives)
@@ -1033,7 +1036,7 @@ def test_torch_tensor_nccl_all_reduce_get_partial(ray_start_regular):
worker.compute_with_tuple_args.bind(inp, i)
for i, worker in enumerate(workers)
]
collectives = collective.allreduce.bind(computes, ReduceOp.SUM)
collectives = collective.allreduce.bind(computes, AllReduceReduceOp.SUM)
recv = workers[0].recv.bind(collectives[0])
tensor = workers[1].recv_tensor.bind(collectives[0])
dag = MultiOutputNode([recv, tensor, collectives[1]])
@@ -1077,7 +1080,7 @@ def test_torch_tensor_nccl_all_reduce_wrong_shape(ray_start_regular):
worker.compute_with_tuple_args.bind(inp, i)
for i, worker in enumerate(workers)
]
collectives = collective.allreduce.bind(computes, ReduceOp.SUM)
collectives = collective.allreduce.bind(computes, AllReduceReduceOp.SUM)
recvs = [
worker.recv.bind(collective)
for worker, collective in zip(workers, collectives)
@@ -1185,7 +1188,7 @@ def allreduce(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
op: ReduceOp = ReduceOp.SUM,
op: AllReduceReduceOp = AllReduceReduceOp.SUM,
) -> None:
self._inner.allreduce(send_buf, recv_buf, op)
recv_buf += 1
Expand All @@ -1194,7 +1197,7 @@ def reducescatter(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
op: ReduceOp = ReduceOp.SUM,
op: ReduceScatterReduceOp = ReduceScatterReduceOp.SUM,
) -> None:
self._inner.reducescatter(send_buf, recv_buf, op)
recv_buf += 1
@@ -1317,7 +1320,7 @@ def test_nccl_all_reduce_with_class_method_output_node(ray_start_regular):
with InputNode() as inp:
t1, t2 = workers[0].return_two_tensors.bind(inp[0], inp[1])
t3, t4 = workers[1].return_two_tensors.bind(inp[2], inp[3])
tensors = collective.allreduce.bind([t1, t4], ReduceOp.SUM)
tensors = collective.allreduce.bind([t1, t4], AllReduceReduceOp.SUM)
dag = MultiOutputNode(tensors + [t2, t3])

compiled_dag = dag.experimental_compile()
@@ -1356,7 +1359,7 @@ def test_torch_tensor_nccl_reduce_scatter(ray_start_regular):
worker.compute_with_tuple_args.bind(inp, i)
for i, worker in enumerate(workers)
]
collectives = collective.reducescatter.bind(computes, ReduceOp.SUM)
collectives = collective.reducescatter.bind(computes, ReduceScatterReduceOp.SUM)
recvs = [
worker.recv_tensor.bind(collective)
for worker, collective in zip(workers, collectives)
@@ -1402,7 +1405,7 @@ def test_torch_tensor_nccl_reduce_scatter_get_partial(ray_start_regular):
worker.compute_with_tuple_args.bind(inp, i)
for i, worker in enumerate(workers)
]
collectives = collective.reducescatter.bind(computes, ReduceOp.SUM)
collectives = collective.reducescatter.bind(computes, ReduceScatterReduceOp.SUM)
tensor1 = workers[0].recv_tensor.bind(collectives[0])
tensor2 = workers[1].recv_tensor.bind(collectives[0])
dag = MultiOutputNode([tensor1, tensor2, collectives[1]])
@@ -1427,7 +1430,7 @@ def test_torch_tensor_nccl_reduce_scatter_get_partial(ray_start_regular):


@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
def test_torch_tensor_nccl_reduce_scatter_different_shapes_among_participants(
def test_torch_tensor_nccl_reduce_scatter_wrong_shape(
ray_start_regular,
):
"""
@@ -1453,7 +1456,7 @@ def test_torch_tensor_nccl_reduce_scatter_different_shapes_among_participants(
worker.compute_with_tuple_args.bind(inp, i)
for i, worker in enumerate(workers)
]
collectives = collective.reducescatter.bind(computes, ReduceOp.SUM)
collectives = collective.reducescatter.bind(computes, ReduceScatterReduceOp.SUM)
recvs = [
worker.recv_tensor.bind(collective)
for worker, collective in zip(workers, collectives)
@@ -1487,6 +1490,59 @@ def test_torch_tensor_nccl_reduce_scatter_different_shapes_among_participants(
ref = compiled_dag.execute([((20,), dtype, 1) for _ in workers])


@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
def test_torch_tensor_nccl_reduce_scatter_tensor_wrong_first_dimension(
ray_start_regular,
):
"""
Test that an error is thrown when the input tensors' first dimension
is not divisible by the number of participants.
"""
if not USE_GPU:
pytest.skip("NCCL tests require GPUs")

assert (
sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) > 1
), "This test requires at least 2 GPUs"

actor_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1)

num_workers = 2
workers = [actor_cls.remote() for _ in range(num_workers)]

dtype = torch.float16

with InputNode() as inp:
computes = [
worker.compute_with_tuple_args.bind(inp, i)
for i, worker in enumerate(workers)
]
collectives = collective.reducescatter.bind(computes, ReduceScatterReduceOp.SUM)
recvs = [
worker.recv_tensor.bind(collective)
for worker, collective in zip(workers, collectives)
]
dag = MultiOutputNode(recvs)

compiled_dag = dag.experimental_compile()

ref = compiled_dag.execute(
[((20, 10), dtype, 1 + idx) for idx in range(num_workers)]
)
result = ray.get(ref)
reduced_val = sum(1 + idx for idx in range(num_workers))
for tensor in result:
tensor = tensor.to("cpu")
expected_tensor_val = torch.ones((10, 10), dtype=dtype) * reduced_val
assert torch.equal(tensor, expected_tensor_val)

ref = compiled_dag.execute(
[((3, 10), dtype, 1 + idx) for idx in range(num_workers)]
)
with pytest.raises(RayTaskError):
ray.get(ref)
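
The divisibility requirement exercised here is plain shape arithmetic from _CollectiveOperation.execute(): with two participants, a first dimension of 20 splits evenly into chunks of 10, while 3 does not and raises a ValueError, which the compiled graph surfaces to the caller as a RayTaskError. A standalone sketch of just that check (check_first_dim is a hypothetical helper for illustration, not part of this change):

def check_first_dim(first_dim: int, world_size: int) -> int:
    # Mirrors the guard in execute(): reduce-scatter keeps
    # first_dim // world_size rows per participant, so first_dim
    # must divide evenly by the number of participants.
    if first_dim % world_size != 0:
        raise ValueError(
            f"first dimension {first_dim} is not divisible by {world_size}"
        )
    return first_dim // world_size

print(check_first_dim(20, 2))  # 10, as in the passing execute() call above
try:
    check_first_dim(3, 2)
except ValueError as err:
    print(err)  # raised inside the DAG, it surfaces to the caller as a RayTaskError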


@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
def test_torch_tensor_nccl_reduce_scatter_custom_comm(ray_start_regular):
"""
@@ -1567,7 +1623,7 @@ def allreduce(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
op: ReduceOp = ReduceOp.SUM,
op: AllReduceReduceOp = AllReduceReduceOp.SUM,
) -> None:
self._inner.allreduce(send_buf, recv_buf, op)
recv_buf += 1
@@ -1576,7 +1632,7 @@ def reducescatter(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
op: ReduceOp = ReduceOp.SUM,
op: ReduceScatterReduceOp = ReduceScatterReduceOp.SUM,
) -> None:
self._inner.reducescatter(send_buf, recv_buf, op)
recv_buf += 1
@@ -1703,7 +1759,7 @@ def test_nccl_reduce_scatter_with_class_method_output_node(ray_start_regular):
with InputNode() as inp:
t1, t2 = workers[0].return_two_tensors.bind(inp[0], inp[1])
t3, t4 = workers[1].return_two_tensors.bind(inp[2], inp[3])
tensors = collective.reducescatter.bind([t1, t4], ReduceOp.SUM)
tensors = collective.reducescatter.bind([t1, t4], ReduceScatterReduceOp.SUM)
dag = MultiOutputNode(tensors + [t2, t3])

compiled_dag = dag.experimental_compile()
6 changes: 3 additions & 3 deletions python/ray/experimental/channel/communicator.py
@@ -2,7 +2,7 @@
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple

import ray
from ray.experimental.util.types import ReduceOp
from ray.experimental.util.types import AllReduceReduceOp, ReduceScatterReduceOp
from ray.util.annotations import DeveloperAPI

if TYPE_CHECKING:
@@ -127,7 +127,7 @@ def allreduce(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
op: ReduceOp,
op: AllReduceReduceOp,
) -> None:
"""
Collectively allreduce the tensor across the group.
@@ -145,7 +145,7 @@ def reducescatter(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
op: ReduceOp,
op: ReduceScatterReduceOp,
) -> None:
"""
Collectively reducescatter the tensor across the group.
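
To make the contract these signatures document concrete, here is a hedged sketch of how a communicator outside this PR might back reducescatter with torch.distributed instead of Ray's NCCL wrapper. The _to_dist_op helper and its SUM-only mapping are assumptions for illustration, not part of this change, and a real implementation would also need a process group initialized across the participating ranks:

import torch
import torch.distributed as dist

from ray.experimental.util.types import ReduceScatterReduceOp


def _to_dist_op(op: ReduceScatterReduceOp) -> dist.ReduceOp:
    # Hypothetical mapping; only SUM is shown because that is what the tests use.
    if op == ReduceScatterReduceOp.SUM:
        return dist.ReduceOp.SUM
    raise NotImplementedError(f"Unsupported reduce op: {op}")


def reducescatter(
    send_buf: "torch.Tensor",
    recv_buf: "torch.Tensor",
    op: ReduceScatterReduceOp = ReduceScatterReduceOp.SUM,
) -> None:
    # Assumes dist.init_process_group() has already been called on every rank.
    # As in _CollectiveOperation.execute(), send_buf.shape[0] must be divisible
    # by the world size; recv_buf receives this rank's chunk of the reduced tensor.
    dist.reduce_scatter_tensor(recv_buf, send_buf, op=_to_dist_op(op))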