From 5b48480e2932de0aaa1853b8abcc3e1f69e59cd5 Mon Sep 17 00:00:00 2001 From: Hao Zhang Date: Mon, 21 Dec 2020 15:48:00 -0500 Subject: [PATCH] [Collective][PR 3/6] Other collectives (#12864) --- python/ray/util/collective/__init__.py | 10 +- python/ray/util/collective/collective.py | 149 ++++++++-- .../collective_group/base_collective_group.py | 28 +- .../collective_group/nccl_collective_group.py | 175 +++++++++-- .../collective/collective_group/nccl_util.py | 103 ++++++- python/ray/util/collective/const.py | 3 +- .../tests/distributed_tests/__init__.py | 0 .../test_distributed_allgather.py | 133 +++++++++ .../test_distributed_allreduce.py | 139 +++++++++ .../test_distributed_basic_apis.py | 135 +++++++++ .../test_distributed_broadcast.py | 67 +++++ .../test_distributed_reduce.py | 119 ++++++++ .../test_distributed_reducescatter.py | 128 ++++++++ .../util/collective/tests/test_allgather.py | 131 +++++++++ .../util/collective/tests/test_allreduce.py | 143 +++++++++ .../util/collective/tests/test_basic_apis.py | 127 ++++++++ .../util/collective/tests/test_broadcast.py | 67 +++++ .../tests/test_collective_2_nodes_4_gpus.py | 276 ------------------ .../test_collective_single_node_2_gpus.py | 267 ----------------- .../ray/util/collective/tests/test_reduce.py | 143 +++++++++ .../collective/tests/test_reducescatter.py | 127 ++++++++ python/ray/util/collective/tests/util.py | 67 ++++- python/ray/util/collective/types.py | 37 ++- 23 files changed, 1968 insertions(+), 606 deletions(-) create mode 100644 python/ray/util/collective/tests/distributed_tests/__init__.py create mode 100644 python/ray/util/collective/tests/distributed_tests/test_distributed_allgather.py create mode 100644 python/ray/util/collective/tests/distributed_tests/test_distributed_allreduce.py create mode 100644 python/ray/util/collective/tests/distributed_tests/test_distributed_basic_apis.py create mode 100644 python/ray/util/collective/tests/distributed_tests/test_distributed_broadcast.py create mode 100644 python/ray/util/collective/tests/distributed_tests/test_distributed_reduce.py create mode 100644 python/ray/util/collective/tests/distributed_tests/test_distributed_reducescatter.py create mode 100644 python/ray/util/collective/tests/test_allgather.py create mode 100644 python/ray/util/collective/tests/test_allreduce.py create mode 100644 python/ray/util/collective/tests/test_basic_apis.py create mode 100644 python/ray/util/collective/tests/test_broadcast.py delete mode 100644 python/ray/util/collective/tests/test_collective_2_nodes_4_gpus.py delete mode 100644 python/ray/util/collective/tests/test_collective_single_node_2_gpus.py create mode 100644 python/ray/util/collective/tests/test_reduce.py create mode 100644 python/ray/util/collective/tests/test_reducescatter.py diff --git a/python/ray/util/collective/__init__.py b/python/ray/util/collective/__init__.py index 68fcb78d4..fcc879589 100644 --- a/python/ray/util/collective/__init__.py +++ b/python/ray/util/collective/__init__.py @@ -1,9 +1,11 @@ -from .collective import nccl_available, mpi_available, is_group_initialized, \ - init_collective_group, destroy_collective_group, get_rank, \ - get_world_size, allreduce, barrier +from ray.util.collective.collective import nccl_available, mpi_available, \ + is_group_initialized, init_collective_group, destroy_collective_group, \ + get_rank, get_world_size, allreduce, barrier, reduce, broadcast, \ + allgather, reducescatter __all__ = [ "nccl_available", "mpi_available", "is_group_initialized", "init_collective_group", 
"destroy_collective_group", "get_rank", - "get_world_size", "allreduce", "barrier" + "get_world_size", "allreduce", "barrier", "reduce", "broadcast", + "allgather", "reducescatter" ] diff --git a/python/ray/util/collective/collective.py b/python/ray/util/collective/collective.py index 343487e71..464b116a0 100644 --- a/python/ray/util/collective/collective.py +++ b/python/ray/util/collective/collective.py @@ -32,8 +32,7 @@ def mpi_available(): class GroupManager(object): - """ - Use this class to manage the collective groups we created so far. + """Use this class to manage the collective groups we created so far. Each process will have an instance of `GroupManager`. Each process could belong to multiple collective groups. The membership information @@ -45,8 +44,7 @@ class GroupManager(object): self._group_name_map = {} def create_collective_group(self, backend, world_size, rank, group_name): - """ - The entry to create new collective groups and register in the manager. + """The entry to create new collective groups in the manager. Put the registration and the group information into the manager metadata as well. @@ -120,8 +118,7 @@ def init_collective_group(world_size: int, rank: int, backend=types.Backend.NCCL, group_name: str = "default"): - """ - Initialize a collective group inside an actor process. + """Initialize a collective group inside an actor process. Args: world_size (int): the total number of processed in the group. @@ -158,8 +155,7 @@ def destroy_collective_group(group_name: str = "default") -> None: def get_rank(group_name: str = "default") -> int: - """ - Return the rank of this process in the given group. + """Return the rank of this process in the given group. Args: group_name (str): the name of the group to query @@ -176,9 +172,8 @@ def get_rank(group_name: str = "default") -> int: return g.rank -def get_world_size(group_name="default") -> int: - """ - Return the size of the collective gropu with the given name. +def get_world_size(group_name: str = "default") -> int: + """Return the size of the collective gropu with the given name. Args: group_name: the name of the group to query @@ -195,9 +190,8 @@ def get_world_size(group_name="default") -> int: return g.world_size -def allreduce(tensor, group_name: str, op=types.ReduceOp.SUM): - """ - Collective allreduce the tensor across the group with name group_name. +def allreduce(tensor, group_name: str = "default", op=types.ReduceOp.SUM): + """Collective allreduce the tensor across the group. Args: tensor: the tensor to be all-reduced on this process. @@ -214,9 +208,8 @@ def allreduce(tensor, group_name: str, op=types.ReduceOp.SUM): g.allreduce(tensor, opts) -def barrier(group_name): - """ - Barrier all processes in the collective group. +def barrier(group_name: str = "default"): + """Barrier all processes in the collective group. Args: group_name (str): the name of the group to barrier. @@ -228,6 +221,107 @@ def barrier(group_name): g.barrier() +def reduce(tensor, + dst_rank: int = 0, + group_name: str = "default", + op=types.ReduceOp.SUM): + """Reduce the tensor across the group to the destination rank. + + Args: + tensor: the tensor to be reduced on this process. + dst_rank: the rank of the destination process. + group_name: the collective group name to perform reduce. + op: The reduce operation. 
+
+    Returns:
+        None
+    """
+    _check_single_tensor_input(tensor)
+    g = _check_and_get_group(group_name)
+
+    # check dst rank
+    _check_rank_valid(g, dst_rank)
+    opts = types.ReduceOptions()
+    opts.reduceOp = op
+    opts.root_rank = dst_rank
+    g.reduce(tensor, opts)
+
+
+def broadcast(tensor, src_rank: int = 0, group_name: str = "default"):
+    """Broadcast the tensor from a source process to all others.
+
+    Args:
+        tensor: the tensor to be broadcast (src) or received (destination).
+        src_rank: the rank of the source process.
+        group_name: the collective group name to perform broadcast.
+
+    Returns:
+        None
+    """
+    _check_single_tensor_input(tensor)
+    g = _check_and_get_group(group_name)
+
+    # check src rank
+    _check_rank_valid(g, src_rank)
+    opts = types.BroadcastOptions()
+    opts.root_rank = src_rank
+    g.broadcast(tensor, opts)
+
+
+def allgather(tensor_list: list, tensor, group_name: str = "default"):
+    """Allgather tensors from each process of the group into a list.
+
+    Args:
+        tensor_list (list): the results, stored as a list of tensors.
+        tensor: the tensor (to be gathered) in the current process.
+        group_name: the name of the collective group.
+
+    Returns:
+        None
+    """
+    _check_single_tensor_input(tensor)
+    _check_tensor_list_input(tensor_list)
+    g = _check_and_get_group(group_name)
+    if len(tensor_list) != g.world_size:
+        # Typically a CCL library requires len(tensor_list) >= world_size;
+        # Here we make it more strict: len(tensor_list) == world_size.
+        raise RuntimeError(
+            "The length of the tensor list operands to allgather "
+            "must be equal to world_size.")
+    opts = types.AllGatherOptions()
+    g.allgather(tensor_list, tensor, opts)
+
+
+def reducescatter(tensor,
+                  tensor_list: list,
+                  group_name: str = "default",
+                  op=types.ReduceOp.SUM):
+    """Reducescatter a list of tensors across the group.
+
+    Reduce the list of tensors across each process in the group, then
+    scatter the reduced list of tensors -- one tensor for each process.
+
+    Args:
+        tensor: the resulting tensor on this process.
+        tensor_list (list): The list of tensors to be reduced and scattered.
+        group_name (str): the name of the collective group.
+        op: The reduce operation.
+
+    Returns:
+        None
+    """
+    _check_single_tensor_input(tensor)
+    _check_tensor_list_input(tensor_list)
+    g = _check_and_get_group(group_name)
+    if len(tensor_list) != g.world_size:
+        raise RuntimeError(
+            "The length of the tensor list operands to reducescatter "
+            "must be equal to world_size.")
+    opts = types.ReduceScatterOptions()
+    opts.reduceOp = op
+    g.reducescatter(tensor, tensor_list, opts)
+
+
 def _check_and_get_group(group_name):
     """Check the existence and return the group handle."""
     _check_inside_actor()
@@ -244,8 +338,6 @@ def _check_backend_availability(backend: types.Backend):
         if not mpi_available():
            raise RuntimeError("MPI is not available.")
     elif backend == types.Backend.NCCL:
-        # expect some slowdown at the first call
-        # as I defer the import to invocation.
if not nccl_available(): raise RuntimeError("NCCL is not available.") @@ -273,3 +365,22 @@ def _check_inside_actor(): else: raise RuntimeError("The collective APIs shall be only used inside " "a Ray actor or task.") + + +def _check_rank_valid(g, rank: int): + if rank < 0: + raise ValueError("rank '{}' is negative.".format(rank)) + if rank > g.world_size: + raise ValueError("rank '{}' is greater than world size " + "'{}'".format(rank, g.world_size)) + + +def _check_tensor_list_input(tensor_list): + """Check if the input is a list of supported tensor types.""" + if not isinstance(tensor_list, list): + raise RuntimeError("The input must be a list of tensors. " + "Got '{}'.".format(type(tensor_list))) + if not tensor_list: + raise RuntimeError("Got an empty list of tensors.") + for t in tensor_list: + _check_single_tensor_input(t) diff --git a/python/ray/util/collective/collective_group/base_collective_group.py b/python/ray/util/collective/collective_group/base_collective_group.py index a3f54fa26..81caf1a6b 100644 --- a/python/ray/util/collective/collective_group/base_collective_group.py +++ b/python/ray/util/collective/collective_group/base_collective_group.py @@ -2,13 +2,13 @@ from abc import ABCMeta from abc import abstractmethod -from ray.util.collective.types import AllReduceOptions, BarrierOptions +from ray.util.collective.types import AllReduceOptions, BarrierOptions, \ + ReduceOptions, AllGatherOptions, BroadcastOptions, ReduceScatterOptions class BaseGroup(metaclass=ABCMeta): def __init__(self, world_size, rank, group_name): - """ - Init the process group with basic information. + """Init the process group with basic information. Args: world_size (int): The total number of processes in the group. @@ -50,3 +50,25 @@ class BaseGroup(metaclass=ABCMeta): @abstractmethod def barrier(self, barrier_options=BarrierOptions()): raise NotImplementedError() + + @abstractmethod + def reduce(self, tensor, reduce_options=ReduceOptions()): + raise NotImplementedError() + + @abstractmethod + def allgather(self, + tensor_list, + tensor, + allgather_options=AllGatherOptions()): + raise NotImplementedError() + + @abstractmethod + def broadcast(self, tensor, broadcast_options=BroadcastOptions()): + raise NotImplementedError() + + @abstractmethod + def reducescatter(self, + tensor, + tensor_list, + reducescatter_options=ReduceScatterOptions()): + raise NotImplementedError() diff --git a/python/ray/util/collective/collective_group/nccl_collective_group.py b/python/ray/util/collective/collective_group/nccl_collective_group.py index 31412b5a4..4341f8e67 100644 --- a/python/ray/util/collective/collective_group/nccl_collective_group.py +++ b/python/ray/util/collective/collective_group/nccl_collective_group.py @@ -9,7 +9,8 @@ from ray.util.collective.collective_group import nccl_util from ray.util.collective.collective_group.base_collective_group \ import BaseGroup from ray.util.collective.types import AllReduceOptions, \ - BarrierOptions, Backend + BarrierOptions, Backend, ReduceOptions, BroadcastOptions, \ + AllGatherOptions, ReduceScatterOptions from ray.util.collective.const import get_nccl_store_name logger = logging.getLogger(__name__) @@ -21,8 +22,7 @@ logger = logging.getLogger(__name__) class Rendezvous: - """ - A rendezvous class for different actor/task processes to meet. + """A rendezvous class for different actor/task processes to meet. 
To initialize an NCCL collective communication group, different actors/tasks spawned in Ray in a collective group needs to meet @@ -42,8 +42,7 @@ class Rendezvous: self._store = None def meet(self, timeout_s=180): - """ - Meet at the named actor store. + """Meet at the named actor store. Args: timeout_s: timeout in seconds. @@ -80,8 +79,7 @@ class Rendezvous: return self._store def get_nccl_id(self, timeout_s=180): - """ - Get the NCCLUniqueID from the store through Ray. + """Get the NCCLUniqueID from the store through Ray. Args: timeout_s: timeout in seconds. @@ -132,10 +130,7 @@ class NCCLGroup(BaseGroup): self._barrier_tensor = cupy.array([1]) def _init_nccl_unique_id(self): - """ - Init the NCCL unique ID required for setting up NCCL communicator. - - """ + """Init the NCCLUniqueID required for creating NCCL communicators.""" self._nccl_uid = self._rendezvous.get_nccl_id() @property @@ -143,10 +138,7 @@ class NCCLGroup(BaseGroup): return self._nccl_uid def destroy_group(self): - """ - Destroy the group and release the NCCL communicators safely. - - """ + """Destroy the group and release the NCCL communicators safely.""" if self._nccl_comm is not None: self.barrier() # We also need a barrier call here. @@ -162,8 +154,7 @@ class NCCLGroup(BaseGroup): return Backend.NCCL def allreduce(self, tensor, allreduce_options=AllReduceOptions()): - """ - AllReduce a list of tensors following options. + """AllReduce the tensor across the collective group following options. Args: tensor: the tensor to be reduced, each tensor locates on a GPU @@ -186,8 +177,7 @@ class NCCLGroup(BaseGroup): comm.allReduce(ptr, ptr, n_elems, dtype, reduce_op, stream.ptr) def barrier(self, barrier_options=BarrierOptions()): - """ - Blocks until all processes reach this barrier. + """Blocks until all processes reach this barrier. Args: barrier_options: @@ -196,9 +186,108 @@ class NCCLGroup(BaseGroup): """ self.allreduce(self._barrier_tensor) - def _get_nccl_communicator(self): + def reduce(self, tensor, reduce_options=ReduceOptions()): + """Reduce tensor to a destination process following options. + + Args: + tensor: the tensor to be reduced. + reduce_options: reduce options + + Returns: + None """ - Create or use a cached NCCL communicator for the collective task. + comm = self._get_nccl_communicator() + stream = self._get_cuda_stream() + + dtype = nccl_util.get_nccl_tensor_dtype(tensor) + ptr = nccl_util.get_tensor_ptr(tensor) + n_elems = nccl_util.get_tensor_n_elements(tensor) + reduce_op = nccl_util.get_nccl_reduce_op(reduce_options.reduceOp) + + # in-place reduce + comm.reduce(ptr, ptr, n_elems, dtype, reduce_op, + reduce_options.root_rank, stream.ptr) + + def broadcast(self, tensor, broadcast_options=BroadcastOptions()): + """Broadcast tensor to all other processes following options. + + Args: + tensor: the tensor to be broadcasted. + broadcast_options: broadcast options. + + Returns: + None + """ + comm = self._get_nccl_communicator() + stream = self._get_cuda_stream() + + dtype = nccl_util.get_nccl_tensor_dtype(tensor) + ptr = nccl_util.get_tensor_ptr(tensor) + n_elems = nccl_util.get_tensor_n_elements(tensor) + # in-place broadcast + comm.broadcast(ptr, ptr, n_elems, dtype, broadcast_options.root_rank, + stream.ptr) + + def allgather(self, + tensor_list, + tensor, + allgather_options=AllGatherOptions()): + """Allgather tensors across the group into a list of tensors. + + Args: + tensor_list: the tensor list to store the results. + tensor: the tensor to be allgather-ed across the group. 
+ allgather_options: allgather options. + + Returns: + None + """ + + _check_inputs_compatibility_for_scatter_gather(tensor, tensor_list) + comm = self._get_nccl_communicator() + stream = self._get_cuda_stream() + + dtype = nccl_util.get_nccl_tensor_dtype(tensor) + send_ptr = nccl_util.get_tensor_ptr(tensor) + n_elems = nccl_util.get_tensor_n_elements(tensor) + flattened = _flatten_for_scatter_gather(tensor_list, copy=False) + recv_ptr = nccl_util.get_tensor_ptr(flattened) + comm.allGather(send_ptr, recv_ptr, n_elems, dtype, stream.ptr) + for i, t in enumerate(tensor_list): + nccl_util.copy_tensor(t, flattened[i]) + + def reducescatter(self, + tensor, + tensor_list, + reducescatter_options=ReduceScatterOptions()): + """Reducescatter a list of tensors across the group. + + Args: + tensor: the output after reducescatter (could be unspecified). + tensor_list: the list of tensor to be reduce and scattered. + reducescatter_options: reducescatter options. + + Returns: + None + """ + _check_inputs_compatibility_for_scatter_gather(tensor, tensor_list) + + comm = self._get_nccl_communicator() + stream = self._get_cuda_stream() + dtype = nccl_util.get_nccl_tensor_dtype(tensor_list[0]) + n_elems = nccl_util.get_tensor_n_elements(tensor_list[0]) + reduce_op = nccl_util.get_nccl_reduce_op( + reducescatter_options.reduceOp) + + # get the send_ptr + flattened = _flatten_for_scatter_gather(tensor_list, copy=True) + send_ptr = nccl_util.get_tensor_ptr(flattened) + recv_ptr = nccl_util.get_tensor_ptr(tensor) + comm.reduceScatter(send_ptr, recv_ptr, n_elems, dtype, reduce_op, + stream.ptr) + + def _get_nccl_communicator(self): + """Create or use a cached NCCL communicator for the collective task. """ # TODO(Hao): later change this to use device keys and query from cache. @@ -217,3 +306,47 @@ class NCCLGroup(BaseGroup): # def _collective_call(self, *args): # """Private method to encapsulate all collective calls""" # pass + + +def _flatten_for_scatter_gather(tensor_list, copy=False): + """Flatten the tensor for gather/scatter operations. + + Args: + tensor_list: the list of tensors to be scattered/gathered. + copy: whether the copy the tensors in tensor_list into the buffer. + + Returns: + The flattened tensor buffer. + """ + if not tensor_list: + raise RuntimeError("Received an empty list.") + t = tensor_list[0] + # note we need a cupy dtype here. + dtype = nccl_util.get_cupy_tensor_dtype(t) + buffer_shape = [len(tensor_list)] + nccl_util.get_tensor_shape(t) + buffer = cupy.empty(buffer_shape, dtype=dtype) + if copy: + for i, tensor in enumerate(tensor_list): + nccl_util.copy_tensor(buffer[i], tensor) + return buffer + + +def _check_inputs_compatibility_for_scatter_gather(tensor, tensor_list): + """Check the compatibility between tensor input and tensor list inputs.""" + if not tensor_list: + raise RuntimeError("Got empty list of tensors.") + dtype = nccl_util.get_nccl_tensor_dtype(tensor) + shape = nccl_util.get_tensor_shape(tensor) + for t in tensor_list: + # check dtype + dt = nccl_util.get_nccl_tensor_dtype(t) + if dt != dtype: + raise RuntimeError("All tensor operands to scatter/gather must " + "have the same dtype. Got '{}' and '{}'" + "".format(dt, dtype)) + # Note: typically CCL libraries only requires they have the same + # number of elements; + # Here we make it more strict -- we require exact shape match. 
+ if nccl_util.get_tensor_shape(t) != shape: + raise RuntimeError("All tensor operands to scatter/gather must " + "have the same shape.") diff --git a/python/ray/util/collective/collective_group/nccl_util.py b/python/ray/util/collective/collective_group/nccl_util.py index 4d2fc456f..da9ced35a 100644 --- a/python/ray/util/collective/collective_group/nccl_util.py +++ b/python/ray/util/collective/collective_group/nccl_util.py @@ -28,6 +28,7 @@ NUMPY_NCCL_DTYPE_MAP = { if torch_available(): import torch + import torch.utils.dlpack TORCH_NCCL_DTYPE_MAP = { torch.uint8: nccl.NCCL_UINT8, torch.float16: nccl.NCCL_FLOAT16, @@ -35,6 +36,13 @@ if torch_available(): torch.float64: nccl.NCCL_FLOAT64, } + TORCH_NUMPY_DTYPE_MAP = { + torch.uint8: numpy.uint8, + torch.float16: numpy.float16, + torch.float32: numpy.float32, + torch.float64: numpy.float64, + } + def get_nccl_build_version(): return get_build_version() @@ -49,8 +57,7 @@ def get_nccl_unique_id(): def create_nccl_communicator(world_size, nccl_unique_id, rank): - """ - Create an NCCL communicator using NCCL APIs. + """Create an NCCL communicator using NCCL APIs. Args: world_size (int): the number of processes of this communcator group. @@ -66,8 +73,7 @@ def create_nccl_communicator(world_size, nccl_unique_id, rank): def get_nccl_reduce_op(reduce_op): - """ - Map the reduce op to NCCL reduce op type. + """Map the reduce op to NCCL reduce op type. Args: reduce_op (ReduceOp): ReduceOp Enum (SUM/PRODUCT/MIN/MAX). @@ -87,8 +93,21 @@ def get_nccl_tensor_dtype(tensor): if torch_available(): if isinstance(tensor, torch.Tensor): return TORCH_NCCL_DTYPE_MAP[tensor.dtype] - raise ValueError("Unsupported tensor type. " - "Got: {}.".format(type(tensor))) + raise ValueError("Unsupported tensor type. Got: {}. Supported " + "GPU tensor types are: torch.Tensor, " + "cupy.ndarray.".format(type(tensor))) + + +def get_cupy_tensor_dtype(tensor): + """Return the corresponded Cupy dtype given a tensor.""" + if isinstance(tensor, cupy.ndarray): + return tensor.dtype.type + if torch_available(): + if isinstance(tensor, torch.Tensor): + return TORCH_NUMPY_DTYPE_MAP[tensor.dtype] + raise ValueError("Unsupported tensor type. Got: {}. Supported " + "GPU tensor types are: torch.Tensor, " + "cupy.ndarray.".format(type(tensor))) def get_tensor_ptr(tensor): @@ -102,8 +121,9 @@ def get_tensor_ptr(tensor): if not tensor.is_cuda: raise RuntimeError("torch tensor must be on gpu.") return tensor.data_ptr() - raise ValueError("Unsupported tensor type. " - "Got: {}.".format(type(tensor))) + raise ValueError("Unsupported tensor type. Got: {}. Supported " + "GPU tensor types are: torch.Tensor, " + "cupy.ndarray.".format(type(tensor))) def get_tensor_n_elements(tensor): @@ -113,5 +133,68 @@ def get_tensor_n_elements(tensor): if torch_available(): if isinstance(tensor, torch.Tensor): return torch.numel(tensor) - raise ValueError("Unsupported tensor type. " - "Got: {}.".format(type(tensor))) + raise ValueError("Unsupported tensor type. Got: {}. Supported " + "GPU tensor types are: torch.Tensor, " + "cupy.ndarray.".format(type(tensor))) + + +def get_tensor_shape(tensor): + """Return the shape of the tensor as a list.""" + if isinstance(tensor, cupy.ndarray): + return list(tensor.shape) + if torch_available(): + if isinstance(tensor, torch.Tensor): + return list(tensor.size()) + raise ValueError("Unsupported tensor type. Got: {}. 
Supported " + "GPU tensor types are: torch.Tensor, " + "cupy.ndarray.".format(type(tensor))) + + +def get_tensor_strides(tensor): + """Return the strides of the tensor as a list.""" + if isinstance(tensor, cupy.ndarray): + return [ + int(stride / tensor.dtype.itemsize) for stride in tensor.strides + ] + if torch_available(): + if isinstance(tensor, torch.Tensor): + return list(tensor.stride()) + raise ValueError("Unsupported tensor type. Got: {}. Supported " + "GPU tensor types are: torch.Tensor, " + "cupy.ndarray.".format(type(tensor))) + + +def copy_tensor(dst_tensor, src_tensor): + """Copy the content from src_tensor to dst_tensor. + + Args: + dst_tensor: the tensor to copy from. + src_tensor: the tensor to copy to. + + Returns: + None + """ + copied = True + if isinstance(dst_tensor, cupy.ndarray) \ + and isinstance(src_tensor, cupy.ndarray): + cupy.copyto(dst_tensor, src_tensor) + elif torch_available(): + if isinstance(dst_tensor, torch.Tensor) and isinstance( + src_tensor, torch.Tensor): + dst_tensor.copy_(src_tensor) + elif isinstance(dst_tensor, torch.Tensor) and isinstance( + src_tensor, cupy.ndarray): + t = torch.utils.dlpack.from_dlpack(src_tensor.toDlpack()) + dst_tensor.copy_(t) + elif isinstance(dst_tensor, cupy.ndarray) and isinstance( + src_tensor, torch.Tensor): + t = cupy.fromDlpack(torch.utils.dlpack.to_dlpack(src_tensor)) + cupy.copyto(dst_tensor, t) + else: + copied = False + else: + copied = False + if not copied: + raise ValueError("Unsupported tensor type. Got: {} and {}. Supported " + "GPU tensor types are: torch.Tensor, cupy.ndarray." + .format(type(dst_tensor), type(src_tensor))) diff --git a/python/ray/util/collective/const.py b/python/ray/util/collective/const.py index 6eded9c51..ebc48982d 100644 --- a/python/ray/util/collective/const.py +++ b/python/ray/util/collective/const.py @@ -7,8 +7,7 @@ import hashlib def get_nccl_store_name(group_name): - """ - Generate the unique name for the NCCLUniqueID store (named actor). + """Generate the unique name for the NCCLUniqueID store (named actor). Args: group_name (str): unique user name for the store. 
diff --git a/python/ray/util/collective/tests/distributed_tests/__init__.py b/python/ray/util/collective/tests/distributed_tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/ray/util/collective/tests/distributed_tests/test_distributed_allgather.py b/python/ray/util/collective/tests/distributed_tests/test_distributed_allgather.py new file mode 100644 index 000000000..5a369c852 --- /dev/null +++ b/python/ray/util/collective/tests/distributed_tests/test_distributed_allgather.py @@ -0,0 +1,133 @@ +"""Test the allgather API on a distributed Ray cluster.""" +import pytest +import ray + +import cupy as cp +import torch + +from ray.util.collective.tests.util import create_collective_workers, \ + init_tensors_for_gather_scatter + + +@pytest.mark.parametrize("tensor_backend", ["cupy", "torch"]) +@pytest.mark.parametrize("array_size", + [2, 2**5, 2**10, 2**15, 2**20, [2, 2], [5, 5, 5]]) +def test_allgather_different_array_size(ray_start_distributed_2_nodes_4_gpus, + array_size, tensor_backend): + world_size = 4 + actors, _ = create_collective_workers(world_size) + init_tensors_for_gather_scatter( + actors, array_size=array_size, tensor_backend=tensor_backend) + results = ray.get([a.do_allgather.remote() for a in actors]) + for i in range(world_size): + for j in range(world_size): + if tensor_backend == "cupy": + assert (results[i][j] == cp.ones(array_size, dtype=cp.float32) + * (j + 1)).all() + else: + assert (results[i][j] == torch.ones( + array_size, dtype=torch.float32).cuda() * (j + 1)).all() + + +@pytest.mark.parametrize("dtype", + [cp.uint8, cp.float16, cp.float32, cp.float64]) +def test_allgather_different_dtype(ray_start_distributed_2_nodes_4_gpus, + dtype): + world_size = 4 + actors, _ = create_collective_workers(world_size) + init_tensors_for_gather_scatter(actors, dtype=dtype) + results = ray.get([a.do_allgather.remote() for a in actors]) + for i in range(world_size): + for j in range(world_size): + assert (results[i][j] == cp.ones(10, dtype=dtype) * (j + 1)).all() + + +@pytest.mark.parametrize("length", [0, 1, 3, 4, 7, 8]) +def test_unmatched_tensor_list_length(ray_start_distributed_2_nodes_4_gpus, + length): + world_size = 4 + actors, _ = create_collective_workers(world_size) + list_buffer = [cp.ones(10, dtype=cp.float32) for _ in range(length)] + ray.wait([a.set_list_buffer.remote(list_buffer) for a in actors]) + if length != world_size: + with pytest.raises(RuntimeError): + ray.get([a.do_allgather.remote() for a in actors]) + else: + ray.get([a.do_allgather.remote() for a in actors]) + + +@pytest.mark.parametrize("shape", [10, 20, [4, 5], [1, 3, 5, 7]]) +def test_unmatched_tensor_shape(ray_start_distributed_2_nodes_4_gpus, shape): + world_size = 4 + actors, _ = create_collective_workers(world_size) + init_tensors_for_gather_scatter(actors, array_size=10) + list_buffer = [cp.ones(shape, dtype=cp.float32) for _ in range(world_size)] + ray.get([a.set_list_buffer.remote(list_buffer) for a in actors]) + if shape != 10: + with pytest.raises(RuntimeError): + ray.get([a.do_allgather.remote() for a in actors]) + else: + ray.get([a.do_allgather.remote() for a in actors]) + + +def test_allgather_torch_cupy(ray_start_distributed_2_nodes_4_gpus): + world_size = 4 + shape = [10, 10] + actors, _ = create_collective_workers(world_size) + + # tensor is pytorch, list is cupy + for i, a in enumerate(actors): + t = torch.ones(shape, dtype=torch.float32).cuda() * (i + 1) + ray.wait([a.set_buffer.remote(t)]) + list_buffer = [ + cp.ones(shape, dtype=cp.float32) for _ in 
range(world_size) + ] + ray.wait([a.set_list_buffer.remote(list_buffer)]) + results = ray.get([a.do_allgather.remote() for a in actors]) + for i in range(world_size): + for j in range(world_size): + assert (results[i][j] == cp.ones(shape, dtype=cp.float32) * + (j + 1)).all() + + # tensor is cupy, list is pytorch + for i, a in enumerate(actors): + t = cp.ones(shape, dtype=cp.float32) * (i + 1) + ray.wait([a.set_buffer.remote(t)]) + list_buffer = [ + torch.ones(shape, dtype=torch.float32).cuda() + for _ in range(world_size) + ] + ray.wait([a.set_list_buffer.remote(list_buffer)]) + results = ray.get([a.do_allgather.remote() for a in actors]) + for i in range(world_size): + for j in range(world_size): + assert (results[i][j] == torch.ones( + shape, dtype=torch.float32).cuda() * (j + 1)).all() + + # some tensors in the list are pytorch, some are cupy + for i, a in enumerate(actors): + t = cp.ones(shape, dtype=cp.float32) * (i + 1) + ray.wait([a.set_buffer.remote(t)]) + list_buffer = [] + for j in range(world_size): + if j % 2 == 0: + list_buffer.append( + torch.ones(shape, dtype=torch.float32).cuda()) + else: + list_buffer.append(cp.ones(shape, dtype=cp.float32)) + ray.wait([a.set_list_buffer.remote(list_buffer)]) + results = ray.get([a.do_allgather.remote() for a in actors]) + for i in range(world_size): + for j in range(world_size): + if j % 2 == 0: + assert (results[i][j] == torch.ones( + shape, dtype=torch.float32).cuda() * (j + 1)).all() + else: + assert (results[i][j] == cp.ones(shape, dtype=cp.float32) * + (j + 1)).all() + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/python/ray/util/collective/tests/distributed_tests/test_distributed_allreduce.py b/python/ray/util/collective/tests/distributed_tests/test_distributed_allreduce.py new file mode 100644 index 000000000..35aae35b2 --- /dev/null +++ b/python/ray/util/collective/tests/distributed_tests/test_distributed_allreduce.py @@ -0,0 +1,139 @@ +"""Test the collective allreduice API on a distributed Ray cluster.""" +import pytest +import ray +from ray.util.collective.types import ReduceOp + +import cupy as cp +import torch + +from ray.util.collective.tests.util import create_collective_workers + + +@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) +@pytest.mark.parametrize("world_size", [2, 3, 4]) +def test_allreduce_different_name(ray_start_distributed_2_nodes_4_gpus, + group_name, world_size): + actors, _ = create_collective_workers( + num_workers=world_size, group_name=group_name) + results = ray.get([a.do_allreduce.remote(group_name) for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * world_size).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * world_size).all() + + +@pytest.mark.parametrize("array_size", [2, 2**5, 2**10, 2**15, 2**20]) +def test_allreduce_different_array_size(ray_start_distributed_2_nodes_4_gpus, + array_size): + world_size = 4 + actors, _ = create_collective_workers(world_size) + ray.wait([ + a.set_buffer.remote(cp.ones(array_size, dtype=cp.float32)) + for a in actors + ]) + results = ray.get([a.do_allreduce.remote() for a in actors]) + assert (results[0] == cp.ones( + (array_size, ), dtype=cp.float32) * world_size).all() + assert (results[1] == cp.ones( + (array_size, ), dtype=cp.float32) * world_size).all() + + +def test_allreduce_destroy(ray_start_distributed_2_nodes_4_gpus, + backend="nccl", + group_name="default"): + world_size = 4 + actors, _ = 
create_collective_workers(world_size) + + results = ray.get([a.do_allreduce.remote() for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * world_size).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * world_size).all() + + # destroy the group and try do work, should fail + ray.wait([a.destroy_group.remote() for a in actors]) + with pytest.raises(RuntimeError): + results = ray.get([a.do_allreduce.remote() for a in actors]) + + # reinit the same group and all reduce + ray.get([ + actor.init_group.remote(world_size, i, backend, group_name) + for i, actor in enumerate(actors) + ]) + results = ray.get([a.do_allreduce.remote() for a in actors]) + assert (results[0] == cp.ones( + (10, ), dtype=cp.float32) * world_size * world_size).all() + assert (results[1] == cp.ones( + (10, ), dtype=cp.float32) * world_size * world_size).all() + + +def test_allreduce_multiple_group(ray_start_distributed_2_nodes_4_gpus, + backend="nccl", + num_groups=5): + world_size = 4 + actors, _ = create_collective_workers(world_size) + for group_name in range(1, num_groups): + ray.get([ + actor.init_group.remote(world_size, i, backend, str(group_name)) + for i, actor in enumerate(actors) + ]) + for i in range(num_groups): + group_name = "default" if i == 0 else str(i) + results = ray.get([a.do_allreduce.remote(group_name) for a in actors]) + assert (results[0] == cp.ones( + (10, ), dtype=cp.float32) * (world_size**(i + 1))).all() + + +def test_allreduce_different_op(ray_start_distributed_2_nodes_4_gpus): + world_size = 4 + actors, _ = create_collective_workers(world_size) + + # check product + ray.wait([ + a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get( + [a.do_allreduce.remote(op=ReduceOp.PRODUCT) for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 120).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 120).all() + + # check min + ray.wait([ + a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([a.do_allreduce.remote(op=ReduceOp.MIN) for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 2).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 2).all() + + # check max + ray.wait([ + a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([a.do_allreduce.remote(op=ReduceOp.MAX) for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 5).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 5).all() + + +@pytest.mark.parametrize("dtype", + [cp.uint8, cp.float16, cp.float32, cp.float64]) +def test_allreduce_different_dtype(ray_start_distributed_2_nodes_4_gpus, + dtype): + world_size = 4 + actors, _ = create_collective_workers(world_size) + ray.wait([a.set_buffer.remote(cp.ones(10, dtype=dtype)) for a in actors]) + results = ray.get([a.do_allreduce.remote() for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=dtype) * world_size).all() + assert (results[1] == cp.ones((10, ), dtype=dtype) * world_size).all() + + +def test_allreduce_torch_cupy(ray_start_distributed_2_nodes_4_gpus): + # import torch + world_size = 4 + actors, _ = create_collective_workers(world_size) + ray.wait([actors[1].set_buffer.remote(torch.ones(10, ).cuda())]) + results = ray.get([a.do_allreduce.remote() for a in actors]) + assert (results[0] == cp.ones((10, )) * world_size).all() + + 
ray.wait([actors[0].set_buffer.remote(torch.ones(10, ))]) + ray.wait([actors[1].set_buffer.remote(cp.ones(10, ))]) + with pytest.raises(RuntimeError): + results = ray.get([a.do_allreduce.remote() for a in actors]) diff --git a/python/ray/util/collective/tests/distributed_tests/test_distributed_basic_apis.py b/python/ray/util/collective/tests/distributed_tests/test_distributed_basic_apis.py new file mode 100644 index 000000000..0f17b79ba --- /dev/null +++ b/python/ray/util/collective/tests/distributed_tests/test_distributed_basic_apis.py @@ -0,0 +1,135 @@ +"""Test the collective group APIs.""" +import pytest +import ray +from random import shuffle + +from ray.util.collective.tests.util import Worker, \ + create_collective_workers + + +@pytest.mark.parametrize("world_size", [2, 3, 4]) +@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) +def test_init_two_actors(ray_start_distributed_2_nodes_4_gpus, world_size, + group_name): + actors, results = create_collective_workers(world_size, group_name) + for i in range(world_size): + assert (results[i]) + + +@pytest.mark.parametrize("world_size", [2, 3, 4]) +def test_init_multiple_groups(ray_start_distributed_2_nodes_4_gpus, + world_size): + num_groups = 1 + actors = [Worker.remote() for _ in range(world_size)] + for i in range(num_groups): + group_name = str(i) + init_results = ray.get([ + actor.init_group.remote(world_size, i, group_name=group_name) + for i, actor in enumerate(actors) + ]) + for j in range(world_size): + assert init_results[j] + + +@pytest.mark.parametrize("world_size", [2, 3, 4]) +def test_get_rank(ray_start_distributed_2_nodes_4_gpus, world_size): + actors, _ = create_collective_workers(world_size) + actor0_rank = ray.get(actors[0].report_rank.remote()) + assert actor0_rank == 0 + actor1_rank = ray.get(actors[1].report_rank.remote()) + assert actor1_rank == 1 + + # create a second group with a different name, and different + # orders of ranks. 
+ new_group_name = "default2" + ranks = list(range(world_size)) + shuffle(ranks) + _ = ray.get([ + actor.init_group.remote( + world_size, ranks[i], group_name=new_group_name) + for i, actor in enumerate(actors) + ]) + actor0_rank = ray.get(actors[0].report_rank.remote(new_group_name)) + assert actor0_rank == ranks[0] + actor1_rank = ray.get(actors[1].report_rank.remote(new_group_name)) + assert actor1_rank == ranks[1] + + +@pytest.mark.parametrize("world_size", [2, 3, 4]) +def test_get_world_size(ray_start_distributed_2_nodes_4_gpus, world_size): + actors, _ = create_collective_workers(world_size) + actor0_world_size = ray.get(actors[0].report_world_size.remote()) + actor1_world_size = ray.get(actors[1].report_world_size.remote()) + assert actor0_world_size == actor1_world_size == world_size + + +def test_availability(ray_start_distributed_2_nodes_4_gpus): + world_size = 4 + actors, _ = create_collective_workers(world_size) + actor0_nccl_availability = ray.get( + actors[0].report_nccl_availability.remote()) + assert actor0_nccl_availability + actor0_mpi_availability = ray.get( + actors[0].report_mpi_availability.remote()) + assert not actor0_mpi_availability + + +def test_is_group_initialized(ray_start_distributed_2_nodes_4_gpus): + world_size = 4 + actors, _ = create_collective_workers(world_size) + # check group is_init + actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert actor0_is_init + actor0_is_init = ray.get( + actors[0].report_is_group_initialized.remote("random")) + assert not actor0_is_init + actor0_is_init = ray.get( + actors[0].report_is_group_initialized.remote("123")) + assert not actor0_is_init + actor1_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert actor1_is_init + actor1_is_init = ray.get( + actors[0].report_is_group_initialized.remote("456")) + assert not actor1_is_init + + +def test_destroy_group(ray_start_distributed_2_nodes_4_gpus): + world_size = 4 + actors, _ = create_collective_workers(world_size) + # Now destroy the group at actor0 + ray.wait([actors[0].destroy_group.remote()]) + actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert not actor0_is_init + + # should go well as the group `random` does not exist at all + ray.wait([actors[0].destroy_group.remote("random")]) + + actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) + assert actor1_is_init + ray.wait([actors[1].destroy_group.remote("random")]) + actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) + assert actor1_is_init + ray.wait([actors[1].destroy_group.remote("default")]) + actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) + assert not actor1_is_init + for i in [2, 3]: + ray.wait([actors[i].destroy_group.remote("default")]) + + # Now reconstruct the group using the same name + init_results = ray.get([ + actor.init_group.remote(world_size, i) + for i, actor in enumerate(actors) + ]) + for i in range(world_size): + assert init_results[i] + actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert actor0_is_init + actor1_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert actor1_is_init + + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/python/ray/util/collective/tests/distributed_tests/test_distributed_broadcast.py b/python/ray/util/collective/tests/distributed_tests/test_distributed_broadcast.py new file mode 100644 index 
000000000..408ebce76 --- /dev/null +++ b/python/ray/util/collective/tests/distributed_tests/test_distributed_broadcast.py @@ -0,0 +1,67 @@ +"""Test the broadcast API.""" +import pytest +import cupy as cp +import ray + +from ray.util.collective.tests.util import create_collective_workers + + +@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) +@pytest.mark.parametrize("src_rank", [0, 1, 2, 3]) +def test_broadcast_different_name(ray_start_distributed_2_nodes_4_gpus, + group_name, src_rank): + world_size = 4 + actors, _ = create_collective_workers( + num_workers=world_size, group_name=group_name) + ray.wait([ + a.set_buffer.remote(cp.ones((10, ), dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([ + a.do_broadcast.remote(group_name=group_name, src_rank=src_rank) + for a in actors + ]) + for i in range(world_size): + assert (results[i] == cp.ones( + (10, ), dtype=cp.float32) * (src_rank + 2)).all() + + +@pytest.mark.parametrize("array_size", [2, 2**5, 2**10, 2**15, 2**20]) +@pytest.mark.parametrize("src_rank", [0, 1, 2, 3]) +def test_broadcast_different_array_size(ray_start_distributed_2_nodes_4_gpus, + array_size, src_rank): + world_size = 4 + actors, _ = create_collective_workers(world_size) + ray.wait([ + a.set_buffer.remote(cp.ones(array_size, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get( + [a.do_broadcast.remote(src_rank=src_rank) for a in actors]) + for i in range(world_size): + assert (results[i] == cp.ones( + (array_size, ), dtype=cp.float32) * (src_rank + 2)).all() + + +@pytest.mark.parametrize("src_rank", [0, 1]) +def test_broadcast_torch_cupy(ray_start_distributed_2_nodes_4_gpus, src_rank): + import torch + world_size = 4 + actors, _ = create_collective_workers(world_size) + ray.wait( + [actors[1].set_buffer.remote(torch.ones(10, ).cuda() * world_size)]) + results = ray.get( + [a.do_broadcast.remote(src_rank=src_rank) for a in actors]) + if src_rank == 0: + assert (results[0] == cp.ones((10, ))).all() + assert (results[1] == torch.ones((10, )).cuda()).all() + else: + assert (results[0] == cp.ones((10, )) * world_size).all() + assert (results[1] == torch.ones((10, )).cuda() * world_size).all() + + +def test_broadcast_invalid_rank(ray_start_single_node_2_gpus, src_rank=3): + world_size = 2 + actors, _ = create_collective_workers(world_size) + with pytest.raises(ValueError): + _ = ray.get([a.do_broadcast.remote(src_rank=src_rank) for a in actors]) diff --git a/python/ray/util/collective/tests/distributed_tests/test_distributed_reduce.py b/python/ray/util/collective/tests/distributed_tests/test_distributed_reduce.py new file mode 100644 index 000000000..9646f8d12 --- /dev/null +++ b/python/ray/util/collective/tests/distributed_tests/test_distributed_reduce.py @@ -0,0 +1,119 @@ +"""Test the reduce API.""" +import pytest +import cupy as cp +import ray +from ray.util.collective.types import ReduceOp + +from ray.util.collective.tests.util import create_collective_workers + + +@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) +@pytest.mark.parametrize("dst_rank", [0, 1, 2, 3]) +def test_reduce_different_name(ray_start_distributed_2_nodes_4_gpus, + group_name, dst_rank): + world_size = 4 + actors, _ = create_collective_workers( + num_workers=world_size, group_name=group_name) + results = ray.get( + [a.do_reduce.remote(group_name, dst_rank) for a in actors]) + for i in range(world_size): + if i == dst_rank: + assert (results[i] == cp.ones( + (10, ), dtype=cp.float32) * 
world_size).all() + else: + assert (results[i] == cp.ones((10, ), dtype=cp.float32)).all() + + +@pytest.mark.parametrize("array_size", [2, 2**5, 2**10, 2**15, 2**20]) +@pytest.mark.parametrize("dst_rank", [0, 1, 2, 3]) +def test_reduce_different_array_size(ray_start_distributed_2_nodes_4_gpus, + array_size, dst_rank): + world_size = 4 + actors, _ = create_collective_workers(world_size) + ray.wait([ + a.set_buffer.remote(cp.ones(array_size, dtype=cp.float32)) + for a in actors + ]) + results = ray.get([a.do_reduce.remote(dst_rank=dst_rank) for a in actors]) + for i in range(world_size): + if i == dst_rank: + assert (results[i] == cp.ones( + (array_size, ), dtype=cp.float32) * world_size).all() + else: + assert (results[i] == cp.ones((array_size, ), + dtype=cp.float32)).all() + + +@pytest.mark.parametrize("dst_rank", [0, 1, 2, 3]) +def test_reduce_different_op(ray_start_distributed_2_nodes_4_gpus, dst_rank): + world_size = 4 + actors, _ = create_collective_workers(world_size) + + # check product + ray.wait([ + a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([ + a.do_reduce.remote(dst_rank=dst_rank, op=ReduceOp.PRODUCT) + for a in actors + ]) + for i in range(world_size): + if i == dst_rank: + assert (results[i] == cp.ones( + (10, ), dtype=cp.float32) * 120).all() + else: + assert (results[i] == cp.ones( + (10, ), dtype=cp.float32) * (i + 2)).all() + + # check min + ray.wait([ + a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([ + a.do_reduce.remote(dst_rank=dst_rank, op=ReduceOp.MIN) for a in actors + ]) + for i in range(world_size): + if i == dst_rank: + assert (results[i] == cp.ones((10, ), dtype=cp.float32) * 2).all() + else: + assert (results[i] == cp.ones( + (10, ), dtype=cp.float32) * (i + 2)).all() + + # check max + ray.wait([ + a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([ + a.do_reduce.remote(dst_rank=dst_rank, op=ReduceOp.MAX) for a in actors + ]) + for i in range(world_size): + if i == dst_rank: + assert (results[i] == cp.ones((10, ), dtype=cp.float32) * 5).all() + else: + assert (results[i] == cp.ones( + (10, ), dtype=cp.float32) * (i + 2)).all() + + +@pytest.mark.parametrize("dst_rank", [0, 1]) +def test_reduce_torch_cupy(ray_start_distributed_2_nodes_4_gpus, dst_rank): + import torch + world_size = 4 + actors, _ = create_collective_workers(world_size) + ray.wait([actors[1].set_buffer.remote(torch.ones(10, ).cuda())]) + results = ray.get([a.do_reduce.remote(dst_rank=dst_rank) for a in actors]) + if dst_rank == 0: + assert (results[0] == cp.ones((10, )) * world_size).all() + assert (results[1] == torch.ones((10, )).cuda()).all() + else: + assert (results[0] == cp.ones((10, ))).all() + assert (results[1] == torch.ones((10, )).cuda() * world_size).all() + + +def test_reduce_invalid_rank(ray_start_distributed_2_nodes_4_gpus, dst_rank=7): + world_size = 4 + actors, _ = create_collective_workers(world_size) + with pytest.raises(ValueError): + _ = ray.get([a.do_reduce.remote(dst_rank=dst_rank) for a in actors]) diff --git a/python/ray/util/collective/tests/distributed_tests/test_distributed_reducescatter.py b/python/ray/util/collective/tests/distributed_tests/test_distributed_reducescatter.py new file mode 100644 index 000000000..63230402a --- /dev/null +++ b/python/ray/util/collective/tests/distributed_tests/test_distributed_reducescatter.py @@ -0,0 +1,128 @@ +"""Test the 
collective reducescatter API on a distributed Ray cluster.""" +import pytest +import ray + +import cupy as cp +import torch + +from ray.util.collective.tests.util import create_collective_workers, \ + init_tensors_for_gather_scatter + + +@pytest.mark.parametrize("tensor_backend", ["cupy", "torch"]) +@pytest.mark.parametrize("array_size", + [2, 2**5, 2**10, 2**15, 2**20, [2, 2], [5, 5, 5]]) +def test_reducescatter_different_array_size( + ray_start_distributed_2_nodes_4_gpus, array_size, tensor_backend): + world_size = 4 + actors, _ = create_collective_workers(world_size) + init_tensors_for_gather_scatter( + actors, array_size=array_size, tensor_backend=tensor_backend) + results = ray.get([a.do_reducescatter.remote() for a in actors]) + for i in range(world_size): + if tensor_backend == "cupy": + assert (results[i] == cp.ones(array_size, dtype=cp.float32) * + world_size).all() + else: + assert (results[i] == torch.ones( + array_size, dtype=torch.float32).cuda() * world_size).all() + + +@pytest.mark.parametrize("dtype", + [cp.uint8, cp.float16, cp.float32, cp.float64]) +def test_reducescatter_different_dtype(ray_start_distributed_2_nodes_4_gpus, + dtype): + world_size = 4 + actors, _ = create_collective_workers(world_size) + init_tensors_for_gather_scatter(actors, dtype=dtype) + results = ray.get([a.do_reducescatter.remote() for a in actors]) + for i in range(world_size): + for j in range(world_size): + assert (results[i] == cp.ones(10, dtype=dtype) * world_size).all() + + +def test_reducescatter_torch_cupy(ray_start_distributed_2_nodes_4_gpus): + world_size = 4 + shape = [10, 10] + actors, _ = create_collective_workers(world_size) + + # tensor is pytorch, list is cupy + for i, a in enumerate(actors): + t = torch.ones(shape, dtype=torch.float32).cuda() * (i + 1) + ray.wait([a.set_buffer.remote(t)]) + list_buffer = [ + cp.ones(shape, dtype=cp.float32) for _ in range(world_size) + ] + ray.wait([a.set_list_buffer.remote(list_buffer)]) + results = ray.get([a.do_reducescatter.remote() for a in actors]) + for i in range(world_size): + assert (results[i] == torch.ones(shape, dtype=torch.float32).cuda() * + world_size).all() + + # tensor is cupy, list is pytorch + for i, a in enumerate(actors): + t = cp.ones(shape, dtype=cp.float32) * (i + 1) + ray.wait([a.set_buffer.remote(t)]) + list_buffer = [ + torch.ones(shape, dtype=torch.float32).cuda() + for _ in range(world_size) + ] + ray.wait([a.set_list_buffer.remote(list_buffer)]) + results = ray.get([a.do_reducescatter.remote() for a in actors]) + for i in range(world_size): + assert ( + results[i] == cp.ones(shape, dtype=cp.float32) * world_size).all() + + # some tensors in the list are pytorch, some are cupy + for i, a in enumerate(actors): + if i % 2 == 0: + t = torch.ones(shape, dtype=torch.float32).cuda() * (i + 1) + else: + t = cp.ones(shape, dtype=cp.float32) * (i + 1) + ray.wait([a.set_buffer.remote(t)]) + list_buffer = [] + for j in range(world_size): + if j % 2 == 0: + list_buffer.append( + torch.ones(shape, dtype=torch.float32).cuda()) + else: + list_buffer.append(cp.ones(shape, dtype=cp.float32)) + ray.wait([a.set_list_buffer.remote(list_buffer)]) + results = ray.get([a.do_reducescatter.remote() for a in actors]) + for i in range(world_size): + if i % 2 == 0: + assert (results[i] == torch.ones( + shape, dtype=torch.float32).cuda() * world_size).all() + else: + assert (results[i] == cp.ones(shape, dtype=cp.float32) * + world_size).all() + + # mixed case + for i, a in enumerate(actors): + if i % 2 == 0: + t = torch.ones(shape, 
dtype=torch.float32).cuda() * (i + 1) + else: + t = cp.ones(shape, dtype=cp.float32) * (i + 1) + ray.wait([a.set_buffer.remote(t)]) + list_buffer = [] + for j in range(world_size): + if j % 2 == 0: + list_buffer.append(cp.ones(shape, dtype=cp.float32)) + else: + list_buffer.append( + torch.ones(shape, dtype=torch.float32).cuda()) + ray.wait([a.set_list_buffer.remote(list_buffer)]) + results = ray.get([a.do_reducescatter.remote() for a in actors]) + for i in range(world_size): + if i % 2 == 0: + assert (results[i] == torch.ones( + shape, dtype=torch.float32).cuda() * world_size).all() + else: + assert (results[i] == cp.ones(shape, dtype=cp.float32) * + world_size).all() + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/python/ray/util/collective/tests/test_allgather.py b/python/ray/util/collective/tests/test_allgather.py new file mode 100644 index 000000000..33cf9a6d0 --- /dev/null +++ b/python/ray/util/collective/tests/test_allgather.py @@ -0,0 +1,131 @@ +"""Test the collective allgather API.""" +import pytest +import ray + +import cupy as cp +import torch + +from ray.util.collective.tests.util import create_collective_workers, \ + init_tensors_for_gather_scatter + + +@pytest.mark.parametrize("tensor_backend", ["cupy", "torch"]) +@pytest.mark.parametrize("array_size", + [2, 2**5, 2**10, 2**15, 2**20, [2, 2], [5, 5, 5]]) +def test_allgather_different_array_size(ray_start_single_node_2_gpus, + array_size, tensor_backend): + world_size = 2 + actors, _ = create_collective_workers(world_size) + init_tensors_for_gather_scatter( + actors, array_size=array_size, tensor_backend=tensor_backend) + results = ray.get([a.do_allgather.remote() for a in actors]) + for i in range(world_size): + for j in range(world_size): + if tensor_backend == "cupy": + assert (results[i][j] == cp.ones(array_size, dtype=cp.float32) + * (j + 1)).all() + else: + assert (results[i][j] == torch.ones( + array_size, dtype=torch.float32).cuda() * (j + 1)).all() + + +@pytest.mark.parametrize("dtype", + [cp.uint8, cp.float16, cp.float32, cp.float64]) +def test_allgather_different_dtype(ray_start_single_node_2_gpus, dtype): + world_size = 2 + actors, _ = create_collective_workers(world_size) + init_tensors_for_gather_scatter(actors, dtype=dtype) + results = ray.get([a.do_allgather.remote() for a in actors]) + for i in range(world_size): + for j in range(world_size): + assert (results[i][j] == cp.ones(10, dtype=dtype) * (j + 1)).all() + + +@pytest.mark.parametrize("length", [0, 1, 2, 3]) +def test_unmatched_tensor_list_length(ray_start_single_node_2_gpus, length): + world_size = 2 + actors, _ = create_collective_workers(world_size) + list_buffer = [cp.ones(10, dtype=cp.float32) for _ in range(length)] + ray.wait([a.set_list_buffer.remote(list_buffer) for a in actors]) + if length != world_size: + with pytest.raises(RuntimeError): + ray.get([a.do_allgather.remote() for a in actors]) + else: + ray.get([a.do_allgather.remote() for a in actors]) + + +@pytest.mark.parametrize("shape", [10, 20, [4, 5], [1, 3, 5, 7]]) +def test_unmatched_tensor_shape(ray_start_single_node_2_gpus, shape): + world_size = 2 + actors, _ = create_collective_workers(world_size) + init_tensors_for_gather_scatter(actors, array_size=10) + list_buffer = [cp.ones(shape, dtype=cp.float32) for _ in range(world_size)] + ray.get([a.set_list_buffer.remote(list_buffer) for a in actors]) + if shape != 10: + with pytest.raises(RuntimeError): + ray.get([a.do_allgather.remote() for a in actors]) + else: + 
ray.get([a.do_allgather.remote() for a in actors]) + + +def test_allgather_torch_cupy(ray_start_single_node_2_gpus): + world_size = 2 + shape = [10, 10] + actors, _ = create_collective_workers(world_size) + + # tensor is pytorch, list is cupy + for i, a in enumerate(actors): + t = torch.ones(shape, dtype=torch.float32).cuda() * (i + 1) + ray.wait([a.set_buffer.remote(t)]) + list_buffer = [ + cp.ones(shape, dtype=cp.float32) for _ in range(world_size) + ] + ray.wait([a.set_list_buffer.remote(list_buffer)]) + results = ray.get([a.do_allgather.remote() for a in actors]) + for i in range(world_size): + for j in range(world_size): + assert (results[i][j] == cp.ones(shape, dtype=cp.float32) * + (j + 1)).all() + + # tensor is cupy, list is pytorch + for i, a in enumerate(actors): + t = cp.ones(shape, dtype=cp.float32) * (i + 1) + ray.wait([a.set_buffer.remote(t)]) + list_buffer = [ + torch.ones(shape, dtype=torch.float32).cuda() + for _ in range(world_size) + ] + ray.wait([a.set_list_buffer.remote(list_buffer)]) + results = ray.get([a.do_allgather.remote() for a in actors]) + for i in range(world_size): + for j in range(world_size): + assert (results[i][j] == torch.ones( + shape, dtype=torch.float32).cuda() * (j + 1)).all() + + # some tensors in the list are pytorch, some are cupy + for i, a in enumerate(actors): + t = cp.ones(shape, dtype=cp.float32) * (i + 1) + ray.wait([a.set_buffer.remote(t)]) + list_buffer = [] + for j in range(world_size): + if j % 2 == 0: + list_buffer.append( + torch.ones(shape, dtype=torch.float32).cuda()) + else: + list_buffer.append(cp.ones(shape, dtype=cp.float32)) + ray.wait([a.set_list_buffer.remote(list_buffer)]) + results = ray.get([a.do_allgather.remote() for a in actors]) + for i in range(world_size): + for j in range(world_size): + if j % 2 == 0: + assert (results[i][j] == torch.ones( + shape, dtype=torch.float32).cuda() * (j + 1)).all() + else: + assert (results[i][j] == cp.ones(shape, dtype=cp.float32) * + (j + 1)).all() + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/python/ray/util/collective/tests/test_allreduce.py b/python/ray/util/collective/tests/test_allreduce.py new file mode 100644 index 000000000..1fbdf526b --- /dev/null +++ b/python/ray/util/collective/tests/test_allreduce.py @@ -0,0 +1,143 @@ +"""Test the collective allreduice API.""" +import pytest +import ray +from ray.util.collective.types import ReduceOp + +import cupy as cp +import torch + +from ray.util.collective.tests.util import create_collective_workers + + +@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) +def test_allreduce_different_name(ray_start_single_node_2_gpus, group_name): + world_size = 2 + actors, _ = create_collective_workers( + num_workers=world_size, group_name=group_name) + results = ray.get([a.do_allreduce.remote(group_name) for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * world_size).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * world_size).all() + + +@pytest.mark.parametrize("array_size", [2, 2**5, 2**10, 2**15, 2**20]) +def test_allreduce_different_array_size(ray_start_single_node_2_gpus, + array_size): + world_size = 2 + actors, _ = create_collective_workers(world_size) + ray.wait([ + a.set_buffer.remote(cp.ones(array_size, dtype=cp.float32)) + for a in actors + ]) + results = ray.get([a.do_allreduce.remote() for a in actors]) + assert (results[0] == cp.ones( + (array_size, ), dtype=cp.float32) * world_size).all() + assert 
(results[1] == cp.ones( + (array_size, ), dtype=cp.float32) * world_size).all() + + +def test_allreduce_destroy(ray_start_single_node_2_gpus, + backend="nccl", + group_name="default"): + world_size = 2 + actors, _ = create_collective_workers(world_size) + + results = ray.get([a.do_allreduce.remote() for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * world_size).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * world_size).all() + + # destroy the group and try do work, should fail + ray.wait([a.destroy_group.remote() for a in actors]) + with pytest.raises(RuntimeError): + results = ray.get([a.do_allreduce.remote() for a in actors]) + + # reinit the same group and all reduce + ray.get([ + actor.init_group.remote(world_size, i, backend, group_name) + for i, actor in enumerate(actors) + ]) + results = ray.get([a.do_allreduce.remote() for a in actors]) + assert (results[0] == cp.ones( + (10, ), dtype=cp.float32) * world_size * 2).all() + assert (results[1] == cp.ones( + (10, ), dtype=cp.float32) * world_size * 2).all() + + +def test_allreduce_multiple_group(ray_start_single_node_2_gpus, + backend="nccl", + num_groups=5): + world_size = 2 + actors, _ = create_collective_workers(world_size) + for group_name in range(1, num_groups): + ray.get([ + actor.init_group.remote(world_size, i, backend, str(group_name)) + for i, actor in enumerate(actors) + ]) + for i in range(num_groups): + group_name = "default" if i == 0 else str(i) + results = ray.get([a.do_allreduce.remote(group_name) for a in actors]) + assert (results[0] == cp.ones( + (10, ), dtype=cp.float32) * (world_size**(i + 1))).all() + + +def test_allreduce_different_op(ray_start_single_node_2_gpus): + world_size = 2 + actors, _ = create_collective_workers(world_size) + + # check product + ray.wait([ + a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get( + [a.do_allreduce.remote(op=ReduceOp.PRODUCT) for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 6).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 6).all() + + # check min + ray.wait([ + a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([a.do_allreduce.remote(op=ReduceOp.MIN) for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 2).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 2).all() + + # check max + ray.wait([ + a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([a.do_allreduce.remote(op=ReduceOp.MAX) for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 3).all() + assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 3).all() + + +@pytest.mark.parametrize("dtype", + [cp.uint8, cp.float16, cp.float32, cp.float64]) +def test_allreduce_different_dtype(ray_start_single_node_2_gpus, dtype): + world_size = 2 + actors, _ = create_collective_workers(world_size) + ray.wait([a.set_buffer.remote(cp.ones(10, dtype=dtype)) for a in actors]) + results = ray.get([a.do_allreduce.remote() for a in actors]) + assert (results[0] == cp.ones((10, ), dtype=dtype) * world_size).all() + assert (results[1] == cp.ones((10, ), dtype=dtype) * world_size).all() + + +def test_allreduce_torch_cupy(ray_start_single_node_2_gpus): + # import torch + world_size = 2 + actors, _ = create_collective_workers(world_size) + 
ray.wait([actors[1].set_buffer.remote(torch.ones(10, ).cuda())]) + results = ray.get([a.do_allreduce.remote() for a in actors]) + assert (results[0] == cp.ones((10, )) * world_size).all() + + ray.wait([actors[0].set_buffer.remote(torch.ones(10, ))]) + ray.wait([actors[1].set_buffer.remote(cp.ones(10, ))]) + with pytest.raises(RuntimeError): + results = ray.get([a.do_allreduce.remote() for a in actors]) + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/python/ray/util/collective/tests/test_basic_apis.py b/python/ray/util/collective/tests/test_basic_apis.py new file mode 100644 index 000000000..8c23442a3 --- /dev/null +++ b/python/ray/util/collective/tests/test_basic_apis.py @@ -0,0 +1,127 @@ +"""Test the collective group APIs.""" +import pytest +import ray + +from ray.util.collective.tests.util import Worker, \ + create_collective_workers + + +@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) +def test_init_two_actors(ray_start_single_node_2_gpus, group_name): + world_size = 2 + actors, results = create_collective_workers(world_size, group_name) + for i in range(world_size): + assert (results[i]) + + +def test_init_multiple_groups(ray_start_single_node_2_gpus): + world_size = 2 + num_groups = 10 + actors = [Worker.remote() for i in range(world_size)] + for i in range(num_groups): + group_name = str(i) + init_results = ray.get([ + actor.init_group.remote(world_size, i, group_name=group_name) + for i, actor in enumerate(actors) + ]) + for j in range(world_size): + assert init_results[j] + + +def test_get_rank(ray_start_single_node_2_gpus): + world_size = 2 + actors, _ = create_collective_workers(world_size) + actor0_rank = ray.get(actors[0].report_rank.remote()) + assert actor0_rank == 0 + actor1_rank = ray.get(actors[1].report_rank.remote()) + assert actor1_rank == 1 + + # create a second group with a different name, + # and different order of ranks. 
+ new_group_name = "default2" + _ = ray.get([ + actor.init_group.remote( + world_size, world_size - 1 - i, group_name=new_group_name) + for i, actor in enumerate(actors) + ]) + actor0_rank = ray.get(actors[0].report_rank.remote(new_group_name)) + assert actor0_rank == 1 + actor1_rank = ray.get(actors[1].report_rank.remote(new_group_name)) + assert actor1_rank == 0 + + +def test_get_world_size(ray_start_single_node_2_gpus): + world_size = 2 + actors, _ = create_collective_workers(world_size) + actor0_world_size = ray.get(actors[0].report_world_size.remote()) + actor1_world_size = ray.get(actors[1].report_world_size.remote()) + assert actor0_world_size == actor1_world_size == world_size + + +def test_availability(ray_start_single_node_2_gpus): + world_size = 2 + actors, _ = create_collective_workers(world_size) + actor0_nccl_availability = ray.get( + actors[0].report_nccl_availability.remote()) + assert actor0_nccl_availability + actor0_mpi_availability = ray.get( + actors[0].report_mpi_availability.remote()) + assert not actor0_mpi_availability + + +def test_is_group_initialized(ray_start_single_node_2_gpus): + world_size = 2 + actors, _ = create_collective_workers(world_size) + # check group is_init + actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert actor0_is_init + actor0_is_init = ray.get( + actors[0].report_is_group_initialized.remote("random")) + assert not actor0_is_init + actor0_is_init = ray.get( + actors[0].report_is_group_initialized.remote("123")) + assert not actor0_is_init + actor1_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert actor1_is_init + actor1_is_init = ray.get( + actors[0].report_is_group_initialized.remote("456")) + assert not actor1_is_init + + +def test_destroy_group(ray_start_single_node_2_gpus): + world_size = 2 + actors, _ = create_collective_workers(world_size) + # Now destroy the group at actor0 + ray.wait([actors[0].destroy_group.remote()]) + actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert not actor0_is_init + + # should go well as the group `random` does not exist at all + ray.wait([actors[0].destroy_group.remote("random")]) + + actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) + assert actor1_is_init + ray.wait([actors[1].destroy_group.remote("random")]) + actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) + assert actor1_is_init + ray.wait([actors[1].destroy_group.remote("default")]) + actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) + assert not actor1_is_init + + # Now reconstruct the group using the same name + init_results = ray.get([ + actor.init_group.remote(world_size, i) + for i, actor in enumerate(actors) + ]) + for i in range(world_size): + assert init_results[i] + actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert actor0_is_init + actor1_is_init = ray.get(actors[0].report_is_group_initialized.remote()) + assert actor1_is_init + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/python/ray/util/collective/tests/test_broadcast.py b/python/ray/util/collective/tests/test_broadcast.py new file mode 100644 index 000000000..3d62b6d2e --- /dev/null +++ b/python/ray/util/collective/tests/test_broadcast.py @@ -0,0 +1,67 @@ +"""Test the broadcast API.""" +import pytest +import cupy as cp +import ray + +from ray.util.collective.tests.util import create_collective_workers + + 
+@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) +@pytest.mark.parametrize("src_rank", [0, 1]) +def test_broadcast_different_name(ray_start_single_node_2_gpus, group_name, + src_rank): + world_size = 2 + actors, _ = create_collective_workers( + num_workers=world_size, group_name=group_name) + ray.wait([ + a.set_buffer.remote(cp.ones((10, ), dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([ + a.do_broadcast.remote(group_name=group_name, src_rank=src_rank) + for a in actors + ]) + for i in range(world_size): + assert (results[i] == cp.ones( + (10, ), dtype=cp.float32) * (src_rank + 2)).all() + + +@pytest.mark.parametrize("array_size", [2, 2**5, 2**10, 2**15, 2**20]) +@pytest.mark.parametrize("src_rank", [0, 1]) +def test_broadcast_different_array_size(ray_start_single_node_2_gpus, + array_size, src_rank): + world_size = 2 + actors, _ = create_collective_workers(world_size) + ray.wait([ + a.set_buffer.remote(cp.ones(array_size, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get( + [a.do_broadcast.remote(src_rank=src_rank) for a in actors]) + for i in range(world_size): + assert (results[i] == cp.ones( + (array_size, ), dtype=cp.float32) * (src_rank + 2)).all() + + +@pytest.mark.parametrize("src_rank", [0, 1]) +def test_broadcast_torch_cupy(ray_start_single_node_2_gpus, src_rank): + import torch + world_size = 2 + actors, _ = create_collective_workers(world_size) + ray.wait( + [actors[1].set_buffer.remote(torch.ones(10, ).cuda() * world_size)]) + results = ray.get( + [a.do_broadcast.remote(src_rank=src_rank) for a in actors]) + if src_rank == 0: + assert (results[0] == cp.ones((10, ))).all() + assert (results[1] == torch.ones((10, )).cuda()).all() + else: + assert (results[0] == cp.ones((10, )) * world_size).all() + assert (results[1] == torch.ones((10, )).cuda() * world_size).all() + + +def test_broadcast_invalid_rank(ray_start_single_node_2_gpus, src_rank=3): + world_size = 2 + actors, _ = create_collective_workers(world_size) + with pytest.raises(ValueError): + _ = ray.get([a.do_broadcast.remote(src_rank=src_rank) for a in actors]) diff --git a/python/ray/util/collective/tests/test_collective_2_nodes_4_gpus.py b/python/ray/util/collective/tests/test_collective_2_nodes_4_gpus.py deleted file mode 100644 index c35e48b9a..000000000 --- a/python/ray/util/collective/tests/test_collective_2_nodes_4_gpus.py +++ /dev/null @@ -1,276 +0,0 @@ -"""Test the collective group APIs.""" -from random import shuffle -import pytest -import ray -from ray.util.collective.types import ReduceOp - -import cupy as cp -import torch - -from .util import Worker - - -def get_actors_group(num_workers=2, group_name="default", backend="nccl"): - actors = [Worker.remote() for i in range(num_workers)] - world_size = num_workers - init_results = ray.get([ - actor.init_group.remote(world_size, i, backend, group_name) - for i, actor in enumerate(actors) - ]) - return actors, init_results - - -@pytest.mark.parametrize("world_size", [2, 3, 4]) -@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) -def test_init_two_actors(ray_start_distributed_2_nodes_4_gpus, world_size, - group_name): - actors, results = get_actors_group(world_size, group_name) - for i in range(world_size): - assert (results[i]) - - -@pytest.mark.parametrize("world_size", [2, 3, 4]) -def test_init_multiple_groups(ray_start_distributed_2_nodes_4_gpus, - world_size): - num_groups = 1 - actors = [Worker.remote() for _ in range(world_size)] - for i in 
range(num_groups): - group_name = str(i) - init_results = ray.get([ - actor.init_group.remote(world_size, i, group_name=group_name) - for i, actor in enumerate(actors) - ]) - for j in range(world_size): - assert init_results[j] - - -@pytest.mark.parametrize("world_size", [2, 3, 4]) -def test_get_rank(ray_start_distributed_2_nodes_4_gpus, world_size): - actors, _ = get_actors_group(world_size) - actor0_rank = ray.get(actors[0].report_rank.remote()) - assert actor0_rank == 0 - actor1_rank = ray.get(actors[1].report_rank.remote()) - assert actor1_rank == 1 - - # create a second group with a different name, and different - # orders of ranks. - new_group_name = "default2" - ranks = list(range(world_size)) - shuffle(ranks) - _ = ray.get([ - actor.init_group.remote( - world_size, ranks[i], group_name=new_group_name) - for i, actor in enumerate(actors) - ]) - actor0_rank = ray.get(actors[0].report_rank.remote(new_group_name)) - assert actor0_rank == ranks[0] - actor1_rank = ray.get(actors[1].report_rank.remote(new_group_name)) - assert actor1_rank == ranks[1] - - -@pytest.mark.parametrize("world_size", [2, 3, 4]) -def test_get_world_size(ray_start_distributed_2_nodes_4_gpus, world_size): - actors, _ = get_actors_group(world_size) - actor0_world_size = ray.get(actors[0].report_world_size.remote()) - actor1_world_size = ray.get(actors[1].report_world_size.remote()) - assert actor0_world_size == actor1_world_size == world_size - - -def test_availability(ray_start_distributed_2_nodes_4_gpus): - world_size = 4 - actors, _ = get_actors_group(world_size) - actor0_nccl_availability = ray.get( - actors[0].report_nccl_availability.remote()) - assert actor0_nccl_availability - actor0_mpi_availability = ray.get( - actors[0].report_mpi_availability.remote()) - assert not actor0_mpi_availability - - -def test_is_group_initialized(ray_start_distributed_2_nodes_4_gpus): - world_size = 4 - actors, _ = get_actors_group(world_size) - # check group is_init - actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) - assert actor0_is_init - actor0_is_init = ray.get( - actors[0].report_is_group_initialized.remote("random")) - assert not actor0_is_init - actor0_is_init = ray.get( - actors[0].report_is_group_initialized.remote("123")) - assert not actor0_is_init - actor1_is_init = ray.get(actors[0].report_is_group_initialized.remote()) - assert actor1_is_init - actor1_is_init = ray.get( - actors[0].report_is_group_initialized.remote("456")) - assert not actor1_is_init - - -def test_destroy_group(ray_start_distributed_2_nodes_4_gpus): - world_size = 4 - actors, _ = get_actors_group(world_size) - # Now destroy the group at actor0 - ray.wait([actors[0].destroy_group.remote()]) - actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) - assert not actor0_is_init - - # should go well as the group `random` does not exist at all - ray.wait([actors[0].destroy_group.remote("random")]) - - actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) - assert actor1_is_init - ray.wait([actors[1].destroy_group.remote("random")]) - actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) - assert actor1_is_init - ray.wait([actors[1].destroy_group.remote("default")]) - actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) - assert not actor1_is_init - for i in [2, 3]: - ray.wait([actors[i].destroy_group.remote("default")]) - - # Now reconstruct the group using the same name - init_results = ray.get([ - actor.init_group.remote(world_size, i) - for i, actor 
in enumerate(actors) - ]) - for i in range(world_size): - assert init_results[i] - actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) - assert actor0_is_init - actor1_is_init = ray.get(actors[0].report_is_group_initialized.remote()) - assert actor1_is_init - - -@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) -@pytest.mark.parametrize("world_size", [2, 3, 4]) -def test_allreduce_different_name(ray_start_distributed_2_nodes_4_gpus, - group_name, world_size): - actors, _ = get_actors_group(num_workers=world_size, group_name=group_name) - results = ray.get([a.do_work.remote(group_name) for a in actors]) - assert (results[0] == cp.ones((10, ), dtype=cp.float32) * world_size).all() - assert (results[1] == cp.ones((10, ), dtype=cp.float32) * world_size).all() - - -@pytest.mark.parametrize("array_size", [2, 2**5, 2**10, 2**15, 2**20]) -def test_allreduce_different_array_size(ray_start_distributed_2_nodes_4_gpus, - array_size): - world_size = 4 - actors, _ = get_actors_group(world_size) - ray.wait([ - a.set_buffer.remote(cp.ones(array_size, dtype=cp.float32)) - for a in actors - ]) - results = ray.get([a.do_work.remote() for a in actors]) - assert (results[0] == cp.ones( - (array_size, ), dtype=cp.float32) * world_size).all() - assert (results[1] == cp.ones( - (array_size, ), dtype=cp.float32) * world_size).all() - - -def test_allreduce_destroy(ray_start_distributed_2_nodes_4_gpus, - backend="nccl", - group_name="default"): - world_size = 4 - actors, _ = get_actors_group(world_size) - - results = ray.get([a.do_work.remote() for a in actors]) - assert (results[0] == cp.ones((10, ), dtype=cp.float32) * world_size).all() - assert (results[1] == cp.ones((10, ), dtype=cp.float32) * world_size).all() - - # destroy the group and try do work, should fail - ray.wait([a.destroy_group.remote() for a in actors]) - with pytest.raises(RuntimeError): - results = ray.get([a.do_work.remote() for a in actors]) - - # reinit the same group and all reduce - ray.get([ - actor.init_group.remote(world_size, i, backend, group_name) - for i, actor in enumerate(actors) - ]) - results = ray.get([a.do_work.remote() for a in actors]) - assert (results[0] == cp.ones( - (10, ), dtype=cp.float32) * world_size * world_size).all() - assert (results[1] == cp.ones( - (10, ), dtype=cp.float32) * world_size * world_size).all() - - -def test_allreduce_multiple_group(ray_start_distributed_2_nodes_4_gpus, - backend="nccl", - num_groups=5): - world_size = 4 - actors, _ = get_actors_group(world_size) - for group_name in range(1, num_groups): - ray.get([ - actor.init_group.remote(world_size, i, backend, str(group_name)) - for i, actor in enumerate(actors) - ]) - for i in range(num_groups): - group_name = "default" if i == 0 else str(i) - results = ray.get([a.do_work.remote(group_name) for a in actors]) - assert (results[0] == cp.ones( - (10, ), dtype=cp.float32) * (world_size**(i + 1))).all() - - -def test_allreduce_different_op(ray_start_distributed_2_nodes_4_gpus): - world_size = 4 - actors, _ = get_actors_group(world_size) - - # check product - ray.wait([ - a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) - for i, a in enumerate(actors) - ]) - results = ray.get([a.do_work.remote(op=ReduceOp.PRODUCT) for a in actors]) - assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 120).all() - assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 120).all() - - # check min - ray.wait([ - a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) - for i, a in enumerate(actors) - ]) 
- results = ray.get([a.do_work.remote(op=ReduceOp.MIN) for a in actors]) - assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 2).all() - assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 2).all() - - # check max - ray.wait([ - a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) - for i, a in enumerate(actors) - ]) - results = ray.get([a.do_work.remote(op=ReduceOp.MAX) for a in actors]) - assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 5).all() - assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 5).all() - - -@pytest.mark.parametrize("dtype", - [cp.uint8, cp.float16, cp.float32, cp.float64]) -def test_allreduce_different_dtype(ray_start_distributed_2_nodes_4_gpus, - dtype): - world_size = 4 - actors, _ = get_actors_group(world_size) - ray.wait([a.set_buffer.remote(cp.ones(10, dtype=dtype)) for a in actors]) - results = ray.get([a.do_work.remote() for a in actors]) - assert (results[0] == cp.ones((10, ), dtype=dtype) * world_size).all() - assert (results[1] == cp.ones((10, ), dtype=dtype) * world_size).all() - - -def test_allreduce_torch_cupy(ray_start_distributed_2_nodes_4_gpus): - # import torch - world_size = 4 - actors, _ = get_actors_group(world_size) - ray.wait([actors[1].set_buffer.remote(torch.ones(10, ).cuda())]) - results = ray.get([a.do_work.remote() for a in actors]) - assert (results[0] == cp.ones((10, )) * world_size).all() - - ray.wait([actors[0].set_buffer.remote(torch.ones(10, ))]) - ray.wait([actors[1].set_buffer.remote(cp.ones(10, ))]) - with pytest.raises(RuntimeError): - results = ray.get([a.do_work.remote() for a in actors]) - - -if __name__ == "__main__": - import pytest - import sys - - sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/python/ray/util/collective/tests/test_collective_single_node_2_gpus.py b/python/ray/util/collective/tests/test_collective_single_node_2_gpus.py deleted file mode 100644 index 267375e29..000000000 --- a/python/ray/util/collective/tests/test_collective_single_node_2_gpus.py +++ /dev/null @@ -1,267 +0,0 @@ -"""Test the collective group APIs.""" -import pytest -import ray -from ray.util.collective.types import ReduceOp - -import cupy as cp -import torch - -from .util import Worker - - -def get_actors_group(num_workers=2, group_name="default", backend="nccl"): - actors = [Worker.remote() for _ in range(num_workers)] - world_size = num_workers - init_results = ray.get([ - actor.init_group.remote(world_size, i, backend, group_name) - for i, actor in enumerate(actors) - ]) - return actors, init_results - - -@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) -def test_init_two_actors(ray_start_single_node_2_gpus, group_name): - world_size = 2 - actors, results = get_actors_group(world_size, group_name) - for i in range(world_size): - assert (results[i]) - - -def test_init_multiple_groups(ray_start_single_node_2_gpus): - world_size = 2 - num_groups = 10 - actors = [Worker.remote() for i in range(world_size)] - for i in range(num_groups): - group_name = str(i) - init_results = ray.get([ - actor.init_group.remote(world_size, i, group_name=group_name) - for i, actor in enumerate(actors) - ]) - for j in range(world_size): - assert init_results[j] - - -def test_get_rank(ray_start_single_node_2_gpus): - world_size = 2 - actors, _ = get_actors_group(world_size) - actor0_rank = ray.get(actors[0].report_rank.remote()) - assert actor0_rank == 0 - actor1_rank = ray.get(actors[1].report_rank.remote()) - assert actor1_rank == 1 - - # create a second group with a different name, - # and 
different order of ranks. - new_group_name = "default2" - _ = ray.get([ - actor.init_group.remote( - world_size, world_size - 1 - i, group_name=new_group_name) - for i, actor in enumerate(actors) - ]) - actor0_rank = ray.get(actors[0].report_rank.remote(new_group_name)) - assert actor0_rank == 1 - actor1_rank = ray.get(actors[1].report_rank.remote(new_group_name)) - assert actor1_rank == 0 - - -def test_get_world_size(ray_start_single_node_2_gpus): - world_size = 2 - actors, _ = get_actors_group(world_size) - actor0_world_size = ray.get(actors[0].report_world_size.remote()) - actor1_world_size = ray.get(actors[1].report_world_size.remote()) - assert actor0_world_size == actor1_world_size == world_size - - -def test_availability(ray_start_single_node_2_gpus): - world_size = 2 - actors, _ = get_actors_group(world_size) - actor0_nccl_availability = ray.get( - actors[0].report_nccl_availability.remote()) - assert actor0_nccl_availability - actor0_mpi_availability = ray.get( - actors[0].report_mpi_availability.remote()) - assert not actor0_mpi_availability - - -def test_is_group_initialized(ray_start_single_node_2_gpus): - world_size = 2 - actors, _ = get_actors_group(world_size) - # check group is_init - actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) - assert actor0_is_init - actor0_is_init = ray.get( - actors[0].report_is_group_initialized.remote("random")) - assert not actor0_is_init - actor0_is_init = ray.get( - actors[0].report_is_group_initialized.remote("123")) - assert not actor0_is_init - actor1_is_init = ray.get(actors[0].report_is_group_initialized.remote()) - assert actor1_is_init - actor1_is_init = ray.get( - actors[0].report_is_group_initialized.remote("456")) - assert not actor1_is_init - - -def test_destroy_group(ray_start_single_node_2_gpus): - world_size = 2 - actors, _ = get_actors_group(world_size) - # Now destroy the group at actor0 - ray.wait([actors[0].destroy_group.remote()]) - actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) - assert not actor0_is_init - - # should go well as the group `random` does not exist at all - ray.wait([actors[0].destroy_group.remote("random")]) - - actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) - assert actor1_is_init - ray.wait([actors[1].destroy_group.remote("random")]) - actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) - assert actor1_is_init - ray.wait([actors[1].destroy_group.remote("default")]) - actor1_is_init = ray.get(actors[1].report_is_group_initialized.remote()) - assert not actor1_is_init - - # Now reconstruct the group using the same name - init_results = ray.get([ - actor.init_group.remote(world_size, i) - for i, actor in enumerate(actors) - ]) - for i in range(world_size): - assert init_results[i] - actor0_is_init = ray.get(actors[0].report_is_group_initialized.remote()) - assert actor0_is_init - actor1_is_init = ray.get(actors[0].report_is_group_initialized.remote()) - assert actor1_is_init - - -@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) -# @pytest.mark.parametrize("group_name", ['123?34!']) -def test_allreduce_different_name(ray_start_single_node_2_gpus, group_name): - world_size = 2 - actors, _ = get_actors_group(num_workers=world_size, group_name=group_name) - results = ray.get([a.do_work.remote(group_name) for a in actors]) - assert (results[0] == cp.ones((10, ), dtype=cp.float32) * world_size).all() - assert (results[1] == cp.ones((10, ), dtype=cp.float32) * world_size).all() - - 
-@pytest.mark.parametrize("array_size", [2, 2**5, 2**10, 2**15, 2**20]) -def test_allreduce_different_array_size(ray_start_single_node_2_gpus, - array_size): - world_size = 2 - actors, _ = get_actors_group(world_size) - ray.wait([ - a.set_buffer.remote(cp.ones(array_size, dtype=cp.float32)) - for a in actors - ]) - results = ray.get([a.do_work.remote() for a in actors]) - assert (results[0] == cp.ones( - (array_size, ), dtype=cp.float32) * world_size).all() - assert (results[1] == cp.ones( - (array_size, ), dtype=cp.float32) * world_size).all() - - -def test_allreduce_destroy(ray_start_single_node_2_gpus, - backend="nccl", - group_name="default"): - world_size = 2 - actors, _ = get_actors_group(world_size) - - results = ray.get([a.do_work.remote() for a in actors]) - assert (results[0] == cp.ones((10, ), dtype=cp.float32) * world_size).all() - assert (results[1] == cp.ones((10, ), dtype=cp.float32) * world_size).all() - - # destroy the group and try do work, should fail - ray.wait([a.destroy_group.remote() for a in actors]) - with pytest.raises(RuntimeError): - results = ray.get([a.do_work.remote() for a in actors]) - - # reinit the same group and all reduce - ray.get([ - actor.init_group.remote(world_size, i, backend, group_name) - for i, actor in enumerate(actors) - ]) - results = ray.get([a.do_work.remote() for a in actors]) - assert (results[0] == cp.ones( - (10, ), dtype=cp.float32) * world_size * 2).all() - assert (results[1] == cp.ones( - (10, ), dtype=cp.float32) * world_size * 2).all() - - -def test_allreduce_multiple_group(ray_start_single_node_2_gpus, - backend="nccl", - num_groups=5): - world_size = 2 - actors, _ = get_actors_group(world_size) - for group_name in range(1, num_groups): - ray.get([ - actor.init_group.remote(world_size, i, backend, str(group_name)) - for i, actor in enumerate(actors) - ]) - for i in range(num_groups): - group_name = "default" if i == 0 else str(i) - results = ray.get([a.do_work.remote(group_name) for a in actors]) - assert (results[0] == cp.ones( - (10, ), dtype=cp.float32) * (world_size**(i + 1))).all() - - -def test_allreduce_different_op(ray_start_single_node_2_gpus): - world_size = 2 - actors, _ = get_actors_group(world_size) - - # check product - ray.wait([ - a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) - for i, a in enumerate(actors) - ]) - results = ray.get([a.do_work.remote(op=ReduceOp.PRODUCT) for a in actors]) - assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 6).all() - assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 6).all() - - # check min - ray.wait([ - a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) - for i, a in enumerate(actors) - ]) - results = ray.get([a.do_work.remote(op=ReduceOp.MIN) for a in actors]) - assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 2).all() - assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 2).all() - - # check max - ray.wait([ - a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) - for i, a in enumerate(actors) - ]) - results = ray.get([a.do_work.remote(op=ReduceOp.MAX) for a in actors]) - assert (results[0] == cp.ones((10, ), dtype=cp.float32) * 3).all() - assert (results[1] == cp.ones((10, ), dtype=cp.float32) * 3).all() - - -@pytest.mark.parametrize("dtype", - [cp.uint8, cp.float16, cp.float32, cp.float64]) -def test_allreduce_different_dtype(ray_start_single_node_2_gpus, dtype): - world_size = 2 - actors, _ = get_actors_group(world_size) - ray.wait([a.set_buffer.remote(cp.ones(10, dtype=dtype)) for a in actors]) - results = 
ray.get([a.do_work.remote() for a in actors]) - assert (results[0] == cp.ones((10, ), dtype=dtype) * world_size).all() - assert (results[1] == cp.ones((10, ), dtype=dtype) * world_size).all() - - -def test_allreduce_torch_cupy(ray_start_single_node_2_gpus): - # import torch - world_size = 2 - actors, _ = get_actors_group(world_size) - ray.wait([actors[1].set_buffer.remote(torch.ones(10, ).cuda())]) - results = ray.get([a.do_work.remote() for a in actors]) - assert (results[0] == cp.ones((10, )) * world_size).all() - - ray.wait([actors[0].set_buffer.remote(torch.ones(10, ))]) - ray.wait([actors[1].set_buffer.remote(cp.ones(10, ))]) - with pytest.raises(RuntimeError): - results = ray.get([a.do_work.remote() for a in actors]) - - -if __name__ == "__main__": - import pytest - import sys - sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/python/ray/util/collective/tests/test_reduce.py b/python/ray/util/collective/tests/test_reduce.py new file mode 100644 index 000000000..89063620c --- /dev/null +++ b/python/ray/util/collective/tests/test_reduce.py @@ -0,0 +1,143 @@ +"""Test the reduce API.""" +import pytest +import cupy as cp +import ray +from ray.util.collective.types import ReduceOp + +from ray.util.collective.tests.util import create_collective_workers + + +@pytest.mark.parametrize("group_name", ["default", "test", "123?34!"]) +@pytest.mark.parametrize("dst_rank", [0, 1]) +def test_reduce_different_name(ray_start_single_node_2_gpus, group_name, + dst_rank): + world_size = 2 + actors, _ = create_collective_workers( + num_workers=world_size, group_name=group_name) + results = ray.get( + [a.do_reduce.remote(group_name, dst_rank) for a in actors]) + for i in range(world_size): + if i == dst_rank: + assert (results[i] == cp.ones( + (10, ), dtype=cp.float32) * world_size).all() + else: + assert (results[i] == cp.ones((10, ), dtype=cp.float32)).all() + + +@pytest.mark.parametrize("array_size", [2, 2**5, 2**10, 2**15, 2**20]) +@pytest.mark.parametrize("dst_rank", [0, 1]) +def test_reduce_different_array_size(ray_start_single_node_2_gpus, array_size, + dst_rank): + world_size = 2 + actors, _ = create_collective_workers(world_size) + ray.wait([ + a.set_buffer.remote(cp.ones(array_size, dtype=cp.float32)) + for a in actors + ]) + results = ray.get([a.do_reduce.remote(dst_rank=dst_rank) for a in actors]) + for i in range(world_size): + if i == dst_rank: + assert (results[i] == cp.ones( + (array_size, ), dtype=cp.float32) * world_size).all() + else: + assert (results[i] == cp.ones((array_size, ), + dtype=cp.float32)).all() + + +@pytest.mark.parametrize("dst_rank", [0, 1]) +def test_reduce_multiple_group(ray_start_single_node_2_gpus, + dst_rank, + num_groups=5): + world_size = 2 + actors, _ = create_collective_workers(world_size) + for group_name in range(1, num_groups): + ray.get([ + actor.init_group.remote(world_size, i, "nccl", str(group_name)) + for i, actor in enumerate(actors) + ]) + for i in range(num_groups): + group_name = "default" if i == 0 else str(i) + results = ray.get([ + a.do_reduce.remote(dst_rank=dst_rank, group_name=group_name) + for a in actors + ]) + for j in range(world_size): + if j == dst_rank: + assert (results[j] == cp.ones( + (10, ), dtype=cp.float32) * (i + 2)).all() + else: + assert (results[j] == cp.ones((10, ), dtype=cp.float32)).all() + + +@pytest.mark.parametrize("dst_rank", [0, 1]) +def test_reduce_different_op(ray_start_single_node_2_gpus, dst_rank): + world_size = 2 + actors, _ = create_collective_workers(world_size) + + # check product + ray.wait([ + 
a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([ + a.do_reduce.remote(dst_rank=dst_rank, op=ReduceOp.PRODUCT) + for a in actors + ]) + for i in range(world_size): + if i == dst_rank: + assert (results[i] == cp.ones((10, ), dtype=cp.float32) * 6).all() + else: + assert (results[i] == cp.ones( + (10, ), dtype=cp.float32) * (i + 2)).all() + + # check min + ray.wait([ + a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([ + a.do_reduce.remote(dst_rank=dst_rank, op=ReduceOp.MIN) for a in actors + ]) + for i in range(world_size): + if i == dst_rank: + assert (results[i] == cp.ones((10, ), dtype=cp.float32) * 2).all() + else: + assert (results[i] == cp.ones( + (10, ), dtype=cp.float32) * (i + 2)).all() + + # check max + ray.wait([ + a.set_buffer.remote(cp.ones(10, dtype=cp.float32) * (i + 2)) + for i, a in enumerate(actors) + ]) + results = ray.get([ + a.do_reduce.remote(dst_rank=dst_rank, op=ReduceOp.MAX) for a in actors + ]) + for i in range(world_size): + if i == dst_rank: + assert (results[i] == cp.ones((10, ), dtype=cp.float32) * 3).all() + else: + assert (results[i] == cp.ones( + (10, ), dtype=cp.float32) * (i + 2)).all() + + +@pytest.mark.parametrize("dst_rank", [0, 1]) +def test_reduce_torch_cupy(ray_start_single_node_2_gpus, dst_rank): + import torch + world_size = 2 + actors, _ = create_collective_workers(world_size) + ray.wait([actors[1].set_buffer.remote(torch.ones(10, ).cuda())]) + results = ray.get([a.do_reduce.remote(dst_rank=dst_rank) for a in actors]) + if dst_rank == 0: + assert (results[0] == cp.ones((10, )) * world_size).all() + assert (results[1] == torch.ones((10, )).cuda()).all() + else: + assert (results[0] == cp.ones((10, ))).all() + assert (results[1] == torch.ones((10, )).cuda() * world_size).all() + + +def test_reduce_invalid_rank(ray_start_single_node_2_gpus, dst_rank=3): + world_size = 2 + actors, _ = create_collective_workers(world_size) + with pytest.raises(ValueError): + _ = ray.get([a.do_reduce.remote(dst_rank=dst_rank) for a in actors]) diff --git a/python/ray/util/collective/tests/test_reducescatter.py b/python/ray/util/collective/tests/test_reducescatter.py new file mode 100644 index 000000000..4b1322ed4 --- /dev/null +++ b/python/ray/util/collective/tests/test_reducescatter.py @@ -0,0 +1,127 @@ +"""Test the collective reducescatter API.""" +import pytest +import ray + +import cupy as cp +import torch + +from ray.util.collective.tests.util import create_collective_workers, \ + init_tensors_for_gather_scatter + + +@pytest.mark.parametrize("tensor_backend", ["cupy", "torch"]) +@pytest.mark.parametrize("array_size", + [2, 2**5, 2**10, 2**15, 2**20, [2, 2], [5, 5, 5]]) +def test_reducescatter_different_array_size(ray_start_single_node_2_gpus, + array_size, tensor_backend): + world_size = 2 + actors, _ = create_collective_workers(world_size) + init_tensors_for_gather_scatter( + actors, array_size=array_size, tensor_backend=tensor_backend) + results = ray.get([a.do_reducescatter.remote() for a in actors]) + for i in range(world_size): + if tensor_backend == "cupy": + assert (results[i] == cp.ones(array_size, dtype=cp.float32) * + world_size).all() + else: + assert (results[i] == torch.ones( + array_size, dtype=torch.float32).cuda() * world_size).all() + + +@pytest.mark.parametrize("dtype", + [cp.uint8, cp.float16, cp.float32, cp.float64]) +def test_reducescatter_different_dtype(ray_start_single_node_2_gpus, dtype): + world_size = 
2 + actors, _ = create_collective_workers(world_size) + init_tensors_for_gather_scatter(actors, dtype=dtype) + results = ray.get([a.do_reducescatter.remote() for a in actors]) + for i in range(world_size): + for j in range(world_size): + assert (results[i] == cp.ones(10, dtype=dtype) * world_size).all() + + +def test_reducescatter_torch_cupy(ray_start_single_node_2_gpus): + world_size = 2 + shape = [10, 10] + actors, _ = create_collective_workers(world_size) + + # tensor is pytorch, list is cupy + for i, a in enumerate(actors): + t = torch.ones(shape, dtype=torch.float32).cuda() * (i + 1) + ray.wait([a.set_buffer.remote(t)]) + list_buffer = [ + cp.ones(shape, dtype=cp.float32) for _ in range(world_size) + ] + ray.wait([a.set_list_buffer.remote(list_buffer)]) + results = ray.get([a.do_reducescatter.remote() for a in actors]) + for i in range(world_size): + assert (results[i] == torch.ones(shape, dtype=torch.float32).cuda() * + world_size).all() + + # tensor is cupy, list is pytorch + for i, a in enumerate(actors): + t = cp.ones(shape, dtype=cp.float32) * (i + 1) + ray.wait([a.set_buffer.remote(t)]) + list_buffer = [ + torch.ones(shape, dtype=torch.float32).cuda() + for _ in range(world_size) + ] + ray.wait([a.set_list_buffer.remote(list_buffer)]) + results = ray.get([a.do_reducescatter.remote() for a in actors]) + for i in range(world_size): + assert ( + results[i] == cp.ones(shape, dtype=cp.float32) * world_size).all() + + # some tensors in the list are pytorch, some are cupy + for i, a in enumerate(actors): + if i % 2 == 0: + t = torch.ones(shape, dtype=torch.float32).cuda() * (i + 1) + else: + t = cp.ones(shape, dtype=cp.float32) * (i + 1) + ray.wait([a.set_buffer.remote(t)]) + list_buffer = [] + for j in range(world_size): + if j % 2 == 0: + list_buffer.append( + torch.ones(shape, dtype=torch.float32).cuda()) + else: + list_buffer.append(cp.ones(shape, dtype=cp.float32)) + ray.wait([a.set_list_buffer.remote(list_buffer)]) + results = ray.get([a.do_reducescatter.remote() for a in actors]) + for i in range(world_size): + if i % 2 == 0: + assert (results[i] == torch.ones( + shape, dtype=torch.float32).cuda() * world_size).all() + else: + assert (results[i] == cp.ones(shape, dtype=cp.float32) * + world_size).all() + + # mixed case + for i, a in enumerate(actors): + if i % 2 == 0: + t = torch.ones(shape, dtype=torch.float32).cuda() * (i + 1) + else: + t = cp.ones(shape, dtype=cp.float32) * (i + 1) + ray.wait([a.set_buffer.remote(t)]) + list_buffer = [] + for j in range(world_size): + if j % 2 == 0: + list_buffer.append(cp.ones(shape, dtype=cp.float32)) + else: + list_buffer.append( + torch.ones(shape, dtype=torch.float32).cuda()) + ray.wait([a.set_list_buffer.remote(list_buffer)]) + results = ray.get([a.do_reducescatter.remote() for a in actors]) + for i in range(world_size): + if i % 2 == 0: + assert (results[i] == torch.ones( + shape, dtype=torch.float32).cuda() * world_size).all() + else: + assert (results[i] == cp.ones(shape, dtype=cp.float32) * + world_size).all() + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/python/ray/util/collective/tests/util.py b/python/ray/util/collective/tests/util.py index d59294d3f..3cee4de59 100644 --- a/python/ray/util/collective/tests/util.py +++ b/python/ray/util/collective/tests/util.py @@ -4,11 +4,17 @@ import ray import ray.util.collective as col from ray.util.collective.types import Backend, ReduceOp +import torch + @ray.remote(num_gpus=1) class Worker: def __init__(self): 
self.buffer = cp.ones((10, ), dtype=cp.float32) + self.list_buffer = [ + cp.ones((10, ), dtype=cp.float32), + cp.ones((10, ), dtype=cp.float32) + ] def init_group(self, world_size, @@ -22,10 +28,30 @@ class Worker: self.buffer = data return self.buffer - def do_work(self, group_name="default", op=ReduceOp.SUM): + def set_list_buffer(self, list_of_arrays): + self.list_buffer = list_of_arrays + return self.list_buffer + + def do_allreduce(self, group_name="default", op=ReduceOp.SUM): col.allreduce(self.buffer, group_name, op) return self.buffer + def do_reduce(self, group_name="default", dst_rank=0, op=ReduceOp.SUM): + col.reduce(self.buffer, dst_rank, group_name, op) + return self.buffer + + def do_broadcast(self, group_name="default", src_rank=0): + col.broadcast(self.buffer, src_rank, group_name) + return self.buffer + + def do_allgather(self, group_name="default"): + col.allgather(self.list_buffer, self.buffer, group_name) + return self.list_buffer + + def do_reducescatter(self, group_name="default", op=ReduceOp.SUM): + col.reducescatter(self.buffer, self.list_buffer, group_name, op) + return self.buffer + def destroy_group(self, group_name="default"): col.destroy_collective_group(group_name) return True @@ -49,3 +75,42 @@ class Worker: def report_is_group_initialized(self, group_name="default"): is_init = col.is_group_initialized(group_name) return is_init + + +def create_collective_workers(num_workers=2, + group_name="default", + backend="nccl"): + actors = [Worker.remote() for _ in range(num_workers)] + world_size = num_workers + init_results = ray.get([ + actor.init_group.remote(world_size, i, backend, group_name) + for i, actor in enumerate(actors) + ]) + return actors, init_results + + +def init_tensors_for_gather_scatter(actors, + array_size=10, + dtype=cp.float32, + tensor_backend="cupy"): + world_size = len(actors) + for i, a in enumerate(actors): + if tensor_backend == "cupy": + t = cp.ones(array_size, dtype=dtype) * (i + 1) + elif tensor_backend == "torch": + t = torch.ones(array_size, dtype=torch.float32).cuda() * (i + 1) + else: + raise RuntimeError("Unsupported tensor backend.") + ray.wait([a.set_buffer.remote(t)]) + if tensor_backend == "cupy": + list_buffer = [ + cp.ones(array_size, dtype=dtype) for _ in range(world_size) + ] + elif tensor_backend == "torch": + list_buffer = [ + torch.ones(array_size, dtype=torch.float32).cuda() + for _ in range(world_size) + ] + else: + raise RuntimeError("Unsupported tensor backend.") + ray.get([a.set_list_buffer.remote(list_buffer) for a in actors]) diff --git a/python/ray/util/collective/types.py b/python/ray/util/collective/types.py index ef037373a..be92b98f2 100644 --- a/python/ray/util/collective/types.py +++ b/python/ray/util/collective/types.py @@ -50,15 +50,46 @@ class ReduceOp(Enum): MAX = 3 -unset_timeout = timedelta(milliseconds=-1) +unset_timeout_ms = timedelta(milliseconds=-1) @dataclass class AllReduceOptions: reduceOp = ReduceOp.SUM - timeout = unset_timeout + timeout_ms = unset_timeout_ms @dataclass class BarrierOptions: - timeout = unset_timeout + timeout_ms = unset_timeout_ms + + +@dataclass +class ReduceOptions: + reduceOp = ReduceOp.SUM + root_rank = 0 + timeout_ms = unset_timeout_ms + + +@dataclass +class AllGatherOptions: + timeout_ms = unset_timeout_ms + + +# +# @dataclass +# class GatherOptions: +# root_rank = 0 +# timeout = unset_timeout + + +@dataclass +class BroadcastOptions: + root_rank = 0 + timeout_ms = unset_timeout_ms + + +@dataclass +class ReduceScatterOptions: + reduceOp = ReduceOp.SUM + timeout_ms = 
unset_timeout_ms
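
The hunks above add tests and plumbing only; as a quick orientation, the following is a minimal, hypothetical usage sketch (not part of the patch) showing how the elementwise collectives exercised by these tests — allreduce, reduce, and broadcast — can be driven from Ray GPU actors. The actor and driver names below are invented for illustration; the col.* argument order is copied from the Worker helper added in python/ray/util/collective/tests/util.py, and running it assumes a node with at least two GPUs, cupy installed, and this PR applied.

"""Minimal usage sketch (illustration only, not part of this patch).

Call order follows the Worker methods added in tests/util.py:
    col.allreduce(tensor, group_name, op)
    col.reduce(tensor, dst_rank, group_name, op)
    col.broadcast(tensor, src_rank, group_name)
"""
import cupy as cp
import ray
import ray.util.collective as col
from ray.util.collective.types import ReduceOp


@ray.remote(num_gpus=1)
class SketchWorker:
    def __init__(self):
        self.buffer = cp.ones((10, ), dtype=cp.float32)

    def init(self, world_size, rank, group_name="default"):
        col.init_collective_group(world_size, rank, "nccl", group_name)
        return True

    def allreduce(self, group_name="default"):
        # Every rank ends up holding the elementwise sum.
        col.allreduce(self.buffer, group_name, ReduceOp.SUM)
        return self.buffer

    def reduce_to_root(self, dst_rank=0, group_name="default"):
        # Only dst_rank receives the reduced result; other ranks keep their input.
        col.reduce(self.buffer, dst_rank, group_name, ReduceOp.SUM)
        return self.buffer

    def broadcast_from_root(self, src_rank=0, group_name="default"):
        # Every rank's buffer is overwritten with src_rank's buffer.
        col.broadcast(self.buffer, src_rank, group_name)
        return self.buffer


if __name__ == "__main__":
    ray.init()
    world_size = 2
    workers = [SketchWorker.remote() for _ in range(world_size)]
    ray.get([w.init.remote(world_size, i) for i, w in enumerate(workers)])
    print(ray.get([w.allreduce.remote() for w in workers]))

The tests in this PR follow the same pattern through create_collective_workers, which constructs the actors and calls init_group on each rank before issuing collective calls.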
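
A second hypothetical sketch covers the list-buffer collectives: allgather takes the output tensor list first and the local input tensor second, while reducescatter takes the local output tensor first and the input tensor list second, mirroring the do_allgather and do_reducescatter methods added to the test Worker. As test_unmatched_tensor_list_length and test_unmatched_tensor_shape check, the list is expected to contain world_size tensors with matching shapes; otherwise the call raises RuntimeError. Names and the driver below are illustrative assumptions, not part of the patch.

"""Hypothetical sketch of the list-buffer collectives (illustration only)."""
import cupy as cp
import ray
import ray.util.collective as col


@ray.remote(num_gpus=1)
class GatherScatterWorker:
    def __init__(self, world_size):
        self.buffer = cp.ones((10, ), dtype=cp.float32)
        # One slot per rank; every slot must match the input tensor's shape.
        self.list_buffer = [
            cp.ones((10, ), dtype=cp.float32) for _ in range(world_size)
        ]

    def init(self, world_size, rank):
        col.init_collective_group(world_size, rank, "nccl", "default")
        return True

    def allgather(self):
        # Output list first, local input tensor second.
        col.allgather(self.list_buffer, self.buffer, "default")
        return self.list_buffer

    def reducescatter(self):
        # Local output tensor first, input list second (reduced with SUM by default).
        col.reducescatter(self.buffer, self.list_buffer, "default")
        return self.buffer


if __name__ == "__main__":
    ray.init()
    world_size = 2
    workers = [
        GatherScatterWorker.remote(world_size) for _ in range(world_size)
    ]
    ray.get([w.init.remote(world_size, i) for i, w in enumerate(workers)])
    # After allgather, each rank holds the full list of every rank's buffer.
    gathered = ray.get([w.allgather.remote() for w in workers])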