From 36e1f20e9cfa1d8f3fb969b7deb2086eac883b73 Mon Sep 17 00:00:00 2001 From: Clark Zinzow Date: Tue, 8 Sep 2020 14:00:58 -0600 Subject: [PATCH] Add Dask-Ray scheduler callbacks. (#10519) Improve Dask-on-Ray documentation. Move to RayCallback(s) namedtuples, and use top-level CBS tuple as source-of-truth for callback methods. --- doc/source/dask-on-ray.rst | 225 +++++++++++++++++++++--- python/ray/tests/BUILD | 1 + python/ray/tests/test_dask_callback.py | 234 +++++++++++++++++++++++++ python/ray/util/dask/__init__.py | 13 +- python/ray/util/dask/callbacks.py | 211 ++++++++++++++++++++++ python/ray/util/dask/scheduler.py | 226 +++++++++++++++++++----- 6 files changed, 837 insertions(+), 73 deletions(-) create mode 100644 python/ray/tests/test_dask_callback.py create mode 100644 python/ray/util/dask/callbacks.py diff --git a/doc/source/dask-on-ray.rst b/doc/source/dask-on-ray.rst index 7c5626543..bb9e58101 100644 --- a/doc/source/dask-on-ray.rst +++ b/doc/source/dask-on-ray.rst @@ -1,42 +1,213 @@ +*********** Dask on Ray -=========== +*********** -Ray offers a scheduler backend for Dask. With this plugin, you can use familiar Dask APIs such as Dask DataFrames, and the computation will be executed by the Ray system. +Ray offers an experimental scheduler for Dask, allowing you to build data +analyses using the familiar Dask collections (dataframes, arrays) and execute +the underlying computations on a Ray cluster. Using this Dask scheduler, the +entire Dask ecosystem can be executed on top of Ray. -The Ray plugin can be used with any Dask `.compute() `__ call. -Note that for execution on a Ray cluster, you should *not* use the `Dask.distributed `__ client. -Just follow the instructions for :ref:`using Ray on a cluster ` to modify the ``ray.init()`` call. +========= +Scheduler +========= + +The Dask-Ray scheduler can execute any valid Dask graph, and can be used with +any Dask `.compute() `__ +call. Here's an example: .. code-block:: python - import ray - from ray.util.dask import ray_dask_get - import dask.delayed - from time import sleep + import ray + from ray.util.dask import ray_dask_get + import dask.delayed + import time - # Start Ray. - # Tip: If you're connecting to an existing cluster, use ray.init(address="auto"). - ray.init() + # Start Ray. + # Tip: If you're connecting to an existing cluster, use ray.init(address="auto"). + ray.init() - def inc(x): - sleep(1) - return x + 1 + @dask.delayed + def inc(x): + time.sleep(1) + return x + 1 - def add(x, y): - sleep(1) - return x + y + @dask.delayed + def add(x, y): + time.sleep(3) + return x + y - x = dask.delayed(inc)(1) - y = dask.delayed(inc)(2) - z = dask.delayed(add)(x, y) - # The Dask scheduler submits the recorded task graph to Ray. - z.compute(scheduler=ray_dask_get) + x = inc(1) + y = inc(2) + z = add(x, y) + # The Dask scheduler submits the underlying task graph to Ray. + z.compute(scheduler=ray_dask_get) -Why use this feature? +Why use Dask on Ray? - 1. If you'd like to use Dask and Ray libraries in the same application. - 2. To take advantage of Ray-specific features such as the :ref:`cluster launcher ` and :ref:`shared-memory store `. + 1. If you'd like to create data analyses using the familiar NumPy and Pandas + APIs provided by Dask and execute them on a production-ready distributed + task execution system like Ray. + 2. If you'd like to use Dask and Ray libraries in the same application + without having two different task execution backends. + 3. To take advantage of Ray-specific features such as the + :ref:`cluster launcher ` and + :ref:`shared-memory store `. -Note that Dask-on-Ray is an ongoing project and is not expected to achieve the same performance as using Ray directly. +Note that for execution on a Ray cluster, you should *not* use the +`Dask.distributed `__ +client; simply use plain Dask and its collections, and pass ``ray_dask_get`` +to ``.compute()`` calls. Follow the instructions for +:ref:`using Ray on a cluster ` to modify the +``ray.init()`` call. + +Dask-on-Ray is an ongoing project and is not expected to achieve the same performance as using Ray directly. + +========= +Callbacks +========= + +Dask's `custom callback abstraction `__ +is extended with Ray-specific callbacks, allowing the user to hook into the +Ray task submission and execution lifecycles. +With these hooks, implementing Dask-level scheduler and task introspection, +such as progress reporting, diagnostics, caching, etc., is simple. + +Here's an example that measures and logs the execution time of each task using +the ``ray_pretask`` and ``ray_posttask`` hooks: + +.. code-block:: python + + from ray.util.dask import RayDaskCallback + from timeit import default_timer as timer + + + class MyTimerCallback(RayDaskCallback): + def _ray_pretask(self, key, object_refs): + # Executed at the start of the Ray task. + start_time = timer() + return start_time + + def _ray_posttask(self, key, result, pre_state): + # Executed at the end of the Ray task. + execution_time = timer() - pre_state + print(f"Execution time for task {key}: {execution_time}s") + + + with MyTimerCallback(): + # Any .compute() calls within this context will get MyTimerCallback() + # as a Dask-Ray callback. + z.compute(scheduler=ray_dask_get) + +The following Ray-specific callbacks are provided: + + 1. :code:`ray_presubmit(task, key, deps)`: Run before submitting a Ray + task. If this callback returns a non-`None` value, a Ray task will _not_ + be created and this value will be used as the would-be task's result + value. + 2. :code:`ray_postsubmit(task, key, deps, object_ref)`: Run after submitting + a Ray task. + 3. :code:`ray_pretask(key, object_refs)`: Run before executing a Dask task + within a Ray task. This executes after the task has been submitted, + within a Ray worker. The return value of this task will be passed to the + ray_posttask callback, if provided. + 4. :code:`ray_posttask(key, result, pre_state)`: Run after executing a Dask + task within a Ray task. This executes within a Ray worker. This callback + receives the return value of the ray_pretask callback, if provided. + 5. :code:`ray_postsubmit_all(object_refs, dsk)`: Run after all Ray tasks + have been submitted. + 6. :code:`ray_finish(result)`: Run after all Ray tasks have finished + executing and the final result has been returned. + +See the docstring for +:meth:`RayDaskCallback.__init__() .__init__` +for further details about these callbacks, their arguments, and their return +values. + +When creating your own callbacks, you can use +:class:`RayDaskCallback ` +directly, passing the callback functions as constructor arguments: + +.. code-block:: python + + def my_presubmit_cb(task, key, deps): + print(f"About to submit task {key}!") + + with RayDaskCallback(ray_presubmit=my_presubmit_cb): + z.compute(scheduler=ray_dask_get) + +or you can subclass it, implementing the callback methods that you need: + +.. code-block:: python + + class MyPresubmitCallback(RayDaskCallback): + def _ray_presubmit(self, task, key, deps): + print(f"About to submit task {key}!") + + with MyPresubmitCallback(): + z.compute(scheduler=ray_dask_get) + +You can also specify multiple callbacks: + +.. code-block:: python + + # The hooks for both MyTimerCallback and MyPresubmitCallback will be + # called. + with MyTimerCallback(), MyPresubmitCallback(): + z.compute(scheduler=ray_dask_get) + +Combining Dask callbacks with an actor yields simple patterns for stateful data +aggregation, such as capturing task execution statistics and caching results. +Here is an example that does both, caching the result of a task if its +execution time exceeds some user-defined threshold: + +.. code-block:: python + + @ray.remote + class SimpleCacheActor: + def __init__(self): + self.cache = {} + + def get(self, key): + # Raises KeyError if key isn't in cache. + return self.cache[key] + + def put(self, key, value): + self.cache[key] = value + + + class SimpleCacheCallback(RayDaskCallback): + def __init__(self, cache_actor_handle, put_threshold=10): + self.cache_actor = cache_actor_handle + self.put_threshold = put_threshold + + def _ray_presubmit(self, task, key, deps): + try: + return ray.get(self.cache_actor.get.remote(str(key))) + except KeyError: + return None + + def _ray_pretask(self, key, object_refs): + start_time = timer() + return start_time + + def _ray_posttask(self, key, result, pre_state): + execution_time = timer() - pre_state + if execution_time > self.put_threshold: + self.cache_actor.put.remote(str(key), result) + + + cache_actor = SimpleCacheActor.remote() + cache_callback = SimpleCacheCallback(cache_actor, put_threshold=2) + with cache_callback: + z.compute(scheduler=ray_dask_get) + +Note that the existing Dask scheduler callbacks (``start``, ``start_state``, +``pretask``, ``posttask``, ``finish``) are also available, which can be used to +introspect the Dask task to Ray task conversion process, but that ``pretask`` +and ``posttask`` are executed before and after the Ray task is *submitted*, not +executed, and that ``finish`` is executed after all Ray tasks have been +*submitted*, not executed. + +This callback API is currently unstable and subject to change. diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index d79a18359..ddd186324 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -79,6 +79,7 @@ py_test_module_list( "test_command_runner.py", "test_coordinator_server.py", "test_dask_scheduler.py", + "test_dask_callback.py", "test_debug_tools.py", "test_global_state.py", "test_job.py", diff --git a/python/ray/tests/test_dask_callback.py b/python/ray/tests/test_dask_callback.py new file mode 100644 index 000000000..58c720cea --- /dev/null +++ b/python/ray/tests/test_dask_callback.py @@ -0,0 +1,234 @@ +import dask +import pytest + +import ray +from ray.util.dask import ray_dask_get, RayDaskCallback + + +@dask.delayed +def add(x, y): + return x + y + + +def test_callback_active(): + """Test that callbacks are active within context""" + assert not RayDaskCallback.ray_active + + with RayDaskCallback(): + assert RayDaskCallback.ray_active + + assert not RayDaskCallback.ray_active + + +def test_presubmit_shortcircuit(ray_start_regular_shared): + """ + Test that presubmit return short-circuits task submission, and that task's + result is set to the presubmit return value. + """ + + class PresubmitShortcircuitCallback(RayDaskCallback): + def _ray_presubmit(self, task, key, deps): + return 0 + + def _ray_postsubmit(self, task, key, deps, object_ref): + pytest.fail("_ray_postsubmit shouldn't be called when " + "_ray_presubmit returns a value") + + with PresubmitShortcircuitCallback(): + z = add(2, 3) + result = z.compute(scheduler=ray_dask_get) + + assert result == 0 + + +def test_pretask_posttask_shared_state(ray_start_regular_shared): + """ + Test that pretask return value is passed to corresponding posttask + callback. + """ + + class PretaskPosttaskCallback(RayDaskCallback): + def _ray_pretask(self, key, object_refs): + return key + + def _ray_posttask(self, key, result, pre_state): + assert pre_state == key + + with PretaskPosttaskCallback(): + z = add(2, 3) + result = z.compute(scheduler=ray_dask_get) + + assert result == 5 + + +def test_postsubmit(ray_start_regular_shared): + """ + Test that postsubmit is called after each task. + """ + + class PostsubmitCallback(RayDaskCallback): + def __init__(self, postsubmit_actor): + self.postsubmit_actor = postsubmit_actor + + def _ray_postsubmit(self, task, key, deps, object_ref): + self.postsubmit_actor.postsubmit.remote(task, key, deps, + object_ref) + + @ray.remote + class PostsubmitActor: + def __init__(self): + self.postsubmit_counter = 0 + + def postsubmit(self, task, key, deps, object_ref): + self.postsubmit_counter += 1 + + def get_postsubmit_counter(self): + return self.postsubmit_counter + + postsubmit_actor = PostsubmitActor.remote() + with PostsubmitCallback(postsubmit_actor): + z = add(2, 3) + result = z.compute(scheduler=ray_dask_get) + + assert ray.get(postsubmit_actor.get_postsubmit_counter.remote()) == 1 + assert result == 5 + + +def test_postsubmit_all(ray_start_regular_shared): + """ + Test that postsubmit_all is called once. + """ + + class PostsubmitAllCallback(RayDaskCallback): + def __init__(self, postsubmit_all_actor): + self.postsubmit_all_actor = postsubmit_all_actor + + def _ray_postsubmit_all(self, object_refs, dsk): + self.postsubmit_all_actor.postsubmit_all.remote(object_refs, dsk) + + @ray.remote + class PostsubmitAllActor: + def __init__(self): + self.postsubmit_all_called = False + + def postsubmit_all(self, object_refs, dsk): + self.postsubmit_all_called = True + + def get_postsubmit_all_called(self): + return self.postsubmit_all_called + + postsubmit_all_actor = PostsubmitAllActor.remote() + with PostsubmitAllCallback(postsubmit_all_actor): + z = add(2, 3) + result = z.compute(scheduler=ray_dask_get) + + assert ray.get(postsubmit_all_actor.get_postsubmit_all_called.remote()) + assert result == 5 + + +def test_finish(ray_start_regular_shared): + """ + Test that finish callback is called once. + """ + + class FinishCallback(RayDaskCallback): + def __init__(self, finish_actor): + self.finish_actor = finish_actor + + def _ray_finish(self, result): + self.finish_actor.finish.remote(result) + + @ray.remote + class FinishActor: + def __init__(self): + self.finish_called = False + + def finish(self, result): + self.finish_called = True + + def get_finish_called(self): + return self.finish_called + + finish_actor = FinishActor.remote() + with FinishCallback(finish_actor): + z = add(2, 3) + result = z.compute(scheduler=ray_dask_get) + + assert ray.get(finish_actor.get_finish_called.remote()) + assert result == 5 + + +def test_multiple_callbacks(ray_start_regular_shared): + """ + Test that multiple callbacks are supported. + """ + + class PostsubmitCallback(RayDaskCallback): + def __init__(self, postsubmit_actor): + self.postsubmit_actor = postsubmit_actor + + def _ray_postsubmit(self, task, key, deps, object_ref): + self.postsubmit_actor.postsubmit.remote(task, key, deps, + object_ref) + + @ray.remote + class PostsubmitActor: + def __init__(self): + self.postsubmit_counter = 0 + + def postsubmit(self, task, key, deps, object_ref): + self.postsubmit_counter += 1 + + def get_postsubmit_counter(self): + return self.postsubmit_counter + + postsubmit_actor = PostsubmitActor.remote() + cb1 = PostsubmitCallback(postsubmit_actor) + cb2 = PostsubmitCallback(postsubmit_actor) + cb3 = PostsubmitCallback(postsubmit_actor) + with cb1, cb2, cb3: + z = add(2, 3) + result = z.compute(scheduler=ray_dask_get) + + assert ray.get(postsubmit_actor.get_postsubmit_counter.remote()) == 3 + assert result == 5 + + +def test_pretask_posttask_shared_state_multi(ray_start_regular_shared): + """ + Test that pretask return values are passed to the correct corresponding + posttask callbacks when multiple callbacks are given. + """ + + class PretaskPosttaskCallback(RayDaskCallback): + def __init__(self, suffix): + self.suffix = suffix + + def _ray_pretask(self, key, object_refs): + return key + self.suffix + + def _ray_posttask(self, key, result, pre_state): + assert pre_state == key + self.suffix + + class PretaskOnlyCallback(RayDaskCallback): + def _ray_pretask(self, key, object_refs): + return "baz" + + class PosttaskOnlyCallback(RayDaskCallback): + def _ray_posttask(self, key, result, pre_state): + assert pre_state is None + + cb1 = PretaskPosttaskCallback("foo") + cb2 = PretaskOnlyCallback() + cb3 = PosttaskOnlyCallback() + cb4 = PretaskPosttaskCallback("bar") + with cb1, cb2, cb3, cb4: + z = add(2, 3) + result = z.compute(scheduler=ray_dask_get) + + assert result == 5 + + +if __name__ == "__main__": + import sys + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/util/dask/__init__.py b/python/ray/util/dask/__init__.py index 42dbf2010..bfe28571a 100644 --- a/python/ray/util/dask/__init__.py +++ b/python/ray/util/dask/__init__.py @@ -1,3 +1,14 @@ from .scheduler import ray_dask_get, ray_dask_get_sync +from .callbacks import ( + RayDaskCallback, + local_ray_callbacks, + unpack_ray_callbacks, +) -__all__ = ["ray_dask_get", "ray_dask_get_sync"] +__all__ = [ + "ray_dask_get", + "ray_dask_get_sync", + "RayDaskCallback", + "local_ray_callbacks", + "unpack_ray_callbacks", +] diff --git a/python/ray/util/dask/callbacks.py b/python/ray/util/dask/callbacks.py new file mode 100644 index 000000000..8b4e8039b --- /dev/null +++ b/python/ray/util/dask/callbacks.py @@ -0,0 +1,211 @@ +import contextlib +from collections import namedtuple + +from dask.callbacks import Callback + +# The names of the Ray-specific callbacks. These are the kwarg names that +# RayDaskCallback will accept on construction, and is considered the +# source-of-truth for what Ray-specific callbacks exist. +CBS = ( + "ray_presubmit", + "ray_postsubmit", + "ray_pretask", + "ray_posttask", + "ray_postsubmit_all", + "ray_finish", +) +# The Ray-specific callback method names for RayDaskCallback. +CB_FIELDS = tuple("_" + field for field in CBS) +# The Ray-specific callbacks that we do _not_ wish to drop from RayCallbacks +# if not given on a RayDaskCallback instance (will be filled with None +# instead). +CBS_DONT_DROP = {"ray_pretask", "ray_posttask"} + +# The Ray-specific callbacks for a single RayDaskCallback. +RayCallback = namedtuple("RayCallback", " ".join(CBS)) + +# The Ray-specific callbacks for one or more RayDaskCallbacks. +RayCallbacks = namedtuple("RayCallbacks", + " ".join([field + "_cbs" for field in CBS])) + + +class RayDaskCallback(Callback): + """ + Extends Dask's `Callback` class with Ray-specific hooks. When instantiating + or subclassing this class, both the normal Dask hooks (e.g. pretask, + posttask, etc.) and the Ray-specific hooks can be provided. + + See `dask.callbacks.Callback` for usage. + + Caveats: Any Dask-Ray scheduler must bring the Ray-specific callbacks into + context using the `local_ray_callbacks` context manager, since the built-in + `local_callbacks` context manager provided by Dask isn't aware of this + class. + """ + + # Set of active Ray-specific callbacks. + ray_active = set() + + def __init__(self, **kwargs): + """ + Ray-specific callbacks: + - def _ray_presubmit(task, key, deps): + Run before submitting a Ray task. If this callback returns a + non-`None` value, a Ray task will _not_ be created and this + value will be used as the would-be task's result value. + + Args: + task (tuple): A Dask task, where the first tuple item is + the task function, and the remaining tuple items are + the task arguments (either the actual argument values, + or Dask keys into the deps dictionary whose + corresponding values are the argument values). + key (str): The Dask graph key for the given task. + deps (dict): The dependencies of this task. + + Returns: + Either None, in which case a Ray task will be submitted, or + a non-None value, in which case a Ray task will not be + submitted and this return value will be used as the + would-be task result value. + + - def _ray_postsubmit(task, key, deps, object_ref): + Run after submitting a Ray task. + + Args: + task (tuple): A Dask task, where the first tuple item is + the task function, and the remaining tuple items are + the task arguments (either the actual argument values, + or Dask keys into the deps dictionary whose + corresponding values are the argument values). + key (str): The Dask graph key for the given task. + deps (dict): The dependencies of this task. + object_ref (ray.ObjectRef): The object reference for the + return value of the Ray task. + + - def _ray_pretask(key, object_refs): + Run before executing a Dask task within a Ray task. This + executes after the task has been submitted, within a Ray + worker. The return value of this task will be passed to the + _ray_posttask callback, if provided. + + Args: + key (str): The Dask graph key for the Dask task. + object_refs (List[ray.ObjectRef]): The object references + for the arguments of the Ray task. + + Returns: + A value that will be passed to the corresponding + _ray_posttask callback, if said callback is defined. + + - def _ray_posttask(key, result, pre_state): + Run after executing a Dask task within a Ray task. This + executes within a Ray worker. This callback receives the return + value of the _ray_pretask callback, if provided. + + Args: + key (str): The Dask graph key for the Dask task. + result (object): The task result value. + pre_state (object): The return value of the corresponding + _ray_pretask callback, if said callback is defined. + + - def _ray_postsubmit_all(object_refs, dsk): + Run after all Ray tasks have been submitted. + + Args: + object_refs (List[ray.ObjectRef]): The object references + for the output (leaf) Ray tasks of the task graph. + dsk (dict): The Dask graph. + + - def _ray_finish(result): + Run after all Ray tasks have finished executing and the final + result has been returned. + + Args: + result (object): The final result (output) of the Dask + computation, before any repackaging is done by + Dask collection-specific post-compute callbacks. + """ + for cb in CBS: + cb_func = kwargs.pop(cb, None) + if cb_func is not None: + setattr(self, "_" + cb, cb_func) + + super().__init__(**kwargs) + + @property + def _ray_callback(self): + return RayCallback( + *[getattr(self, field, None) for field in CB_FIELDS]) + + def __enter__(self): + self._ray_cm = add_ray_callbacks(self) + self._ray_cm.__enter__() + super().__enter__() + return self + + def __exit__(self, *args): + super().__exit__(*args) + self._ray_cm.__exit__(*args) + + def register(self): + type(self).ray_active.add(self._ray_callback) + super().register() + + def unregister(self): + type(self).ray_active.remove(self._ray_callback) + super().unregister() + + +class add_ray_callbacks: + def __init__(self, *callbacks): + self.callbacks = [normalize_ray_callback(c) for c in callbacks] + RayDaskCallback.ray_active.update(self.callbacks) + + def __enter__(self): + return self + + def __exit__(self, *args): + for c in self.callbacks: + RayDaskCallback.ray_active.discard(c) + + +def normalize_ray_callback(cb): + if isinstance(cb, RayDaskCallback): + return cb._ray_callback + elif isinstance(cb, RayCallback): + return cb + else: + raise TypeError( + "Callbacks must be either 'RayDaskCallback' or 'RayCallback' " + "namedtuple") + + +def unpack_ray_callbacks(cbs): + """Take an iterable of callbacks, return a list of each callback.""" + if cbs: + # Only drop callback methods that aren't in CBS_DONT_DROP. + return RayCallbacks(*( + [cb for cb in cbs_ if cb or CBS[idx] in CBS_DONT_DROP] or None + for idx, cbs_ in enumerate(zip(*cbs)))) + else: + return RayCallbacks(*([()] * len(CBS))) + + +@contextlib.contextmanager +def local_ray_callbacks(callbacks=None): + """ + Allows Dask-Ray callbacks to work with nested schedulers. + + Callbacks will only be used by the first started scheduler they encounter. + This means that only the outermost scheduler will use global callbacks. + """ + global_callbacks = callbacks is None + if global_callbacks: + callbacks, RayDaskCallback.ray_active = (RayDaskCallback.ray_active, + set()) + try: + yield callbacks or () + finally: + if global_callbacks: + RayDaskCallback.ray_active = callbacks diff --git a/python/ray/util/dask/scheduler.py b/python/ray/util/dask/scheduler.py index 6d948cf25..aeded7bca 100644 --- a/python/ray/util/dask/scheduler.py +++ b/python/ray/util/dask/scheduler.py @@ -10,6 +10,7 @@ from dask.local import get_async, apply_sync from dask.system import CPU_COUNT from dask.threaded import pack_exception, _thread_get_id +from .callbacks import local_ray_callbacks, unpack_ray_callbacks from .common import unpack_object_refs main_thread = threading.current_thread() @@ -30,12 +31,14 @@ def ray_dask_get(dsk, keys, **kwargs): >>> dask.compute(obj, scheduler=ray_dask_get) - You can override the number of threads to use when submitting the - Ray tasks, or the threadpool used to submit Ray tasks: + You can override the currently active global Dask-Ray callbacks (e.g. + supplied via a context manager), the number of threads to use when + submitting the Ray tasks, or the threadpool used to submit Ray tasks: >>> dask.compute( obj, scheduler=ray_dask_get, + ray_callbacks=some_ray_dask_callbacks, num_workers=8, pool=some_cool_pool, ) @@ -44,6 +47,7 @@ def ray_dask_get(dsk, keys, **kwargs): dsk (Dict): Dask graph, represented as a task DAG dictionary. keys (List[str]): List of Dask graph keys whose values we wish to compute and return. + ray_callbacks (Optional[list[callable]]): Dask-Ray callbacks. num_workers (Optional[int]): The number of worker threads to use in the Ray task submission traversal of the Dask graph. pool (Optional[ThreadPool]): A multiprocessing threadpool to use to @@ -73,23 +77,48 @@ def ray_dask_get(dsk, keys, **kwargs): atexit.register(pool.close) pools[thread][num_workers] = pool - # NOTE: We hijack Dask's `get_async` function, injecting a different task - # executor. - object_refs = get_async( - _apply_async_wrapper(pool.apply_async, _rayify_task_wrapper), - len(pool._pool), - dsk, - keys, - get_id=_thread_get_id, - pack_exception=pack_exception, - **kwargs, - ) - # NOTE: We explicitly delete the Dask graph here so object references - # are garbage-collected before this function returns, i.e. before all Ray - # tasks are done. Otherwise, no intermediate objects will be cleaned up - # until all Ray tasks are done. - del dsk - result = ray_get_unpack(object_refs) + ray_callbacks = kwargs.pop("ray_callbacks", None) + + with local_ray_callbacks(ray_callbacks) as ray_callbacks: + # Unpack the Ray-specific callbacks. + ( + ray_presubmit_cbs, + ray_postsubmit_cbs, + ray_pretask_cbs, + ray_posttask_cbs, + ray_postsubmit_all_cbs, + ray_finish_cbs, + ) = unpack_ray_callbacks(ray_callbacks) + # NOTE: We hijack Dask's `get_async` function, injecting a different + # task executor. + object_refs = get_async( + _apply_async_wrapper( + pool.apply_async, + _rayify_task_wrapper, + ray_presubmit_cbs, + ray_postsubmit_cbs, + ray_pretask_cbs, + ray_posttask_cbs, + ), + len(pool._pool), + dsk, + keys, + get_id=_thread_get_id, + pack_exception=pack_exception, + **kwargs, + ) + if ray_postsubmit_all_cbs is not None: + for cb in ray_postsubmit_all_cbs: + cb(object_refs, dsk) + # NOTE: We explicitly delete the Dask graph here so object references + # are garbage-collected before this function returns, i.e. before all + # Ray tasks are done. Otherwise, no intermediate objects will be + # cleaned up until all Ray tasks are done. + del dsk + result = ray_get_unpack(object_refs) + if ray_finish_cbs is not None: + for cb in ray_finish_cbs: + cb(result) # cleanup pools associated with dead threads. with pools_lock: @@ -138,6 +167,10 @@ def _rayify_task_wrapper( loads, get_id, pack_exception, + ray_presubmit_cbs, + ray_postsubmit_cbs, + ray_pretask_cbs, + ray_posttask_cbs, ): """ The core Ray-Dask task execution wrapper, to be given to the thread pool's @@ -152,6 +185,10 @@ def _rayify_task_wrapper( loads (callable): A task_info deserializing function. get_id (callable): An ID generating function. pack_exception (callable): An exception serializing function. + ray_presubmit_cbs (callable): Pre-task submission callbacks. + ray_postsubmit_cbs (callable): Post-task submission callbacks. + ray_pretask_cbs (callable): Pre-task execution callbacks. + ray_posttask_cbs (callable): Post-task execution callbacks. Returns: A 3-tuple of the task's key, a literal or a Ray object reference for a @@ -159,7 +196,15 @@ def _rayify_task_wrapper( """ try: task, deps = loads(task_info) - result = _rayify_task(task, key, deps) + result = _rayify_task( + task, + key, + deps, + ray_presubmit_cbs, + ray_postsubmit_cbs, + ray_pretask_cbs, + ray_posttask_cbs, + ) id = get_id() result = dumps((result, id)) failed = False @@ -169,15 +214,27 @@ def _rayify_task_wrapper( return key, result, failed -def _rayify_task(task, key, deps): +def _rayify_task( + task, + key, + deps, + ray_presubmit_cbs, + ray_postsubmit_cbs, + ray_pretask_cbs, + ray_posttask_cbs, +): """ Rayifies the given task, submitting it as a Ray task to the Ray cluster. Args: - task: A Dask graph value, being either a literal, dependency key, Dask - task, or a list thereof. - key: The Dask graph key for the given task. - deps: The dependencies of this task. + task (tuple): A Dask graph value, being either a literal, dependency + key, Dask task, or a list thereof. + key (str): The Dask graph key for the given task. + deps (dict): The dependencies of this task. + ray_presubmit_cbs (callable): Pre-task submission callbacks. + ray_postsubmit_cbs (callable): Post-task submission callbacks. + ray_pretask_cbs (callable): Pre-task execution callbacks. + ray_posttask_cbs (callable): Post-task execution callbacks. Returns: A literal, a Ray object reference representing a submitted task, or a @@ -186,18 +243,46 @@ def _rayify_task(task, key, deps): if isinstance(task, list): # Recursively rayify this list. This will still bottom out at the first # actual task encountered, inlining any tasks in that task's arguments. - return [_rayify_task(t, deps) for t in task] + return [ + _rayify_task( + t, + key, + deps, + ray_presubmit_cbs, + ray_postsubmit_cbs, + ray_pretask_cbs, + ray_posttask_cbs, + ) for t in task + ] elif istask(task): # Unpacks and repacks Ray object references and submits the task to the # Ray cluster for execution. + if ray_presubmit_cbs is not None: + alternate_returns = [ + cb(task, key, deps) for cb in ray_presubmit_cbs + ] + for alternate_return in alternate_returns: + # We don't submit a Ray task if a presubmit callback returns + # a non-`None` value, instead we return said value. + # NOTE: This returns the first non-None presubmit callback + # return value. + if alternate_return is not None: + return alternate_return + func, args = task[0], task[1:] # If the function's arguments contain nested object references, we must # unpack said object references into a flat set of arguments so that # Ray properly tracks the object dependencies between Ray tasks. object_refs, repack = unpack_object_refs(args, deps) # Submit the task using a wrapper function. - return dask_task_wrapper.options(name=f"dask:{key!s}").remote( - func, repack, *object_refs) + object_ref = dask_task_wrapper.options(name=f"dask:{key!s}").remote( + func, repack, key, ray_pretask_cbs, ray_posttask_cbs, *object_refs) + + if ray_postsubmit_cbs is not None: + for cb in ray_postsubmit_cbs: + cb(task, key, deps, object_ref) + + return object_ref elif not ishashable(task): return task elif task in deps: @@ -207,7 +292,8 @@ def _rayify_task(task, key, deps): @ray.remote -def dask_task_wrapper(func, repack, *args): +def dask_task_wrapper(func, repack, key, ray_pretask_cbs, ray_posttask_cbs, + *args): """ A Ray remote function acting as a Dask task wrapper. This function will repackage the given flat `args` into its original data structures using @@ -219,6 +305,9 @@ def dask_task_wrapper(func, repack, *args): func (callable): The Dask task function to execute. repack (callable): A function that repackages the provided args into the original (possibly nested) Python objects. + key (str): The Dask key for this task. + ray_pretask_cbs (callable): Pre-task execution callbacks. + ray_posttask_cbs (callable): Post-task execution callback. *args (ObjectRef): Ray object references representing the Dask task's arguments. @@ -227,11 +316,21 @@ def dask_task_wrapper(func, repack, *args): dask_task_wrapper.remote() invocation will return a Ray object reference representing the Ray task's result. """ + if ray_pretask_cbs is not None: + pre_states = [ + cb(key, args) if cb is not None else None for cb in ray_pretask_cbs + ] repacked_args, repacked_deps = repack(args) # Recursively execute Dask-inlined tasks. actual_args = [_execute_task(a, repacked_deps) for a in repacked_args] # Execute the actual underlying Dask task. - return func(*actual_args) + result = func(*actual_args) + if ray_posttask_cbs is not None: + for cb, pre_state in zip(ray_posttask_cbs, pre_states): + if cb is not None: + cb(key, result, pre_state) + + return result def ray_get_unpack(object_refs): @@ -273,6 +372,15 @@ def ray_dask_get_sync(dsk, keys, **kwargs): >>> dask.compute(obj, scheduler=ray_dask_get_sync) + You can override the currently active global Dask-Ray callbacks (e.g. + supplied via a context manager): + + >>> dask.compute( + obj, + scheduler=ray_dask_get_sync, + ray_callbacks=some_ray_dask_callbacks, + ) + Args: dsk (Dict): Dask graph, represented as a task DAG dictionary. keys (List[str]): List of Dask graph keys whose values we wish to @@ -281,18 +389,46 @@ def ray_dask_get_sync(dsk, keys, **kwargs): Returns: Computed values corresponding to the provided keys. """ - # NOTE: We hijack Dask's `get_async` function, injecting a different task - # executor. - object_refs = get_async( - _apply_async_wrapper(apply_sync, _rayify_task_wrapper), - 1, - dsk, - keys, - **kwargs, - ) - # NOTE: We explicitly delete the Dask graph here so object references - # are garbage-collected before this function returns, i.e. before all Ray - # tasks are done. Otherwise, no intermediate objects will be cleaned up - # until all Ray tasks are done. - del dsk - return ray_get_unpack(object_refs) + + ray_callbacks = kwargs.pop("ray_callbacks", None) + + with local_ray_callbacks(ray_callbacks) as ray_callbacks: + # Unpack the Ray-specific callbacks. + ( + ray_presubmit_cbs, + ray_postsubmit_cbs, + ray_pretask_cbs, + ray_posttask_cbs, + ray_postsubmit_all_cbs, + ray_finish_cbs, + ) = unpack_ray_callbacks(ray_callbacks) + # NOTE: We hijack Dask's `get_async` function, injecting a different + # task executor. + object_refs = get_async( + _apply_async_wrapper( + apply_sync, + _rayify_task_wrapper, + ray_presubmit_cbs, + ray_postsubmit_cbs, + ray_pretask_cbs, + ray_posttask_cbs, + ), + 1, + dsk, + keys, + **kwargs, + ) + if ray_postsubmit_all_cbs is not None: + for cb in ray_postsubmit_all_cbs: + cb(object_refs, dsk) + # NOTE: We explicitly delete the Dask graph here so object references + # are garbage-collected before this function returns, i.e. before all + # Ray tasks are done. Otherwise, no intermediate objects will be + # cleaned up until all Ray tasks are done. + del dsk + result = ray_get_unpack(object_refs) + if ray_finish_cbs is not None: + for cb in ray_finish_cbs: + cb(result) + + return result