[dask-on-ray] Add multiple return DataFrame shuffle optimization. (#13951)

This commit is contained in:
Clark Zinzow
2021-02-09 16:39:48 -07:00
committed by GitHub
parent e0b81796c5
commit 79c7c181f3
4 changed files with 259 additions and 5 deletions
@@ -0,0 +1,63 @@
import dask
import dask.dataframe as dd
from dask.dataframe.shuffle import SimpleShuffleLayer
import mock
import numpy as np
import pandas as pd
import pytest
from ray.util.dask import dataframe_optimize
from ray.util.dask.optimizations import (rewrite_simple_shuffle_layer,
MultipleReturnSimpleShuffleLayer)
def test_rewrite_simple_shuffle_layer():
npartitions = 10
df = dd.from_pandas(
pd.DataFrame(
np.random.randint(0, 100, size=(100, 2)), columns=["age",
"grade"]),
npartitions=npartitions)
# We set max_branch=npartitions in order to ensure that the task-based
# shuffle happens in a single stage, which is required in order for our
# optimization to work.
a = df.set_index(["age"], shuffle="tasks", max_branch=npartitions)
dsk = a.__dask_graph__()
keys = a.__dask_keys__()
assert any(type(v) is SimpleShuffleLayer for k, v in dsk.layers.items())
dsk = rewrite_simple_shuffle_layer(dsk, keys)
assert all(
type(v) is not SimpleShuffleLayer for k, v in dsk.layers.items())
assert any(
type(v) is MultipleReturnSimpleShuffleLayer
for k, v in dsk.layers.items())
@mock.patch("ray.util.dask.optimizations.rewrite_simple_shuffle_layer")
def test_dataframe_optimize(mock_rewrite):
def side_effect(dsk, keys):
return rewrite_simple_shuffle_layer(dsk, keys)
mock_rewrite.side_effect = side_effect
with dask.config.set(dataframe_optimize=dataframe_optimize):
npartitions = 10
df = dd.from_pandas(
pd.DataFrame(
np.random.randint(0, 100, size=(100, 2)),
columns=["age", "grade"]),
npartitions=npartitions)
# We set max_branch=npartitions in order to ensure that the task-based
# shuffle happens in a single stage, which is required in order for our
# optimization to work.
a = df.set_index(
["age"], shuffle="tasks", max_branch=npartitions).compute()
assert mock_rewrite.call_count == 2
assert a.index.is_monotonic_increasing
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", __file__]))
+5
View File
@@ -4,11 +4,16 @@ from .callbacks import (
local_ray_callbacks,
unpack_ray_callbacks,
)
from .optimizations import dataframe_optimize
__all__ = [
# Schedulers
"ray_dask_get",
"ray_dask_get_sync",
# Callbacks
"RayDaskCallback",
"local_ray_callbacks",
"unpack_ray_callbacks",
# Optimizations
"dataframe_optimize",
]
+160
View File
@@ -0,0 +1,160 @@
import operator
import warnings
import dask
from dask import core
from dask.core import istask
from dask.dataframe.core import _concat
from dask.dataframe.optimize import optimize
from dask.dataframe.shuffle import shuffle_group
from dask.highlevelgraph import HighLevelGraph
from .scheduler import MultipleReturnFunc, multiple_return_get
try:
from dask.dataframe.shuffle import SimpleShuffleLayer
except ImportError:
# SimpleShuffleLayer doesn't exist in this version of Dask.
SimpleShuffleLayer = None
if SimpleShuffleLayer is not None:
class MultipleReturnSimpleShuffleLayer(SimpleShuffleLayer):
@classmethod
def clone(cls, layer: SimpleShuffleLayer):
# TODO(Clark): Probably don't need this since SimpleShuffleLayer
# implements __copy__() and the shallow clone should be enough?
return cls(
name=layer.name,
column=layer.column,
npartitions=layer.npartitions,
npartitions_input=layer.npartitions_input,
ignore_index=layer.ignore_index,
name_input=layer.name_input,
meta_input=layer.meta_input,
parts_out=layer.parts_out,
annotations=layer.annotations,
)
def __repr__(self):
return (f"MultipleReturnSimpleShuffleLayer<name='{self.name}', "
f"npartitions={self.npartitions}>")
def __reduce__(self):
attrs = [
"name",
"column",
"npartitions",
"npartitions_input",
"ignore_index",
"name_input",
"meta_input",
"parts_out",
"annotations",
]
return (MultipleReturnSimpleShuffleLayer,
tuple(getattr(self, attr) for attr in attrs))
def _cull(self, parts_out):
return MultipleReturnSimpleShuffleLayer(
self.name,
self.column,
self.npartitions,
self.npartitions_input,
self.ignore_index,
self.name_input,
self.meta_input,
parts_out=parts_out,
)
def _construct_graph(self):
"""Construct graph for a simple shuffle operation."""
shuffle_group_name = "group-" + self.name
shuffle_split_name = "split-" + self.name
dsk = {}
n_parts_out = len(self.parts_out)
for part_out in self.parts_out:
# TODO(Clark): Find better pattern than in-scheduler concat.
_concat_list = [(shuffle_split_name, part_out, part_in)
for part_in in range(self.npartitions_input)]
dsk[(self.name, part_out)] = (_concat, _concat_list,
self.ignore_index)
for _, _part_out, _part_in in _concat_list:
dsk[(shuffle_split_name, _part_out, _part_in)] = (
multiple_return_get,
(shuffle_group_name, _part_in),
_part_out,
)
if (shuffle_group_name, _part_in) not in dsk:
dsk[(shuffle_group_name, _part_in)] = (
MultipleReturnFunc(
shuffle_group,
n_parts_out,
),
(self.name_input, _part_in),
self.column,
0,
self.npartitions,
self.npartitions,
self.ignore_index,
self.npartitions,
)
return dsk
def rewrite_simple_shuffle_layer(dsk, keys):
if not isinstance(dsk, HighLevelGraph):
dsk = HighLevelGraph.from_collections(
id(dsk), dsk, dependencies=())
else:
dsk = dsk.copy()
layers = dsk.layers.copy()
for key, layer in layers.items():
if type(layer) is SimpleShuffleLayer:
dsk.layers[key] = MultipleReturnSimpleShuffleLayer.clone(layer)
return dsk
def dataframe_optimize(dsk, keys, **kwargs):
if not isinstance(keys, (list, set)):
keys = [keys]
keys = list(core.flatten(keys))
if not isinstance(dsk, HighLevelGraph):
dsk = HighLevelGraph.from_collections(
id(dsk), dsk, dependencies=())
dsk = rewrite_simple_shuffle_layer(dsk, keys=keys)
return optimize(dsk, keys, **kwargs)
else:
def dataframe_optimize(dsk, keys, **kwargs):
warnings.warn("Custom dataframe shuffle optimization only works on "
"dask>=2020.12.0, you are on version "
f"{dask.__version__}, please upgrade Dask."
"Falling back to default dataframe optimizer.")
return optimize(dsk, keys, **kwargs)
# Stale approaches below.
def fuse_splits_into_multiple_return(dsk, keys):
if not isinstance(dsk, HighLevelGraph):
dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=())
else:
dsk = dsk.copy()
dependencies = dsk.dependencies.copy()
for k, v in dsk.items():
if istask(v) and v[0] == shuffle_group:
task_deps = dependencies[k]
# Only rewrite shuffle group split if all downstream dependencies
# are splits.
if all(
istask(dsk[dep]) and dsk[dep][0] == operator.getitem
for dep in task_deps):
for dep in task_deps:
# Rewrite split
pass
+31 -5
View File
@@ -1,6 +1,7 @@
import atexit
from collections import defaultdict
from multiprocessing.pool import ThreadPool
from dataclasses import dataclass
import threading
import ray
@@ -270,19 +271,31 @@ def _rayify_task(
return alternate_return
func, args = task[0], task[1:]
if func is multiple_return_get:
return _execute_task(task, deps)
# If the function's arguments contain nested object references, we must
# unpack said object references into a flat set of arguments so that
# Ray properly tracks the object dependencies between Ray tasks.
object_refs, repack = unpack_object_refs(args, deps)
arg_object_refs, repack = unpack_object_refs(args, deps)
# Submit the task using a wrapper function.
object_ref = dask_task_wrapper.options(name=f"dask:{key!s}").remote(
func, repack, key, ray_pretask_cbs, ray_posttask_cbs, *object_refs)
object_refs = dask_task_wrapper.options(
name=f"dask:{key!s}",
num_returns=(1 if not isinstance(func, MultipleReturnFunc) else
func.num_returns),
).remote(
func,
repack,
key,
ray_pretask_cbs,
ray_posttask_cbs,
*arg_object_refs,
)
if ray_postsubmit_cbs is not None:
for cb in ray_postsubmit_cbs:
cb(task, key, deps, object_ref)
cb(task, key, deps, object_refs)
return object_ref
return object_refs
elif not ishashable(task):
return task
elif task in deps:
@@ -434,3 +447,16 @@ def ray_dask_get_sync(dsk, keys, **kwargs):
cb(result)
return result
@dataclass
class MultipleReturnFunc:
func: callable
num_returns: int
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)
def multiple_return_get(multiple_returns, idx):
return multiple_returns[idx]