From 79c7c181f36e63035e82d883c1af2f7f04873fc9 Mon Sep 17 00:00:00 2001
From: Clark Zinzow
Date: Tue, 9 Feb 2021 16:39:48 -0700
Subject: [PATCH] [dask-on-ray] Add multiple return DataFrame shuffle optimization. (#13951)

---
 python/ray/tests/test_dask_optimization.py |  67 ++++++++
 python/ray/util/dask/__init__.py           |   5 +
 python/ray/util/dask/optimizations.py      | 188 +++++++++++++++++++++
 python/ray/util/dask/scheduler.py          |  44 ++++-
 4 files changed, 299 insertions(+), 5 deletions(-)
 create mode 100644 python/ray/tests/test_dask_optimization.py
 create mode 100644 python/ray/util/dask/optimizations.py

diff --git a/python/ray/tests/test_dask_optimization.py b/python/ray/tests/test_dask_optimization.py
new file mode 100644
index 000000000..e8a045aee
--- /dev/null
+++ b/python/ray/tests/test_dask_optimization.py
@@ -0,0 +1,67 @@
+from unittest import mock
+
+import dask
+import dask.dataframe as dd
+from dask.dataframe.shuffle import SimpleShuffleLayer
+import numpy as np
+import pandas as pd
+import pytest
+
+from ray.util.dask import dataframe_optimize
+from ray.util.dask.optimizations import (rewrite_simple_shuffle_layer,
+                                         MultipleReturnSimpleShuffleLayer)
+
+
+def test_rewrite_simple_shuffle_layer():
+    """Every SimpleShuffleLayer is swapped for the multiple-return layer."""
+    npartitions = 10
+    df = dd.from_pandas(
+        pd.DataFrame(
+            np.random.randint(0, 100, size=(100, 2)),
+            columns=["age", "grade"]),
+        npartitions=npartitions)
+    # We set max_branch=npartitions in order to ensure that the task-based
+    # shuffle happens in a single stage, which is required in order for our
+    # optimization to work.
+    a = df.set_index(["age"], shuffle="tasks", max_branch=npartitions)
+
+    dsk = a.__dask_graph__()
+    keys = a.__dask_keys__()
+    assert any(type(v) is SimpleShuffleLayer for v in dsk.layers.values())
+    dsk = rewrite_simple_shuffle_layer(dsk, keys)
+    assert all(
+        type(v) is not SimpleShuffleLayer for v in dsk.layers.values())
+    assert any(
+        type(v) is MultipleReturnSimpleShuffleLayer
+        for v in dsk.layers.values())
+
+
+@mock.patch("ray.util.dask.optimizations.rewrite_simple_shuffle_layer")
+def test_dataframe_optimize(mock_rewrite):
+    """dataframe_optimize invokes the rewrite and set_index still sorts."""
+
+    def side_effect(dsk, keys):
+        return rewrite_simple_shuffle_layer(dsk, keys)
+
+    mock_rewrite.side_effect = side_effect
+    with dask.config.set(dataframe_optimize=dataframe_optimize):
+        npartitions = 10
+        df = dd.from_pandas(
+            pd.DataFrame(
+                np.random.randint(0, 100, size=(100, 2)),
+                columns=["age", "grade"]),
+            npartitions=npartitions)
+        # We set max_branch=npartitions in order to ensure that the task-based
+        # shuffle happens in a single stage, which is required in order for our
+        # optimization to work.
+        a = df.set_index(
+            ["age"], shuffle="tasks", max_branch=npartitions).compute()
+
+    assert mock_rewrite.call_count == 2
+    assert a.index.is_monotonic_increasing
+
+
+if __name__ == "__main__":
+    import sys
+
+    sys.exit(pytest.main(["-v", __file__]))
diff --git a/python/ray/util/dask/__init__.py b/python/ray/util/dask/__init__.py
index bfe28571a..10a08379c 100644
--- a/python/ray/util/dask/__init__.py
+++ b/python/ray/util/dask/__init__.py
@@ -4,11 +4,16 @@ from .callbacks import (
     local_ray_callbacks,
     unpack_ray_callbacks,
 )
+from .optimizations import dataframe_optimize
 
 __all__ = [
+    # Schedulers
     "ray_dask_get",
     "ray_dask_get_sync",
+    # Callbacks
     "RayDaskCallback",
     "local_ray_callbacks",
     "unpack_ray_callbacks",
+    # Optimizations
+    "dataframe_optimize",
 ]
diff --git a/python/ray/util/dask/optimizations.py b/python/ray/util/dask/optimizations.py
new file mode 100644
index 000000000..c36757af6
--- /dev/null
+++ b/python/ray/util/dask/optimizations.py
@@ -0,0 +1,188 @@
+import operator
+import warnings
+
+import dask
+from dask import core
+from dask.core import istask
+from dask.dataframe.core import _concat
+from dask.dataframe.optimize import optimize
+from dask.dataframe.shuffle import shuffle_group
+from dask.highlevelgraph import HighLevelGraph
+
+from .scheduler import MultipleReturnFunc, multiple_return_get
+
+try:
+    from dask.dataframe.shuffle import SimpleShuffleLayer
+except ImportError:
+    # SimpleShuffleLayer doesn't exist in this version of Dask.
+    SimpleShuffleLayer = None
+
+if SimpleShuffleLayer is not None:
+
+    class MultipleReturnSimpleShuffleLayer(SimpleShuffleLayer):
+        """SimpleShuffleLayer whose group tasks are multiple-return tasks.
+
+        Each shuffle-group task is wrapped in ``MultipleReturnFunc`` and each
+        split task reads one of its outputs via ``multiple_return_get``.
+        """
+
+        @classmethod
+        def clone(cls, layer: SimpleShuffleLayer):
+            """Build an instance of ``cls`` from the fields of ``layer``."""
+            # TODO(Clark): Probably don't need this since SimpleShuffleLayer
+            # implements __copy__() and the shallow clone should be enough?
+            return cls(
+                name=layer.name,
+                column=layer.column,
+                npartitions=layer.npartitions,
+                npartitions_input=layer.npartitions_input,
+                ignore_index=layer.ignore_index,
+                name_input=layer.name_input,
+                meta_input=layer.meta_input,
+                parts_out=layer.parts_out,
+                annotations=layer.annotations,
+            )
+
+        def __repr__(self):
+            return (f"MultipleReturnSimpleShuffleLayer<name='{self.name}', "
+                    f"npartitions={self.npartitions}>")
+
+        def __reduce__(self):
+            """Pickle as the class plus its positional constructor args."""
+            attrs = [
+                "name",
+                "column",
+                "npartitions",
+                "npartitions_input",
+                "ignore_index",
+                "name_input",
+                "meta_input",
+                "parts_out",
+                "annotations",
+            ]
+            return (MultipleReturnSimpleShuffleLayer,
+                    tuple(getattr(self, attr) for attr in attrs))
+
+        def _cull(self, parts_out):
+            """Return a new layer restricted to ``parts_out``."""
+            return MultipleReturnSimpleShuffleLayer(
+                self.name,
+                self.column,
+                self.npartitions,
+                self.npartitions_input,
+                self.ignore_index,
+                self.name_input,
+                self.meta_input,
+                parts_out=parts_out,
+                # Keep annotations on the culled layer, matching clone()
+                # and __reduce__().
+                annotations=self.annotations,
+            )
+
+        def _construct_graph(self):
+            """Construct graph for a simple shuffle operation."""
+
+            shuffle_group_name = "group-" + self.name
+            shuffle_split_name = "split-" + self.name
+
+            dsk = {}
+            # NOTE(review): num_returns below is len(self.parts_out) while
+            # shuffle_group is called with self.npartitions; presumably the
+            # two only agree when the layer has not been culled. TODO confirm.
+            n_parts_out = len(self.parts_out)
+            for part_out in self.parts_out:
+                # TODO(Clark): Find better pattern than in-scheduler concat.
+                _concat_list = [(shuffle_split_name, part_out, part_in)
+                                for part_in in range(self.npartitions_input)]
+                dsk[(self.name, part_out)] = (_concat, _concat_list,
+                                              self.ignore_index)
+                for _, _part_out, _part_in in _concat_list:
+                    # Split task: select output _part_out of the group task.
+                    dsk[(shuffle_split_name, _part_out, _part_in)] = (
+                        multiple_return_get,
+                        (shuffle_group_name, _part_in),
+                        _part_out,
+                    )
+                    # Add each input partition's group task only once.
+                    if (shuffle_group_name, _part_in) not in dsk:
+                        dsk[(shuffle_group_name, _part_in)] = (
+                            MultipleReturnFunc(
+                                shuffle_group,
+                                n_parts_out,
+                            ),
+                            (self.name_input, _part_in),
+                            self.column,
+                            0,
+                            self.npartitions,
+                            self.npartitions,
+                            self.ignore_index,
+                            self.npartitions,
+                        )
+
+            return dsk
+
+    def rewrite_simple_shuffle_layer(dsk, keys):
+        """Replace SimpleShuffleLayer layers with the multiple-return layer.
+
+        Operates on a HighLevelGraph built from, or copied from, ``dsk`` and
+        returns it. ``keys`` is accepted but not used.
+        """
+        if not isinstance(dsk, HighLevelGraph):
+            dsk = HighLevelGraph.from_collections(
+                id(dsk), dsk, dependencies=())
+        else:
+            dsk = dsk.copy()
+
+        layers = dsk.layers.copy()
+        for key, layer in layers.items():
+            if type(layer) is SimpleShuffleLayer:
+                dsk.layers[key] = MultipleReturnSimpleShuffleLayer.clone(layer)
+        return dsk
+
+    def dataframe_optimize(dsk, keys, **kwargs):
+        """Rewrite shuffle layers, then run Dask's dataframe ``optimize``."""
+        if not isinstance(keys, (list, set)):
+            keys = [keys]
+        keys = list(core.flatten(keys))
+
+        if not isinstance(dsk, HighLevelGraph):
+            dsk = HighLevelGraph.from_collections(
+                id(dsk), dsk, dependencies=())
+
+        dsk = rewrite_simple_shuffle_layer(dsk, keys=keys)
+        return optimize(dsk, keys, **kwargs)
+else:
+
+    def dataframe_optimize(dsk, keys, **kwargs):
+        """Warn, then fall back to Dask's dataframe ``optimize``."""
+        warnings.warn("Custom dataframe shuffle optimization only works on "
+                      "dask>=2020.12.0, you are on version "
+                      f"{dask.__version__}, please upgrade Dask. "
+                      "Falling back to default dataframe optimizer.")
+        return optimize(dsk, keys, **kwargs)
+
+
+# Stale approaches below.
+
+
+def fuse_splits_into_multiple_return(dsk, keys):
+    """Unfinished stub: locates shuffle-group tasks but rewrites nothing.
+
+    Has no return statement, so it returns None.
+    """
+    if not isinstance(dsk, HighLevelGraph):
+        dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=())
+    else:
+        dsk = dsk.copy()
+    dependencies = dsk.dependencies.copy()
+    for k, v in dsk.items():
+        if istask(v) and v[0] == shuffle_group:
+            task_deps = dependencies[k]
+            # Only rewrite shuffle group split if all downstream dependencies
+            # are splits.
+            if all(
+                    istask(dsk[dep]) and dsk[dep][0] == operator.getitem
+                    for dep in task_deps):
+                for dep in task_deps:
+                    # Rewrite split
+                    pass
diff --git a/python/ray/util/dask/scheduler.py b/python/ray/util/dask/scheduler.py
index 0614d3564..d6a8a6edc 100644
--- a/python/ray/util/dask/scheduler.py
+++ b/python/ray/util/dask/scheduler.py
@@ -1,6 +1,7 @@
 import atexit
 from collections import defaultdict
 from multiprocessing.pool import ThreadPool
+from dataclasses import dataclass
 import threading
 
 import ray
@@ -270,19 +271,33 @@ def _rayify_task(
                     return alternate_return
 
         func, args = task[0], task[1:]
+        if func is multiple_return_get:
+            # Picking one output of a multiple-return task is evaluated here
+            # rather than being submitted as its own Ray task.
+            return _execute_task(task, deps)
         # If the function's arguments contain nested object references, we must
         # unpack said object references into a flat set of arguments so that
         # Ray properly tracks the object dependencies between Ray tasks.
-        object_refs, repack = unpack_object_refs(args, deps)
+        arg_object_refs, repack = unpack_object_refs(args, deps)
         # Submit the task using a wrapper function.
-        object_ref = dask_task_wrapper.options(name=f"dask:{key!s}").remote(
-            func, repack, key, ray_pretask_cbs, ray_posttask_cbs, *object_refs)
+        object_refs = dask_task_wrapper.options(
+            name=f"dask:{key!s}",
+            num_returns=(func.num_returns
+                         if isinstance(func, MultipleReturnFunc) else 1),
+        ).remote(
+            func,
+            repack,
+            key,
+            ray_pretask_cbs,
+            ray_posttask_cbs,
+            *arg_object_refs,
+        )
 
         if ray_postsubmit_cbs is not None:
             for cb in ray_postsubmit_cbs:
-                cb(task, key, deps, object_ref)
+                cb(task, key, deps, object_refs)
 
-        return object_ref
+        return object_refs
     elif not ishashable(task):
         return task
     elif task in deps:
@@ -434,3 +449,22 @@ def ray_dask_get_sync(dsk, keys, **kwargs):
                 cb(result)
 
         return result
+
+
+@dataclass
+class MultipleReturnFunc:
+    """Callable wrapper that tags ``func`` with a return count.
+
+    ``_rayify_task`` passes ``num_returns`` to the Ray task options when the
+    task function is an instance of this class.
+    """
+    func: callable
+    num_returns: int
+
+    def __call__(self, *args, **kwargs):
+        return self.func(*args, **kwargs)
+
+
+def multiple_return_get(multiple_returns, idx):
+    """Return element ``idx`` of ``multiple_returns``."""
+    return multiple_returns[idx]