mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 02:01:24 +08:00
da83bbd764
* move dask * fix dask
78 lines
2.4 KiB
Python
78 lines
2.4 KiB
Python
from collections import OrderedDict
|
|
from collections.abc import Iterator
|
|
from operator import getitem
|
|
import uuid
|
|
|
|
import ray
|
|
|
|
from dask.base import quote
|
|
from dask.core import get as get_sync
|
|
from dask.compatibility import apply, is_dataclass, dataclass_fields
|
|
|
|
|
|
def unpack_object_refs(*args):
    """
    Extract `ray.ObjectRef`s from a set of potentially arbitrarily nested
    Python objects.

    Intended use is to find all Ray object references in a set of (possibly
    nested) Python objects, do something to them (get(), wait(), etc.), then
    repackage them into equivalent Python objects.

    Lists, tuples, sets, dicts, OrderedDicts, dataclass instances and
    iterators are traversed recursively. Iterators are consumed and rebuilt
    as lists. Any other object (including dataclass *classes*) is treated as
    an opaque leaf and is not searched for object references.

    Args:
        *args: One or more (potentially nested) Python objects that contain
            Ray object references.

    Returns:
        A 2-tuple of a flat list of all contained Ray object references, and a
        function that, when given the corresponding flat list of concrete
        values, will return a set of Python objects equivalent to that which
        was given in *args, but with all Ray object references replaced with
        their corresponding concrete values. Every occurrence of a reference
        is appended to the list, so a reference that appears several times
        in *args appears several times in the list.
    """
    object_refs = []
    # Dask-style task graph describing how to rebuild ``args``. Keys are
    # string tokens; values are tasks that reference other keys.
    repack_dsk = {}

    # Graph key under which ``repack`` inserts the flat list of concrete values.
    object_refs_token = uuid.uuid4().hex

    def _unpack(expr):
        """Return a graph key standing in for ``expr``, or ``expr`` itself if it is a leaf."""
        if isinstance(expr, ray.ObjectRef):
            # The reference is replaced by a lookup into the concrete-values
            # list, at the index where it is appended to ``object_refs``.
            token = expr.hex()
            repack_dsk[token] = (getitem, object_refs_token, len(object_refs))
            object_refs.append(expr)
            return token

        token = uuid.uuid4().hex
        # Treat iterators like lists
        typ = list if isinstance(expr, Iterator) else type(expr)
        if typ in (list, tuple, set):
            repack_task = (typ, [_unpack(i) for i in expr])
        elif typ in (dict, OrderedDict):
            repack_task = (typ,
                           [[_unpack(k), _unpack(v)] for k, v in expr.items()])
        elif is_dataclass(expr) and not isinstance(expr, type):
            # is_dataclass() is also true for dataclass classes. Only
            # instances can be rebuilt field by field through the constructor;
            # a class must be passed through as a leaf.
            repack_task = (
                apply,
                typ,
                (),
                (
                    dict,
                    [[f.name, _unpack(getattr(expr, f.name))]
                     for f in dataclass_fields(expr)],
                ),
            )
        else:
            return expr

        repack_dsk[token] = repack_task
        return token

    out = uuid.uuid4().hex
    # The top-level task rebuilds the positional args as a tuple.
    repack_dsk[out] = (tuple, [_unpack(i) for i in args])

    def repack(results):
        """Rebuild ``args`` with ``results[i]`` substituted for ``object_refs[i]``."""
        dsk = repack_dsk.copy()
        # quote() is Dask's wrapper for literal values; presumably used so the
        # results themselves are not interpreted as tasks/keys — confirm.
        dsk[object_refs_token] = quote(results)
        return get_sync(dsk, out)

    return object_refs, repack
|