mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:22:39 +08:00
[Collective] Some necessary abstraction of collective calls before introducing stream management (#13162)
This commit is contained in:
@@ -153,20 +153,25 @@ def declare_collective_group(actors,
|
||||
pass
|
||||
|
||||
if len(ranks) != len(actors):
|
||||
raise RuntimeError("Each actor should correspond to one rank.")
|
||||
raise RuntimeError(
|
||||
"Each actor should correspond to one rank. Got '{}' "
|
||||
"ranks but '{}' actors".format(len(ranks), len(actors)))
|
||||
|
||||
if set(ranks) != set(range(len(ranks))):
|
||||
raise RuntimeError("Rank must be a permutation from 0 to len-1.")
|
||||
raise RuntimeError(
|
||||
"Ranks must be a permutation from 0 to '{}'. Got '{}'.".format(
|
||||
len(ranks), "".join([str(r) for r in ranks])))
|
||||
|
||||
assert world_size > 0
|
||||
assert all(ranks) >= 0 and all(ranks) < world_size
|
||||
|
||||
# avoid a circular dependency
|
||||
from ray.util.collective.util import Info
|
||||
# store the information into a NamedActor that can be accessed later.
|
||||
name = "info_" + group_name
|
||||
actors_id = [a._ray_actor_id for a in actors]
|
||||
info = Info.options(name=name, lifetime="detached").remote()
|
||||
ray.wait([info.set_info.remote(actors_id, world_size, ranks, backend)])
|
||||
ray.get([info.set_info.remote(actors_id, world_size, ranks, backend)])
|
||||
|
||||
|
||||
def destroy_collective_group(group_name: str = "default") -> None:
|
||||
|
||||
@@ -158,24 +158,23 @@ class NCCLGroup(BaseGroup):
|
||||
"""AllReduce the tensor across the collective group following options.
|
||||
|
||||
Args:
|
||||
tensor: the tensor to be reduced, each tensor locates on a GPU
|
||||
allreduce_options:
|
||||
tensor: the tensor to be reduced; each tensor is located on a GPU.
|
||||
allreduce_options: allreduce options.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# obtain the communicator
|
||||
comm = self._get_nccl_collective_communicator()
|
||||
# obtain the stream: using default stream by now
|
||||
# TODO(Hao): implement a simple stream manager here
|
||||
stream = self._get_cuda_stream()
|
||||
|
||||
dtype = nccl_util.get_nccl_tensor_dtype(tensor)
|
||||
ptr = nccl_util.get_tensor_ptr(tensor)
|
||||
n_elems = nccl_util.get_tensor_n_elements(tensor)
|
||||
reduce_op = nccl_util.get_nccl_reduce_op(allreduce_options.reduceOp)
|
||||
def collective_fn(input_tensor, output_tensor, comm, stream):
|
||||
comm.allReduce(
|
||||
nccl_util.get_tensor_ptr(input_tensor),
|
||||
nccl_util.get_tensor_ptr(output_tensor),
|
||||
nccl_util.get_tensor_n_elements(input_tensor),
|
||||
nccl_util.get_nccl_tensor_dtype(input_tensor),
|
||||
nccl_util.get_nccl_reduce_op(allreduce_options.reduceOp),
|
||||
stream.ptr)
|
||||
|
||||
# in-place allreduce
|
||||
comm.allReduce(ptr, ptr, n_elems, dtype, reduce_op, stream.ptr)
|
||||
self._collective(tensor, tensor, collective_fn)
|
||||
|
||||
def barrier(self, barrier_options=BarrierOptions()):
|
||||
"""Blocks until all processes reach this barrier.
|
||||
@@ -184,6 +183,7 @@ class NCCLGroup(BaseGroup):
|
||||
barrier_options:
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.allreduce(self._barrier_tensor)
|
||||
|
||||
@@ -197,17 +197,17 @@ class NCCLGroup(BaseGroup):
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
comm = self._get_nccl_collective_communicator()
|
||||
stream = self._get_cuda_stream()
|
||||
|
||||
dtype = nccl_util.get_nccl_tensor_dtype(tensor)
|
||||
ptr = nccl_util.get_tensor_ptr(tensor)
|
||||
n_elems = nccl_util.get_tensor_n_elements(tensor)
|
||||
reduce_op = nccl_util.get_nccl_reduce_op(reduce_options.reduceOp)
|
||||
def collective_fn(input_tensor, output_tensor, comm, stream):
|
||||
comm.reduce(
|
||||
nccl_util.get_tensor_ptr(input_tensor),
|
||||
nccl_util.get_tensor_ptr(output_tensor),
|
||||
nccl_util.get_tensor_n_elements(input_tensor),
|
||||
nccl_util.get_nccl_tensor_dtype(input_tensor),
|
||||
nccl_util.get_nccl_reduce_op(reduce_options.reduceOp),
|
||||
reduce_options.root_rank, stream.ptr)
|
||||
|
||||
# in-place reduce
|
||||
comm.reduce(ptr, ptr, n_elems, dtype, reduce_op,
|
||||
reduce_options.root_rank, stream.ptr)
|
||||
self._collective(tensor, tensor, collective_fn)
|
||||
|
||||
def broadcast(self, tensor, broadcast_options=BroadcastOptions()):
|
||||
"""Broadcast tensor to all other processes following options.
|
||||
@@ -219,15 +219,16 @@ class NCCLGroup(BaseGroup):
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
comm = self._get_nccl_collective_communicator()
|
||||
stream = self._get_cuda_stream()
|
||||
|
||||
dtype = nccl_util.get_nccl_tensor_dtype(tensor)
|
||||
ptr = nccl_util.get_tensor_ptr(tensor)
|
||||
n_elems = nccl_util.get_tensor_n_elements(tensor)
|
||||
# in-place broadcast
|
||||
comm.broadcast(ptr, ptr, n_elems, dtype, broadcast_options.root_rank,
|
||||
stream.ptr)
|
||||
def collective_fn(input_tensor, output_tensor, comm, stream):
|
||||
comm.broadcast(
|
||||
nccl_util.get_tensor_ptr(input_tensor),
|
||||
nccl_util.get_tensor_ptr(output_tensor),
|
||||
nccl_util.get_tensor_n_elements(input_tensor),
|
||||
nccl_util.get_nccl_tensor_dtype(input_tensor),
|
||||
broadcast_options.root_rank, stream.ptr)
|
||||
|
||||
self._collective(tensor, tensor, collective_fn)
|
||||
|
||||
def allgather(self,
|
||||
tensor_list,
|
||||
@@ -244,18 +245,26 @@ class NCCLGroup(BaseGroup):
|
||||
None
|
||||
"""
|
||||
|
||||
_check_inputs_compatibility_for_scatter_gather(tensor, tensor_list)
|
||||
comm = self._get_nccl_collective_communicator()
|
||||
stream = self._get_cuda_stream()
|
||||
def collective_fn(input_tensor, output_tensor, comm, stream):
|
||||
comm.allGather(
|
||||
nccl_util.get_tensor_ptr(input_tensor),
|
||||
nccl_util.get_tensor_ptr(output_tensor),
|
||||
nccl_util.get_tensor_n_elements(input_tensor),
|
||||
nccl_util.get_nccl_tensor_dtype(input_tensor), stream.ptr)
|
||||
|
||||
dtype = nccl_util.get_nccl_tensor_dtype(tensor)
|
||||
send_ptr = nccl_util.get_tensor_ptr(tensor)
|
||||
n_elems = nccl_util.get_tensor_n_elements(tensor)
|
||||
flattened = _flatten_for_scatter_gather(tensor_list, copy=False)
|
||||
recv_ptr = nccl_util.get_tensor_ptr(flattened)
|
||||
comm.allGather(send_ptr, recv_ptr, n_elems, dtype, stream.ptr)
|
||||
for i, t in enumerate(tensor_list):
|
||||
nccl_util.copy_tensor(t, flattened[i])
|
||||
_check_inputs_compatibility_for_scatter_gather(tensor, tensor_list)
|
||||
flattened_output_tensor = _flatten_for_scatter_gather(
|
||||
tensor_list, copy=False)
|
||||
|
||||
def postprocess_fn(stream):
|
||||
for i, tensor in enumerate(tensor_list):
|
||||
nccl_util.copy_tensor(tensor, flattened_output_tensor[i])
|
||||
|
||||
self._collective(
|
||||
tensor,
|
||||
flattened_output_tensor,
|
||||
collective_fn,
|
||||
postprocess_fn=postprocess_fn)
|
||||
|
||||
def reducescatter(self,
|
||||
tensor,
|
||||
@@ -264,28 +273,36 @@ class NCCLGroup(BaseGroup):
|
||||
"""Reducescatter a list of tensors across the group.
|
||||
|
||||
Args:
|
||||
tensor: the output after reducescatter (could be unspecified).
|
||||
tensor_list: the list of tensor to be reduce and scattered.
|
||||
tensor: the output tensor (could be unspecified).
|
||||
tensor_list: the list of tensors to be reduced then scattered.
|
||||
reducescatter_options: reducescatter options.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
def collective_fn(input_tensor, output_tensor, comm, stream):
|
||||
comm.reduceScatter(
|
||||
nccl_util.get_tensor_ptr(input_tensor),
|
||||
nccl_util.get_tensor_ptr(output_tensor),
|
||||
nccl_util.get_tensor_n_elements(output_tensor),
|
||||
nccl_util.get_nccl_tensor_dtype(output_tensor),
|
||||
nccl_util.get_nccl_reduce_op(reducescatter_options.reduceOp),
|
||||
stream.ptr)
|
||||
|
||||
_check_inputs_compatibility_for_scatter_gather(tensor, tensor_list)
|
||||
flattened_input_tensor = _flatten_for_scatter_gather(
|
||||
tensor_list, copy=False)
|
||||
|
||||
comm = self._get_nccl_collective_communicator()
|
||||
stream = self._get_cuda_stream()
|
||||
dtype = nccl_util.get_nccl_tensor_dtype(tensor_list[0])
|
||||
n_elems = nccl_util.get_tensor_n_elements(tensor_list[0])
|
||||
reduce_op = nccl_util.get_nccl_reduce_op(
|
||||
reducescatter_options.reduceOp)
|
||||
def preprocess_fn(stream):
|
||||
for i, tensor in enumerate(tensor_list):
|
||||
nccl_util.copy_tensor(flattened_input_tensor[i], tensor)
|
||||
|
||||
# get the send_ptr
|
||||
flattened = _flatten_for_scatter_gather(tensor_list, copy=True)
|
||||
send_ptr = nccl_util.get_tensor_ptr(flattened)
|
||||
recv_ptr = nccl_util.get_tensor_ptr(tensor)
|
||||
comm.reduceScatter(send_ptr, recv_ptr, n_elems, dtype, reduce_op,
|
||||
stream.ptr)
|
||||
self._collective(
|
||||
flattened_input_tensor,
|
||||
tensor,
|
||||
collective_fn,
|
||||
preprocess_fn=preprocess_fn)
|
||||
|
||||
def send(self, tensor, dst_rank):
|
||||
"""Send tensor to a destination process in the group.
|
||||
@@ -298,20 +315,13 @@ class NCCLGroup(BaseGroup):
|
||||
None
|
||||
"""
|
||||
|
||||
# check whether send/recv is available
|
||||
if nccl_util.get_nccl_runtime_version() < 2704:
|
||||
raise RuntimeError("send is not available requires NCCL >= 2.7.4. "
|
||||
"Got '{}'.".format(
|
||||
nccl_util.get_nccl_runtime_version()))
|
||||
def p2p_fn(tensor, comm, stream, peer):
|
||||
comm.send(
|
||||
nccl_util.get_tensor_ptr(tensor),
|
||||
nccl_util.get_tensor_n_elements(tensor),
|
||||
nccl_util.get_nccl_tensor_dtype(tensor), peer, stream.ptr)
|
||||
|
||||
peer_p2p_rank = 0 if self.rank > dst_rank else 1
|
||||
comm = self._get_nccl_p2p_communicator(self.rank, dst_rank)
|
||||
stream = self._get_cuda_stream()
|
||||
|
||||
dtype = nccl_util.get_nccl_tensor_dtype(tensor)
|
||||
ptr = nccl_util.get_tensor_ptr(tensor)
|
||||
n_elems = nccl_util.get_tensor_n_elements(tensor)
|
||||
comm.send(ptr, n_elems, dtype, peer_p2p_rank, stream.ptr)
|
||||
self._point2point(tensor, p2p_fn, dst_rank)
|
||||
|
||||
def recv(self, tensor, src_rank):
|
||||
"""Receive tensor from a source process in the group.
|
||||
@@ -323,18 +333,14 @@ class NCCLGroup(BaseGroup):
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if nccl_util.get_nccl_runtime_version() < 2704:
|
||||
raise RuntimeError("recv is not available requires NCCL >= 2.7.4. "
|
||||
"Got '{}'.".format(
|
||||
nccl_util.get_nccl_runtime_version()))
|
||||
peer_p2p_rank = 0 if self.rank > src_rank else 1
|
||||
comm = self._get_nccl_p2p_communicator(src_rank, self.rank)
|
||||
stream = self._get_cuda_stream()
|
||||
|
||||
dtype = nccl_util.get_nccl_tensor_dtype(tensor)
|
||||
ptr = nccl_util.get_tensor_ptr(tensor)
|
||||
n_elems = nccl_util.get_tensor_n_elements(tensor)
|
||||
comm.recv(ptr, n_elems, dtype, peer_p2p_rank, stream.ptr)
|
||||
def p2p_fn(tensor, comm, stream, peer):
|
||||
comm.recv(
|
||||
nccl_util.get_tensor_ptr(tensor),
|
||||
nccl_util.get_tensor_n_elements(tensor),
|
||||
nccl_util.get_nccl_tensor_dtype(tensor), peer, stream.ptr)
|
||||
|
||||
self._point2point(tensor, p2p_fn, src_rank)
|
||||
|
||||
def _get_nccl_collective_communicator(self):
|
||||
"""Create or retrieve a cached NCCL communicator.
|
||||
@@ -342,7 +348,6 @@ class NCCLGroup(BaseGroup):
|
||||
Returns:
|
||||
communicator
|
||||
"""
|
||||
# TODO(Hao): implement a thin wrapper
|
||||
if not self._collective_comm_cache:
|
||||
# create the communicator
|
||||
if self.rank == 0:
|
||||
@@ -357,18 +362,18 @@ class NCCLGroup(BaseGroup):
|
||||
self.rank)
|
||||
return self._collective_comm_cache
|
||||
|
||||
def _get_nccl_p2p_communicator(self, src_rank, dst_rank):
|
||||
def _get_nccl_p2p_communicator(self, rank1, rank2):
|
||||
"""Create or retrieve an NCCL communicator for p2p tasks.
|
||||
|
||||
Args:
|
||||
src_rank (int): source rank.
|
||||
dst_rank (int): destination rank.
|
||||
rank1 (int): rank of one peer in the pair (argument order does not matter).
|
||||
rank2 (int): rank of the other peer in the pair.
|
||||
|
||||
Returns:
|
||||
communicator
|
||||
"""
|
||||
min_rank = min(src_rank, dst_rank)
|
||||
max_rank = max(src_rank, dst_rank)
|
||||
min_rank = min(rank1, rank2)
|
||||
max_rank = max(rank1, rank2)
|
||||
my_rank = 0 if self.rank == min_rank else 1
|
||||
p2p_group_key = self._generate_p2p_group_key(min_rank, max_rank)
|
||||
comm = self._p2p_comm_cache.get(p2p_group_key)
|
||||
@@ -422,10 +427,57 @@ class NCCLGroup(BaseGroup):
|
||||
# TODO: implement a simple stream manager.
|
||||
return cupy.cuda.Stream.null
|
||||
|
||||
# Note(Hao): too much boilerplate code -- make some abstraction.
|
||||
# def _collective_call(self, *args):
|
||||
# """Private method to encapsulate all collective calls"""
|
||||
# pass
|
||||
def _collective(self,
|
||||
input_tensor,
|
||||
output_tensor,
|
||||
collective_fn,
|
||||
preprocess_fn=None,
|
||||
postprocess_fn=None):
|
||||
"""A method to encapsulate all collective calls.
|
||||
|
||||
Args:
|
||||
input_tensor: the input tensor.
|
||||
output_tensor: the output tensor.
|
||||
collective_fn: the collective function call.
|
||||
preprocess_fn: preprocess function to call before collectives.
|
||||
postprocess_fn: postprocess function to call after collectives.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
comm = self._get_nccl_collective_communicator()
|
||||
stream = self._get_cuda_stream()
|
||||
|
||||
# Make the collective call
|
||||
if preprocess_fn:
|
||||
preprocess_fn(stream)
|
||||
collective_fn(input_tensor, output_tensor, comm, stream)
|
||||
if postprocess_fn:
|
||||
postprocess_fn(stream)
|
||||
|
||||
def _point2point(self, tensor, p2p_fn, peer_rank: int):
|
||||
"""A method to encapsulate all p2p calls.
|
||||
|
||||
Args:
|
||||
tensor: the tensor to be sent/received.
|
||||
p2p_fn: the p2p function call.
|
||||
peer_rank (int): the peer rank of the current process.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# check send/recv availability.
|
||||
if nccl_util.get_nccl_runtime_version() < 2704:
|
||||
raise RuntimeError("P2p send/recv requires NCCL >= 2.7.4. "
|
||||
"Got '{}'.".format(
|
||||
nccl_util.get_nccl_runtime_version()))
|
||||
|
||||
# We have made sure that self.rank != peer_rank during API check.
|
||||
peer_p2p_rank = 0 if self.rank > peer_rank else 1
|
||||
comm = self._get_nccl_p2p_communicator(self.rank, peer_rank)
|
||||
stream = self._get_cuda_stream()
|
||||
# Make the p2p call:
|
||||
p2p_fn(tensor, comm, stream, peer_p2p_rank)
|
||||
|
||||
|
||||
def _flatten_for_scatter_gather(tensor_list, copy=False):
|
||||
|
||||
@@ -42,7 +42,6 @@ class Backend(object):
|
||||
return backend
|
||||
|
||||
|
||||
# TODO(Hao): extend this to support more MPI types
|
||||
class ReduceOp(Enum):
|
||||
SUM = 0
|
||||
PRODUCT = 1
|
||||
|
||||
Reference in New Issue
Block a user