mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 20:40:09 +08:00
Add Dask-Ray scheduler callbacks. (#10519)
Improve Dask-on-Ray documentation. Move to RayCallback(s) namedtuples, and use top-level CBS tuple as source-of-truth for callback methods.
This commit is contained in:
@@ -79,6 +79,7 @@ py_test_module_list(
|
||||
"test_command_runner.py",
|
||||
"test_coordinator_server.py",
|
||||
"test_dask_scheduler.py",
|
||||
"test_dask_callback.py",
|
||||
"test_debug_tools.py",
|
||||
"test_global_state.py",
|
||||
"test_job.py",
|
||||
|
||||
@@ -0,0 +1,234 @@
|
||||
import dask
|
||||
import pytest
|
||||
|
||||
import ray
|
||||
from ray.util.dask import ray_dask_get, RayDaskCallback
|
||||
|
||||
|
||||
@dask.delayed
|
||||
def add(x, y):
|
||||
return x + y
|
||||
|
||||
|
||||
def test_callback_active():
|
||||
"""Test that callbacks are active within context"""
|
||||
assert not RayDaskCallback.ray_active
|
||||
|
||||
with RayDaskCallback():
|
||||
assert RayDaskCallback.ray_active
|
||||
|
||||
assert not RayDaskCallback.ray_active
|
||||
|
||||
|
||||
def test_presubmit_shortcircuit(ray_start_regular_shared):
|
||||
"""
|
||||
Test that presubmit return short-circuits task submission, and that task's
|
||||
result is set to the presubmit return value.
|
||||
"""
|
||||
|
||||
class PresubmitShortcircuitCallback(RayDaskCallback):
|
||||
def _ray_presubmit(self, task, key, deps):
|
||||
return 0
|
||||
|
||||
def _ray_postsubmit(self, task, key, deps, object_ref):
|
||||
pytest.fail("_ray_postsubmit shouldn't be called when "
|
||||
"_ray_presubmit returns a value")
|
||||
|
||||
with PresubmitShortcircuitCallback():
|
||||
z = add(2, 3)
|
||||
result = z.compute(scheduler=ray_dask_get)
|
||||
|
||||
assert result == 0
|
||||
|
||||
|
||||
def test_pretask_posttask_shared_state(ray_start_regular_shared):
|
||||
"""
|
||||
Test that pretask return value is passed to corresponding posttask
|
||||
callback.
|
||||
"""
|
||||
|
||||
class PretaskPosttaskCallback(RayDaskCallback):
|
||||
def _ray_pretask(self, key, object_refs):
|
||||
return key
|
||||
|
||||
def _ray_posttask(self, key, result, pre_state):
|
||||
assert pre_state == key
|
||||
|
||||
with PretaskPosttaskCallback():
|
||||
z = add(2, 3)
|
||||
result = z.compute(scheduler=ray_dask_get)
|
||||
|
||||
assert result == 5
|
||||
|
||||
|
||||
def test_postsubmit(ray_start_regular_shared):
|
||||
"""
|
||||
Test that postsubmit is called after each task.
|
||||
"""
|
||||
|
||||
class PostsubmitCallback(RayDaskCallback):
|
||||
def __init__(self, postsubmit_actor):
|
||||
self.postsubmit_actor = postsubmit_actor
|
||||
|
||||
def _ray_postsubmit(self, task, key, deps, object_ref):
|
||||
self.postsubmit_actor.postsubmit.remote(task, key, deps,
|
||||
object_ref)
|
||||
|
||||
@ray.remote
|
||||
class PostsubmitActor:
|
||||
def __init__(self):
|
||||
self.postsubmit_counter = 0
|
||||
|
||||
def postsubmit(self, task, key, deps, object_ref):
|
||||
self.postsubmit_counter += 1
|
||||
|
||||
def get_postsubmit_counter(self):
|
||||
return self.postsubmit_counter
|
||||
|
||||
postsubmit_actor = PostsubmitActor.remote()
|
||||
with PostsubmitCallback(postsubmit_actor):
|
||||
z = add(2, 3)
|
||||
result = z.compute(scheduler=ray_dask_get)
|
||||
|
||||
assert ray.get(postsubmit_actor.get_postsubmit_counter.remote()) == 1
|
||||
assert result == 5
|
||||
|
||||
|
||||
def test_postsubmit_all(ray_start_regular_shared):
|
||||
"""
|
||||
Test that postsubmit_all is called once.
|
||||
"""
|
||||
|
||||
class PostsubmitAllCallback(RayDaskCallback):
|
||||
def __init__(self, postsubmit_all_actor):
|
||||
self.postsubmit_all_actor = postsubmit_all_actor
|
||||
|
||||
def _ray_postsubmit_all(self, object_refs, dsk):
|
||||
self.postsubmit_all_actor.postsubmit_all.remote(object_refs, dsk)
|
||||
|
||||
@ray.remote
|
||||
class PostsubmitAllActor:
|
||||
def __init__(self):
|
||||
self.postsubmit_all_called = False
|
||||
|
||||
def postsubmit_all(self, object_refs, dsk):
|
||||
self.postsubmit_all_called = True
|
||||
|
||||
def get_postsubmit_all_called(self):
|
||||
return self.postsubmit_all_called
|
||||
|
||||
postsubmit_all_actor = PostsubmitAllActor.remote()
|
||||
with PostsubmitAllCallback(postsubmit_all_actor):
|
||||
z = add(2, 3)
|
||||
result = z.compute(scheduler=ray_dask_get)
|
||||
|
||||
assert ray.get(postsubmit_all_actor.get_postsubmit_all_called.remote())
|
||||
assert result == 5
|
||||
|
||||
|
||||
def test_finish(ray_start_regular_shared):
|
||||
"""
|
||||
Test that finish callback is called once.
|
||||
"""
|
||||
|
||||
class FinishCallback(RayDaskCallback):
|
||||
def __init__(self, finish_actor):
|
||||
self.finish_actor = finish_actor
|
||||
|
||||
def _ray_finish(self, result):
|
||||
self.finish_actor.finish.remote(result)
|
||||
|
||||
@ray.remote
|
||||
class FinishActor:
|
||||
def __init__(self):
|
||||
self.finish_called = False
|
||||
|
||||
def finish(self, result):
|
||||
self.finish_called = True
|
||||
|
||||
def get_finish_called(self):
|
||||
return self.finish_called
|
||||
|
||||
finish_actor = FinishActor.remote()
|
||||
with FinishCallback(finish_actor):
|
||||
z = add(2, 3)
|
||||
result = z.compute(scheduler=ray_dask_get)
|
||||
|
||||
assert ray.get(finish_actor.get_finish_called.remote())
|
||||
assert result == 5
|
||||
|
||||
|
||||
def test_multiple_callbacks(ray_start_regular_shared):
|
||||
"""
|
||||
Test that multiple callbacks are supported.
|
||||
"""
|
||||
|
||||
class PostsubmitCallback(RayDaskCallback):
|
||||
def __init__(self, postsubmit_actor):
|
||||
self.postsubmit_actor = postsubmit_actor
|
||||
|
||||
def _ray_postsubmit(self, task, key, deps, object_ref):
|
||||
self.postsubmit_actor.postsubmit.remote(task, key, deps,
|
||||
object_ref)
|
||||
|
||||
@ray.remote
|
||||
class PostsubmitActor:
|
||||
def __init__(self):
|
||||
self.postsubmit_counter = 0
|
||||
|
||||
def postsubmit(self, task, key, deps, object_ref):
|
||||
self.postsubmit_counter += 1
|
||||
|
||||
def get_postsubmit_counter(self):
|
||||
return self.postsubmit_counter
|
||||
|
||||
postsubmit_actor = PostsubmitActor.remote()
|
||||
cb1 = PostsubmitCallback(postsubmit_actor)
|
||||
cb2 = PostsubmitCallback(postsubmit_actor)
|
||||
cb3 = PostsubmitCallback(postsubmit_actor)
|
||||
with cb1, cb2, cb3:
|
||||
z = add(2, 3)
|
||||
result = z.compute(scheduler=ray_dask_get)
|
||||
|
||||
assert ray.get(postsubmit_actor.get_postsubmit_counter.remote()) == 3
|
||||
assert result == 5
|
||||
|
||||
|
||||
def test_pretask_posttask_shared_state_multi(ray_start_regular_shared):
|
||||
"""
|
||||
Test that pretask return values are passed to the correct corresponding
|
||||
posttask callbacks when multiple callbacks are given.
|
||||
"""
|
||||
|
||||
class PretaskPosttaskCallback(RayDaskCallback):
|
||||
def __init__(self, suffix):
|
||||
self.suffix = suffix
|
||||
|
||||
def _ray_pretask(self, key, object_refs):
|
||||
return key + self.suffix
|
||||
|
||||
def _ray_posttask(self, key, result, pre_state):
|
||||
assert pre_state == key + self.suffix
|
||||
|
||||
class PretaskOnlyCallback(RayDaskCallback):
|
||||
def _ray_pretask(self, key, object_refs):
|
||||
return "baz"
|
||||
|
||||
class PosttaskOnlyCallback(RayDaskCallback):
|
||||
def _ray_posttask(self, key, result, pre_state):
|
||||
assert pre_state is None
|
||||
|
||||
cb1 = PretaskPosttaskCallback("foo")
|
||||
cb2 = PretaskOnlyCallback()
|
||||
cb3 = PosttaskOnlyCallback()
|
||||
cb4 = PretaskPosttaskCallback("bar")
|
||||
with cb1, cb2, cb3, cb4:
|
||||
z = add(2, 3)
|
||||
result = z.compute(scheduler=ray_dask_get)
|
||||
|
||||
assert result == 5
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
@@ -1,3 +1,14 @@
|
||||
from .scheduler import ray_dask_get, ray_dask_get_sync
|
||||
from .callbacks import (
|
||||
RayDaskCallback,
|
||||
local_ray_callbacks,
|
||||
unpack_ray_callbacks,
|
||||
)
|
||||
|
||||
__all__ = ["ray_dask_get", "ray_dask_get_sync"]
|
||||
__all__ = [
|
||||
"ray_dask_get",
|
||||
"ray_dask_get_sync",
|
||||
"RayDaskCallback",
|
||||
"local_ray_callbacks",
|
||||
"unpack_ray_callbacks",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,211 @@
|
||||
import contextlib
|
||||
from collections import namedtuple
|
||||
|
||||
from dask.callbacks import Callback
|
||||
|
||||
# The names of the Ray-specific callbacks. These are the kwarg names that
|
||||
# RayDaskCallback will accept on construction, and is considered the
|
||||
# source-of-truth for what Ray-specific callbacks exist.
|
||||
CBS = (
|
||||
"ray_presubmit",
|
||||
"ray_postsubmit",
|
||||
"ray_pretask",
|
||||
"ray_posttask",
|
||||
"ray_postsubmit_all",
|
||||
"ray_finish",
|
||||
)
|
||||
# The Ray-specific callback method names for RayDaskCallback.
|
||||
CB_FIELDS = tuple("_" + field for field in CBS)
|
||||
# The Ray-specific callbacks that we do _not_ wish to drop from RayCallbacks
|
||||
# if not given on a RayDaskCallback instance (will be filled with None
|
||||
# instead).
|
||||
CBS_DONT_DROP = {"ray_pretask", "ray_posttask"}
|
||||
|
||||
# The Ray-specific callbacks for a single RayDaskCallback.
|
||||
RayCallback = namedtuple("RayCallback", " ".join(CBS))
|
||||
|
||||
# The Ray-specific callbacks for one or more RayDaskCallbacks.
|
||||
RayCallbacks = namedtuple("RayCallbacks",
|
||||
" ".join([field + "_cbs" for field in CBS]))
|
||||
|
||||
|
||||
class RayDaskCallback(Callback):
|
||||
"""
|
||||
Extends Dask's `Callback` class with Ray-specific hooks. When instantiating
|
||||
or subclassing this class, both the normal Dask hooks (e.g. pretask,
|
||||
posttask, etc.) and the Ray-specific hooks can be provided.
|
||||
|
||||
See `dask.callbacks.Callback` for usage.
|
||||
|
||||
Caveats: Any Dask-Ray scheduler must bring the Ray-specific callbacks into
|
||||
context using the `local_ray_callbacks` context manager, since the built-in
|
||||
`local_callbacks` context manager provided by Dask isn't aware of this
|
||||
class.
|
||||
"""
|
||||
|
||||
# Set of active Ray-specific callbacks.
|
||||
ray_active = set()
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
Ray-specific callbacks:
|
||||
- def _ray_presubmit(task, key, deps):
|
||||
Run before submitting a Ray task. If this callback returns a
|
||||
non-`None` value, a Ray task will _not_ be created and this
|
||||
value will be used as the would-be task's result value.
|
||||
|
||||
Args:
|
||||
task (tuple): A Dask task, where the first tuple item is
|
||||
the task function, and the remaining tuple items are
|
||||
the task arguments (either the actual argument values,
|
||||
or Dask keys into the deps dictionary whose
|
||||
corresponding values are the argument values).
|
||||
key (str): The Dask graph key for the given task.
|
||||
deps (dict): The dependencies of this task.
|
||||
|
||||
Returns:
|
||||
Either None, in which case a Ray task will be submitted, or
|
||||
a non-None value, in which case a Ray task will not be
|
||||
submitted and this return value will be used as the
|
||||
would-be task result value.
|
||||
|
||||
- def _ray_postsubmit(task, key, deps, object_ref):
|
||||
Run after submitting a Ray task.
|
||||
|
||||
Args:
|
||||
task (tuple): A Dask task, where the first tuple item is
|
||||
the task function, and the remaining tuple items are
|
||||
the task arguments (either the actual argument values,
|
||||
or Dask keys into the deps dictionary whose
|
||||
corresponding values are the argument values).
|
||||
key (str): The Dask graph key for the given task.
|
||||
deps (dict): The dependencies of this task.
|
||||
object_ref (ray.ObjectRef): The object reference for the
|
||||
return value of the Ray task.
|
||||
|
||||
- def _ray_pretask(key, object_refs):
|
||||
Run before executing a Dask task within a Ray task. This
|
||||
executes after the task has been submitted, within a Ray
|
||||
worker. The return value of this task will be passed to the
|
||||
_ray_posttask callback, if provided.
|
||||
|
||||
Args:
|
||||
key (str): The Dask graph key for the Dask task.
|
||||
object_refs (List[ray.ObjectRef]): The object references
|
||||
for the arguments of the Ray task.
|
||||
|
||||
Returns:
|
||||
A value that will be passed to the corresponding
|
||||
_ray_posttask callback, if said callback is defined.
|
||||
|
||||
- def _ray_posttask(key, result, pre_state):
|
||||
Run after executing a Dask task within a Ray task. This
|
||||
executes within a Ray worker. This callback receives the return
|
||||
value of the _ray_pretask callback, if provided.
|
||||
|
||||
Args:
|
||||
key (str): The Dask graph key for the Dask task.
|
||||
result (object): The task result value.
|
||||
pre_state (object): The return value of the corresponding
|
||||
_ray_pretask callback, if said callback is defined.
|
||||
|
||||
- def _ray_postsubmit_all(object_refs, dsk):
|
||||
Run after all Ray tasks have been submitted.
|
||||
|
||||
Args:
|
||||
object_refs (List[ray.ObjectRef]): The object references
|
||||
for the output (leaf) Ray tasks of the task graph.
|
||||
dsk (dict): The Dask graph.
|
||||
|
||||
- def _ray_finish(result):
|
||||
Run after all Ray tasks have finished executing and the final
|
||||
result has been returned.
|
||||
|
||||
Args:
|
||||
result (object): The final result (output) of the Dask
|
||||
computation, before any repackaging is done by
|
||||
Dask collection-specific post-compute callbacks.
|
||||
"""
|
||||
for cb in CBS:
|
||||
cb_func = kwargs.pop(cb, None)
|
||||
if cb_func is not None:
|
||||
setattr(self, "_" + cb, cb_func)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def _ray_callback(self):
|
||||
return RayCallback(
|
||||
*[getattr(self, field, None) for field in CB_FIELDS])
|
||||
|
||||
def __enter__(self):
|
||||
self._ray_cm = add_ray_callbacks(self)
|
||||
self._ray_cm.__enter__()
|
||||
super().__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
super().__exit__(*args)
|
||||
self._ray_cm.__exit__(*args)
|
||||
|
||||
def register(self):
|
||||
type(self).ray_active.add(self._ray_callback)
|
||||
super().register()
|
||||
|
||||
def unregister(self):
|
||||
type(self).ray_active.remove(self._ray_callback)
|
||||
super().unregister()
|
||||
|
||||
|
||||
class add_ray_callbacks:
|
||||
def __init__(self, *callbacks):
|
||||
self.callbacks = [normalize_ray_callback(c) for c in callbacks]
|
||||
RayDaskCallback.ray_active.update(self.callbacks)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
for c in self.callbacks:
|
||||
RayDaskCallback.ray_active.discard(c)
|
||||
|
||||
|
||||
def normalize_ray_callback(cb):
|
||||
if isinstance(cb, RayDaskCallback):
|
||||
return cb._ray_callback
|
||||
elif isinstance(cb, RayCallback):
|
||||
return cb
|
||||
else:
|
||||
raise TypeError(
|
||||
"Callbacks must be either 'RayDaskCallback' or 'RayCallback' "
|
||||
"namedtuple")
|
||||
|
||||
|
||||
def unpack_ray_callbacks(cbs):
|
||||
"""Take an iterable of callbacks, return a list of each callback."""
|
||||
if cbs:
|
||||
# Only drop callback methods that aren't in CBS_DONT_DROP.
|
||||
return RayCallbacks(*(
|
||||
[cb for cb in cbs_ if cb or CBS[idx] in CBS_DONT_DROP] or None
|
||||
for idx, cbs_ in enumerate(zip(*cbs))))
|
||||
else:
|
||||
return RayCallbacks(*([()] * len(CBS)))
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def local_ray_callbacks(callbacks=None):
|
||||
"""
|
||||
Allows Dask-Ray callbacks to work with nested schedulers.
|
||||
|
||||
Callbacks will only be used by the first started scheduler they encounter.
|
||||
This means that only the outermost scheduler will use global callbacks.
|
||||
"""
|
||||
global_callbacks = callbacks is None
|
||||
if global_callbacks:
|
||||
callbacks, RayDaskCallback.ray_active = (RayDaskCallback.ray_active,
|
||||
set())
|
||||
try:
|
||||
yield callbacks or ()
|
||||
finally:
|
||||
if global_callbacks:
|
||||
RayDaskCallback.ray_active = callbacks
|
||||
@@ -10,6 +10,7 @@ from dask.local import get_async, apply_sync
|
||||
from dask.system import CPU_COUNT
|
||||
from dask.threaded import pack_exception, _thread_get_id
|
||||
|
||||
from .callbacks import local_ray_callbacks, unpack_ray_callbacks
|
||||
from .common import unpack_object_refs
|
||||
|
||||
main_thread = threading.current_thread()
|
||||
@@ -30,12 +31,14 @@ def ray_dask_get(dsk, keys, **kwargs):
|
||||
|
||||
>>> dask.compute(obj, scheduler=ray_dask_get)
|
||||
|
||||
You can override the number of threads to use when submitting the
|
||||
Ray tasks, or the threadpool used to submit Ray tasks:
|
||||
You can override the currently active global Dask-Ray callbacks (e.g.
|
||||
supplied via a context manager), the number of threads to use when
|
||||
submitting the Ray tasks, or the threadpool used to submit Ray tasks:
|
||||
|
||||
>>> dask.compute(
|
||||
obj,
|
||||
scheduler=ray_dask_get,
|
||||
ray_callbacks=some_ray_dask_callbacks,
|
||||
num_workers=8,
|
||||
pool=some_cool_pool,
|
||||
)
|
||||
@@ -44,6 +47,7 @@ def ray_dask_get(dsk, keys, **kwargs):
|
||||
dsk (Dict): Dask graph, represented as a task DAG dictionary.
|
||||
keys (List[str]): List of Dask graph keys whose values we wish to
|
||||
compute and return.
|
||||
ray_callbacks (Optional[list[callable]]): Dask-Ray callbacks.
|
||||
num_workers (Optional[int]): The number of worker threads to use in
|
||||
the Ray task submission traversal of the Dask graph.
|
||||
pool (Optional[ThreadPool]): A multiprocessing threadpool to use to
|
||||
@@ -73,23 +77,48 @@ def ray_dask_get(dsk, keys, **kwargs):
|
||||
atexit.register(pool.close)
|
||||
pools[thread][num_workers] = pool
|
||||
|
||||
# NOTE: We hijack Dask's `get_async` function, injecting a different task
|
||||
# executor.
|
||||
object_refs = get_async(
|
||||
_apply_async_wrapper(pool.apply_async, _rayify_task_wrapper),
|
||||
len(pool._pool),
|
||||
dsk,
|
||||
keys,
|
||||
get_id=_thread_get_id,
|
||||
pack_exception=pack_exception,
|
||||
**kwargs,
|
||||
)
|
||||
# NOTE: We explicitly delete the Dask graph here so object references
|
||||
# are garbage-collected before this function returns, i.e. before all Ray
|
||||
# tasks are done. Otherwise, no intermediate objects will be cleaned up
|
||||
# until all Ray tasks are done.
|
||||
del dsk
|
||||
result = ray_get_unpack(object_refs)
|
||||
ray_callbacks = kwargs.pop("ray_callbacks", None)
|
||||
|
||||
with local_ray_callbacks(ray_callbacks) as ray_callbacks:
|
||||
# Unpack the Ray-specific callbacks.
|
||||
(
|
||||
ray_presubmit_cbs,
|
||||
ray_postsubmit_cbs,
|
||||
ray_pretask_cbs,
|
||||
ray_posttask_cbs,
|
||||
ray_postsubmit_all_cbs,
|
||||
ray_finish_cbs,
|
||||
) = unpack_ray_callbacks(ray_callbacks)
|
||||
# NOTE: We hijack Dask's `get_async` function, injecting a different
|
||||
# task executor.
|
||||
object_refs = get_async(
|
||||
_apply_async_wrapper(
|
||||
pool.apply_async,
|
||||
_rayify_task_wrapper,
|
||||
ray_presubmit_cbs,
|
||||
ray_postsubmit_cbs,
|
||||
ray_pretask_cbs,
|
||||
ray_posttask_cbs,
|
||||
),
|
||||
len(pool._pool),
|
||||
dsk,
|
||||
keys,
|
||||
get_id=_thread_get_id,
|
||||
pack_exception=pack_exception,
|
||||
**kwargs,
|
||||
)
|
||||
if ray_postsubmit_all_cbs is not None:
|
||||
for cb in ray_postsubmit_all_cbs:
|
||||
cb(object_refs, dsk)
|
||||
# NOTE: We explicitly delete the Dask graph here so object references
|
||||
# are garbage-collected before this function returns, i.e. before all
|
||||
# Ray tasks are done. Otherwise, no intermediate objects will be
|
||||
# cleaned up until all Ray tasks are done.
|
||||
del dsk
|
||||
result = ray_get_unpack(object_refs)
|
||||
if ray_finish_cbs is not None:
|
||||
for cb in ray_finish_cbs:
|
||||
cb(result)
|
||||
|
||||
# cleanup pools associated with dead threads.
|
||||
with pools_lock:
|
||||
@@ -138,6 +167,10 @@ def _rayify_task_wrapper(
|
||||
loads,
|
||||
get_id,
|
||||
pack_exception,
|
||||
ray_presubmit_cbs,
|
||||
ray_postsubmit_cbs,
|
||||
ray_pretask_cbs,
|
||||
ray_posttask_cbs,
|
||||
):
|
||||
"""
|
||||
The core Ray-Dask task execution wrapper, to be given to the thread pool's
|
||||
@@ -152,6 +185,10 @@ def _rayify_task_wrapper(
|
||||
loads (callable): A task_info deserializing function.
|
||||
get_id (callable): An ID generating function.
|
||||
pack_exception (callable): An exception serializing function.
|
||||
ray_presubmit_cbs (callable): Pre-task submission callbacks.
|
||||
ray_postsubmit_cbs (callable): Post-task submission callbacks.
|
||||
ray_pretask_cbs (callable): Pre-task execution callbacks.
|
||||
ray_posttask_cbs (callable): Post-task execution callbacks.
|
||||
|
||||
Returns:
|
||||
A 3-tuple of the task's key, a literal or a Ray object reference for a
|
||||
@@ -159,7 +196,15 @@ def _rayify_task_wrapper(
|
||||
"""
|
||||
try:
|
||||
task, deps = loads(task_info)
|
||||
result = _rayify_task(task, key, deps)
|
||||
result = _rayify_task(
|
||||
task,
|
||||
key,
|
||||
deps,
|
||||
ray_presubmit_cbs,
|
||||
ray_postsubmit_cbs,
|
||||
ray_pretask_cbs,
|
||||
ray_posttask_cbs,
|
||||
)
|
||||
id = get_id()
|
||||
result = dumps((result, id))
|
||||
failed = False
|
||||
@@ -169,15 +214,27 @@ def _rayify_task_wrapper(
|
||||
return key, result, failed
|
||||
|
||||
|
||||
def _rayify_task(task, key, deps):
|
||||
def _rayify_task(
|
||||
task,
|
||||
key,
|
||||
deps,
|
||||
ray_presubmit_cbs,
|
||||
ray_postsubmit_cbs,
|
||||
ray_pretask_cbs,
|
||||
ray_posttask_cbs,
|
||||
):
|
||||
"""
|
||||
Rayifies the given task, submitting it as a Ray task to the Ray cluster.
|
||||
|
||||
Args:
|
||||
task: A Dask graph value, being either a literal, dependency key, Dask
|
||||
task, or a list thereof.
|
||||
key: The Dask graph key for the given task.
|
||||
deps: The dependencies of this task.
|
||||
task (tuple): A Dask graph value, being either a literal, dependency
|
||||
key, Dask task, or a list thereof.
|
||||
key (str): The Dask graph key for the given task.
|
||||
deps (dict): The dependencies of this task.
|
||||
ray_presubmit_cbs (callable): Pre-task submission callbacks.
|
||||
ray_postsubmit_cbs (callable): Post-task submission callbacks.
|
||||
ray_pretask_cbs (callable): Pre-task execution callbacks.
|
||||
ray_posttask_cbs (callable): Post-task execution callbacks.
|
||||
|
||||
Returns:
|
||||
A literal, a Ray object reference representing a submitted task, or a
|
||||
@@ -186,18 +243,46 @@ def _rayify_task(task, key, deps):
|
||||
if isinstance(task, list):
|
||||
# Recursively rayify this list. This will still bottom out at the first
|
||||
# actual task encountered, inlining any tasks in that task's arguments.
|
||||
return [_rayify_task(t, deps) for t in task]
|
||||
return [
|
||||
_rayify_task(
|
||||
t,
|
||||
key,
|
||||
deps,
|
||||
ray_presubmit_cbs,
|
||||
ray_postsubmit_cbs,
|
||||
ray_pretask_cbs,
|
||||
ray_posttask_cbs,
|
||||
) for t in task
|
||||
]
|
||||
elif istask(task):
|
||||
# Unpacks and repacks Ray object references and submits the task to the
|
||||
# Ray cluster for execution.
|
||||
if ray_presubmit_cbs is not None:
|
||||
alternate_returns = [
|
||||
cb(task, key, deps) for cb in ray_presubmit_cbs
|
||||
]
|
||||
for alternate_return in alternate_returns:
|
||||
# We don't submit a Ray task if a presubmit callback returns
|
||||
# a non-`None` value, instead we return said value.
|
||||
# NOTE: This returns the first non-None presubmit callback
|
||||
# return value.
|
||||
if alternate_return is not None:
|
||||
return alternate_return
|
||||
|
||||
func, args = task[0], task[1:]
|
||||
# If the function's arguments contain nested object references, we must
|
||||
# unpack said object references into a flat set of arguments so that
|
||||
# Ray properly tracks the object dependencies between Ray tasks.
|
||||
object_refs, repack = unpack_object_refs(args, deps)
|
||||
# Submit the task using a wrapper function.
|
||||
return dask_task_wrapper.options(name=f"dask:{key!s}").remote(
|
||||
func, repack, *object_refs)
|
||||
object_ref = dask_task_wrapper.options(name=f"dask:{key!s}").remote(
|
||||
func, repack, key, ray_pretask_cbs, ray_posttask_cbs, *object_refs)
|
||||
|
||||
if ray_postsubmit_cbs is not None:
|
||||
for cb in ray_postsubmit_cbs:
|
||||
cb(task, key, deps, object_ref)
|
||||
|
||||
return object_ref
|
||||
elif not ishashable(task):
|
||||
return task
|
||||
elif task in deps:
|
||||
@@ -207,7 +292,8 @@ def _rayify_task(task, key, deps):
|
||||
|
||||
|
||||
@ray.remote
|
||||
def dask_task_wrapper(func, repack, *args):
|
||||
def dask_task_wrapper(func, repack, key, ray_pretask_cbs, ray_posttask_cbs,
|
||||
*args):
|
||||
"""
|
||||
A Ray remote function acting as a Dask task wrapper. This function will
|
||||
repackage the given flat `args` into its original data structures using
|
||||
@@ -219,6 +305,9 @@ def dask_task_wrapper(func, repack, *args):
|
||||
func (callable): The Dask task function to execute.
|
||||
repack (callable): A function that repackages the provided args into
|
||||
the original (possibly nested) Python objects.
|
||||
key (str): The Dask key for this task.
|
||||
ray_pretask_cbs (callable): Pre-task execution callbacks.
|
||||
ray_posttask_cbs (callable): Post-task execution callback.
|
||||
*args (ObjectRef): Ray object references representing the Dask task's
|
||||
arguments.
|
||||
|
||||
@@ -227,11 +316,21 @@ def dask_task_wrapper(func, repack, *args):
|
||||
dask_task_wrapper.remote() invocation will return a Ray object
|
||||
reference representing the Ray task's result.
|
||||
"""
|
||||
if ray_pretask_cbs is not None:
|
||||
pre_states = [
|
||||
cb(key, args) if cb is not None else None for cb in ray_pretask_cbs
|
||||
]
|
||||
repacked_args, repacked_deps = repack(args)
|
||||
# Recursively execute Dask-inlined tasks.
|
||||
actual_args = [_execute_task(a, repacked_deps) for a in repacked_args]
|
||||
# Execute the actual underlying Dask task.
|
||||
return func(*actual_args)
|
||||
result = func(*actual_args)
|
||||
if ray_posttask_cbs is not None:
|
||||
for cb, pre_state in zip(ray_posttask_cbs, pre_states):
|
||||
if cb is not None:
|
||||
cb(key, result, pre_state)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def ray_get_unpack(object_refs):
|
||||
@@ -273,6 +372,15 @@ def ray_dask_get_sync(dsk, keys, **kwargs):
|
||||
|
||||
>>> dask.compute(obj, scheduler=ray_dask_get_sync)
|
||||
|
||||
You can override the currently active global Dask-Ray callbacks (e.g.
|
||||
supplied via a context manager):
|
||||
|
||||
>>> dask.compute(
|
||||
obj,
|
||||
scheduler=ray_dask_get_sync,
|
||||
ray_callbacks=some_ray_dask_callbacks,
|
||||
)
|
||||
|
||||
Args:
|
||||
dsk (Dict): Dask graph, represented as a task DAG dictionary.
|
||||
keys (List[str]): List of Dask graph keys whose values we wish to
|
||||
@@ -281,18 +389,46 @@ def ray_dask_get_sync(dsk, keys, **kwargs):
|
||||
Returns:
|
||||
Computed values corresponding to the provided keys.
|
||||
"""
|
||||
# NOTE: We hijack Dask's `get_async` function, injecting a different task
|
||||
# executor.
|
||||
object_refs = get_async(
|
||||
_apply_async_wrapper(apply_sync, _rayify_task_wrapper),
|
||||
1,
|
||||
dsk,
|
||||
keys,
|
||||
**kwargs,
|
||||
)
|
||||
# NOTE: We explicitly delete the Dask graph here so object references
|
||||
# are garbage-collected before this function returns, i.e. before all Ray
|
||||
# tasks are done. Otherwise, no intermediate objects will be cleaned up
|
||||
# until all Ray tasks are done.
|
||||
del dsk
|
||||
return ray_get_unpack(object_refs)
|
||||
|
||||
ray_callbacks = kwargs.pop("ray_callbacks", None)
|
||||
|
||||
with local_ray_callbacks(ray_callbacks) as ray_callbacks:
|
||||
# Unpack the Ray-specific callbacks.
|
||||
(
|
||||
ray_presubmit_cbs,
|
||||
ray_postsubmit_cbs,
|
||||
ray_pretask_cbs,
|
||||
ray_posttask_cbs,
|
||||
ray_postsubmit_all_cbs,
|
||||
ray_finish_cbs,
|
||||
) = unpack_ray_callbacks(ray_callbacks)
|
||||
# NOTE: We hijack Dask's `get_async` function, injecting a different
|
||||
# task executor.
|
||||
object_refs = get_async(
|
||||
_apply_async_wrapper(
|
||||
apply_sync,
|
||||
_rayify_task_wrapper,
|
||||
ray_presubmit_cbs,
|
||||
ray_postsubmit_cbs,
|
||||
ray_pretask_cbs,
|
||||
ray_posttask_cbs,
|
||||
),
|
||||
1,
|
||||
dsk,
|
||||
keys,
|
||||
**kwargs,
|
||||
)
|
||||
if ray_postsubmit_all_cbs is not None:
|
||||
for cb in ray_postsubmit_all_cbs:
|
||||
cb(object_refs, dsk)
|
||||
# NOTE: We explicitly delete the Dask graph here so object references
|
||||
# are garbage-collected before this function returns, i.e. before all
|
||||
# Ray tasks are done. Otherwise, no intermediate objects will be
|
||||
# cleaned up until all Ray tasks are done.
|
||||
del dsk
|
||||
result = ray_get_unpack(object_refs)
|
||||
if ray_finish_cbs is not None:
|
||||
for cb in ray_finish_cbs:
|
||||
cb(result)
|
||||
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user