diff --git a/python/ray/experimental/dask/__init__.py b/python/ray/experimental/dask/__init__.py new file mode 100644 index 000000000..42dbf2010 --- /dev/null +++ b/python/ray/experimental/dask/__init__.py @@ -0,0 +1,3 @@ +from .scheduler import ray_dask_get, ray_dask_get_sync + +__all__ = ["ray_dask_get", "ray_dask_get_sync"] diff --git a/python/ray/experimental/dask/common.py b/python/ray/experimental/dask/common.py new file mode 100644 index 000000000..4ceb81e03 --- /dev/null +++ b/python/ray/experimental/dask/common.py @@ -0,0 +1,77 @@ +from collections import OrderedDict +from collections.abc import Iterator +from operator import getitem +import uuid + +import ray + +from dask.base import quote +from dask.core import get as get_sync +from dask.compatibility import apply, is_dataclass, dataclass_fields + + +def unpack_object_refs(*args): + """ + Extract `ray.ObjectRef`s from a set of potentially arbitrarily nested + Python objects. + + Intended use is to find all Ray object references in a set of (possibly + nested) Python objects, do something to them (get(), wait(), etc.), then + repackage them into equivalent Python objects. + + Args: + *args: One or more (potentially nested) Python objects that contain + Ray object references. + + Returns: + A 2-tuple of a flat list of all contained Ray object references, and a + function that, when given the corresponding flat list of concrete + values, will return a set of Python objects equivalent to that which + was given in *args, but with all Ray object references replaced with + their corresponding concrete values. + """ + object_refs = [] + repack_dsk = {} + + object_refs_token = uuid.uuid4().hex + + def _unpack(expr): + if isinstance(expr, ray.ObjectRef): + token = expr.hex() + repack_dsk[token] = (getitem, object_refs_token, len(object_refs)) + object_refs.append(expr) + return token + + token = uuid.uuid4().hex + # Treat iterators like lists + typ = list if isinstance(expr, Iterator) else type(expr) + if typ in (list, tuple, set): + repack_task = (typ, [_unpack(i) for i in expr]) + elif typ in (dict, OrderedDict): + repack_task = (typ, + [[_unpack(k), _unpack(v)] for k, v in expr.items()]) + elif is_dataclass(expr): + repack_task = ( + apply, + typ, + (), + ( + dict, + [[f.name, _unpack(getattr(expr, f.name))] + for f in dataclass_fields(expr)], + ), + ) + else: + return expr + repack_dsk[token] = repack_task + return token + + out = uuid.uuid4().hex + repack_dsk[out] = (tuple, [_unpack(i) for i in args]) + + def repack(results): + dsk = repack_dsk.copy() + dsk[object_refs_token] = quote(results) + return get_sync(dsk, out) + + return object_refs, repack diff --git a/python/ray/experimental/dask/scheduler.py b/python/ray/experimental/dask/scheduler.py new file mode 100644 index 000000000..b15164525 --- /dev/null +++ b/python/ray/experimental/dask/scheduler.py @@ -0,0 +1,297 @@ +import atexit +from collections import defaultdict +from multiprocessing.pool import ThreadPool +import threading + +import ray + +from dask.core import istask, ishashable, _execute_task +from dask.local import get_async, apply_sync +from dask.system import CPU_COUNT +from dask.threaded import pack_exception, _thread_get_id + +from .common import unpack_object_refs + +main_thread = threading.current_thread() +default_pool = None +pools = defaultdict(dict) +pools_lock = threading.Lock() + + +def ray_dask_get(dsk, keys, **kwargs): + """ + A Dask-Ray scheduler. This scheduler will send top-level (non-inlined) Dask + tasks to a Ray cluster for execution. The scheduler will wait for the + tasks to finish executing, fetch the results, and repackage them into the + appropriate Dask collections. This particular scheduler uses a threadpool + to submit Ray tasks. + + This can be passed directly to `dask.compute()`, as the scheduler: + + >>> dask.compute(obj, scheduler=ray_dask_get) + + You can override the number of threads to use when submitting the + Ray tasks, or the threadpool used to submit Ray tasks: + + >>> dask.compute( + obj, + scheduler=ray_dask_get, + num_workers=8, + pool=some_cool_pool, + ) + + Args: + dsk (Dict): Dask graph, represented as a task DAG dictionary. + keys (List[str]): List of Dask graph keys whose values we wish to + compute and return. + num_workers (Optional[int]): The number of worker threads to use in + the Ray task submission traversal of the Dask graph. + pool (Optional[ThreadPool]): A multiprocessing threadpool to use to + submit Ray tasks. + + Returns: + Computed values corresponding to the provided keys. + """ + num_workers = kwargs.pop("num_workers", None) + pool = kwargs.pop("pool", None) + # We attempt to reuse any other thread pools that have been created within + # this thread and with the given number of workers. We reuse a global + # thread pool if num_workers is not given and we're in the main thread. + global default_pool + thread = threading.current_thread() + if pool is None: + with pools_lock: + if num_workers is None and thread is main_thread: + if default_pool is None: + default_pool = ThreadPool(CPU_COUNT) + atexit.register(default_pool.close) + pool = default_pool + elif thread in pools and num_workers in pools[thread]: + pool = pools[thread][num_workers] + else: + pool = ThreadPool(num_workers) + atexit.register(pool.close) + pools[thread][num_workers] = pool + + # NOTE: We hijack Dask's `get_async` function, injecting a different task + # executor. + object_refs = get_async( + _apply_async_wrapper(pool.apply_async, _rayify_task_wrapper), + len(pool._pool), + dsk, + keys, + get_id=_thread_get_id, + pack_exception=pack_exception, + **kwargs, + ) + # NOTE: We explicitly delete the Dask graph here so object references + # are garbage-collected before this function returns, i.e. before all Ray + # tasks are done. Otherwise, no intermediate objects will be cleaned up + # until all Ray tasks are done. + del dsk + result = ray_get_unpack(object_refs) + + # cleanup pools associated with dead threads. + with pools_lock: + active_threads = set(threading.enumerate()) + if thread is not main_thread: + for t in list(pools): + if t not in active_threads: + for p in pools.pop(t).values(): + p.close() + return result + + +def _apply_async_wrapper(apply_async, real_func, *extra_args, **extra_kwargs): + """ + Wraps the given pool `apply_async` function, hotswapping `real_func` in as + the function to be applied and adding `extra_args` and `extra_kwargs` to + `real_func`'s call. + + Args: + apply_async (callable): The pool function to be wrapped. + real_func (callable): The real function that we wish the pool apply + function to execute. + *extra_args: Extra positional arguments to pass to the `real_func`. + **extra_kwargs: Extra keyword arguments to pass to the `real_func`. + + Returns: + A wrapper function that will ignore it's first `func` argument and + pass `real_func` in its place. To be passed to `dask.local.get_async`. + """ + + def wrapper(func, args=(), kwds={}, callback=None): # noqa: M511 + return apply_async( + real_func, + args=args + extra_args, + kwds=dict(kwds, **extra_kwargs), + callback=callback, + ) + + return wrapper + + +def _rayify_task_wrapper( + key, + task_info, + dumps, + loads, + get_id, + pack_exception, +): + """ + The core Ray-Dask task execution wrapper, to be given to the thread pool's + `apply_async` function. Exactly the same as `execute_task`, except that it + calls `_rayify_task` on the task instead of `_execute_task`. + + Args: + key (str): The Dask graph key whose corresponding task we wish to + execute. + task_info: The task to execute and its dependencies. + dumps (callable): A result serializing function. + loads (callable): A task_info deserializing function. + get_id (callable): An ID generating function. + pack_exception (callable): An exception serializing function. + + Returns: + A 3-tuple of the task's key, a literal or a Ray object reference for a + Ray task's result, and whether the Ray task submission failed. + """ + try: + task, deps = loads(task_info) + result = _rayify_task(task, key, deps) + id = get_id() + result = dumps((result, id)) + failed = False + except BaseException as e: + result = pack_exception(e, dumps) + failed = True + return key, result, failed + + +def _rayify_task(task, key, deps): + """ + Rayifies the given task, submitting it as a Ray task to the Ray cluster. + + Args: + task: A Dask graph value, being either a literal, dependency key, Dask + task, or a list thereof. + key: The Dask graph key for the given task. + deps: The dependencies of this task. + + Returns: + A literal, a Ray object reference representing a submitted task, or a + list thereof. + """ + if isinstance(task, list): + # Recursively rayify this list. This will still bottom out at the first + # actual task encountered, inlining any tasks in that task's arguments. + return [_rayify_task(t, deps) for t in task] + elif istask(task): + # Unpacks and repacks Ray object references and submits the task to the + # Ray cluster for execution. + func, args = task[0], task[1:] + # If the function's arguments contain nested object references, we must + # unpack said object references into a flat set of arguments so that + # Ray properly tracks the object dependencies between Ray tasks. + object_refs, repack = unpack_object_refs(args, deps) + # Submit the task using a wrapper function. + return dask_task_wrapper.remote(func, repack, *object_refs) + elif not ishashable(task): + return task + elif task in deps: + return deps[task] + else: + return task + + +@ray.remote +def dask_task_wrapper(func, repack, *args): + """ + A Ray remote function acting as a Dask task wrapper. This function will + repackage the given flat `args` into its original data structures using + `repack`, execute any Dask subtasks within the repackaged arguments + (inlined by Dask's optimization pass), and then pass the concrete task + arguments to the provide Dask task function, `func`. + + Args: + func (callable): The Dask task function to execute. + repack (callable): A function that repackages the provided args into + the original (possibly nested) Python objects. + *args (ObjectRef): Ray object references representing the Dask task's + arguments. + + Returns: + The output of the Dask task. In the context of Ray, a + dask_task_wrapper.remote() invocation will return a Ray object + reference representing the Ray task's result. + """ + repacked_args, repacked_deps = repack(args) + # Recursively execute Dask-inlined tasks. + actual_args = [_execute_task(a, repacked_deps) for a in repacked_args] + # Execute the actual underlying Dask task. + return func(*actual_args) + + +def ray_get_unpack(object_refs): + """ + Unpacks object references, gets the object references, and repacks. + Traverses arbitrary data structures. + + Args: + object_refs: A (potentially nested) Python object containing Ray object + references. + + Returns: + The input Python object with all contained Ray object references + resolved with their concrete values. + """ + if isinstance(object_refs, + (tuple, list)) and any(not isinstance(x, ray.ObjectRef) + for x in object_refs): + # We flatten the object references before calling ray.get(), since Dask + # loves to nest collections in nested tuples and Ray expects a flat + # list of object references. We repack the results after ray.get() + # completes. + object_refs, repack = unpack_object_refs(*object_refs) + computed_result = ray.get(object_refs) + return repack(computed_result) + else: + return ray.get(object_refs) + + +def ray_dask_get_sync(dsk, keys, **kwargs): + """ + A synchronous Dask-Ray scheduler. This scheduler will send top-level + (non-inlined) Dask tasks to a Ray cluster for execution. The scheduler will + wait for the tasks to finish executing, fetch the results, and repackage + them into the appropriate Dask collections. This particular scheduler + submits Ray tasks synchronously, which can be useful for debugging. + + This can be passed directly to `dask.compute()`, as the scheduler: + + >>> dask.compute(obj, scheduler=ray_dask_get_sync) + + Args: + dsk (Dict): Dask graph, represented as a task DAG dictionary. + keys (List[str]): List of Dask graph keys whose values we wish to + compute and return. + + Returns: + Computed values corresponding to the provided keys. + """ + # NOTE: We hijack Dask's `get_async` function, injecting a different task + # executor. + object_refs = get_async( + _apply_async_wrapper(apply_sync, _rayify_task_wrapper), + 1, + dsk, + keys, + **kwargs, + ) + # NOTE: We explicitly delete the Dask graph here so object references + # are garbage-collected before this function returns, i.e. before all Ray + # tasks are done. Otherwise, no intermediate objects will be cleaned up + # until all Ray tasks are done. + del dsk + return ray_get_unpack(object_refs) diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 81c989ff5..0408ec920 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -38,6 +38,14 @@ py_test( deps = ["//:ray_lib"], ) +py_test( + name = "test_dask_scheduler", + size = "small", + srcs = SRCS + ["test_dask_scheduler.py"], + tags = ["exclusive"], + deps = ["//:ray_lib"], +) + py_test( name = "test_iter", size = "medium", diff --git a/python/ray/tests/test_dask_scheduler.py b/python/ray/tests/test_dask_scheduler.py new file mode 100644 index 000000000..9f55fa4ac --- /dev/null +++ b/python/ray/tests/test_dask_scheduler.py @@ -0,0 +1,35 @@ +import dask +import pytest + +import ray +from ray.experimental.dask import ray_dask_get + + +def test_ray_dask_basic(ray_start_regular_shared): + @ray.remote + def stringify(x): + return "The answer is {}".format(x) + + zero_id = ray.put(0) + + def add(x, y): + # Can retrieve ray objects from inside Dask. + zero = ray.get(zero_id) + # Can call Ray methods from inside Dask. + return ray.get(stringify.remote(x + y + zero)) + + add = dask.delayed(add) + + @ray.remote + def call_add(): + z = add(2, 4) + # Can call Dask graphs from inside Ray. + return z.compute(scheduler=ray_dask_get) + + ans = ray.get(call_add.remote()) + assert ans == "The answer is 6", ans + + +if __name__ == "__main__": + import sys + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/requirements.txt b/python/requirements.txt index f5289741b..45b3c711a 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -43,6 +43,7 @@ blist; platform_system != "Windows" boto3 cython==0.29.0 dataclasses +dask[complete] feather-format gym kubernetes