diff --git a/python/ray/util/collective/__init__.py b/python/ray/util/collective/__init__.py index fcc879589..4ae886607 100644 --- a/python/ray/util/collective/__init__.py +++ b/python/ray/util/collective/__init__.py @@ -1,11 +1,11 @@ from ray.util.collective.collective import nccl_available, mpi_available, \ is_group_initialized, init_collective_group, destroy_collective_group, \ get_rank, get_world_size, allreduce, barrier, reduce, broadcast, \ - allgather, reducescatter + allgather, reducescatter, send, recv __all__ = [ "nccl_available", "mpi_available", "is_group_initialized", "init_collective_group", "destroy_collective_group", "get_rank", "get_world_size", "allreduce", "barrier", "reduce", "broadcast", - "allgather", "reducescatter" + "allgather", "reducescatter", "send", "recv" ] diff --git a/python/ray/util/collective/collective.py b/python/ray/util/collective/collective.py index 464b116a0..e2263648b 100644 --- a/python/ray/util/collective/collective.py +++ b/python/ray/util/collective/collective.py @@ -4,7 +4,6 @@ import logging import numpy as np import ray from ray.util.collective import types -from ray.util.collective.const import get_nccl_store_name _MPI_AVAILABLE = False _NCCL_AVAILABLE = True @@ -16,7 +15,6 @@ _NCCL_AVAILABLE = True # _MPI_AVAILABLE = False try: from ray.util.collective.collective_group import NCCLGroup - from ray.util.collective.collective_group import nccl_util except ImportError: _NCCL_AVAILABLE = False @@ -53,17 +51,6 @@ class GroupManager(object): if backend == types.Backend.MPI: raise NotImplementedError() elif backend == types.Backend.NCCL: - # create the ncclUniqueID - if rank == 0: - # availability has been checked before entering here. 
- group_uid = nccl_util.get_nccl_unique_id() - store_name = get_nccl_store_name(group_name) - # Avoid a potential circular dependency in ray/actor.py - from ray.util.collective.util import NCCLUniqueIDStore - store = NCCLUniqueIDStore.options( - name=store_name, lifetime="detached").remote(store_name) - ray.wait([store.set_id.remote(group_uid)]) - logger.debug("creating NCCL group: '{}'".format(group_name)) g = NCCLGroup(world_size, rank, group_name) self._name_group_map[group_name] = g @@ -89,19 +76,9 @@ class GroupManager(object): # release the collective group resource g = self._name_group_map[group_name] - rank = g.rank - backend = g.backend() - # clean up the dicts del self._group_name_map[g] del self._name_group_map[group_name] - if backend == types.Backend.NCCL: - # release the named actor - if rank == 0: - store_name = get_nccl_store_name(group_name) - store = ray.get_actor(store_name) - ray.wait([store.__ray_terminate__.remote()]) - ray.kill(store) # Release the communicator resources g.destroy_group() @@ -322,6 +299,46 @@ def reducescatter(tensor, g.reducescatter(tensor, tensor_list, opts) +def send(tensor, dst_rank: int, group_name: str = "default"): + """Send a tensor to a remote process synchronously. + + Args: + tensor: the tensor to send. + dst_rank (int): the rank of the destination process. + group_name (str): the name of the collective group. + + Returns: + None + """ + _check_single_tensor_input(tensor) + g = _check_and_get_group(group_name) + _check_rank_valid(g, dst_rank) + if dst_rank == g.rank: + raise RuntimeError( + "The destination rank '{}' is self.".format(dst_rank)) + g.send(tensor, dst_rank) + + +def recv(tensor, src_rank: int, group_name: str = "default"): + """Receive a tensor from a remote process synchronously. + + Args: + tensor: the received tensor. + src_rank (int): the rank of the source process. + group_name (str): the name of the collective group. 
+ + Returns: + None + """ + _check_single_tensor_input(tensor) + g = _check_and_get_group(group_name) + _check_rank_valid(g, src_rank) + if src_rank == g.rank: + raise RuntimeError( + "The source rank '{}' is self.".format(src_rank)) + g.recv(tensor, src_rank) + + def _check_and_get_group(group_name): """Check the existence and return the group handle.""" _check_inside_actor() @@ -368,6 +385,7 @@ def _check_inside_actor(): def _check_rank_valid(g, rank: int): + """Check the rank: 0 <= rank < world_size.""" if rank < 0: raise ValueError("rank '{}' is negative.".format(rank)) - if rank > g.world_size: + if rank >= g.world_size: diff --git a/python/ray/util/collective/collective_group/base_collective_group.py b/python/ray/util/collective/collective_group/base_collective_group.py index 81caf1a6b..5289c562f 100644 --- a/python/ray/util/collective/collective_group/base_collective_group.py +++ b/python/ray/util/collective/collective_group/base_collective_group.py @@ -72,3 +72,11 @@ class BaseGroup(metaclass=ABCMeta): tensor_list, reducescatter_options=ReduceScatterOptions()): raise NotImplementedError() + + @abstractmethod + def send(self, tensor, dst_rank): + raise NotImplementedError() + + @abstractmethod + def recv(self, tensor, src_rank): + raise NotImplementedError() diff --git a/python/ray/util/collective/collective_group/nccl_collective_group.py b/python/ray/util/collective/collective_group/nccl_collective_group.py index 4341f8e67..2c5e09196 100644 --- a/python/ray/util/collective/collective_group/nccl_collective_group.py +++ b/python/ray/util/collective/collective_group/nccl_collective_group.py @@ -8,10 +8,10 @@ import cupy from ray.util.collective.collective_group import nccl_util from ray.util.collective.collective_group.base_collective_group \ import BaseGroup +from ray.util.collective.const import get_nccl_store_name from ray.util.collective.types import AllReduceOptions, \ BarrierOptions, Backend, ReduceOptions, BroadcastOptions, \ AllGatherOptions, ReduceScatterOptions -from 
ray.util.collective.const import get_nccl_store_name logger = logging.getLogger(__name__) @@ -109,10 +109,10 @@ class NCCLGroup(BaseGroup): def __init__(self, world_size, rank, group_name): """Init an NCCL collective group.""" super(NCCLGroup, self).__init__(world_size, rank, group_name) - self._nccl_uid = None # TODO(Hao): change this to a be a cache - self._nccl_comm = None + self._collective_comm_cache = None + self._p2p_comm_cache = {} if nccl_util.get_nccl_build_version() < 2000: raise RuntimeError("NCCL in Ray requires NCCL >= 2.0.") @@ -120,33 +120,34 @@ class NCCLGroup(BaseGroup): if nccl_util.get_nccl_runtime_version() < 2704: logger.warning("NCCL send/recv calls requires NCCL>=2.7.4") - self._rendezvous = Rendezvous(self.group_name) - self._rendezvous.meet() - - # Setup the nccl uid using the store - self._init_nccl_unique_id() - # Setup a tensor for barrier calls self._barrier_tensor = cupy.array([1]) - def _init_nccl_unique_id(self): - """Init the NCCLUniqueID required for creating NCCL communicators.""" - self._nccl_uid = self._rendezvous.get_nccl_id() - - @property - def nccl_uid(self): - return self._nccl_uid - def destroy_group(self): - """Destroy the group and release the NCCL communicators safely.""" - if self._nccl_comm is not None: + """Destroy the group and release NCCL communicators.""" + if self._collective_comm_cache: self.barrier() # We also need a barrier call here. 
stream = self._get_cuda_stream() stream.synchronize() # destroy the communicator - self._nccl_comm.destroy() - self._nccl_comm = None + self._collective_comm_cache.destroy() + self._collective_comm_cache = None + + if self.rank == 0: + self._destroy_store(self.group_name) + + if self._p2p_comm_cache: + for key, comm in self._p2p_comm_cache.items(): + comm.destroy() + min_rank, max_rank = self._parse_p2p_group_key(key) + if self.rank == min_rank: + self._destroy_store(key) + self._p2p_comm_cache[key] = None + for key in list(self._p2p_comm_cache.keys()): + del self._p2p_comm_cache[key] + self._p2p_comm_cache = None + super(NCCLGroup, self).destroy_group() @classmethod @@ -163,7 +164,7 @@ class NCCLGroup(BaseGroup): Returns: """ # obtain the communicator - comm = self._get_nccl_communicator() + comm = self._get_nccl_collective_communicator() # obtain the stream: using default stream by now # TODO(Hao): implement a simple stream manager here stream = self._get_cuda_stream() @@ -196,7 +197,7 @@ class NCCLGroup(BaseGroup): Returns: None """ - comm = self._get_nccl_communicator() + comm = self._get_nccl_collective_communicator() stream = self._get_cuda_stream() dtype = nccl_util.get_nccl_tensor_dtype(tensor) @@ -218,7 +219,7 @@ class NCCLGroup(BaseGroup): Returns: None """ - comm = self._get_nccl_communicator() + comm = self._get_nccl_collective_communicator() stream = self._get_cuda_stream() dtype = nccl_util.get_nccl_tensor_dtype(tensor) @@ -232,7 +233,7 @@ class NCCLGroup(BaseGroup): tensor_list, tensor, allgather_options=AllGatherOptions()): - """Allgather tensors across the group into a list of tensors. + """Allgather tensors across the group into a list of tensors. Args: tensor_list: the tensor list to store the results. 
@@ -244,7 +245,7 @@ class NCCLGroup(BaseGroup): """ _check_inputs_compatibility_for_scatter_gather(tensor, tensor_list) - comm = self._get_nccl_communicator() + comm = self._get_nccl_collective_communicator() stream = self._get_cuda_stream() dtype = nccl_util.get_nccl_tensor_dtype(tensor) @@ -272,7 +273,7 @@ class NCCLGroup(BaseGroup): """ _check_inputs_compatibility_for_scatter_gather(tensor, tensor_list) - comm = self._get_nccl_communicator() + comm = self._get_nccl_collective_communicator() stream = self._get_cuda_stream() dtype = nccl_util.get_nccl_tensor_dtype(tensor_list[0]) n_elems = nccl_util.get_tensor_n_elements(tensor_list[0]) @@ -286,16 +287,134 @@ class NCCLGroup(BaseGroup): comm.reduceScatter(send_ptr, recv_ptr, n_elems, dtype, reduce_op, stream.ptr) - def _get_nccl_communicator(self): - """Create or use a cached NCCL communicator for the collective task. + def send(self, tensor, dst_rank): + """Send tensor to a destination process in the group. + Args: + tensor: the tensor to send. + dst_rank: the rank of the destination process. + + Returns: + None + """ + + # check whether send/recv is available + if nccl_util.get_nccl_runtime_version() < 2704: + raise RuntimeError("send is not available requires NCCL >= 2.7.4. " + "Got '{}'.".format( + nccl_util.get_nccl_runtime_version())) + + peer_p2p_rank = 0 if self.rank > dst_rank else 1 + comm = self._get_nccl_p2p_communicator(self.rank, dst_rank) + stream = self._get_cuda_stream() + + dtype = nccl_util.get_nccl_tensor_dtype(tensor) + ptr = nccl_util.get_tensor_ptr(tensor) + n_elems = nccl_util.get_tensor_n_elements(tensor) + comm.send(ptr, n_elems, dtype, peer_p2p_rank, stream.ptr) + + def recv(self, tensor, src_rank): + """Receive tensor from a source process in the group. + + Args: + tensor: the received tensor. + src_rank: the rank of the source process. + + Returns: + None + """ + if nccl_util.get_nccl_runtime_version() < 2704: + raise RuntimeError("recv is not available requires NCCL >= 2.7.4. 
" + "Got '{}'.".format( + nccl_util.get_nccl_runtime_version())) + peer_p2p_rank = 0 if self.rank > src_rank else 1 + comm = self._get_nccl_p2p_communicator(src_rank, self.rank) + stream = self._get_cuda_stream() + + dtype = nccl_util.get_nccl_tensor_dtype(tensor) + ptr = nccl_util.get_tensor_ptr(tensor) + n_elems = nccl_util.get_tensor_n_elements(tensor) + comm.recv(ptr, n_elems, dtype, peer_p2p_rank, stream.ptr) + + def _get_nccl_collective_communicator(self): + """Create or retrieve a cached NCCL communicator. + + Returns: + communicator """ - # TODO(Hao): later change this to use device keys and query from cache. # TODO(Hao): implement a thin wrapper - if not self._nccl_comm: - self._nccl_comm = nccl_util.create_nccl_communicator( - self.world_size, self.nccl_uid, self.rank) - return self._nccl_comm + if not self._collective_comm_cache: + # create the communicator + if self.rank == 0: + group_uid = self._generate_nccl_uid(self.group_name) + else: + rendezvous = Rendezvous(self.group_name) + rendezvous.meet() + group_uid = rendezvous.get_nccl_id() + self._collective_comm_cache = \ + nccl_util.create_nccl_communicator(self.world_size, + group_uid, + self.rank) + return self._collective_comm_cache + + def _get_nccl_p2p_communicator(self, src_rank, dst_rank): + """Create or retrieve an NCCL communicator for p2p tasks. + + Args: + src_rank (int): source rank. + dst_rank (int): destination rank. 
+ + Returns: + communicator + """ + min_rank = min(src_rank, dst_rank) + max_rank = max(src_rank, dst_rank) + my_rank = 0 if self.rank == min_rank else 1 + p2p_group_key = self._generate_p2p_group_key(min_rank, max_rank) + comm = self._p2p_comm_cache.get(p2p_group_key) + if not comm: + if self.rank == min_rank: + group_uid = self._generate_nccl_uid(p2p_group_key) + else: + rendezvous = Rendezvous(p2p_group_key) + rendezvous.meet() + group_uid = rendezvous.get_nccl_id() + comm = nccl_util.create_nccl_communicator(2, group_uid, my_rank) + self._p2p_comm_cache[p2p_group_key] = comm + return comm + + def _generate_p2p_group_key(self, min_rank, max_rank): + return self.group_name + "_" + str(min_rank) + "_" + str(max_rank) + + @staticmethod + def _parse_p2p_group_key(key): + strs = key.split("_") + return int(strs[-2]), int(strs[-1]) + + @staticmethod + def _destroy_store(group_name): + store_name = get_nccl_store_name(group_name) + store = ray.get_actor(store_name) + # ray.get([store.__ray_terminate__.remote()]) + ray.kill(store) + + def _generate_nccl_uid(self, name): + """Generate an NCCL UID by calling the NCCL API. + + Args: + name: the name of the collective group. + + Returns: + str: NCCL uid. + """ + group_uid = nccl_util.get_nccl_unique_id() + store_name = get_nccl_store_name(name) + # Avoid a potential circular dependency in ray/actor.py + from ray.util.collective.util import NCCLUniqueIDStore + store = NCCLUniqueIDStore.options( + name=store_name, lifetime="detached").remote(store_name) + ray.wait([store.set_id.remote(group_uid)]) + return group_uid @staticmethod def _get_cuda_stream(): @@ -303,6 +422,7 @@ class NCCLGroup(BaseGroup): # TODO: implement a simple stream manager. return cupy.cuda.Stream.null + # Note(Hao): too much boilerplate code -- make some abstraction. 
# def _collective_call(self, *args): # """Private method to encapsulate all collective calls""" # pass diff --git a/python/ray/util/collective/collective_group/nccl_util.py b/python/ray/util/collective/collective_group/nccl_util.py index da9ced35a..889c8c443 100644 --- a/python/ray/util/collective/collective_group/nccl_util.py +++ b/python/ray/util/collective/collective_group/nccl_util.py @@ -20,24 +20,54 @@ NCCL_REDUCE_OP_MAP = { # cupy types are the same with numpy types NUMPY_NCCL_DTYPE_MAP = { + # INT types + int: nccl.NCCL_INT, numpy.uint8: nccl.NCCL_UINT8, + numpy.uint32: nccl.NCCL_UINT32, + numpy.uint64: nccl.NCCL_UINT64, + numpy.int8: nccl.NCCL_INT8, + numpy.int32: nccl.NCCL_INT32, + numpy.int64: nccl.NCCL_INT64, + # FLOAT types + numpy.half: nccl.NCCL_HALF, + float: nccl.NCCL_FLOAT, numpy.float16: nccl.NCCL_FLOAT16, numpy.float32: nccl.NCCL_FLOAT32, numpy.float64: nccl.NCCL_FLOAT64, + numpy.double: nccl.NCCL_DOUBLE } if torch_available(): import torch import torch.utils.dlpack TORCH_NCCL_DTYPE_MAP = { + # INT types + torch.int: nccl.NCCL_INT, torch.uint8: nccl.NCCL_UINT8, + torch.int8: nccl.NCCL_INT8, + torch.int32: nccl.NCCL_INT32, + torch.int64: nccl.NCCL_INT64, + torch.long: nccl.NCCL_INT64, + # FLOAT types + torch.half: nccl.NCCL_HALF, + torch.float: nccl.NCCL_FLOAT, torch.float16: nccl.NCCL_FLOAT16, torch.float32: nccl.NCCL_FLOAT32, torch.float64: nccl.NCCL_FLOAT64, + torch.double: nccl.NCCL_DOUBLE, } TORCH_NUMPY_DTYPE_MAP = { + # INT types + torch.int: int, torch.uint8: numpy.uint8, + torch.int8: numpy.int8, + torch.int32: numpy.int32, + torch.int64: numpy.int64, + torch.long: numpy.int64, + # FLOAT types + torch.half: numpy.half, + torch.float: float, torch.float16: numpy.float16, torch.float32: numpy.float32, torch.float64: numpy.float64, diff --git a/python/ray/util/collective/tests/conftest.py b/python/ray/util/collective/tests/conftest.py index b84a01742..ab5b3765d 100644 --- a/python/ray/util/collective/tests/conftest.py 
+++ b/python/ray/util/collective/tests/conftest.py @@ -5,12 +5,22 @@ import ray from ray.util.collective.const import get_nccl_store_name +# TODO (Hao): remove this clean_up function as it sometimes crashes Ray. def clean_up(): group_names = ["default", "test", "123?34!", "default2", "random"] group_names.extend([str(i) for i in range(10)]) - for group_name in group_names: + max_world_size = 4 + p2p_group_names = [] + for name in group_names: + for i in range(max_world_size): + for j in range(max_world_size): + if i <= j: + p2p_group_name = name + "_" + str(i) + "_" + str(j) + p2p_group_names.append(p2p_group_name) + all_names = group_names + p2p_group_names + for group_name in all_names: + store_name = get_nccl_store_name(group_name) try: - store_name = get_nccl_store_name(group_name) actor = ray.get_actor(store_name) except ValueError: actor = None diff --git a/python/ray/util/collective/tests/distributed_tests/test_distributed_allreduce.py b/python/ray/util/collective/tests/distributed_tests/test_distributed_allreduce.py index 35aae35b2..78587e4e5 100644 --- a/python/ray/util/collective/tests/distributed_tests/test_distributed_allreduce.py +++ b/python/ray/util/collective/tests/distributed_tests/test_distributed_allreduce.py @@ -47,7 +47,7 @@ def test_allreduce_destroy(ray_start_distributed_2_nodes_4_gpus, assert (results[1] == cp.ones((10, ), dtype=cp.float32) * world_size).all() # destroy the group and try do work, should fail - ray.wait([a.destroy_group.remote() for a in actors]) + ray.get([a.destroy_group.remote() for a in actors]) with pytest.raises(RuntimeError): results = ray.get([a.do_allreduce.remote() for a in actors]) diff --git a/python/ray/util/collective/tests/distributed_tests/test_distributed_sendrecv.py b/python/ray/util/collective/tests/distributed_tests/test_distributed_sendrecv.py new file mode 100644 index 000000000..55e2664d5 --- /dev/null +++ b/python/ray/util/collective/tests/distributed_tests/test_distributed_sendrecv.py @@ -0,0 +1,35 @@ 
+"""Test the send/recv API.""" +import cupy as cp +import pytest +import ray + +from ray.util.collective.tests.util import create_collective_workers + + +@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) +@pytest.mark.parametrize("dst_rank", [0, 1, 2, 3]) +@pytest.mark.parametrize("src_rank", [0, 1, 2, 3]) +@pytest.mark.parametrize("array_size", + [2**10, 2**15, 2**20, [2, 2], [5, 9, 10, 85]]) +def test_sendrecv(ray_start_distributed_2_nodes_4_gpus, group_name, array_size, + src_rank, dst_rank): + if src_rank == dst_rank: + return + world_size = 4 + actors, _ = create_collective_workers( + num_workers=world_size, group_name=group_name) + ray.get([ + a.set_buffer.remote(cp.ones(array_size, dtype=cp.float32) * (i + 1)) + for i, a in enumerate(actors) + ]) + refs = [] + for i in range(world_size): + refs.append(actors[i].get_buffer.remote()) + refs[src_rank] = actors[src_rank].do_send.remote(group_name, dst_rank) + refs[dst_rank] = actors[dst_rank].do_recv.remote(group_name, src_rank) + results = ray.get(refs) + assert (results[src_rank] == cp.ones(array_size, dtype=cp.float32) * + (src_rank + 1)).all() + assert (results[dst_rank] == cp.ones(array_size, dtype=cp.float32) * + (src_rank + 1)).all() + ray.get([a.destroy_group.remote(group_name) for a in actors]) diff --git a/python/ray/util/collective/tests/test_allreduce.py b/python/ray/util/collective/tests/test_allreduce.py index 1fbdf526b..e8f7e4237 100644 --- a/python/ray/util/collective/tests/test_allreduce.py +++ b/python/ray/util/collective/tests/test_allreduce.py @@ -46,7 +46,7 @@ def test_allreduce_destroy(ray_start_single_node_2_gpus, assert (results[1] == cp.ones((10, ), dtype=cp.float32) * world_size).all() # destroy the group and try do work, should fail - ray.wait([a.destroy_group.remote() for a in actors]) + ray.get([a.destroy_group.remote() for a in actors]) with pytest.raises(RuntimeError): results = ray.get([a.do_allreduce.remote() for a in actors]) diff --git 
a/python/ray/util/collective/tests/test_sendrecv.py b/python/ray/util/collective/tests/test_sendrecv.py new file mode 100644 index 000000000..91e9aeab1 --- /dev/null +++ b/python/ray/util/collective/tests/test_sendrecv.py @@ -0,0 +1,64 @@ +"""Test the send/recv API.""" +import pytest +import cupy as cp +import ray + +from ray.util.collective.tests.util import create_collective_workers + + +@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) +@pytest.mark.parametrize("dst_rank", [0, 1]) +@pytest.mark.parametrize( + "array_size", [2, 2**5, 2**10, 2**15, 2**20, [2, 2], [5, 9, 10, 85]]) +def test_reduce_different_name(ray_start_single_node_2_gpus, group_name, + array_size, dst_rank): + world_size = 2 + actors, _ = create_collective_workers( + num_workers=world_size, group_name=group_name) + ray.wait([ + a.set_buffer.remote(cp.ones(array_size, dtype=cp.float32) * (i + 1)) + for i, a in enumerate(actors) + ]) + src_rank = 1 - dst_rank + refs = [] + for i, actor in enumerate(actors): + if i != dst_rank: + ref = actor.do_send.remote(group_name, dst_rank) + else: + ref = actor.do_recv.remote(group_name, src_rank) + refs.append(ref) + results = ray.get(refs) + for i in range(world_size): + assert (results[i] == cp.ones(array_size, dtype=cp.float32) * + (src_rank + 1)).all() + + +@pytest.mark.parametrize("dst_rank", [0, 1]) +def test_sendrecv_torch_cupy(ray_start_single_node_2_gpus, dst_rank): + import torch + world_size = 2 + actors, _ = create_collective_workers(world_size) + ray.wait([actors[1].set_buffer.remote(torch.ones(10, ).cuda() * 2)]) + src_rank = 1 - dst_rank + + refs = [] + for i, actor in enumerate(actors): + if i != dst_rank: + ref = actor.do_send.remote(dst_rank=dst_rank) + else: + ref = actor.do_recv.remote(src_rank=src_rank) + refs.append(ref) + results = ray.get(refs) + if dst_rank == 0: + assert (results[0] == cp.ones((10, )) * 2).all() + assert (results[1] == torch.ones((10, )).cuda() * 2).all() + else: + assert (results[0] == 
cp.ones((10, ))).all() + assert (results[1] == torch.ones((10, )).cuda()).all() + + +def test_sendrecv_invalid_rank(ray_start_single_node_2_gpus, dst_rank=3): + world_size = 2 + actors, _ = create_collective_workers(world_size) + with pytest.raises(ValueError): + _ = ray.get([a.do_send.remote(dst_rank=dst_rank) for a in actors]) diff --git a/python/ray/util/collective/tests/util.py b/python/ray/util/collective/tests/util.py index 3cee4de59..259ee24c9 100644 --- a/python/ray/util/collective/tests/util.py +++ b/python/ray/util/collective/tests/util.py @@ -28,6 +28,9 @@ class Worker: self.buffer = data return self.buffer + def get_buffer(self): + return self.buffer + def set_list_buffer(self, list_of_arrays): self.list_buffer = list_of_arrays return self.list_buffer @@ -52,6 +55,14 @@ class Worker: col.reducescatter(self.buffer, self.list_buffer, group_name, op) return self.buffer + def do_send(self, group_name="default", dst_rank=0): + col.send(self.buffer, dst_rank, group_name) + return self.buffer + + def do_recv(self, group_name="default", src_rank=0): + col.recv(self.buffer, src_rank, group_name) + return self.buffer + def destroy_group(self, group_name="default"): col.destroy_collective_group(group_name) return True