[Collective] Some necessary abstraction of collective calls before introducing stream management (#13162)

This commit is contained in:
Hao Zhang
2021-01-05 19:20:12 -05:00
committed by GitHub
parent 4e569ee20b
commit 7e52351ae5
3 changed files with 149 additions and 93 deletions
+8 -3
View File
@@ -153,20 +153,25 @@ def declare_collective_group(actors,
pass
if len(ranks) != len(actors):
raise RuntimeError("Each actor should correspond to one rank.")
raise RuntimeError(
"Each actor should correspond to one rank. Got '{}' "
"ranks but '{}' actors".format(len(ranks), len(actors)))
if set(ranks) != set(range(len(ranks))):
raise RuntimeError("Rank must be a permutation from 0 to len-1.")
# The largest valid rank is len(ranks) - 1, and the ranks are joined with a
# separator so multi-digit ranks stay readable in the message.
raise RuntimeError(
    "Ranks must be a permutation from 0 to '{}'. Got '{}'.".format(
        len(ranks) - 1, ", ".join([str(r) for r in ranks])))
assert world_size > 0
# `all(ranks)` collapses the whole list into a single bool, so comparing it
# against 0 and world_size never inspected any individual rank. Bound-check
# every rank instead.
assert all(0 <= r < world_size for r in ranks)
# avoid a circular dependency
from ray.util.collective.util import Info
# store the information into a NamedActor that can be accessed later.
name = "info_" + group_name
actors_id = [a._ray_actor_id for a in actors]
info = Info.options(name=name, lifetime="detached").remote()
ray.wait([info.set_info.remote(actors_id, world_size, ranks, backend)])
ray.get([info.set_info.remote(actors_id, world_size, ranks, backend)])
def destroy_collective_group(group_name: str = "default") -> None:
@@ -158,24 +158,23 @@ class NCCLGroup(BaseGroup):
"""AllReduce the tensor across the collective group following options.
Args:
tensor: the tensor to be reduced, each tensor locates on a GPU
allreduce_options:
tensor: the tensor to be reduced, each tensor locates on a GPU.
allreduce_options: allreduce options.
Returns:
None
"""
# obtain the communicator
comm = self._get_nccl_collective_communicator()
# obtain the stream: using default stream by now
# TODO(Hao): implement a simple stream manager here
stream = self._get_cuda_stream()
dtype = nccl_util.get_nccl_tensor_dtype(tensor)
ptr = nccl_util.get_tensor_ptr(tensor)
n_elems = nccl_util.get_tensor_n_elements(tensor)
reduce_op = nccl_util.get_nccl_reduce_op(allreduce_options.reduceOp)
def collective_fn(input_tensor, output_tensor, comm, stream):
comm.allReduce(
nccl_util.get_tensor_ptr(input_tensor),
nccl_util.get_tensor_ptr(output_tensor),
nccl_util.get_tensor_n_elements(input_tensor),
nccl_util.get_nccl_tensor_dtype(input_tensor),
nccl_util.get_nccl_reduce_op(allreduce_options.reduceOp),
stream.ptr)
# in-place allreduce
comm.allReduce(ptr, ptr, n_elems, dtype, reduce_op, stream.ptr)
self._collective(tensor, tensor, collective_fn)
def barrier(self, barrier_options=BarrierOptions()):
"""Blocks until all processes reach this barrier.
@@ -184,6 +183,7 @@ class NCCLGroup(BaseGroup):
barrier_options:
Returns:
None
"""
self.allreduce(self._barrier_tensor)
@@ -197,17 +197,17 @@ class NCCLGroup(BaseGroup):
Returns:
None
"""
comm = self._get_nccl_collective_communicator()
stream = self._get_cuda_stream()
dtype = nccl_util.get_nccl_tensor_dtype(tensor)
ptr = nccl_util.get_tensor_ptr(tensor)
n_elems = nccl_util.get_tensor_n_elements(tensor)
reduce_op = nccl_util.get_nccl_reduce_op(reduce_options.reduceOp)
def collective_fn(input_tensor, output_tensor, comm, stream):
comm.reduce(
nccl_util.get_tensor_ptr(input_tensor),
nccl_util.get_tensor_ptr(output_tensor),
nccl_util.get_tensor_n_elements(input_tensor),
nccl_util.get_nccl_tensor_dtype(input_tensor),
nccl_util.get_nccl_reduce_op(reduce_options.reduceOp),
reduce_options.root_rank, stream.ptr)
# in-place reduce
comm.reduce(ptr, ptr, n_elems, dtype, reduce_op,
reduce_options.root_rank, stream.ptr)
self._collective(tensor, tensor, collective_fn)
def broadcast(self, tensor, broadcast_options=BroadcastOptions()):
"""Broadcast tensor to all other processes following options.
@@ -219,15 +219,16 @@ class NCCLGroup(BaseGroup):
Returns:
None
"""
comm = self._get_nccl_collective_communicator()
stream = self._get_cuda_stream()
dtype = nccl_util.get_nccl_tensor_dtype(tensor)
ptr = nccl_util.get_tensor_ptr(tensor)
n_elems = nccl_util.get_tensor_n_elements(tensor)
# in-place broadcast
comm.broadcast(ptr, ptr, n_elems, dtype, broadcast_options.root_rank,
stream.ptr)
def collective_fn(input_tensor, output_tensor, comm, stream):
comm.broadcast(
nccl_util.get_tensor_ptr(input_tensor),
nccl_util.get_tensor_ptr(output_tensor),
nccl_util.get_tensor_n_elements(input_tensor),
nccl_util.get_nccl_tensor_dtype(input_tensor),
broadcast_options.root_rank, stream.ptr)
self._collective(tensor, tensor, collective_fn)
def allgather(self,
tensor_list,
@@ -244,18 +245,26 @@ class NCCLGroup(BaseGroup):
None
"""
_check_inputs_compatibility_for_scatter_gather(tensor, tensor_list)
comm = self._get_nccl_collective_communicator()
stream = self._get_cuda_stream()
def collective_fn(input_tensor, output_tensor, comm, stream):
comm.allGather(
nccl_util.get_tensor_ptr(input_tensor),
nccl_util.get_tensor_ptr(output_tensor),
nccl_util.get_tensor_n_elements(input_tensor),
nccl_util.get_nccl_tensor_dtype(input_tensor), stream.ptr)
dtype = nccl_util.get_nccl_tensor_dtype(tensor)
send_ptr = nccl_util.get_tensor_ptr(tensor)
n_elems = nccl_util.get_tensor_n_elements(tensor)
flattened = _flatten_for_scatter_gather(tensor_list, copy=False)
recv_ptr = nccl_util.get_tensor_ptr(flattened)
comm.allGather(send_ptr, recv_ptr, n_elems, dtype, stream.ptr)
for i, t in enumerate(tensor_list):
nccl_util.copy_tensor(t, flattened[i])
_check_inputs_compatibility_for_scatter_gather(tensor, tensor_list)
flattened_output_tensor = _flatten_for_scatter_gather(
tensor_list, copy=False)
def postprocess_fn(stream):
for i, tensor in enumerate(tensor_list):
nccl_util.copy_tensor(tensor, flattened_output_tensor[i])
self._collective(
tensor,
flattened_output_tensor,
collective_fn,
postprocess_fn=postprocess_fn)
def reducescatter(self,
tensor,
@@ -264,28 +273,36 @@ class NCCLGroup(BaseGroup):
"""Reducescatter a list of tensors across the group.
Args:
tensor: the output after reducescatter (could be unspecified).
tensor_list: the list of tensor to be reduce and scattered.
tensor: the output tensor (could be unspecified).
tensor_list: the list of tensor to be reduced then scattered.
reducescatter_options: reducescatter options.
Returns:
None
"""
def collective_fn(input_tensor, output_tensor, comm, stream):
comm.reduceScatter(
nccl_util.get_tensor_ptr(input_tensor),
nccl_util.get_tensor_ptr(output_tensor),
nccl_util.get_tensor_n_elements(output_tensor),
nccl_util.get_nccl_tensor_dtype(output_tensor),
nccl_util.get_nccl_reduce_op(reducescatter_options.reduceOp),
stream.ptr)
_check_inputs_compatibility_for_scatter_gather(tensor, tensor_list)
flattened_input_tensor = _flatten_for_scatter_gather(
tensor_list, copy=False)
comm = self._get_nccl_collective_communicator()
stream = self._get_cuda_stream()
dtype = nccl_util.get_nccl_tensor_dtype(tensor_list[0])
n_elems = nccl_util.get_tensor_n_elements(tensor_list[0])
reduce_op = nccl_util.get_nccl_reduce_op(
reducescatter_options.reduceOp)
def preprocess_fn(stream):
for i, tensor in enumerate(tensor_list):
nccl_util.copy_tensor(flattened_input_tensor[i], tensor)
# get the send_ptr
flattened = _flatten_for_scatter_gather(tensor_list, copy=True)
send_ptr = nccl_util.get_tensor_ptr(flattened)
recv_ptr = nccl_util.get_tensor_ptr(tensor)
comm.reduceScatter(send_ptr, recv_ptr, n_elems, dtype, reduce_op,
stream.ptr)
self._collective(
flattened_input_tensor,
tensor,
collective_fn,
preprocess_fn=preprocess_fn)
def send(self, tensor, dst_rank):
"""Send tensor to a destination process in the group.
@@ -298,20 +315,13 @@ class NCCLGroup(BaseGroup):
None
"""
# check whether send/recv is available
if nccl_util.get_nccl_runtime_version() < 2704:
raise RuntimeError("send is not available requires NCCL >= 2.7.4. "
"Got '{}'.".format(
nccl_util.get_nccl_runtime_version()))
def p2p_fn(tensor, comm, stream, peer):
comm.send(
nccl_util.get_tensor_ptr(tensor),
nccl_util.get_tensor_n_elements(tensor),
nccl_util.get_nccl_tensor_dtype(tensor), peer, stream.ptr)
peer_p2p_rank = 0 if self.rank > dst_rank else 1
comm = self._get_nccl_p2p_communicator(self.rank, dst_rank)
stream = self._get_cuda_stream()
dtype = nccl_util.get_nccl_tensor_dtype(tensor)
ptr = nccl_util.get_tensor_ptr(tensor)
n_elems = nccl_util.get_tensor_n_elements(tensor)
comm.send(ptr, n_elems, dtype, peer_p2p_rank, stream.ptr)
self._point2point(tensor, p2p_fn, dst_rank)
def recv(self, tensor, src_rank):
"""Receive tensor from a source process in the group.
@@ -323,18 +333,14 @@ class NCCLGroup(BaseGroup):
Returns:
None
"""
if nccl_util.get_nccl_runtime_version() < 2704:
raise RuntimeError("recv is not available requires NCCL >= 2.7.4. "
"Got '{}'.".format(
nccl_util.get_nccl_runtime_version()))
peer_p2p_rank = 0 if self.rank > src_rank else 1
comm = self._get_nccl_p2p_communicator(src_rank, self.rank)
stream = self._get_cuda_stream()
dtype = nccl_util.get_nccl_tensor_dtype(tensor)
ptr = nccl_util.get_tensor_ptr(tensor)
n_elems = nccl_util.get_tensor_n_elements(tensor)
comm.recv(ptr, n_elems, dtype, peer_p2p_rank, stream.ptr)
def p2p_fn(tensor, comm, stream, peer):
comm.recv(
nccl_util.get_tensor_ptr(tensor),
nccl_util.get_tensor_n_elements(tensor),
nccl_util.get_nccl_tensor_dtype(tensor), peer, stream.ptr)
self._point2point(tensor, p2p_fn, src_rank)
def _get_nccl_collective_communicator(self):
"""Create or retrieve a cached NCCL communicator.
@@ -342,7 +348,6 @@ class NCCLGroup(BaseGroup):
Returns:
communicator
"""
# TODO(Hao): implement a thin wrapper
if not self._collective_comm_cache:
# create the communicator
if self.rank == 0:
@@ -357,18 +362,18 @@ class NCCLGroup(BaseGroup):
self.rank)
return self._collective_comm_cache
def _get_nccl_p2p_communicator(self, src_rank, dst_rank):
def _get_nccl_p2p_communicator(self, rank1, rank2):
"""Create or retrieve an NCCL communicator for p2p tasks.
Args:
src_rank (int): source rank.
dst_rank (int): destination rank.
rank1 (int): rank of one peer of the pair (order does not matter).
rank2 (int): rank of the other peer of the pair.
Returns:
communicator
"""
min_rank = min(src_rank, dst_rank)
max_rank = max(src_rank, dst_rank)
min_rank = min(rank1, rank2)
max_rank = max(rank1, rank2)
my_rank = 0 if self.rank == min_rank else 1
p2p_group_key = self._generate_p2p_group_key(min_rank, max_rank)
comm = self._p2p_comm_cache.get(p2p_group_key)
@@ -422,10 +427,57 @@ class NCCLGroup(BaseGroup):
# TODO: implement a simple stream manager.
return cupy.cuda.Stream.null
# Note(Hao): too much boilerplate code -- make some abstraction.
# def _collective_call(self, *args):
# """Private method to encapsulate all collective calls"""
# pass
def _collective(self,
                input_tensor,
                output_tensor,
                collective_fn,
                preprocess_fn=None,
                postprocess_fn=None):
    """Run a single collective call, with optional hooks around it.

    Args:
        input_tensor: the tensor handed to ``collective_fn`` as its input.
        output_tensor: the tensor handed to ``collective_fn`` as its output.
        collective_fn: callable invoked as
            ``collective_fn(input_tensor, output_tensor, comm, stream)``.
        preprocess_fn: optional callable invoked with the stream before
            the collective call.
        postprocess_fn: optional callable invoked with the stream after
            the collective call.

    Returns:
        None
    """
    communicator = self._get_nccl_collective_communicator()
    cuda_stream = self._get_cuda_stream()

    def _run_hook(hook):
        # Hooks are optional; a falsy hook is simply skipped.
        if hook:
            hook(cuda_stream)

    _run_hook(preprocess_fn)
    collective_fn(input_tensor, output_tensor, communicator, cuda_stream)
    _run_hook(postprocess_fn)
def _point2point(self, tensor, p2p_fn, peer_rank: int):
    """Run a single point-to-point (send/recv) call against a peer.

    Args:
        tensor: the tensor to be sent/received.
        p2p_fn: callable invoked as ``p2p_fn(tensor, comm, stream, peer)``,
            where ``peer`` is the peer's rank (0 or 1) as computed below.
        peer_rank (int): the rank of the peer process in this group.

    Returns:
        None

    Raises:
        RuntimeError: if the NCCL runtime version is below 2704.
    """
    # Query the runtime version once so the value that is checked is the
    # same value reported in the error message.
    nccl_version = nccl_util.get_nccl_runtime_version()
    if nccl_version < 2704:
        raise RuntimeError("P2p send/recv requires NCCL >= 2.7.4. "
                           "Got '{}'.".format(nccl_version))

    # We have made sure that self.rank != peer_rank during API check.
    # The peer is rank 0 of the pair when its group rank is the smaller
    # one, otherwise rank 1.
    peer_p2p_rank = 0 if self.rank > peer_rank else 1
    comm = self._get_nccl_p2p_communicator(self.rank, peer_rank)
    stream = self._get_cuda_stream()
    # Make the p2p call:
    p2p_fn(tensor, comm, stream, peer_p2p_rank)
def _flatten_for_scatter_gather(tensor_list, copy=False):
-1
View File
@@ -42,7 +42,6 @@ class Backend(object):
return backend
# TODO(Hao): extend this to support more MPI types
class ReduceOp(Enum):
SUM = 0
PRODUCT = 1