Add Dask-Ray scheduler callbacks. (#10519)

Improve Dask-on-Ray documentation.

Move to RayCallback(s) namedtuples, and use top-level CBS tuple as source-of-truth for callback methods.
This commit is contained in:
Clark Zinzow
2020-09-08 14:00:58 -06:00
committed by GitHub
parent fdd3acd492
commit 36e1f20e9c
6 changed files with 837 additions and 73 deletions
+1
View File
@@ -79,6 +79,7 @@ py_test_module_list(
"test_command_runner.py",
"test_coordinator_server.py",
"test_dask_scheduler.py",
"test_dask_callback.py",
"test_debug_tools.py",
"test_global_state.py",
"test_job.py",
+234
View File
@@ -0,0 +1,234 @@
import dask
import pytest
import ray
from ray.util.dask import ray_dask_get, RayDaskCallback
@dask.delayed
def add(x, y):
    """Lazily add ``x`` and ``y``; used as the one-task graph in these tests."""
    return x + y
def test_callback_active():
    """A callback is in the active set only while its context is open."""
    # Nothing is registered before the context is entered.
    assert len(RayDaskCallback.ray_active) == 0

    with RayDaskCallback():
        # Entering the context registers the callback's Ray hooks.
        assert len(RayDaskCallback.ray_active) > 0

    # Leaving the context removes them again.
    assert len(RayDaskCallback.ray_active) == 0
def test_presubmit_shortcircuit(ray_start_regular_shared):
    """
    A non-None presubmit return value must replace the task's result and
    prevent the Ray task (and hence the postsubmit hook) from running.
    """

    class PresubmitShortcircuitCallback(RayDaskCallback):
        def _ray_presubmit(self, task, key, deps):
            # Any non-None value short-circuits submission.
            return 0

        def _ray_postsubmit(self, task, key, deps, object_ref):
            pytest.fail("_ray_postsubmit shouldn't be called when "
                        "_ray_presubmit returns a value")

    with PresubmitShortcircuitCallback():
        computed = add(2, 3).compute(scheduler=ray_dask_get)

    # The presubmit value (0) is returned instead of 2 + 3.
    assert computed == 0
def test_pretask_posttask_shared_state(ray_start_regular_shared):
    """
    The value returned by the pretask hook must be handed to the matching
    posttask hook as ``pre_state``.
    """

    class PretaskPosttaskCallback(RayDaskCallback):
        def _ray_pretask(self, key, object_refs):
            # Use the key itself as the state to round-trip.
            return key

        def _ray_posttask(self, key, result, pre_state):
            assert pre_state == key

    with PretaskPosttaskCallback():
        computed = add(2, 3).compute(scheduler=ray_dask_get)

    assert computed == 5
def test_postsubmit(ray_start_regular_shared):
    """
    The postsubmit hook must fire once per submitted task.
    """

    # An actor keeps the count; the test reads it back after the computation.
    @ray.remote
    class PostsubmitActor:
        def __init__(self):
            self.postsubmit_counter = 0

        def postsubmit(self, task, key, deps, object_ref):
            self.postsubmit_counter += 1

        def get_postsubmit_counter(self):
            return self.postsubmit_counter

    class PostsubmitCallback(RayDaskCallback):
        def __init__(self, postsubmit_actor):
            self.postsubmit_actor = postsubmit_actor

        def _ray_postsubmit(self, task, key, deps, object_ref):
            self.postsubmit_actor.postsubmit.remote(task, key, deps,
                                                    object_ref)

    counter_actor = PostsubmitActor.remote()
    with PostsubmitCallback(counter_actor):
        computed = add(2, 3).compute(scheduler=ray_dask_get)

    # The graph holds a single task, so exactly one postsubmit is expected.
    submitted = ray.get(counter_actor.get_postsubmit_counter.remote())
    assert submitted == 1
    assert computed == 5
def test_postsubmit_all(ray_start_regular_shared):
    """
    The postsubmit_all hook must be invoked during a computation.
    """

    # An actor records the call; the test reads the flag back afterwards.
    @ray.remote
    class PostsubmitAllActor:
        def __init__(self):
            self.postsubmit_all_called = False

        def postsubmit_all(self, object_refs, dsk):
            self.postsubmit_all_called = True

        def get_postsubmit_all_called(self):
            return self.postsubmit_all_called

    class PostsubmitAllCallback(RayDaskCallback):
        def __init__(self, postsubmit_all_actor):
            self.postsubmit_all_actor = postsubmit_all_actor

        def _ray_postsubmit_all(self, object_refs, dsk):
            self.postsubmit_all_actor.postsubmit_all.remote(object_refs, dsk)

    flag_actor = PostsubmitAllActor.remote()
    with PostsubmitAllCallback(flag_actor):
        computed = add(2, 3).compute(scheduler=ray_dask_get)

    was_called = ray.get(flag_actor.get_postsubmit_all_called.remote())
    assert was_called
    assert computed == 5
def test_finish(ray_start_regular_shared):
    """
    The finish hook must be invoked once the computation has completed.
    """

    # An actor records the call; the test reads the flag back afterwards.
    @ray.remote
    class FinishActor:
        def __init__(self):
            self.finish_called = False

        def finish(self, result):
            self.finish_called = True

        def get_finish_called(self):
            return self.finish_called

    class FinishCallback(RayDaskCallback):
        def __init__(self, finish_actor):
            self.finish_actor = finish_actor

        def _ray_finish(self, result):
            self.finish_actor.finish.remote(result)

    flag_actor = FinishActor.remote()
    with FinishCallback(flag_actor):
        computed = add(2, 3).compute(scheduler=ray_dask_get)

    was_called = ray.get(flag_actor.get_finish_called.remote())
    assert was_called
    assert computed == 5
def test_multiple_callbacks(ray_start_regular_shared):
    """
    Several callbacks may be active at once; each one's hook must fire.
    """

    # One shared actor counts postsubmit calls from every callback.
    @ray.remote
    class PostsubmitActor:
        def __init__(self):
            self.postsubmit_counter = 0

        def postsubmit(self, task, key, deps, object_ref):
            self.postsubmit_counter += 1

        def get_postsubmit_counter(self):
            return self.postsubmit_counter

    class PostsubmitCallback(RayDaskCallback):
        def __init__(self, postsubmit_actor):
            self.postsubmit_actor = postsubmit_actor

        def _ray_postsubmit(self, task, key, deps, object_ref):
            self.postsubmit_actor.postsubmit.remote(task, key, deps,
                                                    object_ref)

    counter_actor = PostsubmitActor.remote()
    # Three distinct callback instances reporting to the same actor.
    first, second, third = (PostsubmitCallback(counter_actor)
                            for _ in range(3))
    with first, second, third:
        computed = add(2, 3).compute(scheduler=ray_dask_get)

    # One task times three callbacks gives three postsubmit calls.
    submitted = ray.get(counter_actor.get_postsubmit_counter.remote())
    assert submitted == 3
    assert computed == 5
def test_pretask_posttask_shared_state_multi(ray_start_regular_shared):
    """
    With several callbacks active, each posttask hook must receive the
    pre_state produced by its own callback's pretask hook (or None when that
    callback defines no pretask hook).
    """

    class PretaskPosttaskCallback(RayDaskCallback):
        def __init__(self, suffix):
            self.suffix = suffix

        def _ray_pretask(self, key, object_refs):
            # The suffix makes each instance's state distinguishable.
            return key + self.suffix

        def _ray_posttask(self, key, result, pre_state):
            assert pre_state == key + self.suffix

    class PretaskOnlyCallback(RayDaskCallback):
        def _ray_pretask(self, key, object_refs):
            return "baz"

    class PosttaskOnlyCallback(RayDaskCallback):
        def _ray_posttask(self, key, result, pre_state):
            # No pretask hook on this callback, so there is no state.
            assert pre_state is None

    both_foo = PretaskPosttaskCallback("foo")
    pre_only = PretaskOnlyCallback()
    post_only = PosttaskOnlyCallback()
    both_bar = PretaskPosttaskCallback("bar")
    with both_foo, pre_only, post_only, both_bar:
        computed = add(2, 3).compute(scheduler=ray_dask_get)

    assert computed == 5
if __name__ == "__main__":
    import sys

    # Run this file's tests verbosely and propagate pytest's exit status.
    exit_status = pytest.main(["-v", __file__])
    sys.exit(exit_status)
+12 -1
View File
@@ -1,3 +1,14 @@
from .scheduler import ray_dask_get, ray_dask_get_sync
from .callbacks import (
RayDaskCallback,
local_ray_callbacks,
unpack_ray_callbacks,
)
__all__ = ["ray_dask_get", "ray_dask_get_sync"]
__all__ = [
"ray_dask_get",
"ray_dask_get_sync",
"RayDaskCallback",
"local_ray_callbacks",
"unpack_ray_callbacks",
]
+211
View File
@@ -0,0 +1,211 @@
import contextlib
from collections import namedtuple
from dask.callbacks import Callback
# Names of the Ray-specific callbacks. RayDaskCallback accepts exactly these
# as keyword arguments on construction, and this tuple is the single
# source-of-truth for which Ray-specific callbacks exist.
CBS = (
    "ray_presubmit",
    "ray_postsubmit",
    "ray_pretask",
    "ray_posttask",
    "ray_postsubmit_all",
    "ray_finish",
)

# Method names on RayDaskCallback that hold the callbacks above: the same
# names with a leading underscore, in the same order.
CB_FIELDS = tuple("_{}".format(name) for name in CBS)

# Callbacks that are _not_ dropped from RayCallbacks when a RayDaskCallback
# instance doesn't define them; a None placeholder is kept instead.
CBS_DONT_DROP = {"ray_pretask", "ray_posttask"}

# The Ray-specific callbacks of a single RayDaskCallback, one field per name
# in CBS.
RayCallback = namedtuple("RayCallback", CBS)

# The Ray-specific callbacks of one or more RayDaskCallbacks, one
# "<name>_cbs" field per name in CBS.
RayCallbacks = namedtuple("RayCallbacks", [name + "_cbs" for name in CBS])
class RayDaskCallback(Callback):
    """
    A Dask `Callback` extended with Ray-specific hooks. Both the regular Dask
    hooks (e.g. pretask, posttask, etc.) and the Ray-specific hooks may be
    supplied, either as constructor keyword arguments or by subclassing.

    See `dask.callbacks.Callback` for usage.

    Caveats: Any Dask-Ray scheduler must bring the Ray-specific callbacks into
    context using the `local_ray_callbacks` context manager, since the built-in
    `local_callbacks` context manager provided by Dask isn't aware of this
    class.
    """

    # Set of active Ray-specific callbacks (RayCallback namedtuples).
    ray_active = set()

    def __init__(self, **kwargs):
        """
        Store any Ray-specific hooks given as keyword arguments and forward
        the remaining keyword arguments to Dask's `Callback`.

        Ray-specific callbacks:
            - def _ray_presubmit(task, key, deps):
                Run before submitting a Ray task. If this callback returns a
                non-`None` value, a Ray task will _not_ be created and this
                value will be used as the would-be task's result value.

                Args:
                    task (tuple): A Dask task, where the first tuple item is
                        the task function, and the remaining tuple items are
                        the task arguments (either the actual argument values,
                        or Dask keys into the deps dictionary whose
                        corresponding values are the argument values).
                    key (str): The Dask graph key for the given task.
                    deps (dict): The dependencies of this task.

                Returns:
                    Either None, in which case a Ray task will be submitted, or
                    a non-None value, in which case a Ray task will not be
                    submitted and this return value will be used as the
                    would-be task result value.

            - def _ray_postsubmit(task, key, deps, object_ref):
                Run after submitting a Ray task.

                Args:
                    task (tuple): A Dask task, where the first tuple item is
                        the task function, and the remaining tuple items are
                        the task arguments (either the actual argument values,
                        or Dask keys into the deps dictionary whose
                        corresponding values are the argument values).
                    key (str): The Dask graph key for the given task.
                    deps (dict): The dependencies of this task.
                    object_ref (ray.ObjectRef): The object reference for the
                        return value of the Ray task.

            - def _ray_pretask(key, object_refs):
                Run before executing a Dask task within a Ray task. This
                executes after the task has been submitted, within a Ray
                worker. The return value of this task will be passed to the
                _ray_posttask callback, if provided.

                Args:
                    key (str): The Dask graph key for the Dask task.
                    object_refs (List[ray.ObjectRef]): The object references
                        for the arguments of the Ray task.

                Returns:
                    A value that will be passed to the corresponding
                    _ray_posttask callback, if said callback is defined.

            - def _ray_posttask(key, result, pre_state):
                Run after executing a Dask task within a Ray task. This
                executes within a Ray worker. This callback receives the return
                value of the _ray_pretask callback, if provided.

                Args:
                    key (str): The Dask graph key for the Dask task.
                    result (object): The task result value.
                    pre_state (object): The return value of the corresponding
                        _ray_pretask callback, if said callback is defined.

            - def _ray_postsubmit_all(object_refs, dsk):
                Run after all Ray tasks have been submitted.

                Args:
                    object_refs (List[ray.ObjectRef]): The object references
                        for the output (leaf) Ray tasks of the task graph.
                    dsk (dict): The Dask graph.

            - def _ray_finish(result):
                Run after all Ray tasks have finished executing and the final
                result has been returned.

                Args:
                    result (object): The final result (output) of the Dask
                        computation, before any repackaging is done by
                        Dask collection-specific post-compute callbacks.
        """
        # Pull each Ray-specific hook out of kwargs so that only Dask's own
        # keyword arguments reach the parent constructor.
        for name, attr in zip(CBS, CB_FIELDS):
            hook = kwargs.pop(name, None)
            if hook is None:
                continue
            setattr(self, attr, hook)

        super().__init__(**kwargs)

    @property
    def _ray_callback(self):
        """This instance's Ray hooks as a RayCallback (None where unset)."""
        return RayCallback._make(
            getattr(self, attr, None) for attr in CB_FIELDS)

    def __enter__(self):
        """Activate the Ray hooks first, then the Dask hooks."""
        manager = add_ray_callbacks(self)
        self._ray_cm = manager
        manager.__enter__()
        super().__enter__()
        return self

    def __exit__(self, *args):
        """Deactivate the Dask hooks, then the Ray hooks."""
        super().__exit__(*args)
        self._ray_cm.__exit__(*args)

    def register(self):
        """Add the Ray hooks to the active set, then register with Dask."""
        ray_hooks = self._ray_callback
        type(self).ray_active.add(ray_hooks)
        super().register()

    def unregister(self):
        """Remove the Ray hooks from the active set, then unregister."""
        ray_hooks = self._ray_callback
        type(self).ray_active.remove(ray_hooks)
        super().unregister()
class add_ray_callbacks:
    """
    Context manager that keeps the given callbacks in
    `RayDaskCallback.ray_active` until the context exits.

    The callbacks are added to the active set on construction, not on entry.
    """

    def __init__(self, *callbacks):
        normalized = []
        for callback in callbacks:
            normalized.append(normalize_ray_callback(callback))
        self.callbacks = normalized
        RayDaskCallback.ray_active.update(normalized)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # Drop only our own callbacks; ones that are already gone are ignored.
        RayDaskCallback.ray_active.difference_update(self.callbacks)
def normalize_ray_callback(cb):
    """
    Return the `RayCallback` namedtuple form of ``cb``.

    A `RayDaskCallback` is converted through its `_ray_callback` property; a
    `RayCallback` is returned unchanged.

    Raises:
        TypeError: If ``cb`` is neither of those two types.
    """
    if isinstance(cb, RayDaskCallback):
        return cb._ray_callback
    if isinstance(cb, RayCallback):
        return cb
    raise TypeError(
        "Callbacks must be either 'RayDaskCallback' or 'RayCallback' "
        "namedtuple")
def unpack_ray_callbacks(cbs):
    """
    Group the Ray-specific hooks of several callbacks by hook type.

    Args:
        cbs (Iterable): The callbacks to unpack. Each item is either a
            `RayCallback` namedtuple (or an equivalent sequence with one
            entry per name in `CBS`, in that order), or a `RayDaskCallback`
            instance, which is first converted to its `RayCallback` form.

    Returns:
        A `RayCallbacks` namedtuple with one field per name in `CBS`. If
        ``cbs`` is empty or None, every field is an empty tuple. Otherwise
        each field is the list of hooks of that type; unset (falsy) hooks are
        dropped, except for the hook types in `CBS_DONT_DROP`, where they are
        kept as placeholders so those lists keep one entry per callback. A
        field whose list ends up empty is set to None.
    """
    if not cbs:
        return RayCallbacks(*([()] * len(CBS)))

    # RayDaskCallback instances are not iterable, so handing them straight to
    # zip(*...) below would raise TypeError. Convert them to their namedtuple
    # form first; anything else is passed through untouched.
    normalized = [
        cb._ray_callback if isinstance(cb, RayDaskCallback) else cb
        for cb in cbs
    ]

    # Only drop callback methods that aren't in CBS_DONT_DROP.
    return RayCallbacks(*(
        [cb for cb in cbs_ if cb or CBS[idx] in CBS_DONT_DROP] or None
        for idx, cbs_ in enumerate(zip(*normalized))))
@contextlib.contextmanager
def local_ray_callbacks(callbacks=None):
    """
    Allows Dask-Ray callbacks to work with nested schedulers.

    When ``callbacks`` is None, the currently active global callbacks are
    yielded and the global set is emptied for the duration of the context, so
    that only the outermost scheduler uses them; the global set is restored
    on exit. When ``callbacks`` is given, it is yielded as-is (or an empty
    tuple if it is empty) and the global set is left untouched.
    """
    took_global = callbacks is None
    if took_global:
        # Claim the global callbacks and leave an empty set behind, so that
        # nested schedulers started inside this context don't see them.
        callbacks = RayDaskCallback.ray_active
        RayDaskCallback.ray_active = set()
    try:
        yield callbacks or ()
    finally:
        if took_global:
            RayDaskCallback.ray_active = callbacks
+181 -45
View File
@@ -10,6 +10,7 @@ from dask.local import get_async, apply_sync
from dask.system import CPU_COUNT
from dask.threaded import pack_exception, _thread_get_id
from .callbacks import local_ray_callbacks, unpack_ray_callbacks
from .common import unpack_object_refs
main_thread = threading.current_thread()
@@ -30,12 +31,14 @@ def ray_dask_get(dsk, keys, **kwargs):
>>> dask.compute(obj, scheduler=ray_dask_get)
You can override the number of threads to use when submitting the
Ray tasks, or the threadpool used to submit Ray tasks:
You can override the currently active global Dask-Ray callbacks (e.g.
supplied via a context manager), the number of threads to use when
submitting the Ray tasks, or the threadpool used to submit Ray tasks:
>>> dask.compute(
obj,
scheduler=ray_dask_get,
ray_callbacks=some_ray_dask_callbacks,
num_workers=8,
pool=some_cool_pool,
)
@@ -44,6 +47,7 @@ def ray_dask_get(dsk, keys, **kwargs):
dsk (Dict): Dask graph, represented as a task DAG dictionary.
keys (List[str]): List of Dask graph keys whose values we wish to
compute and return.
ray_callbacks (Optional[list[callable]]): Dask-Ray callbacks.
num_workers (Optional[int]): The number of worker threads to use in
the Ray task submission traversal of the Dask graph.
pool (Optional[ThreadPool]): A multiprocessing threadpool to use to
@@ -73,23 +77,48 @@ def ray_dask_get(dsk, keys, **kwargs):
atexit.register(pool.close)
pools[thread][num_workers] = pool
# NOTE: We hijack Dask's `get_async` function, injecting a different task
# executor.
object_refs = get_async(
_apply_async_wrapper(pool.apply_async, _rayify_task_wrapper),
len(pool._pool),
dsk,
keys,
get_id=_thread_get_id,
pack_exception=pack_exception,
**kwargs,
)
# NOTE: We explicitly delete the Dask graph here so object references
# are garbage-collected before this function returns, i.e. before all Ray
# tasks are done. Otherwise, no intermediate objects will be cleaned up
# until all Ray tasks are done.
del dsk
result = ray_get_unpack(object_refs)
ray_callbacks = kwargs.pop("ray_callbacks", None)
with local_ray_callbacks(ray_callbacks) as ray_callbacks:
# Unpack the Ray-specific callbacks.
(
ray_presubmit_cbs,
ray_postsubmit_cbs,
ray_pretask_cbs,
ray_posttask_cbs,
ray_postsubmit_all_cbs,
ray_finish_cbs,
) = unpack_ray_callbacks(ray_callbacks)
# NOTE: We hijack Dask's `get_async` function, injecting a different
# task executor.
object_refs = get_async(
_apply_async_wrapper(
pool.apply_async,
_rayify_task_wrapper,
ray_presubmit_cbs,
ray_postsubmit_cbs,
ray_pretask_cbs,
ray_posttask_cbs,
),
len(pool._pool),
dsk,
keys,
get_id=_thread_get_id,
pack_exception=pack_exception,
**kwargs,
)
if ray_postsubmit_all_cbs is not None:
for cb in ray_postsubmit_all_cbs:
cb(object_refs, dsk)
# NOTE: We explicitly delete the Dask graph here so object references
# are garbage-collected before this function returns, i.e. before all
# Ray tasks are done. Otherwise, no intermediate objects will be
# cleaned up until all Ray tasks are done.
del dsk
result = ray_get_unpack(object_refs)
if ray_finish_cbs is not None:
for cb in ray_finish_cbs:
cb(result)
# cleanup pools associated with dead threads.
with pools_lock:
@@ -138,6 +167,10 @@ def _rayify_task_wrapper(
loads,
get_id,
pack_exception,
ray_presubmit_cbs,
ray_postsubmit_cbs,
ray_pretask_cbs,
ray_posttask_cbs,
):
"""
The core Ray-Dask task execution wrapper, to be given to the thread pool's
@@ -152,6 +185,10 @@ def _rayify_task_wrapper(
loads (callable): A task_info deserializing function.
get_id (callable): An ID generating function.
pack_exception (callable): An exception serializing function.
ray_presubmit_cbs (callable): Pre-task submission callbacks.
ray_postsubmit_cbs (callable): Post-task submission callbacks.
ray_pretask_cbs (callable): Pre-task execution callbacks.
ray_posttask_cbs (callable): Post-task execution callbacks.
Returns:
A 3-tuple of the task's key, a literal or a Ray object reference for a
@@ -159,7 +196,15 @@ def _rayify_task_wrapper(
"""
try:
task, deps = loads(task_info)
result = _rayify_task(task, key, deps)
result = _rayify_task(
task,
key,
deps,
ray_presubmit_cbs,
ray_postsubmit_cbs,
ray_pretask_cbs,
ray_posttask_cbs,
)
id = get_id()
result = dumps((result, id))
failed = False
@@ -169,15 +214,27 @@ def _rayify_task_wrapper(
return key, result, failed
def _rayify_task(task, key, deps):
def _rayify_task(
task,
key,
deps,
ray_presubmit_cbs,
ray_postsubmit_cbs,
ray_pretask_cbs,
ray_posttask_cbs,
):
"""
Rayifies the given task, submitting it as a Ray task to the Ray cluster.
Args:
task: A Dask graph value, being either a literal, dependency key, Dask
task, or a list thereof.
key: The Dask graph key for the given task.
deps: The dependencies of this task.
task (tuple): A Dask graph value, being either a literal, dependency
key, Dask task, or a list thereof.
key (str): The Dask graph key for the given task.
deps (dict): The dependencies of this task.
ray_presubmit_cbs (callable): Pre-task submission callbacks.
ray_postsubmit_cbs (callable): Post-task submission callbacks.
ray_pretask_cbs (callable): Pre-task execution callbacks.
ray_posttask_cbs (callable): Post-task execution callbacks.
Returns:
A literal, a Ray object reference representing a submitted task, or a
@@ -186,18 +243,46 @@ def _rayify_task(task, key, deps):
if isinstance(task, list):
# Recursively rayify this list. This will still bottom out at the first
# actual task encountered, inlining any tasks in that task's arguments.
return [_rayify_task(t, deps) for t in task]
return [
_rayify_task(
t,
key,
deps,
ray_presubmit_cbs,
ray_postsubmit_cbs,
ray_pretask_cbs,
ray_posttask_cbs,
) for t in task
]
elif istask(task):
# Unpacks and repacks Ray object references and submits the task to the
# Ray cluster for execution.
if ray_presubmit_cbs is not None:
alternate_returns = [
cb(task, key, deps) for cb in ray_presubmit_cbs
]
for alternate_return in alternate_returns:
# We don't submit a Ray task if a presubmit callback returns
# a non-`None` value, instead we return said value.
# NOTE: This returns the first non-None presubmit callback
# return value.
if alternate_return is not None:
return alternate_return
func, args = task[0], task[1:]
# If the function's arguments contain nested object references, we must
# unpack said object references into a flat set of arguments so that
# Ray properly tracks the object dependencies between Ray tasks.
object_refs, repack = unpack_object_refs(args, deps)
# Submit the task using a wrapper function.
return dask_task_wrapper.options(name=f"dask:{key!s}").remote(
func, repack, *object_refs)
object_ref = dask_task_wrapper.options(name=f"dask:{key!s}").remote(
func, repack, key, ray_pretask_cbs, ray_posttask_cbs, *object_refs)
if ray_postsubmit_cbs is not None:
for cb in ray_postsubmit_cbs:
cb(task, key, deps, object_ref)
return object_ref
elif not ishashable(task):
return task
elif task in deps:
@@ -207,7 +292,8 @@ def _rayify_task(task, key, deps):
@ray.remote
def dask_task_wrapper(func, repack, *args):
def dask_task_wrapper(func, repack, key, ray_pretask_cbs, ray_posttask_cbs,
*args):
"""
A Ray remote function acting as a Dask task wrapper. This function will
repackage the given flat `args` into its original data structures using
@@ -219,6 +305,9 @@ def dask_task_wrapper(func, repack, *args):
func (callable): The Dask task function to execute.
repack (callable): A function that repackages the provided args into
the original (possibly nested) Python objects.
key (str): The Dask key for this task.
ray_pretask_cbs (callable): Pre-task execution callbacks.
ray_posttask_cbs (callable): Post-task execution callback.
*args (ObjectRef): Ray object references representing the Dask task's
arguments.
@@ -227,11 +316,21 @@ def dask_task_wrapper(func, repack, *args):
dask_task_wrapper.remote() invocation will return a Ray object
reference representing the Ray task's result.
"""
if ray_pretask_cbs is not None:
pre_states = [
cb(key, args) if cb is not None else None for cb in ray_pretask_cbs
]
repacked_args, repacked_deps = repack(args)
# Recursively execute Dask-inlined tasks.
actual_args = [_execute_task(a, repacked_deps) for a in repacked_args]
# Execute the actual underlying Dask task.
return func(*actual_args)
result = func(*actual_args)
if ray_posttask_cbs is not None:
for cb, pre_state in zip(ray_posttask_cbs, pre_states):
if cb is not None:
cb(key, result, pre_state)
return result
def ray_get_unpack(object_refs):
@@ -273,6 +372,15 @@ def ray_dask_get_sync(dsk, keys, **kwargs):
>>> dask.compute(obj, scheduler=ray_dask_get_sync)
You can override the currently active global Dask-Ray callbacks (e.g.
supplied via a context manager):
>>> dask.compute(
obj,
scheduler=ray_dask_get_sync,
ray_callbacks=some_ray_dask_callbacks,
)
Args:
dsk (Dict): Dask graph, represented as a task DAG dictionary.
keys (List[str]): List of Dask graph keys whose values we wish to
@@ -281,18 +389,46 @@ def ray_dask_get_sync(dsk, keys, **kwargs):
Returns:
Computed values corresponding to the provided keys.
"""
# NOTE: We hijack Dask's `get_async` function, injecting a different task
# executor.
object_refs = get_async(
_apply_async_wrapper(apply_sync, _rayify_task_wrapper),
1,
dsk,
keys,
**kwargs,
)
# NOTE: We explicitly delete the Dask graph here so object references
# are garbage-collected before this function returns, i.e. before all Ray
# tasks are done. Otherwise, no intermediate objects will be cleaned up
# until all Ray tasks are done.
del dsk
return ray_get_unpack(object_refs)
ray_callbacks = kwargs.pop("ray_callbacks", None)
with local_ray_callbacks(ray_callbacks) as ray_callbacks:
# Unpack the Ray-specific callbacks.
(
ray_presubmit_cbs,
ray_postsubmit_cbs,
ray_pretask_cbs,
ray_posttask_cbs,
ray_postsubmit_all_cbs,
ray_finish_cbs,
) = unpack_ray_callbacks(ray_callbacks)
# NOTE: We hijack Dask's `get_async` function, injecting a different
# task executor.
object_refs = get_async(
_apply_async_wrapper(
apply_sync,
_rayify_task_wrapper,
ray_presubmit_cbs,
ray_postsubmit_cbs,
ray_pretask_cbs,
ray_posttask_cbs,
),
1,
dsk,
keys,
**kwargs,
)
if ray_postsubmit_all_cbs is not None:
for cb in ray_postsubmit_all_cbs:
cb(object_refs, dsk)
# NOTE: We explicitly delete the Dask graph here so object references
# are garbage-collected before this function returns, i.e. before all
# Ray tasks are done. Otherwise, no intermediate objects will be
# cleaned up until all Ray tasks are done.
del dsk
result = ray_get_unpack(object_refs)
if ray_finish_cbs is not None:
for cb in ray_finish_cbs:
cb(result)
return result