mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:38:18 +08:00
[Collective][PR 3.5/6] Send/Recv calls and some initial code for communicator caching (#12935)
* other collectives all work * auto-linting * manual linting #1 * manual linting 2 * bugfix * add send/recv point-to-point calls * add some initial code for communicator caching * auto linting * optimize imports * minor fix * fix unpassed tests * support more dtypes * rerun some distributed tests for send/recv * linting
This commit is contained in:
@@ -1,11 +1,11 @@
|
||||
from ray.util.collective.collective import nccl_available, mpi_available, \
|
||||
is_group_initialized, init_collective_group, destroy_collective_group, \
|
||||
get_rank, get_world_size, allreduce, barrier, reduce, broadcast, \
|
||||
allgather, reducescatter
|
||||
allgather, reducescatter, send, recv
|
||||
|
||||
__all__ = [
|
||||
"nccl_available", "mpi_available", "is_group_initialized",
|
||||
"init_collective_group", "destroy_collective_group", "get_rank",
|
||||
"get_world_size", "allreduce", "barrier", "reduce", "broadcast",
|
||||
"allgather", "reducescatter"
|
||||
"allgather", "reducescatter", "send", "recv"
|
||||
]
|
||||
|
||||
@@ -4,7 +4,6 @@ import logging
|
||||
import numpy as np
|
||||
import ray
|
||||
from ray.util.collective import types
|
||||
from ray.util.collective.const import get_nccl_store_name
|
||||
|
||||
_MPI_AVAILABLE = False
|
||||
_NCCL_AVAILABLE = True
|
||||
@@ -16,7 +15,6 @@ _NCCL_AVAILABLE = True
|
||||
# _MPI_AVAILABLE = False
|
||||
try:
|
||||
from ray.util.collective.collective_group import NCCLGroup
|
||||
from ray.util.collective.collective_group import nccl_util
|
||||
except ImportError:
|
||||
_NCCL_AVAILABLE = False
|
||||
|
||||
@@ -53,17 +51,6 @@ class GroupManager(object):
|
||||
if backend == types.Backend.MPI:
|
||||
raise NotImplementedError()
|
||||
elif backend == types.Backend.NCCL:
|
||||
# create the ncclUniqueID
|
||||
if rank == 0:
|
||||
# availability has been checked before entering here.
|
||||
group_uid = nccl_util.get_nccl_unique_id()
|
||||
store_name = get_nccl_store_name(group_name)
|
||||
# Avoid a potential circular dependency in ray/actor.py
|
||||
from ray.util.collective.util import NCCLUniqueIDStore
|
||||
store = NCCLUniqueIDStore.options(
|
||||
name=store_name, lifetime="detached").remote(store_name)
|
||||
ray.wait([store.set_id.remote(group_uid)])
|
||||
|
||||
logger.debug("creating NCCL group: '{}'".format(group_name))
|
||||
g = NCCLGroup(world_size, rank, group_name)
|
||||
self._name_group_map[group_name] = g
|
||||
@@ -89,19 +76,9 @@ class GroupManager(object):
|
||||
|
||||
# release the collective group resource
|
||||
g = self._name_group_map[group_name]
|
||||
rank = g.rank
|
||||
backend = g.backend()
|
||||
|
||||
# clean up the dicts
|
||||
del self._group_name_map[g]
|
||||
del self._name_group_map[group_name]
|
||||
if backend == types.Backend.NCCL:
|
||||
# release the named actor
|
||||
if rank == 0:
|
||||
store_name = get_nccl_store_name(group_name)
|
||||
store = ray.get_actor(store_name)
|
||||
ray.wait([store.__ray_terminate__.remote()])
|
||||
ray.kill(store)
|
||||
# Release the communicator resources
|
||||
g.destroy_group()
|
||||
|
||||
@@ -322,6 +299,46 @@ def reducescatter(tensor,
|
||||
g.reducescatter(tensor, tensor_list, opts)
|
||||
|
||||
|
||||
def send(tensor, dst_rank: int, group_name: str = "default"):
    """Send a tensor to a remote process synchronously.

    The input is validated, the named group is looked up, and the call is
    forwarded to that group's ``send`` method.

    Args:
        tensor: the tensor to send.
        dst_rank (int): the rank of the destination process.
        group_name (str): the name of the collective group.

    Returns:
        None

    Raises:
        RuntimeError: if ``dst_rank`` equals the rank of the calling process.
    """
    _check_single_tensor_input(tensor)
    group = _check_and_get_group(group_name)
    _check_rank_valid(group, dst_rank)
    # A process is not allowed to send to itself.
    if group.rank == dst_rank:
        message = "The destination rank '{}' is self.".format(dst_rank)
        raise RuntimeError(message)
    group.send(tensor, dst_rank)
|
||||
|
||||
|
||||
def recv(tensor, src_rank: int, group_name: str = "default"):
    """Receive a tensor from a remote process synchronously.

    The input is validated, the named group is looked up, and the call is
    forwarded to that group's ``recv`` method.

    Args:
        tensor: the received tensor.
        src_rank (int): the rank of the source process.
        group_name (str): the name of the collective group.

    Returns:
        None

    Raises:
        RuntimeError: if ``src_rank`` equals the rank of the calling process.
    """
    _check_single_tensor_input(tensor)
    g = _check_and_get_group(group_name)
    _check_rank_valid(g, src_rank)
    # A process is not allowed to receive from itself. The message names the
    # *source* rank, since that is the argument being rejected here.
    if src_rank == g.rank:
        raise RuntimeError(
            "The source rank '{}' is self.".format(src_rank))
    g.recv(tensor, src_rank)
|
||||
|
||||
|
||||
def _check_and_get_group(group_name):
|
||||
"""Check the existence and return the group handle."""
|
||||
_check_inside_actor()
|
||||
@@ -368,6 +385,7 @@ def _check_inside_actor():
|
||||
|
||||
|
||||
def _check_rank_valid(g, rank: int):
|
||||
"""Check the rank: 0 <= rank < world_size."""
|
||||
if rank < 0:
|
||||
raise ValueError("rank '{}' is negative.".format(rank))
|
||||
if rank > g.world_size:
|
||||
|
||||
@@ -72,3 +72,11 @@ class BaseGroup(metaclass=ABCMeta):
|
||||
tensor_list,
|
||||
reducescatter_options=ReduceScatterOptions()):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
def send(self, tensor, dst_rank):
    """Send ``tensor`` to the process with rank ``dst_rank``.

    Abstract hook: this base implementation only raises
    ``NotImplementedError``.
    """
    raise NotImplementedError()
|
||||
|
||||
@abstractmethod
def recv(self, tensor, src_rank):
    """Receive into ``tensor`` from the process with rank ``src_rank``.

    Abstract hook: this base implementation only raises
    ``NotImplementedError``.
    """
    raise NotImplementedError()
|
||||
|
||||
@@ -8,10 +8,10 @@ import cupy
|
||||
from ray.util.collective.collective_group import nccl_util
|
||||
from ray.util.collective.collective_group.base_collective_group \
|
||||
import BaseGroup
|
||||
from ray.util.collective.const import get_nccl_store_name
|
||||
from ray.util.collective.types import AllReduceOptions, \
|
||||
BarrierOptions, Backend, ReduceOptions, BroadcastOptions, \
|
||||
AllGatherOptions, ReduceScatterOptions
|
||||
from ray.util.collective.const import get_nccl_store_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -109,10 +109,10 @@ class NCCLGroup(BaseGroup):
|
||||
def __init__(self, world_size, rank, group_name):
|
||||
"""Init an NCCL collective group."""
|
||||
super(NCCLGroup, self).__init__(world_size, rank, group_name)
|
||||
self._nccl_uid = None
|
||||
|
||||
# TODO(Hao): change this to a be a cache
|
||||
self._nccl_comm = None
|
||||
self._collective_comm_cache = None
|
||||
self._p2p_comm_cache = {}
|
||||
|
||||
if nccl_util.get_nccl_build_version() < 2000:
|
||||
raise RuntimeError("NCCL in Ray requires NCCL >= 2.0.")
|
||||
@@ -120,33 +120,34 @@ class NCCLGroup(BaseGroup):
|
||||
if nccl_util.get_nccl_runtime_version() < 2704:
|
||||
logger.warning("NCCL send/recv calls requires NCCL>=2.7.4")
|
||||
|
||||
self._rendezvous = Rendezvous(self.group_name)
|
||||
self._rendezvous.meet()
|
||||
|
||||
# Setup the nccl uid using the store
|
||||
self._init_nccl_unique_id()
|
||||
|
||||
# Setup a tensor for barrier calls
|
||||
self._barrier_tensor = cupy.array([1])
|
||||
|
||||
def _init_nccl_unique_id(self):
|
||||
"""Init the NCCLUniqueID required for creating NCCL communicators."""
|
||||
self._nccl_uid = self._rendezvous.get_nccl_id()
|
||||
|
||||
@property
|
||||
def nccl_uid(self):
|
||||
return self._nccl_uid
|
||||
|
||||
def destroy_group(self):
|
||||
"""Destroy the group and release the NCCL communicators safely."""
|
||||
if self._nccl_comm is not None:
|
||||
"""Destroy the group and release NCCL communicators."""
|
||||
if self._collective_comm_cache:
|
||||
self.barrier()
|
||||
# We also need a barrier call here.
|
||||
stream = self._get_cuda_stream()
|
||||
stream.synchronize()
|
||||
# destroy the communicator
|
||||
self._nccl_comm.destroy()
|
||||
self._nccl_comm = None
|
||||
self._collective_comm_cache.destroy()
|
||||
self._collective_comm_cache = None
|
||||
|
||||
if self.rank == 0:
|
||||
self._destroy_store(self.group_name)
|
||||
|
||||
if self._p2p_comm_cache:
|
||||
for key, comm in self._p2p_comm_cache.items():
|
||||
comm.destroy()
|
||||
min_rank, max_rank = self._parse_p2p_group_key(key)
|
||||
if self.rank == min_rank:
|
||||
self._destroy_store(key)
|
||||
self._p2p_comm_cache[key] = None
|
||||
for key in list(self._p2p_comm_cache.keys()):
|
||||
del self._p2p_comm_cache[key]
|
||||
self._p2p_comm_cache = None
|
||||
|
||||
super(NCCLGroup, self).destroy_group()
|
||||
|
||||
@classmethod
|
||||
@@ -163,7 +164,7 @@ class NCCLGroup(BaseGroup):
|
||||
Returns:
|
||||
"""
|
||||
# obtain the communicator
|
||||
comm = self._get_nccl_communicator()
|
||||
comm = self._get_nccl_collective_communicator()
|
||||
# obtain the stream: using default stream by now
|
||||
# TODO(Hao): implement a simple stream manager here
|
||||
stream = self._get_cuda_stream()
|
||||
@@ -196,7 +197,7 @@ class NCCLGroup(BaseGroup):
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
comm = self._get_nccl_communicator()
|
||||
comm = self._get_nccl_collective_communicator()
|
||||
stream = self._get_cuda_stream()
|
||||
|
||||
dtype = nccl_util.get_nccl_tensor_dtype(tensor)
|
||||
@@ -218,7 +219,7 @@ class NCCLGroup(BaseGroup):
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
comm = self._get_nccl_communicator()
|
||||
comm = self._get_nccl_collective_communicator()
|
||||
stream = self._get_cuda_stream()
|
||||
|
||||
dtype = nccl_util.get_nccl_tensor_dtype(tensor)
|
||||
@@ -232,7 +233,7 @@ class NCCLGroup(BaseGroup):
|
||||
tensor_list,
|
||||
tensor,
|
||||
allgather_options=AllGatherOptions()):
|
||||
"""Allgather tensors across the group into a list of tensors.
|
||||
"""Allgather tensors across the group into a list of tensors.
|
||||
|
||||
Args:
|
||||
tensor_list: the tensor list to store the results.
|
||||
@@ -244,7 +245,7 @@ class NCCLGroup(BaseGroup):
|
||||
"""
|
||||
|
||||
_check_inputs_compatibility_for_scatter_gather(tensor, tensor_list)
|
||||
comm = self._get_nccl_communicator()
|
||||
comm = self._get_nccl_collective_communicator()
|
||||
stream = self._get_cuda_stream()
|
||||
|
||||
dtype = nccl_util.get_nccl_tensor_dtype(tensor)
|
||||
@@ -272,7 +273,7 @@ class NCCLGroup(BaseGroup):
|
||||
"""
|
||||
_check_inputs_compatibility_for_scatter_gather(tensor, tensor_list)
|
||||
|
||||
comm = self._get_nccl_communicator()
|
||||
comm = self._get_nccl_collective_communicator()
|
||||
stream = self._get_cuda_stream()
|
||||
dtype = nccl_util.get_nccl_tensor_dtype(tensor_list[0])
|
||||
n_elems = nccl_util.get_tensor_n_elements(tensor_list[0])
|
||||
@@ -286,16 +287,134 @@ class NCCLGroup(BaseGroup):
|
||||
comm.reduceScatter(send_ptr, recv_ptr, n_elems, dtype, reduce_op,
|
||||
stream.ptr)
|
||||
|
||||
def _get_nccl_communicator(self):
|
||||
"""Create or use a cached NCCL communicator for the collective task.
|
||||
def send(self, tensor, dst_rank):
    """Send tensor to a destination process in the group.

    The two-process communicator for this rank pair is obtained from
    ``_get_nccl_p2p_communicator`` and the send is issued on the stream
    returned by ``_get_cuda_stream``.

    Args:
        tensor: the tensor to send.
        dst_rank: the rank of the destination process.

    Returns:
        None

    Raises:
        RuntimeError: if the NCCL runtime version is lower than 2704.
    """
    # check whether send/recv is available
    if nccl_util.get_nccl_runtime_version() < 2704:
        template = ("send is not available requires NCCL >= 2.7.4. "
                    "Got '{}'.")
        raise RuntimeError(
            template.format(nccl_util.get_nccl_runtime_version()))

    # Inside the p2p communicator the lower group rank is 0 and the higher
    # one is 1 (see _get_nccl_p2p_communicator), so the peer is 0 only when
    # this process holds the larger group rank.
    if self.rank > dst_rank:
        peer_rank = 0
    else:
        peer_rank = 1
    communicator = self._get_nccl_p2p_communicator(self.rank, dst_rank)
    cuda_stream = self._get_cuda_stream()

    nccl_dtype = nccl_util.get_nccl_tensor_dtype(tensor)
    buffer_ptr = nccl_util.get_tensor_ptr(tensor)
    count = nccl_util.get_tensor_n_elements(tensor)
    communicator.send(buffer_ptr, count, nccl_dtype, peer_rank,
                      cuda_stream.ptr)
|
||||
|
||||
def recv(self, tensor, src_rank):
    """Receive tensor from a source process in the group.

    The two-process communicator for this rank pair is obtained from
    ``_get_nccl_p2p_communicator`` and the receive is issued on the stream
    returned by ``_get_cuda_stream``.

    Args:
        tensor: the received tensor.
        src_rank: the rank of the source process.

    Returns:
        None

    Raises:
        RuntimeError: if the NCCL runtime version is lower than 2704.
    """
    # check whether send/recv is available
    if nccl_util.get_nccl_runtime_version() < 2704:
        template = ("recv is not available requires NCCL >= 2.7.4. "
                    "Got '{}'.")
        raise RuntimeError(
            template.format(nccl_util.get_nccl_runtime_version()))

    # Inside the p2p communicator the lower group rank is 0 and the higher
    # one is 1 (see _get_nccl_p2p_communicator), so the peer is 0 only when
    # this process holds the larger group rank.
    if self.rank > src_rank:
        peer_rank = 0
    else:
        peer_rank = 1
    communicator = self._get_nccl_p2p_communicator(src_rank, self.rank)
    cuda_stream = self._get_cuda_stream()

    nccl_dtype = nccl_util.get_nccl_tensor_dtype(tensor)
    buffer_ptr = nccl_util.get_tensor_ptr(tensor)
    count = nccl_util.get_tensor_n_elements(tensor)
    communicator.recv(buffer_ptr, count, nccl_dtype, peer_rank,
                      cuda_stream.ptr)
|
||||
|
||||
def _get_nccl_collective_communicator(self):
|
||||
"""Create or retrieve a cached NCCL communicator.
|
||||
|
||||
Returns:
|
||||
communicator
|
||||
"""
|
||||
# TODO(Hao): later change this to use device keys and query from cache.
|
||||
# TODO(Hao): implement a thin wrapper
|
||||
if not self._nccl_comm:
|
||||
self._nccl_comm = nccl_util.create_nccl_communicator(
|
||||
self.world_size, self.nccl_uid, self.rank)
|
||||
return self._nccl_comm
|
||||
if not self._collective_comm_cache:
|
||||
# create the communicator
|
||||
if self.rank == 0:
|
||||
group_uid = self._generate_nccl_uid(self.group_name)
|
||||
else:
|
||||
rendezvous = Rendezvous(self.group_name)
|
||||
rendezvous.meet()
|
||||
group_uid = rendezvous.get_nccl_id()
|
||||
self._collective_comm_cache = \
|
||||
nccl_util.create_nccl_communicator(self.world_size,
|
||||
group_uid,
|
||||
self.rank)
|
||||
return self._collective_comm_cache
|
||||
|
||||
def _get_nccl_p2p_communicator(self, src_rank, dst_rank):
    """Create or retrieve an NCCL communicator for p2p tasks.

    Communicators are cached in ``self._p2p_comm_cache`` under a key built
    from the group name and the ordered rank pair. On a cache miss, the
    process holding the lower rank generates the NCCL uid, while the other
    process obtains it through a ``Rendezvous`` on the same key.

    Args:
        src_rank (int): source rank.
        dst_rank (int): destination rank.

    Returns:
        communicator
    """
    low_rank = min(src_rank, dst_rank)
    high_rank = max(src_rank, dst_rank)
    # Rank of this process inside the two-process communicator.
    local_rank = 0 if self.rank == low_rank else 1
    cache_key = self._generate_p2p_group_key(low_rank, high_rank)

    cached = self._p2p_comm_cache.get(cache_key)
    if cached:
        return cached

    if self.rank == low_rank:
        uid = self._generate_nccl_uid(cache_key)
    else:
        meeting = Rendezvous(cache_key)
        meeting.meet()
        uid = meeting.get_nccl_id()
    new_comm = nccl_util.create_nccl_communicator(2, uid, local_rank)
    self._p2p_comm_cache[cache_key] = new_comm
    return new_comm
|
||||
|
||||
def _generate_p2p_group_key(self, min_rank, max_rank):
|
||||
return self.group_name + "_" + str(min_rank) + "_" + str(max_rank)
|
||||
|
||||
@staticmethod
|
||||
def _parse_p2p_group_key(key):
|
||||
strs = key.split("_")
|
||||
return int(strs[-2]), int(strs[-1])
|
||||
|
||||
@staticmethod
def _destroy_store(group_name):
    """Kill the named store actor associated with ``group_name``.

    Args:
        group_name: the name used to derive the store actor's name.
    """
    store_actor = ray.get_actor(get_nccl_store_name(group_name))
    # NOTE: terminating through ``store.__ray_terminate__`` is currently
    # disabled; the actor is removed with ray.kill only.
    ray.kill(store_actor)
|
||||
|
||||
def _generate_nccl_uid(self, name):
    """Generate an NCCL UID and publish it through a named store actor.

    A detached ``NCCLUniqueIDStore`` actor is created under the store name
    derived from ``name``, and the freshly generated uid is set on it.

    Args:
        name: the name of the collective group.

    Returns:
        str: NCCL uid.
    """
    uid = nccl_util.get_nccl_unique_id()
    actor_name = get_nccl_store_name(name)
    # Avoid a potential circular dependency in ray/actor.py
    from ray.util.collective.util import NCCLUniqueIDStore
    store_cls = NCCLUniqueIDStore.options(
        name=actor_name, lifetime="detached")
    store_actor = store_cls.remote(actor_name)
    ray.wait([store_actor.set_id.remote(uid)])
    return uid
|
||||
|
||||
@staticmethod
|
||||
def _get_cuda_stream():
|
||||
@@ -303,6 +422,7 @@ class NCCLGroup(BaseGroup):
|
||||
# TODO: implement a simple stream manager.
|
||||
return cupy.cuda.Stream.null
|
||||
|
||||
# Note(Hao): too many bipolate code -- make some abstraction.
|
||||
# def _collective_call(self, *args):
|
||||
# """Private method to encapsulate all collective calls"""
|
||||
# pass
|
||||
|
||||
@@ -20,24 +20,54 @@ NCCL_REDUCE_OP_MAP = {
|
||||
|
||||
# Maps numpy scalar types to NCCL dtype constants.
# cupy types are the same as numpy types.
#
# The builtin ``int`` / ``float`` keys stand in for the former ``numpy.int`` /
# ``numpy.float`` aliases. Those aliases were plain references to the
# builtins and no longer exist in recent NumPy releases, so using the
# builtins keeps the very same keys while letting the module import.
# NOTE(review): builtin ``float`` is a 64-bit type, yet it is mapped to
# NCCL_FLOAT here as before; confirm this entry is intended.
NUMPY_NCCL_DTYPE_MAP = {
    # INT types
    int: nccl.NCCL_INT,
    numpy.uint8: nccl.NCCL_UINT8,
    numpy.uint32: nccl.NCCL_UINT32,
    numpy.uint64: nccl.NCCL_UINT64,
    numpy.int8: nccl.NCCL_INT8,
    numpy.int32: nccl.NCCL_INT32,
    numpy.int64: nccl.NCCL_INT64,
    # FLOAT types
    numpy.half: nccl.NCCL_HALF,
    float: nccl.NCCL_FLOAT,
    numpy.float16: nccl.NCCL_FLOAT16,
    numpy.float32: nccl.NCCL_FLOAT32,
    numpy.float64: nccl.NCCL_FLOAT64,
    numpy.double: nccl.NCCL_DOUBLE
}
|
||||
|
||||
if torch_available():
|
||||
import torch
|
||||
import torch.utils.dlpack
|
||||
TORCH_NCCL_DTYPE_MAP = {
|
||||
# INT types
|
||||
torch.int: nccl.NCCL_INT,
|
||||
torch.uint8: nccl.NCCL_UINT8,
|
||||
torch.int8: nccl.NCCL_INT8,
|
||||
torch.int32: nccl.NCCL_INT32,
|
||||
torch.int64: nccl.NCCL_INT64,
|
||||
torch.long: nccl.NCCL_INT64,
|
||||
# FLOAT types
|
||||
torch.half: nccl.NCCL_HALF,
|
||||
torch.float: nccl.NCCL_FLOAT,
|
||||
torch.float16: nccl.NCCL_FLOAT16,
|
||||
torch.float32: nccl.NCCL_FLOAT32,
|
||||
torch.float64: nccl.NCCL_FLOAT64,
|
||||
torch.double: nccl.NCCL_DOUBLE,
|
||||
}
|
||||
|
||||
TORCH_NUMPY_DTYPE_MAP = {
|
||||
# INT types
|
||||
torch.int: numpy.int,
|
||||
torch.uint8: numpy.uint8,
|
||||
torch.int8: numpy.int8,
|
||||
torch.int32: numpy.int32,
|
||||
torch.int64: numpy.int64,
|
||||
torch.long: numpy.int64,
|
||||
# FLOAT types
|
||||
torch.half: numpy.half,
|
||||
torch.float: numpy.float,
|
||||
torch.float16: numpy.float16,
|
||||
torch.float32: numpy.float32,
|
||||
torch.float64: numpy.float64,
|
||||
|
||||
@@ -5,12 +5,22 @@ import ray
|
||||
from ray.util.collective.const import get_nccl_store_name
|
||||
|
||||
|
||||
# TODO (Hao): remove this clean_up function as it sometimes crashes Ray.
|
||||
def clean_up():
|
||||
group_names = ["default", "test", "123?34!", "default2", "random"]
|
||||
group_names.extend([str(i) for i in range(10)])
|
||||
for group_name in group_names:
|
||||
max_world_size = 4
|
||||
p2p_group_names = []
|
||||
for name in group_names:
|
||||
for i in range(max_world_size):
|
||||
for j in range(max_world_size):
|
||||
if i <= j:
|
||||
p2p_group_name = name + "_" + str(i) + "_" + str(j)
|
||||
p2p_group_names.append(p2p_group_name)
|
||||
all_names = group_names + p2p_group_names
|
||||
for group_name in all_names:
|
||||
store_name = get_nccl_store_name(group_name)
|
||||
try:
|
||||
store_name = get_nccl_store_name(group_name)
|
||||
actor = ray.get_actor(store_name)
|
||||
except ValueError:
|
||||
actor = None
|
||||
|
||||
@@ -47,7 +47,7 @@ def test_allreduce_destroy(ray_start_distributed_2_nodes_4_gpus,
|
||||
assert (results[1] == cp.ones((10, ), dtype=cp.float32) * world_size).all()
|
||||
|
||||
# destroy the group and try do work, should fail
|
||||
ray.wait([a.destroy_group.remote() for a in actors])
|
||||
ray.get([a.destroy_group.remote() for a in actors])
|
||||
with pytest.raises(RuntimeError):
|
||||
results = ray.get([a.do_allreduce.remote() for a in actors])
|
||||
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
"""Test the send/recv API."""
|
||||
import cupy as cp
|
||||
import pytest
|
||||
import ray
|
||||
|
||||
from ray.util.collective.tests.util import create_collective_workers
|
||||
|
||||
|
||||
@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"])
@pytest.mark.parametrize("dst_rank", [0, 1, 2, 3])
@pytest.mark.parametrize("src_rank", [0, 1, 2, 3])
@pytest.mark.parametrize("array_size",
                         [2**10, 2**15, 2**20, [2, 2], [5, 9, 10, 85]])
def test_sendrecv(ray_start_distributed_2_nodes_4_gpus, group_name, array_size,
                  src_rank, dst_rank):
    """Send from ``src_rank`` to ``dst_rank`` in a 4-worker group.

    Worker ``i`` starts with a buffer filled with ``i + 1``; after the
    exchange both the sender's and the receiver's buffers must hold the
    sender's value. Parameter combinations with equal ranks return early.
    """
    if src_rank == dst_rank:
        return
    world_size = 4
    actors, _ = create_collective_workers(
        num_workers=world_size, group_name=group_name)
    buffer_refs = [
        worker.set_buffer.remote(
            cp.ones(array_size, dtype=cp.float32) * (idx + 1))
        for idx, worker in enumerate(actors)
    ]
    ray.get(buffer_refs)

    # Every worker reports its buffer; the two participating workers'
    # entries are then replaced by the send / recv calls.
    refs = [actors[idx].get_buffer.remote() for idx in range(world_size)]
    refs[src_rank] = actors[src_rank].do_send.remote(group_name, dst_rank)
    refs[dst_rank] = actors[dst_rank].do_recv.remote(group_name, src_rank)
    results = ray.get(refs)

    sent_value = src_rank + 1
    assert (results[src_rank] == cp.ones(array_size, dtype=cp.float32) *
            sent_value).all()
    assert (results[dst_rank] == cp.ones(array_size, dtype=cp.float32) *
            sent_value).all()
    ray.get([worker.destroy_group.remote(group_name) for worker in actors])
|
||||
@@ -46,7 +46,7 @@ def test_allreduce_destroy(ray_start_single_node_2_gpus,
|
||||
assert (results[1] == cp.ones((10, ), dtype=cp.float32) * world_size).all()
|
||||
|
||||
# destroy the group and try do work, should fail
|
||||
ray.wait([a.destroy_group.remote() for a in actors])
|
||||
ray.get([a.destroy_group.remote() for a in actors])
|
||||
with pytest.raises(RuntimeError):
|
||||
results = ray.get([a.do_allreduce.remote() for a in actors])
|
||||
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
"""Test the send/recv API."""
|
||||
import pytest
|
||||
import cupy as cp
|
||||
import ray
|
||||
|
||||
from ray.util.collective.tests.util import create_collective_workers
|
||||
|
||||
|
||||
@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"])
@pytest.mark.parametrize("dst_rank", [0, 1])
@pytest.mark.parametrize(
    "array_size", [2, 2**5, 2**10, 2**15, 2**20, [2, 2], [5, 9, 10, 85]])
def test_reduce_different_name(ray_start_single_node_2_gpus, group_name,
                               array_size, dst_rank):
    """Send/recv between two workers under several group names.

    Worker ``i`` starts with a buffer filled with ``i + 1``. The worker
    that is not ``dst_rank`` sends and the other one receives; both
    returned buffers must equal the sender's value.
    """
    world_size = 2
    actors, _ = create_collective_workers(
        num_workers=world_size, group_name=group_name)
    buffer_refs = [
        worker.set_buffer.remote(
            cp.ones(array_size, dtype=cp.float32) * (idx + 1))
        for idx, worker in enumerate(actors)
    ]
    ray.wait(buffer_refs)

    src_rank = 1 - dst_rank
    refs = [
        worker.do_recv.remote(group_name, src_rank) if idx == dst_rank else
        worker.do_send.remote(group_name, dst_rank)
        for idx, worker in enumerate(actors)
    ]
    results = ray.get(refs)
    for idx in range(world_size):
        assert (results[idx] == cp.ones(array_size, dtype=cp.float32) *
                (src_rank + 1)).all()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dst_rank", [0, 1])
def test_sendrecv_torch_cupy(ray_start_single_node_2_gpus, dst_rank):
    """Send/recv between a cupy-compared worker and a torch-buffer worker.

    Worker 1's buffer is replaced by a CUDA torch tensor of twos; worker
    0 keeps whatever buffer ``create_collective_workers`` gave it
    (presumably cupy ones of length 10 -- confirm in the test utilities).
    """
    import torch
    world_size = 2
    actors, _ = create_collective_workers(world_size)
    torch_buffer_ref = actors[1].set_buffer.remote(torch.ones(10, ).cuda() * 2)
    ray.wait([torch_buffer_ref])
    src_rank = 1 - dst_rank

    refs = [
        worker.do_recv.remote(src_rank=src_rank) if idx == dst_rank else
        worker.do_send.remote(dst_rank=dst_rank)
        for idx, worker in enumerate(actors)
    ]
    results = ray.get(refs)
    if dst_rank != 0:
        # Worker 0 sent its buffer to worker 1.
        assert (results[0] == cp.ones((10, ))).all()
        assert (results[1] == torch.ones((10, )).cuda()).all()
    else:
        # Worker 1 sent its buffer of twos to worker 0.
        assert (results[0] == cp.ones((10, )) * 2).all()
        assert (results[1] == torch.ones((10, )).cuda() * 2).all()
|
||||
|
||||
|
||||
def test_sendrecv_invalid_rank(ray_start_single_node_2_gpus, dst_rank=3):
    """Sending to a rank outside a 2-worker group must raise ValueError."""
    world_size = 2
    actors, _ = create_collective_workers(world_size)
    with pytest.raises(ValueError):
        send_refs = [
            worker.do_send.remote(dst_rank=dst_rank) for worker in actors
        ]
        ray.get(send_refs)
|
||||
@@ -28,6 +28,9 @@ class Worker:
|
||||
self.buffer = data
|
||||
return self.buffer
|
||||
|
||||
def get_buffer(self):
    """Return the worker's current buffer."""
    current = self.buffer
    return current
|
||||
|
||||
def set_list_buffer(self, list_of_arrays):
    """Store ``list_of_arrays`` as the worker's list buffer and return it."""
    setattr(self, "list_buffer", list_of_arrays)
    return self.list_buffer
|
||||
@@ -52,6 +55,14 @@ class Worker:
|
||||
col.reducescatter(self.buffer, self.list_buffer, group_name, op)
|
||||
return self.buffer
|
||||
|
||||
def do_send(self, group_name="default", dst_rank=0):
    """Send this worker's buffer to ``dst_rank`` and return the buffer."""
    outgoing = self.buffer
    col.send(outgoing, dst_rank, group_name)
    return self.buffer
|
||||
|
||||
def do_recv(self, group_name="default", src_rank=0):
    """Receive from ``src_rank`` into this worker's buffer and return it."""
    incoming = self.buffer
    col.recv(incoming, src_rank, group_name)
    return self.buffer
|
||||
|
||||
def destroy_group(self, group_name="default"):
    """Destroy the named collective group on this worker.

    Returns:
        bool: always ``True`` once the destroy call has returned.
    """
    col.destroy_collective_group(group_name)
    destroyed = True
    return destroyed
|
||||
|
||||
Reference in New Issue
Block a user