Add ParallelIterator.transform() and LocalIterator.transform(), which apply a
generator function over an iterator, plus a test covering both. Also lower the
plasma store "Disconnecting client" log message from INFO to DEBUG.

diff --git a/python/ray/tests/test_iter.py b/python/ray/tests/test_iter.py
index 8eb0d606c..072fe0a19 100644
--- a/python/ray/tests/test_iter.py
+++ b/python/ray/tests/test_iter.py
@@ -17,6 +17,24 @@ def test_select_shards(ray_start_regular_shared):
     assert it2.take(4) == [2, 4]
 
 
+def test_transform(ray_start_regular_shared):
+    def f(it):
+        for item in it:
+            yield item * 2
+
+    def g(it):
+        for item in it:
+            if item >= 2:
+                yield item
+
+    it = from_range(4).transform(f)
+    assert repr(it) == "ParallelIterator[from_range[4, shards=2].transform()]"
+    assert list(it.gather_sync()) == [0, 4, 2, 6]
+
+    it = from_range(4)
+    assert list(it.gather_sync().transform(g)) == [2, 3]
+
+
 def test_metrics(ray_start_regular_shared):
     it = from_items([1, 2, 3, 4], num_shards=1)
     it2 = from_items([1, 2, 3, 4], num_shards=1)
diff --git a/python/ray/util/iter.py b/python/ray/util/iter.py
index c277e4128..dde3cce88 100644
--- a/python/ray/util/iter.py
+++ b/python/ray/util/iter.py
@@ -193,10 +193,37 @@ class ParallelIterator(Generic[T]):
             name=self.name + name,
             parent_iterators=self.parent_iterators)
 
+    def transform(self, fn: Callable[[Iterable[T]], Iterable[U]]
+                  ) -> "ParallelIterator[U]":
+        """Remotely transform the iterator.
+
+        This is an advanced version of for_each that allows you to apply
+        arbitrary generator transformations over the iterator. Prefer to use
+        .for_each() when possible for simplicity.
+
+        Args:
+            fn (func): function to use to transform the iterator. The function
+                should pass through instances of _NextValueNotReady that appear
+                in its input iterator. Note that this function is only called
+                **once** over the input iterator.
+
+        Returns:
+            ParallelIterator[U]: a parallel iterator.
+
+        Examples:
+            >>> def f(it):
+            ...     for x in it:
+            ...         if x % 2 == 0:
+            ...             yield x
+            >>> from_range(10, 1).transform(f).gather_sync().take(5)
+            [0, 2, 4, 6, 8]
+        """
+        return self._with_transform(lambda local_it: local_it.transform(fn),
+                                    ".transform()")
+
     def for_each(self, fn: Callable[[T], U], max_concurrency=1,
                  resources=None) -> "ParallelIterator[U]":
-        """Remotely apply fn to each item in this iterator, at most `max_concurrency`
-        at a time per shard.
+        """Remotely apply fn to each item in this iterator.
 
         If `max_concurrency` == 1 then `fn` will be executed serially by each
         shards
@@ -224,7 +251,6 @@ class ParallelIterator(Generic[T]):
             ParallelIterator[U]: a parallel iterator whose elements have `fn`
             applied.
 
-
         Examples:
             >>> next(from_range(4).for_each(
                         lambda x: x * 2,
@@ -737,6 +763,20 @@ class LocalIterator(Generic[T]):
     def __repr__(self):
         return "LocalIterator[{}]".format(self.name)
 
+    def transform(self, fn: Callable[[Iterable[T]], Iterable[U]]
+                  ) -> "LocalIterator[U]":
+
+        # TODO(ekl) can we automatically handle NextValueNotReady here?
+        def apply_transform(it):
+            for item in fn(it):
+                yield item
+
+        return LocalIterator(
+            self.base_iterator,
+            self.shared_metrics,
+            self.local_transforms + [apply_transform],
+            name=self.name + ".transform()")
+
     def for_each(self, fn: Callable[[T], U], max_concurrency=1,
                  resources=None) -> "LocalIterator[U]":
         if max_concurrency == 1:
diff --git a/src/ray/object_manager/plasma/store.cc b/src/ray/object_manager/plasma/store.cc
index e5e169b4a..87827245b 100644
--- a/src/ray/object_manager/plasma/store.cc
+++ b/src/ray/object_manager/plasma/store.cc
@@ -724,7 +724,7 @@ void PlasmaStore::DisconnectClient(int client_fd) {
   loop_->RemoveFileEvent(client_fd);
   // Close the socket.
   close(client_fd);
-  RAY_LOG(INFO) << "Disconnecting client on fd " << client_fd;
+  RAY_LOG(DEBUG) << "Disconnecting client on fd " << client_fd;
   // Release all the objects that the client was using.
   auto client = it->second.get();
   eviction_policy_.ClientDisconnected(client);