mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:38:18 +08:00
[dask-on-ray] Add multiple return DataFrame shuffle optimization. (#13951)
This commit is contained in:
@@ -0,0 +1,63 @@
|
||||
import dask
|
||||
import dask.dataframe as dd
|
||||
from dask.dataframe.shuffle import SimpleShuffleLayer
|
||||
import mock
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from ray.util.dask import dataframe_optimize
|
||||
from ray.util.dask.optimizations import (rewrite_simple_shuffle_layer,
|
||||
MultipleReturnSimpleShuffleLayer)
|
||||
|
||||
|
||||
def test_rewrite_simple_shuffle_layer():
    """Check that the rewrite swaps every SimpleShuffleLayer in the graph.

    Before the rewrite the graph must contain at least one
    ``SimpleShuffleLayer``; afterwards it must contain none, and at least one
    ``MultipleReturnSimpleShuffleLayer``.
    """
    npartitions = 10
    frame = pd.DataFrame(
        np.random.randint(0, 100, size=(100, 2)), columns=["age", "grade"])
    ddf = dd.from_pandas(frame, npartitions=npartitions)
    # We set max_branch=npartitions in order to ensure that the task-based
    # shuffle happens in a single stage, which is required in order for our
    # optimization to work.
    indexed = ddf.set_index(
        ["age"], shuffle="tasks", max_branch=npartitions)

    graph = indexed.__dask_graph__()
    keys = indexed.__dask_keys__()
    # Exact type checks on purpose: the replacement layer subclasses
    # SimpleShuffleLayer, so isinstance() could not tell them apart.
    assert any(
        type(layer) is SimpleShuffleLayer for layer in graph.layers.values())

    graph = rewrite_simple_shuffle_layer(graph, keys)
    rewritten_types = [type(layer) for layer in graph.layers.values()]
    assert all(
        layer_type is not SimpleShuffleLayer
        for layer_type in rewritten_types)
    assert any(
        layer_type is MultipleReturnSimpleShuffleLayer
        for layer_type in rewritten_types)
|
||||
|
||||
|
||||
@mock.patch("ray.util.dask.optimizations.rewrite_simple_shuffle_layer")
def test_dataframe_optimize(mock_rewrite):
    """Check the custom optimizer is invoked and the shuffle stays correct.

    The rewrite function inside ``ray.util.dask.optimizations`` is patched
    with a mock that forwards to the real implementation, so the call count
    can be asserted while the optimization still runs.
    """
    # Forward to the real implementation imported by this test module; the
    # mock is only there to count the calls.
    mock_rewrite.side_effect = rewrite_simple_shuffle_layer

    with dask.config.set(dataframe_optimize=dataframe_optimize):
        npartitions = 10
        frame = pd.DataFrame(
            np.random.randint(0, 100, size=(100, 2)),
            columns=["age", "grade"])
        ddf = dd.from_pandas(frame, npartitions=npartitions)
        # We set max_branch=npartitions in order to ensure that the task-based
        # shuffle happens in a single stage, which is required in order for our
        # optimization to work.
        indexed = ddf.set_index(
            ["age"], shuffle="tasks", max_branch=npartitions)
        result = indexed.compute()

    assert mock_rewrite.call_count == 2
    assert result.index.is_monotonic_increasing
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import sys

    # Run only this file under pytest (verbose) and propagate its exit status
    # to the calling process.
    exit_status = pytest.main(["-v", __file__])
    sys.exit(exit_status)
|
||||
@@ -4,11 +4,16 @@ from .callbacks import (
|
||||
local_ray_callbacks,
|
||||
unpack_ray_callbacks,
|
||||
)
|
||||
from .optimizations import dataframe_optimize
|
||||
|
||||
__all__ = [
|
||||
# Schedulers
|
||||
"ray_dask_get",
|
||||
"ray_dask_get_sync",
|
||||
# Callbacks
|
||||
"RayDaskCallback",
|
||||
"local_ray_callbacks",
|
||||
"unpack_ray_callbacks",
|
||||
# Optimizations
|
||||
"dataframe_optimize",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,160 @@
|
||||
import operator
|
||||
import warnings
|
||||
|
||||
import dask
|
||||
from dask import core
|
||||
from dask.core import istask
|
||||
from dask.dataframe.core import _concat
|
||||
from dask.dataframe.optimize import optimize
|
||||
from dask.dataframe.shuffle import shuffle_group
|
||||
from dask.highlevelgraph import HighLevelGraph
|
||||
|
||||
from .scheduler import MultipleReturnFunc, multiple_return_get
|
||||
|
||||
try:
|
||||
from dask.dataframe.shuffle import SimpleShuffleLayer
|
||||
except ImportError:
|
||||
# SimpleShuffleLayer doesn't exist in this version of Dask.
|
||||
SimpleShuffleLayer = None
|
||||
|
||||
if SimpleShuffleLayer is not None:
|
||||
|
||||
class MultipleReturnSimpleShuffleLayer(SimpleShuffleLayer):
    """A ``SimpleShuffleLayer`` whose group tasks use multiple returns.

    ``_construct_graph`` wraps each ``shuffle_group`` task in a
    ``MultipleReturnFunc`` and builds every split task as a
    ``multiple_return_get`` lookup on that group task's result.
    """

    # Constructor arguments, in positional order. ``clone``, ``__reduce__``
    # and ``_cull`` all rebuild a layer from these attributes; sharing one
    # list keeps them from drifting apart (``_cull`` used to omit
    # ``annotations``).
    _INIT_ATTRS = (
        "name",
        "column",
        "npartitions",
        "npartitions_input",
        "ignore_index",
        "name_input",
        "meta_input",
        "parts_out",
        "annotations",
    )

    @classmethod
    def clone(cls, layer: SimpleShuffleLayer):
        """Build an instance of this class from an existing shuffle layer.

        Args:
            layer: The layer whose constructor attributes are copied.

        Returns:
            A new ``cls`` instance carrying the same attributes as ``layer``.
        """
        # TODO(Clark): Probably don't need this since SimpleShuffleLayer
        # implements __copy__() and the shallow clone should be enough?
        return cls(
            **{attr: getattr(layer, attr)
               for attr in cls._INIT_ATTRS})

    def __repr__(self):
        return (f"MultipleReturnSimpleShuffleLayer<name='{self.name}', "
                f"npartitions={self.npartitions}>")

    def __reduce__(self):
        """Pickle as the class plus its positional constructor arguments."""
        return (MultipleReturnSimpleShuffleLayer,
                tuple(getattr(self, attr) for attr in self._INIT_ATTRS))

    def _cull(self, parts_out):
        """Return a copy of this layer restricted to ``parts_out``.

        Args:
            parts_out: The output partitions to keep.

        Returns:
            A new ``MultipleReturnSimpleShuffleLayer`` that differs from this
            one only in ``parts_out``.
        """
        init_kwargs = {
            attr: getattr(self, attr)
            for attr in self._INIT_ATTRS
        }
        # Everything, including annotations, is carried over from this layer
        # (as clone() and __reduce__() do); only the kept partitions change.
        init_kwargs["parts_out"] = parts_out
        return MultipleReturnSimpleShuffleLayer(**init_kwargs)

    def _construct_graph(self):
        """Construct graph for a simple shuffle operation.

        Returns:
            A dict with three kinds of keys: ``(self.name, part_out)`` concat
            tasks, ``("split-<name>", part_out, part_in)`` lookup tasks, and
            ``("group-<name>", part_in)`` group tasks.
        """
        shuffle_group_name = "group-" + self.name
        shuffle_split_name = "split-" + self.name

        dsk = {}
        # NOTE(review): num_returns is the number of *requested* output
        # partitions, while the split tasks below index the group result by
        # partition number. Confirm this is still correct when parts_out is
        # a strict subset of range(self.npartitions).
        n_parts_out = len(self.parts_out)
        for part_out in self.parts_out:
            # TODO(Clark): Find better pattern than in-scheduler concat.
            _concat_list = [(shuffle_split_name, part_out, part_in)
                            for part_in in range(self.npartitions_input)]
            dsk[(self.name, part_out)] = (_concat, _concat_list,
                                          self.ignore_index)
            for _, _part_out, _part_in in _concat_list:
                group_key = (shuffle_group_name, _part_in)
                # Each split picks one element out of the multi-return group
                # task for its input partition.
                dsk[(shuffle_split_name, _part_out, _part_in)] = (
                    multiple_return_get,
                    group_key,
                    _part_out,
                )
                # One group task per input partition, shared by every output
                # partition's split.
                if group_key not in dsk:
                    dsk[group_key] = (
                        MultipleReturnFunc(
                            shuffle_group,
                            n_parts_out,
                        ),
                        (self.name_input, _part_in),
                        self.column,
                        0,
                        self.npartitions,
                        self.npartitions,
                        self.ignore_index,
                        self.npartitions,
                    )

        return dsk
|
||||
|
||||
def rewrite_simple_shuffle_layer(dsk, keys):
    """Swap each ``SimpleShuffleLayer`` for its multiple-return variant.

    Args:
        dsk: A ``HighLevelGraph`` or any other graph; a non-``HighLevelGraph``
            input is first wrapped via ``HighLevelGraph.from_collections``.
        keys: Unused by this function.

    Returns:
        A ``HighLevelGraph`` (``dsk.copy()`` or the freshly wrapped graph) in
        which every layer whose exact type is ``SimpleShuffleLayer`` has been
        replaced by ``MultipleReturnSimpleShuffleLayer.clone(layer)``.
    """
    if isinstance(dsk, HighLevelGraph):
        graph = dsk.copy()
    else:
        graph = HighLevelGraph.from_collections(
            id(dsk), dsk, dependencies=())

    # Iterate over a snapshot so the layer mapping can be updated in the loop.
    for name, layer in list(graph.layers.items()):
        # Exact type match only: subclasses (including the replacement class
        # itself) are left alone.
        if type(layer) is not SimpleShuffleLayer:
            continue
        graph.layers[name] = MultipleReturnSimpleShuffleLayer.clone(layer)
    return graph
|
||||
|
||||
def dataframe_optimize(dsk, keys, **kwargs):
    """Rewrite simple shuffle layers, then run Dask's dataframe optimizer.

    Args:
        dsk: The graph to optimize; wrapped into a ``HighLevelGraph`` if it
            is not one already.
        keys: A single key or a (possibly nested) list/set of keys.
        **kwargs: Passed through to ``dask.dataframe.optimize.optimize``.

    Returns:
        Whatever ``optimize`` returns for the rewritten graph.
    """
    # Normalize to a flat list: a lone key is boxed first, then any nesting
    # is flattened away.
    boxed_keys = keys if isinstance(keys, (list, set)) else [keys]
    flat_keys = list(core.flatten(boxed_keys))

    graph = dsk
    if not isinstance(graph, HighLevelGraph):
        graph = HighLevelGraph.from_collections(
            id(graph), graph, dependencies=())

    graph = rewrite_simple_shuffle_layer(graph, keys=flat_keys)
    return optimize(graph, flat_keys, **kwargs)
|
||||
else:
|
||||
|
||||
def dataframe_optimize(dsk, keys, **kwargs):
    """Warn and fall back to Dask's default dataframe optimizer.

    This variant is defined when ``SimpleShuffleLayer`` could not be
    imported, so the custom shuffle rewrite is unavailable.

    Args:
        dsk: The graph to optimize.
        keys: The output keys of the graph.
        **kwargs: Passed through to ``dask.dataframe.optimize.optimize``.

    Returns:
        Whatever ``optimize`` returns for ``dsk``.
    """
    # Adjacent literals are concatenated verbatim, so each fragment must end
    # with its own separator; the sentence break before "Falling back" was
    # previously missing its space.
    warnings.warn("Custom dataframe shuffle optimization only works on "
                  "dask>=2020.12.0, you are on version "
                  f"{dask.__version__}, please upgrade Dask. "
                  "Falling back to default dataframe optimizer.")
    return optimize(dsk, keys, **kwargs)
|
||||
|
||||
|
||||
# Stale approaches below.
|
||||
|
||||
|
||||
def fuse_splits_into_multiple_return(dsk, keys):
    """Unfinished stub: locate shuffle-group tasks whose deps are all splits.

    The rewrite step itself is not implemented, so this function changes
    nothing and implicitly returns ``None``. ``keys`` is unused.

    NOTE(review): ``dependencies`` is taken from the graph object but is
    indexed with task keys; confirm that mapping is actually keyed by task
    before reviving this code.
    """
    if isinstance(dsk, HighLevelGraph):
        graph = dsk.copy()
    else:
        graph = HighLevelGraph.from_collections(
            id(dsk), dsk, dependencies=())
    dependencies = graph.dependencies.copy()

    for key, task in graph.items():
        if not (istask(task) and task[0] == shuffle_group):
            continue
        task_deps = dependencies[key]
        # Only rewrite shuffle group split if all downstream dependencies
        # are splits.
        only_splits = all(
            istask(graph[dep]) and graph[dep][0] == operator.getitem
            for dep in task_deps)
        if only_splits:
            for dep in task_deps:
                # Rewrite split
                pass
|
||||
@@ -1,6 +1,7 @@
|
||||
import atexit
|
||||
from collections import defaultdict
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from dataclasses import dataclass
|
||||
import threading
|
||||
|
||||
import ray
|
||||
@@ -270,19 +271,31 @@ def _rayify_task(
|
||||
return alternate_return
|
||||
|
||||
func, args = task[0], task[1:]
|
||||
if func is multiple_return_get:
|
||||
return _execute_task(task, deps)
|
||||
# If the function's arguments contain nested object references, we must
|
||||
# unpack said object references into a flat set of arguments so that
|
||||
# Ray properly tracks the object dependencies between Ray tasks.
|
||||
object_refs, repack = unpack_object_refs(args, deps)
|
||||
arg_object_refs, repack = unpack_object_refs(args, deps)
|
||||
# Submit the task using a wrapper function.
|
||||
object_ref = dask_task_wrapper.options(name=f"dask:{key!s}").remote(
|
||||
func, repack, key, ray_pretask_cbs, ray_posttask_cbs, *object_refs)
|
||||
object_refs = dask_task_wrapper.options(
|
||||
name=f"dask:{key!s}",
|
||||
num_returns=(1 if not isinstance(func, MultipleReturnFunc) else
|
||||
func.num_returns),
|
||||
).remote(
|
||||
func,
|
||||
repack,
|
||||
key,
|
||||
ray_pretask_cbs,
|
||||
ray_posttask_cbs,
|
||||
*arg_object_refs,
|
||||
)
|
||||
|
||||
if ray_postsubmit_cbs is not None:
|
||||
for cb in ray_postsubmit_cbs:
|
||||
cb(task, key, deps, object_ref)
|
||||
cb(task, key, deps, object_refs)
|
||||
|
||||
return object_ref
|
||||
return object_refs
|
||||
elif not ishashable(task):
|
||||
return task
|
||||
elif task in deps:
|
||||
@@ -434,3 +447,16 @@ def ray_dask_get_sync(dsk, keys, **kwargs):
|
||||
cb(result)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
class MultipleReturnFunc:
    """Callable wrapper pairing a function with a return count.

    Calling the wrapper simply calls ``func`` with the same arguments;
    ``num_returns`` is carried alongside for whoever submits the task.
    """

    # The wrapped callable.
    func: callable
    # The number of return values associated with ``func``.
    num_returns: int

    def __call__(self, *positional, **keywords):
        """Invoke the wrapped function and return its result unchanged."""
        wrapped = self.func
        return wrapped(*positional, **keywords)
|
||||
|
||||
|
||||
def multiple_return_get(multiple_returns, idx):
    """Return the element of ``multiple_returns`` selected by ``idx``.

    Args:
        multiple_returns: Any object supporting subscription.
        idx: The index or key to look up.

    Returns:
        ``multiple_returns[idx]``.
    """
    selected = multiple_returns[idx]
    return selected
|
||||
|
||||
Reference in New Issue
Block a user