[1.0] Move dask scheduler from experimental to util (#10553)

* move dask

* fix dask
This commit is contained in:
Eric Liang
2020-09-04 12:16:32 -07:00
committed by GitHub
parent b1bd58581b
commit da83bbd764
5 changed files with 3 additions and 4 deletions
+3
View File
@@ -0,0 +1,3 @@
from .scheduler import ray_dask_get, ray_dask_get_sync
# Public names of this package; both are imported from `.scheduler` above.
__all__ = ["ray_dask_get", "ray_dask_get_sync"]
+77
View File
@@ -0,0 +1,77 @@
from collections import OrderedDict
from collections.abc import Iterator
from operator import getitem
import uuid
import ray
from dask.base import quote
from dask.core import get as get_sync
from dask.compatibility import apply, is_dataclass, dataclass_fields
def unpack_object_refs(*args):
    """
    Flatten the `ray.ObjectRef`s found inside nested Python objects.

    The typical workflow is: collect every Ray object reference contained in
    some (possibly nested) Python objects, operate on the flat list (get(),
    wait(), etc.), and then rebuild objects of the original shape with the
    references swapped for concrete values.

    Lists, tuples, sets, dicts, OrderedDicts, dataclasses and iterators
    (which are treated as lists) are traversed; anything else is kept as-is.

    Args:
        *args: One or more (potentially nested) Python objects that contain
            Ray object references.

    Returns:
        A 2-tuple `(object_refs, repack)`. `object_refs` is the flat list of
        all Ray object references encountered. `repack` is a function that
        takes the corresponding flat list of concrete values and returns a
        tuple of Python objects equivalent to `args`, with every Ray object
        reference replaced by its concrete value.
    """
    flat_refs = []
    # A small Dask graph that, once executed, rebuilds the original structure.
    rebuild_graph = {}
    # Graph key under which the flat list of concrete values is injected.
    values_key = uuid.uuid4().hex

    def _unpack(expr):
        if isinstance(expr, ray.ObjectRef):
            # Each reference becomes a lookup into the flat list of values.
            ref_key = expr.hex()
            rebuild_graph[ref_key] = (getitem, values_key, len(flat_refs))
            flat_refs.append(expr)
            return ref_key

        node_key = uuid.uuid4().hex
        # Iterators are rebuilt as lists.
        container = list if isinstance(expr, Iterator) else type(expr)
        if container in (list, tuple, set):
            rebuild = (container, [_unpack(item) for item in expr])
        elif container in (dict, OrderedDict):
            pairs = [[_unpack(k), _unpack(v)] for k, v in expr.items()]
            rebuild = (container, pairs)
        elif is_dataclass(expr):
            field_pairs = [[f.name, _unpack(getattr(expr, f.name))]
                           for f in dataclass_fields(expr)]
            # Re-instantiate the dataclass from keyword arguments.
            rebuild = (apply, container, (), (dict, field_pairs))
        else:
            # Not a container we know how to traverse: leave untouched.
            return expr

        rebuild_graph[node_key] = rebuild
        return node_key

    root_key = uuid.uuid4().hex
    rebuild_graph[root_key] = (tuple, [_unpack(arg) for arg in args])

    def repack(results):
        graph = rebuild_graph.copy()
        # Quote the results so Dask treats them as data, not as tasks/keys.
        graph[values_key] = quote(results)
        return get_sync(graph, root_key)

    return flat_refs, repack
+298
View File
@@ -0,0 +1,298 @@
import atexit
from collections import defaultdict
from multiprocessing.pool import ThreadPool
import threading
import ray
from dask.core import istask, ishashable, _execute_task
from dask.local import get_async, apply_sync
from dask.system import CPU_COUNT
from dask.threaded import pack_exception, _thread_get_id
from .common import unpack_object_refs
# The thread that was current when this module was first imported.
main_thread = threading.current_thread()
# Shared thread pool, created lazily by `ray_dask_get` when it is called from
# `main_thread` without an explicit `num_workers`.
default_pool = None
# Per-thread pool cache used by `ray_dask_get`: thread -> {num_workers: pool}.
pools = defaultdict(dict)
# Held by `ray_dask_get` while it reads or updates `default_pool` and `pools`.
pools_lock = threading.Lock()
def ray_dask_get(dsk, keys, **kwargs):
    """
    A Dask-Ray scheduler that submits Ray tasks from a thread pool.

    Top-level (non-inlined) Dask tasks are sent to a Ray cluster for
    execution. The scheduler waits for the tasks to finish, fetches the
    results, and repackages them into the appropriate Dask collections.

    This can be passed directly to `dask.compute()`, as the scheduler:

        >>> dask.compute(obj, scheduler=ray_dask_get)

    You can override the number of threads to use when submitting the
    Ray tasks, or the threadpool used to submit Ray tasks:

        >>> dask.compute(
        ...     obj,
        ...     scheduler=ray_dask_get,
        ...     num_workers=8,
        ...     pool=some_cool_pool,
        ... )

    Args:
        dsk (Dict): Dask graph, represented as a task DAG dictionary.
        keys (List[str]): List of Dask graph keys whose values we wish to
            compute and return.
        num_workers (Optional[int]): The number of worker threads to use in
            the Ray task submission traversal of the Dask graph.
        pool (Optional[ThreadPool]): A multiprocessing threadpool to use to
            submit Ray tasks.

    Returns:
        Computed values corresponding to the provided keys.
    """
    global default_pool

    num_workers = kwargs.pop("num_workers", None)
    pool = kwargs.pop("pool", None)
    caller = threading.current_thread()

    # Pool selection: reuse a pool previously created from this thread with
    # the same number of workers; fall back to the global pool when no
    # worker count is given and we are on the main thread.
    if pool is None:
        with pools_lock:
            wants_default = num_workers is None and caller is main_thread
            if wants_default:
                if default_pool is None:
                    default_pool = ThreadPool(CPU_COUNT)
                    atexit.register(default_pool.close)
                pool = default_pool
            elif caller in pools and num_workers in pools[caller]:
                pool = pools[caller][num_workers]
            else:
                pool = ThreadPool(num_workers)
                atexit.register(pool.close)
                pools[caller][num_workers] = pool

    # NOTE: Dask's `get_async` is reused here with a different task executor
    # injected, so that tasks are submitted to Ray instead of run locally.
    submit = _apply_async_wrapper(pool.apply_async, _rayify_task_wrapper)
    object_refs = get_async(
        submit,
        len(pool._pool),
        dsk,
        keys,
        get_id=_thread_get_id,
        pack_exception=pack_exception,
        **kwargs,
    )
    # NOTE: The Dask graph is dropped explicitly so object references can be
    # garbage-collected before this function returns, i.e. before all Ray
    # tasks are done. Otherwise, no intermediate objects would be cleaned up
    # until all Ray tasks are done.
    del dsk
    result = ray_get_unpack(object_refs)

    # Close and forget pools that belong to threads which no longer exist.
    with pools_lock:
        alive = set(threading.enumerate())
        if caller is not main_thread:
            stale = [t for t in list(pools) if t not in alive]
            for dead_thread in stale:
                for stale_pool in pools.pop(dead_thread).values():
                    stale_pool.close()
    return result
def _apply_async_wrapper(apply_async, real_func, *extra_args, **extra_kwargs):
"""
Wraps the given pool `apply_async` function, hotswapping `real_func` in as
the function to be applied and adding `extra_args` and `extra_kwargs` to
`real_func`'s call.
Args:
apply_async (callable): The pool function to be wrapped.
real_func (callable): The real function that we wish the pool apply
function to execute.
*extra_args: Extra positional arguments to pass to the `real_func`.
**extra_kwargs: Extra keyword arguments to pass to the `real_func`.
Returns:
A wrapper function that will ignore it's first `func` argument and
pass `real_func` in its place. To be passed to `dask.local.get_async`.
"""
def wrapper(func, args=(), kwds={}, callback=None): # noqa: M511
return apply_async(
real_func,
args=args + extra_args,
kwds=dict(kwds, **extra_kwargs),
callback=callback,
)
return wrapper
def _rayify_task_wrapper(
key,
task_info,
dumps,
loads,
get_id,
pack_exception,
):
"""
The core Ray-Dask task execution wrapper, to be given to the thread pool's
`apply_async` function. Exactly the same as `execute_task`, except that it
calls `_rayify_task` on the task instead of `_execute_task`.
Args:
key (str): The Dask graph key whose corresponding task we wish to
execute.
task_info: The task to execute and its dependencies.
dumps (callable): A result serializing function.
loads (callable): A task_info deserializing function.
get_id (callable): An ID generating function.
pack_exception (callable): An exception serializing function.
Returns:
A 3-tuple of the task's key, a literal or a Ray object reference for a
Ray task's result, and whether the Ray task submission failed.
"""
try:
task, deps = loads(task_info)
result = _rayify_task(task, key, deps)
id = get_id()
result = dumps((result, id))
failed = False
except BaseException as e:
result = pack_exception(e, dumps)
failed = True
return key, result, failed
def _rayify_task(task, key, deps):
"""
Rayifies the given task, submitting it as a Ray task to the Ray cluster.
Args:
task: A Dask graph value, being either a literal, dependency key, Dask
task, or a list thereof.
key: The Dask graph key for the given task.
deps: The dependencies of this task.
Returns:
A literal, a Ray object reference representing a submitted task, or a
list thereof.
"""
if isinstance(task, list):
# Recursively rayify this list. This will still bottom out at the first
# actual task encountered, inlining any tasks in that task's arguments.
return [_rayify_task(t, deps) for t in task]
elif istask(task):
# Unpacks and repacks Ray object references and submits the task to the
# Ray cluster for execution.
func, args = task[0], task[1:]
# If the function's arguments contain nested object references, we must
# unpack said object references into a flat set of arguments so that
# Ray properly tracks the object dependencies between Ray tasks.
object_refs, repack = unpack_object_refs(args, deps)
# Submit the task using a wrapper function.
return dask_task_wrapper.options(name=f"dask:{key!s}").remote(
func, repack, *object_refs)
elif not ishashable(task):
return task
elif task in deps:
return deps[task]
else:
return task
@ray.remote
def dask_task_wrapper(func, repack, *args):
    """
    Ray remote function that runs a single Dask task.

    The flat `args` are first rebuilt into their original data structures
    with `repack`, yielding the task's arguments and its dependencies. Any
    Dask subtasks found within the rebuilt arguments (inlined by Dask's
    optimization pass) are then executed, and the resulting concrete
    arguments are passed to the provided Dask task function, `func`.

    Args:
        func (callable): The Dask task function to execute.
        repack (callable): A function that repackages the provided args into
            the original (possibly nested) Python objects.
        *args (ObjectRef): Ray object references representing the Dask task's
            arguments.

    Returns:
        The output of the Dask task. In the context of Ray, a
        dask_task_wrapper.remote() invocation will return a Ray object
        reference representing the Ray task's result.
    """
    task_args, task_deps = repack(args)
    # Evaluate every argument, recursively executing Dask-inlined tasks.
    concrete_args = []
    for arg in task_args:
        concrete_args.append(_execute_task(arg, task_deps))
    # Run the underlying Dask task on the concrete arguments.
    return func(*concrete_args)
def ray_get_unpack(object_refs):
    """
    Resolve the Ray object references held in an arbitrary data structure.

    Args:
        object_refs: A (potentially nested) Python object containing Ray object
            references.

    Returns:
        The input Python object with all contained Ray object references
        resolved with their concrete values.
    """
    is_sequence = isinstance(object_refs, (tuple, list))
    if not is_sequence or all(
            isinstance(x, ray.ObjectRef) for x in object_refs):
        # Either a single reference or an already-flat sequence of
        # references: hand it to ray.get() directly.
        return ray.get(object_refs)

    # Dask loves to nest collections in nested tuples, whereas ray.get()
    # expects a flat list of object references. Flatten first, then rebuild
    # the original structure from the fetched values.
    flat_refs, repack = unpack_object_refs(*object_refs)
    fetched = ray.get(flat_refs)
    return repack(fetched)
def ray_dask_get_sync(dsk, keys, **kwargs):
    """
    A synchronous Dask-Ray scheduler, which can be useful for debugging.

    Top-level (non-inlined) Dask tasks are sent to a Ray cluster for
    execution, one submission at a time. The scheduler waits for the tasks
    to finish, fetches the results, and repackages them into the appropriate
    Dask collections.

    This can be passed directly to `dask.compute()`, as the scheduler:

        >>> dask.compute(obj, scheduler=ray_dask_get_sync)

    Args:
        dsk (Dict): Dask graph, represented as a task DAG dictionary.
        keys (List[str]): List of Dask graph keys whose values we wish to
            compute and return.

    Returns:
        Computed values corresponding to the provided keys.
    """
    # NOTE: Dask's `get_async` is reused here with a different task executor
    # injected; `apply_sync` with a single worker makes submission synchronous.
    submit = _apply_async_wrapper(apply_sync, _rayify_task_wrapper)
    object_refs = get_async(submit, 1, dsk, keys, **kwargs)
    # NOTE: The Dask graph is dropped explicitly so object references can be
    # garbage-collected before this function returns, i.e. before all Ray
    # tasks are done. Otherwise, no intermediate objects would be cleaned up
    # until all Ray tasks are done.
    del dsk
    return ray_get_unpack(object_refs)