Files
ray/python/ray/util/collective/collective.py
T
Hao Zhang 18f5743416 [Collective][PR 3.5/6] Send/Recv calls and some initial code for communicator caching (#12935)
* other collectives all work

* auto-linting

* manual linting #1

* manual linting 2

* bugfix

* add send/recv point-to-point calls

* add some initial code for communicator caching

* auto linting

* optimize imports

* minor fix

* fix unpassed tests

* support more dtypes

* rerun some distributed tests for send/recv

* linting
2020-12-28 09:48:07 -08:00

405 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""APIs exposed under the namespace ray.util.collective."""
import logging
import numpy as np
import ray
from ray.util.collective import types
# MPI support is not implemented yet. The MPIGroup import
# (ray.util.collective.collective_group.mpi_collective_group.MPIGroup)
# remains disabled, so MPI is always reported as unavailable.
_MPI_AVAILABLE = False

# NCCL is reported as available only if its collective group
# implementation can be imported.
try:
    from ray.util.collective.collective_group import NCCLGroup
except ImportError:
    _NCCL_AVAILABLE = False
else:
    _NCCL_AVAILABLE = True

logger = logging.getLogger(__name__)
def nccl_available() -> bool:
    """Report whether the NCCL backend can be used in this process.

    Returns:
        The module-level flag set at import time: True if ``NCCLGroup``
        was importable, False otherwise.
    """
    return _NCCL_AVAILABLE
def mpi_available() -> bool:
    """Report whether the MPI backend can be used in this process.

    Returns:
        The module-level MPI availability flag (currently always False,
        because the MPI backend is not wired up).
    """
    return _MPI_AVAILABLE
class GroupManager(object):
    """Manage the collective groups created in this process so far.

    Each process holds a single instance, the module-level ``_group_mgr``.
    A process may belong to several collective groups. The manager keeps
    a two-way mapping between group names and group handles.
    """

    def __init__(self):
        # group name -> group handle
        self._name_group_map = {}
        # group handle -> group name (reverse index)
        self._group_name_map = {}

    def create_collective_group(self, backend, world_size, rank, group_name):
        """Create a collective group and register it under ``group_name``.

        Args:
            backend: the backend to use; normalized via ``types.Backend``.
            world_size (int): total number of processes in the group.
            rank (int): rank of the current process in the group.
            group_name (str): name under which the group is registered.

        Returns:
            The newly created group handle.

        Raises:
            NotImplementedError: if the MPI backend is requested.
            ValueError: if the backend is neither MPI nor NCCL.
        """
        backend = types.Backend(backend)
        if backend == types.Backend.MPI:
            raise NotImplementedError()
        if backend != types.Backend.NCCL:
            # Guard against any other value types.Backend may accept; without
            # this, `g` below would be unbound and raise UnboundLocalError.
            raise ValueError("Unsupported backend: '{}'.".format(backend))
        logger.debug("creating NCCL group: '{}'".format(group_name))
        g = NCCLGroup(world_size, rank, group_name)
        self._name_group_map[group_name] = g
        self._group_name_map[g] = group_name
        return g

    def is_group_exist(self, group_name):
        """Return True if a group named ``group_name`` is registered."""
        return group_name in self._name_group_map

    def get_group_by_name(self, group_name):
        """Get the collective group handle by its name.

        Returns:
            The group handle, or None (after logging a warning) if no group
            with that name is registered.
        """
        if not self.is_group_exist(group_name):
            logger.warning(
                "The group '{}' is not initialized.".format(group_name))
            return None
        return self._name_group_map[group_name]

    def destroy_collective_group(self, group_name):
        """Unregister ``group_name`` and release its communicator resources.

        A missing group is logged as a warning and otherwise ignored.
        """
        if not self.is_group_exist(group_name):
            logger.warning("The group '{}' does not exist.".format(group_name))
            return
        g = self._name_group_map[group_name]
        # Clean up both lookup tables before tearing the group down.
        del self._group_name_map[g]
        del self._name_group_map[group_name]
        # Release the communicator resources.
        g.destroy_group()


# Process-wide registry of collective groups.
_group_mgr = GroupManager()
def is_group_initialized(group_name):
    """Tell whether ``group_name`` is registered in this process.

    Args:
        group_name (str): the group name to look up.

    Returns:
        bool: True if the process-wide group manager knows this group.
    """
    manager = _group_mgr
    return manager.is_group_exist(group_name)
def init_collective_group(world_size: int,
                          rank: int,
                          backend=types.Backend.NCCL,
                          group_name: str = "default"):
    """Initialize a collective group inside an actor process.

    Args:
        world_size (int): the total number of processes in the group.
        rank (int): the rank of the current process; must satisfy
            ``0 <= rank < world_size``.
        backend: the CCL backend to use, NCCL or MPI.
        group_name (str): the name of the collective group.

    Returns:
        None

    Raises:
        RuntimeError: if called outside a Ray actor/task, if the backend is
            unavailable, or if the group is already initialized here.
        ValueError: if ``group_name`` is empty, ``world_size`` is not
            positive, or ``rank`` is outside ``[0, world_size)``.
    """
    _check_inside_actor()
    backend = types.Backend(backend)
    _check_backend_availability(backend)
    # TODO(Hao): implement a group auto-counter.
    if not group_name:
        raise ValueError("group_name '{}' needs to be a string."
                         .format(group_name))
    if _group_mgr.is_group_exist(group_name):
        raise RuntimeError("Trying to initialize a group twice.")
    # Explicit checks rather than `assert`, which `python -O` strips out.
    if world_size <= 0:
        raise ValueError(
            "world_size '{}' must be positive.".format(world_size))
    if rank < 0 or rank >= world_size:
        raise ValueError("rank '{}' must be in [0, world_size={}).".format(
            rank, world_size))
    _group_mgr.create_collective_group(backend, world_size, rank, group_name)
def destroy_collective_group(group_name: str = "default") -> None:
    """Tear down the collective group called ``group_name`` in this process.

    Must be called from inside a Ray actor or task. If the group does not
    exist, the group manager logs a warning instead of raising.
    """
    _check_inside_actor()
    _group_mgr.destroy_collective_group(group_name)
def get_rank(group_name: str = "default") -> int:
    """Look up the rank of this process within a collective group.

    Args:
        group_name (str): the name of the group to query.

    Returns:
        The rank of this process in the named group, or -1 if that group
        has not been initialized in this process.
    """
    _check_inside_actor()
    if is_group_initialized(group_name):
        return _group_mgr.get_group_by_name(group_name).rank
    return -1
def get_world_size(group_name: str = "default") -> int:
    """Return the size of the collective group with the given name.

    Args:
        group_name (str): the name of the group to query.

    Returns:
        The world size of the collective group, or -1 if that group has
        not been initialized in this process.
    """
    _check_inside_actor()
    if is_group_initialized(group_name):
        return _group_mgr.get_group_by_name(group_name).world_size
    return -1
def allreduce(tensor, group_name: str = "default", op=types.ReduceOp.SUM):
    """Collective allreduce the tensor across the group.

    Args:
        tensor: the tensor to be all-reduced on this process.
        group_name (str): the collective group name to perform allreduce.
        op: The reduce operation.

    Returns:
        None
    """
    _check_single_tensor_input(tensor)
    g = _check_and_get_group(group_name)
    # Build a fresh options instance. Assigning `reduceOp` on the class
    # object itself would leak `op` into every later AllReduceOptions use.
    opts = types.AllReduceOptions()
    opts.reduceOp = op
    g.allreduce(tensor, opts)
def barrier(group_name: str = "default"):
    """Invoke a barrier across all processes of the named collective group.

    Args:
        group_name (str): the name of the group to barrier.

    Returns:
        None
    """
    _check_and_get_group(group_name).barrier()
def reduce(tensor,
           dst_rank: int = 0,
           group_name: str = "default",
           op=types.ReduceOp.SUM):
    """Reduce ``tensor`` across the group onto the destination rank.

    Args:
        tensor: the tensor contributed (and, on ``dst_rank``, receiving the
            result) by this process.
        dst_rank: the rank of the destination process.
        group_name: the collective group name to perform reduce.
        op: The reduce operation.

    Returns:
        None
    """
    _check_single_tensor_input(tensor)
    group = _check_and_get_group(group_name)
    _check_rank_valid(group, dst_rank)
    reduce_opts = types.ReduceOptions()
    reduce_opts.reduceOp = op
    reduce_opts.root_rank = dst_rank
    group.reduce(tensor, reduce_opts)
def broadcast(tensor, src_rank: int = 0, group_name: str = "default"):
    """Broadcast ``tensor`` from the source rank to every other process.

    Args:
        tensor: the tensor to be broadcast (on the source) or filled in
            (on every other process).
        src_rank: the rank of the source process.
        group_name: the collective group name to perform broadcast.

    Returns:
        None
    """
    _check_single_tensor_input(tensor)
    group = _check_and_get_group(group_name)
    _check_rank_valid(group, src_rank)
    bcast_opts = types.BroadcastOptions()
    bcast_opts.root_rank = src_rank
    group.broadcast(tensor, bcast_opts)
def allgather(tensor_list: list, tensor, group_name: str = "default"):
    """Allgather tensors from each process of the group into a list.

    Args:
        tensor_list (list): the results, stored as a list of tensors; its
            length must equal the group's world size.
        tensor: the tensor (to be gathered) in the current process.
        group_name: the name of the collective group.

    Returns:
        None

    Raises:
        RuntimeError: if ``len(tensor_list)`` differs from the world size.
    """
    _check_single_tensor_input(tensor)
    _check_tensor_list_input(tensor_list)
    g = _check_and_get_group(group_name)
    if len(tensor_list) != g.world_size:
        # CCL libraries typically require len(tensor_list) >= world_size;
        # here we are stricter and require len(tensor_list) == world_size.
        raise RuntimeError(
            "The length of the tensor list operands to allgather "
            "must be equal to world_size (got {}, world_size={}).".format(
                len(tensor_list), g.world_size))
    opts = types.AllGatherOptions()
    g.allgather(tensor_list, tensor, opts)
def reducescatter(tensor,
                  tensor_list: list,
                  group_name: str = "default",
                  op=types.ReduceOp.SUM):
    """Reducescatter a list of tensors across the group.

    Reduce the list of tensors across every process in the group, then
    scatter the reduced list, one tensor per process.

    Args:
        tensor: the resulting tensor on this process.
        tensor_list (list): The list of tensors to be reduced and
            scattered; its length must equal the group's world size.
        group_name (str): the name of the collective group.
        op: The reduce operation.

    Returns:
        None

    Raises:
        RuntimeError: if ``len(tensor_list)`` differs from the world size.
    """
    _check_single_tensor_input(tensor)
    _check_tensor_list_input(tensor_list)
    g = _check_and_get_group(group_name)
    if len(tensor_list) != g.world_size:
        raise RuntimeError(
            "The length of the tensor list operands to reducescatter "
            "must be equal to world_size (got {}, world_size={}).".format(
                len(tensor_list), g.world_size))
    opts = types.ReduceScatterOptions()
    opts.reduceOp = op
    g.reducescatter(tensor, tensor_list, opts)
def send(tensor, dst_rank: int, group_name: str = "default"):
    """Send a tensor to a remote process synchronously.

    Args:
        tensor: the tensor to send.
        dst_rank (int): the rank of the destination process; must not be
            this process's own rank.
        group_name (str): the name of the collective group.

    Returns:
        None
    """
    _check_single_tensor_input(tensor)
    group = _check_and_get_group(group_name)
    _check_rank_valid(group, dst_rank)
    sending_to_self = dst_rank == group.rank
    if sending_to_self:
        raise RuntimeError(
            "The destination rank '{}' is self.".format(dst_rank))
    group.send(tensor, dst_rank)
def recv(tensor, src_rank: int, group_name: str = "default"):
    """Receive a tensor from a remote process synchronously.

    Args:
        tensor: the tensor that receives the incoming data.
        src_rank (int): the rank of the source process; must not be this
            process's own rank.
        group_name (str): the name of the collective group.

    Returns:
        None

    Raises:
        RuntimeError: if ``src_rank`` is this process's own rank.
    """
    _check_single_tensor_input(tensor)
    g = _check_and_get_group(group_name)
    _check_rank_valid(g, src_rank)
    if src_rank == g.rank:
        # The offending argument is the *source* rank, not the destination.
        raise RuntimeError(
            "The source rank '{}' is self.".format(src_rank))
    g.recv(tensor, src_rank)
def _check_and_get_group(group_name):
    """Return the handle for ``group_name``, raising if it is not set up here.

    Raises:
        RuntimeError: if called outside a Ray actor/task, or if the group
            has not been initialized in this process.
    """
    _check_inside_actor()
    if is_group_initialized(group_name):
        return _group_mgr.get_group_by_name(group_name)
    raise RuntimeError("The collective group '{}' is not "
                       "initialized in the process.".format(group_name))
def _check_backend_availability(backend: types.Backend):
    """Raise RuntimeError if the requested backend is unavailable here.

    Only the MPI and NCCL backends are checked; any other value passes
    through without error.
    """
    if backend == types.Backend.MPI and not mpi_available():
        raise RuntimeError("MPI is not available.")
    if backend == types.Backend.NCCL and not nccl_available():
        raise RuntimeError("NCCL is not available.")
def _check_single_tensor_input(tensor):
    """Validate that ``tensor`` is a numpy, cupy, or torch tensor.

    The cupy and torch checks run only when ``types`` reports the
    corresponding library as available.

    Raises:
        RuntimeError: if ``tensor`` is none of the supported types.
    """
    if isinstance(tensor, np.ndarray):
        return
    if types.cupy_available() and isinstance(tensor, types.cp.ndarray):
        return
    if types.torch_available() and isinstance(tensor, types.th.Tensor):
        return
    raise RuntimeError("Unrecognized tensor type '{}'. Supported types are: "
                       "np.ndarray, torch.Tensor, cupy.ndarray.".format(
                           type(tensor)))
def _check_inside_actor():
    """Raise RuntimeError unless the global Ray worker is in WORKER_MODE.

    The collective APIs are meant to be called only from Ray actors/tasks.
    """
    in_worker = ray.worker.global_worker.mode == ray.WORKER_MODE
    if not in_worker:
        raise RuntimeError("The collective APIs shall be only used inside "
                           "a Ray actor or task.")
def _check_rank_valid(g, rank: int):
    """Check that ``rank`` is a valid rank in group ``g``: 0 <= rank < world_size.

    Args:
        g: the group handle; only its ``world_size`` attribute is read.
        rank (int): the rank to validate.

    Raises:
        ValueError: if ``rank`` is negative or not smaller than
            ``g.world_size``.
    """
    if rank < 0:
        raise ValueError("rank '{}' is negative.".format(rank))
    # Valid ranks are 0 .. world_size - 1, so rank == world_size is invalid.
    if rank >= g.world_size:
        raise ValueError("rank '{}' is greater than or equal to world size "
                         "'{}'".format(rank, g.world_size))
def _check_tensor_list_input(tensor_list):
    """Validate a non-empty ``list`` whose elements are all supported tensors.

    Raises:
        RuntimeError: if ``tensor_list`` is not a list, is empty, or holds an
            element rejected by ``_check_single_tensor_input``.
    """
    if not isinstance(tensor_list, list):
        raise RuntimeError("The input must be a list of tensors. "
                           "Got '{}'.".format(type(tensor_list)))
    if not tensor_list:
        raise RuntimeError("Got an empty list of tensors.")
    for item in tensor_list:
        _check_single_tensor_input(item)