mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 14:48:54 +08:00
[1.0] Move dask scheduler from experimental to util (#10553)
* move dask * fix dask
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
from .scheduler import ray_dask_get, ray_dask_get_sync
|
||||
|
||||
__all__ = ["ray_dask_get", "ray_dask_get_sync"]
|
||||
@@ -0,0 +1,77 @@
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Iterator
|
||||
from operator import getitem
|
||||
import uuid
|
||||
|
||||
import ray
|
||||
|
||||
from dask.base import quote
|
||||
from dask.core import get as get_sync
|
||||
from dask.compatibility import apply, is_dataclass, dataclass_fields
|
||||
|
||||
|
||||
def unpack_object_refs(*args):
|
||||
"""
|
||||
Extract `ray.ObjectRef`s from a set of potentially arbitrarily nested
|
||||
Python objects.
|
||||
|
||||
Intended use is to find all Ray object references in a set of (possibly
|
||||
nested) Python objects, do something to them (get(), wait(), etc.), then
|
||||
repackage them into equivalent Python objects.
|
||||
|
||||
Args:
|
||||
*args: One or more (potentially nested) Python objects that contain
|
||||
Ray object references.
|
||||
|
||||
Returns:
|
||||
A 2-tuple of a flat list of all contained Ray object references, and a
|
||||
function that, when given the corresponding flat list of concrete
|
||||
values, will return a set of Python objects equivalent to that which
|
||||
was given in *args, but with all Ray object references replaced with
|
||||
their corresponding concrete values.
|
||||
"""
|
||||
object_refs = []
|
||||
repack_dsk = {}
|
||||
|
||||
object_refs_token = uuid.uuid4().hex
|
||||
|
||||
def _unpack(expr):
|
||||
if isinstance(expr, ray.ObjectRef):
|
||||
token = expr.hex()
|
||||
repack_dsk[token] = (getitem, object_refs_token, len(object_refs))
|
||||
object_refs.append(expr)
|
||||
return token
|
||||
|
||||
token = uuid.uuid4().hex
|
||||
# Treat iterators like lists
|
||||
typ = list if isinstance(expr, Iterator) else type(expr)
|
||||
if typ in (list, tuple, set):
|
||||
repack_task = (typ, [_unpack(i) for i in expr])
|
||||
elif typ in (dict, OrderedDict):
|
||||
repack_task = (typ,
|
||||
[[_unpack(k), _unpack(v)] for k, v in expr.items()])
|
||||
elif is_dataclass(expr):
|
||||
repack_task = (
|
||||
apply,
|
||||
typ,
|
||||
(),
|
||||
(
|
||||
dict,
|
||||
[[f.name, _unpack(getattr(expr, f.name))]
|
||||
for f in dataclass_fields(expr)],
|
||||
),
|
||||
)
|
||||
else:
|
||||
return expr
|
||||
repack_dsk[token] = repack_task
|
||||
return token
|
||||
|
||||
out = uuid.uuid4().hex
|
||||
repack_dsk[out] = (tuple, [_unpack(i) for i in args])
|
||||
|
||||
def repack(results):
|
||||
dsk = repack_dsk.copy()
|
||||
dsk[object_refs_token] = quote(results)
|
||||
return get_sync(dsk, out)
|
||||
|
||||
return object_refs, repack
|
||||
@@ -0,0 +1,298 @@
|
||||
import atexit
|
||||
from collections import defaultdict
|
||||
from multiprocessing.pool import ThreadPool
|
||||
import threading
|
||||
|
||||
import ray
|
||||
|
||||
from dask.core import istask, ishashable, _execute_task
|
||||
from dask.local import get_async, apply_sync
|
||||
from dask.system import CPU_COUNT
|
||||
from dask.threaded import pack_exception, _thread_get_id
|
||||
|
||||
from .common import unpack_object_refs
|
||||
|
||||
main_thread = threading.current_thread()
|
||||
default_pool = None
|
||||
pools = defaultdict(dict)
|
||||
pools_lock = threading.Lock()
|
||||
|
||||
|
||||
def ray_dask_get(dsk, keys, **kwargs):
|
||||
"""
|
||||
A Dask-Ray scheduler. This scheduler will send top-level (non-inlined) Dask
|
||||
tasks to a Ray cluster for execution. The scheduler will wait for the
|
||||
tasks to finish executing, fetch the results, and repackage them into the
|
||||
appropriate Dask collections. This particular scheduler uses a threadpool
|
||||
to submit Ray tasks.
|
||||
|
||||
This can be passed directly to `dask.compute()`, as the scheduler:
|
||||
|
||||
>>> dask.compute(obj, scheduler=ray_dask_get)
|
||||
|
||||
You can override the number of threads to use when submitting the
|
||||
Ray tasks, or the threadpool used to submit Ray tasks:
|
||||
|
||||
>>> dask.compute(
|
||||
obj,
|
||||
scheduler=ray_dask_get,
|
||||
num_workers=8,
|
||||
pool=some_cool_pool,
|
||||
)
|
||||
|
||||
Args:
|
||||
dsk (Dict): Dask graph, represented as a task DAG dictionary.
|
||||
keys (List[str]): List of Dask graph keys whose values we wish to
|
||||
compute and return.
|
||||
num_workers (Optional[int]): The number of worker threads to use in
|
||||
the Ray task submission traversal of the Dask graph.
|
||||
pool (Optional[ThreadPool]): A multiprocessing threadpool to use to
|
||||
submit Ray tasks.
|
||||
|
||||
Returns:
|
||||
Computed values corresponding to the provided keys.
|
||||
"""
|
||||
num_workers = kwargs.pop("num_workers", None)
|
||||
pool = kwargs.pop("pool", None)
|
||||
# We attempt to reuse any other thread pools that have been created within
|
||||
# this thread and with the given number of workers. We reuse a global
|
||||
# thread pool if num_workers is not given and we're in the main thread.
|
||||
global default_pool
|
||||
thread = threading.current_thread()
|
||||
if pool is None:
|
||||
with pools_lock:
|
||||
if num_workers is None and thread is main_thread:
|
||||
if default_pool is None:
|
||||
default_pool = ThreadPool(CPU_COUNT)
|
||||
atexit.register(default_pool.close)
|
||||
pool = default_pool
|
||||
elif thread in pools and num_workers in pools[thread]:
|
||||
pool = pools[thread][num_workers]
|
||||
else:
|
||||
pool = ThreadPool(num_workers)
|
||||
atexit.register(pool.close)
|
||||
pools[thread][num_workers] = pool
|
||||
|
||||
# NOTE: We hijack Dask's `get_async` function, injecting a different task
|
||||
# executor.
|
||||
object_refs = get_async(
|
||||
_apply_async_wrapper(pool.apply_async, _rayify_task_wrapper),
|
||||
len(pool._pool),
|
||||
dsk,
|
||||
keys,
|
||||
get_id=_thread_get_id,
|
||||
pack_exception=pack_exception,
|
||||
**kwargs,
|
||||
)
|
||||
# NOTE: We explicitly delete the Dask graph here so object references
|
||||
# are garbage-collected before this function returns, i.e. before all Ray
|
||||
# tasks are done. Otherwise, no intermediate objects will be cleaned up
|
||||
# until all Ray tasks are done.
|
||||
del dsk
|
||||
result = ray_get_unpack(object_refs)
|
||||
|
||||
# cleanup pools associated with dead threads.
|
||||
with pools_lock:
|
||||
active_threads = set(threading.enumerate())
|
||||
if thread is not main_thread:
|
||||
for t in list(pools):
|
||||
if t not in active_threads:
|
||||
for p in pools.pop(t).values():
|
||||
p.close()
|
||||
return result
|
||||
|
||||
|
||||
def _apply_async_wrapper(apply_async, real_func, *extra_args, **extra_kwargs):
|
||||
"""
|
||||
Wraps the given pool `apply_async` function, hotswapping `real_func` in as
|
||||
the function to be applied and adding `extra_args` and `extra_kwargs` to
|
||||
`real_func`'s call.
|
||||
|
||||
Args:
|
||||
apply_async (callable): The pool function to be wrapped.
|
||||
real_func (callable): The real function that we wish the pool apply
|
||||
function to execute.
|
||||
*extra_args: Extra positional arguments to pass to the `real_func`.
|
||||
**extra_kwargs: Extra keyword arguments to pass to the `real_func`.
|
||||
|
||||
Returns:
|
||||
A wrapper function that will ignore it's first `func` argument and
|
||||
pass `real_func` in its place. To be passed to `dask.local.get_async`.
|
||||
"""
|
||||
|
||||
def wrapper(func, args=(), kwds={}, callback=None): # noqa: M511
|
||||
return apply_async(
|
||||
real_func,
|
||||
args=args + extra_args,
|
||||
kwds=dict(kwds, **extra_kwargs),
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def _rayify_task_wrapper(
|
||||
key,
|
||||
task_info,
|
||||
dumps,
|
||||
loads,
|
||||
get_id,
|
||||
pack_exception,
|
||||
):
|
||||
"""
|
||||
The core Ray-Dask task execution wrapper, to be given to the thread pool's
|
||||
`apply_async` function. Exactly the same as `execute_task`, except that it
|
||||
calls `_rayify_task` on the task instead of `_execute_task`.
|
||||
|
||||
Args:
|
||||
key (str): The Dask graph key whose corresponding task we wish to
|
||||
execute.
|
||||
task_info: The task to execute and its dependencies.
|
||||
dumps (callable): A result serializing function.
|
||||
loads (callable): A task_info deserializing function.
|
||||
get_id (callable): An ID generating function.
|
||||
pack_exception (callable): An exception serializing function.
|
||||
|
||||
Returns:
|
||||
A 3-tuple of the task's key, a literal or a Ray object reference for a
|
||||
Ray task's result, and whether the Ray task submission failed.
|
||||
"""
|
||||
try:
|
||||
task, deps = loads(task_info)
|
||||
result = _rayify_task(task, key, deps)
|
||||
id = get_id()
|
||||
result = dumps((result, id))
|
||||
failed = False
|
||||
except BaseException as e:
|
||||
result = pack_exception(e, dumps)
|
||||
failed = True
|
||||
return key, result, failed
|
||||
|
||||
|
||||
def _rayify_task(task, key, deps):
|
||||
"""
|
||||
Rayifies the given task, submitting it as a Ray task to the Ray cluster.
|
||||
|
||||
Args:
|
||||
task: A Dask graph value, being either a literal, dependency key, Dask
|
||||
task, or a list thereof.
|
||||
key: The Dask graph key for the given task.
|
||||
deps: The dependencies of this task.
|
||||
|
||||
Returns:
|
||||
A literal, a Ray object reference representing a submitted task, or a
|
||||
list thereof.
|
||||
"""
|
||||
if isinstance(task, list):
|
||||
# Recursively rayify this list. This will still bottom out at the first
|
||||
# actual task encountered, inlining any tasks in that task's arguments.
|
||||
return [_rayify_task(t, deps) for t in task]
|
||||
elif istask(task):
|
||||
# Unpacks and repacks Ray object references and submits the task to the
|
||||
# Ray cluster for execution.
|
||||
func, args = task[0], task[1:]
|
||||
# If the function's arguments contain nested object references, we must
|
||||
# unpack said object references into a flat set of arguments so that
|
||||
# Ray properly tracks the object dependencies between Ray tasks.
|
||||
object_refs, repack = unpack_object_refs(args, deps)
|
||||
# Submit the task using a wrapper function.
|
||||
return dask_task_wrapper.options(name=f"dask:{key!s}").remote(
|
||||
func, repack, *object_refs)
|
||||
elif not ishashable(task):
|
||||
return task
|
||||
elif task in deps:
|
||||
return deps[task]
|
||||
else:
|
||||
return task
|
||||
|
||||
|
||||
@ray.remote
|
||||
def dask_task_wrapper(func, repack, *args):
|
||||
"""
|
||||
A Ray remote function acting as a Dask task wrapper. This function will
|
||||
repackage the given flat `args` into its original data structures using
|
||||
`repack`, execute any Dask subtasks within the repackaged arguments
|
||||
(inlined by Dask's optimization pass), and then pass the concrete task
|
||||
arguments to the provide Dask task function, `func`.
|
||||
|
||||
Args:
|
||||
func (callable): The Dask task function to execute.
|
||||
repack (callable): A function that repackages the provided args into
|
||||
the original (possibly nested) Python objects.
|
||||
*args (ObjectRef): Ray object references representing the Dask task's
|
||||
arguments.
|
||||
|
||||
Returns:
|
||||
The output of the Dask task. In the context of Ray, a
|
||||
dask_task_wrapper.remote() invocation will return a Ray object
|
||||
reference representing the Ray task's result.
|
||||
"""
|
||||
repacked_args, repacked_deps = repack(args)
|
||||
# Recursively execute Dask-inlined tasks.
|
||||
actual_args = [_execute_task(a, repacked_deps) for a in repacked_args]
|
||||
# Execute the actual underlying Dask task.
|
||||
return func(*actual_args)
|
||||
|
||||
|
||||
def ray_get_unpack(object_refs):
|
||||
"""
|
||||
Unpacks object references, gets the object references, and repacks.
|
||||
Traverses arbitrary data structures.
|
||||
|
||||
Args:
|
||||
object_refs: A (potentially nested) Python object containing Ray object
|
||||
references.
|
||||
|
||||
Returns:
|
||||
The input Python object with all contained Ray object references
|
||||
resolved with their concrete values.
|
||||
"""
|
||||
if isinstance(object_refs,
|
||||
(tuple, list)) and any(not isinstance(x, ray.ObjectRef)
|
||||
for x in object_refs):
|
||||
# We flatten the object references before calling ray.get(), since Dask
|
||||
# loves to nest collections in nested tuples and Ray expects a flat
|
||||
# list of object references. We repack the results after ray.get()
|
||||
# completes.
|
||||
object_refs, repack = unpack_object_refs(*object_refs)
|
||||
computed_result = ray.get(object_refs)
|
||||
return repack(computed_result)
|
||||
else:
|
||||
return ray.get(object_refs)
|
||||
|
||||
|
||||
def ray_dask_get_sync(dsk, keys, **kwargs):
|
||||
"""
|
||||
A synchronous Dask-Ray scheduler. This scheduler will send top-level
|
||||
(non-inlined) Dask tasks to a Ray cluster for execution. The scheduler will
|
||||
wait for the tasks to finish executing, fetch the results, and repackage
|
||||
them into the appropriate Dask collections. This particular scheduler
|
||||
submits Ray tasks synchronously, which can be useful for debugging.
|
||||
|
||||
This can be passed directly to `dask.compute()`, as the scheduler:
|
||||
|
||||
>>> dask.compute(obj, scheduler=ray_dask_get_sync)
|
||||
|
||||
Args:
|
||||
dsk (Dict): Dask graph, represented as a task DAG dictionary.
|
||||
keys (List[str]): List of Dask graph keys whose values we wish to
|
||||
compute and return.
|
||||
|
||||
Returns:
|
||||
Computed values corresponding to the provided keys.
|
||||
"""
|
||||
# NOTE: We hijack Dask's `get_async` function, injecting a different task
|
||||
# executor.
|
||||
object_refs = get_async(
|
||||
_apply_async_wrapper(apply_sync, _rayify_task_wrapper),
|
||||
1,
|
||||
dsk,
|
||||
keys,
|
||||
**kwargs,
|
||||
)
|
||||
# NOTE: We explicitly delete the Dask graph here so object references
|
||||
# are garbage-collected before this function returns, i.e. before all Ray
|
||||
# tasks are done. Otherwise, no intermediate objects will be cleaned up
|
||||
# until all Ray tasks are done.
|
||||
del dsk
|
||||
return ray_get_unpack(object_refs)
|
||||
Reference in New Issue
Block a user