[rllib] Fix shared metrics context in parallel iterators (#7666)

* debug

* build

* update

* wip

* wpi

* update

* recurisve sync

* comment

* stream

* fix

* Update .travis.yml
This commit is contained in:
Eric Liang
2020-03-22 14:15:01 -07:00
committed by GitHub
parent 2fb219a658
commit 288933ec6b
7 changed files with 131 additions and 72 deletions
+40
View File
@@ -73,6 +73,46 @@ def test_metrics_union(ray_start_regular_shared):
assert it3.take(10) == [1, 100, 3, 200, 6, 300, 10, 400]
def test_metrics_union_recursive(ray_start_regular_shared):
it1 = from_items([1, 2, 3, 4], num_shards=1)
it2 = from_items([1, 2, 3, 4], num_shards=1)
it3 = from_items([1, 2, 3, 4], num_shards=1)
def foo_metrics(x):
metrics = LocalIterator.get_metrics()
metrics.counters["foo"] += 1
return metrics.counters["foo"]
def bar_metrics(x):
metrics = LocalIterator.get_metrics()
metrics.counters["bar"] += 1
return metrics.counters["bar"]
def baz_metrics(x):
metrics = LocalIterator.get_metrics()
metrics.counters["baz"] += 1
return metrics.counters["baz"]
def verify_metrics(x):
metrics = LocalIterator.get_metrics()
metrics.counters["n"] += 1
# Check the metrics context is shared recursively.
print(metrics.counters)
if metrics.counters["n"] >= 3:
assert "foo" in metrics.counters
assert "bar" in metrics.counters
assert "baz" in metrics.counters
return x
it1 = it1.gather_async().for_each(foo_metrics)
it2 = it2.gather_async().for_each(bar_metrics)
it3 = it3.gather_async().for_each(baz_metrics)
it12 = it1.union(it2, deterministic=True)
it123 = it12.union(it3, deterministic=True)
out = it123.for_each(verify_metrics)
assert out.take(20) == [1, 1, 1, 2, 2, 3, 2, 4, 3, 3, 4, 4]
def test_from_items(ray_start_regular_shared):
it = from_items([1, 2, 3, 4])
assert repr(it) == "ParallelIterator[from_items[int, 4, shards=2]]"
+44 -40
View File
@@ -1,10 +1,11 @@
from contextlib import contextmanager
import collections
import random
import threading
from typing import TypeVar, Generic, Iterable, List, Callable, Any
import ray
from ray.util.iter_metrics import MetricsContext
from ray.util.iter_metrics import MetricsContext, SharedMetrics
# The type of an iterator element.
T = TypeVar("T")
@@ -412,7 +413,7 @@ class ParallelIterator(Generic[T]):
futures = [a.par_iter_next.remote() for a in active]
name = "{}.batch_across_shards()".format(self)
return LocalIterator(base_iterator, MetricsContext(), name=name)
return LocalIterator(base_iterator, SharedMetrics(), name=name)
def gather_async(self, async_queue_depth=1) -> "LocalIterator[T]":
"""Returns a local iterable for asynchronous iteration.
@@ -438,8 +439,10 @@ class ParallelIterator(Generic[T]):
if async_queue_depth < 1:
raise ValueError("queue depth must be positive")
# Forward reference to the returned iterator.
local_iter = None
def base_iterator(timeout=None):
metrics = LocalIterator.get_metrics()
all_actors = []
for actor_set in self.actor_sets:
actor_set.init_actors()
@@ -463,7 +466,7 @@ class ParallelIterator(Generic[T]):
for obj_id in ready:
actor = futures.pop(obj_id)
try:
metrics.current_actor = actor
local_iter.shared_metrics.get().current_actor = actor
yield ray.get(obj_id)
futures[actor.par_iter_next.remote()] = actor
except StopIteration:
@@ -473,7 +476,8 @@ class ParallelIterator(Generic[T]):
yield _NextValueNotReady()
name = "{}.gather_async()".format(self)
return LocalIterator(base_iterator, MetricsContext(), name=name)
local_iter = LocalIterator(base_iterator, SharedMetrics(), name=name)
return local_iter
def take(self, n: int) -> List[T]:
"""Return up to the first n items from this iterator."""
@@ -540,7 +544,7 @@ class ParallelIterator(Generic[T]):
break
name = self.name + ".shard[{}]".format(shard_index)
return LocalIterator(base_iterator, MetricsContext(), name=name)
return LocalIterator(base_iterator, SharedMetrics(), name=name)
class LocalIterator(Generic[T]):
@@ -562,7 +566,7 @@ class LocalIterator(Generic[T]):
def __init__(self,
base_iterator: Callable[[], Iterable[T]],
metrics: MetricsContext,
shared_metrics: SharedMetrics,
local_transforms: List[Callable[[Iterable], Any]] = None,
timeout: int = None,
name=None):
@@ -572,7 +576,7 @@ class LocalIterator(Generic[T]):
base_iterator (func): A function that produces the base iterator.
This is a function so that we can ensure LocalIterator is
serializable.
metrics (MetricsContext): Existing metrics context or a new
shared_metrics (SharedMetrics): Existing metrics context or a new
context. Should be the same for each chained iterator.
local_transforms (list): A list of transformation functions to be
applied on top of the base iterator. When iteration begins, we
@@ -584,10 +588,11 @@ class LocalIterator(Generic[T]):
blocking.
name (str): Optional name for this iterator.
"""
assert isinstance(shared_metrics, SharedMetrics)
self.base_iterator = base_iterator
self.built_iterator = None
self.local_transforms = local_transforms or []
self.metrics = metrics
self.shared_metrics = shared_metrics
self.timeout = timeout
self.name = name or "unknown"
@@ -606,26 +611,20 @@ class LocalIterator(Generic[T]):
it = iter(self.base_iterator(self.timeout))
for fn in self.local_transforms:
it = fn(it)
# This sets the iterator context during iterator execution, and
# clears it after so that multiple iterators can be used at a time.
def set_restore_context(it):
if hasattr(self.thread_local, "metrics"):
prev_metrics = self.thread_local.metrics
else:
prev_metrics = None
self.thread_local.metrics = self.metrics
try:
for item in it:
self.thread_local.metrics = prev_metrics
yield item
self.thread_local.metrics = self.metrics
finally:
self.thread_local.metrics = prev_metrics
it = set_restore_context(it)
self.built_iterator = it
@contextmanager
def _metrics_context(self):
if hasattr(self.thread_local, "metrics"):
prev_metrics = self.thread_local.metrics
else:
prev_metrics = None
try:
self.thread_local.metrics = self.shared_metrics.get()
yield
finally:
self.thread_local.metrics = prev_metrics
def __iter__(self):
self._build_once()
return self.built_iterator
@@ -649,7 +648,8 @@ class LocalIterator(Generic[T]):
# Keep retrying the function until it returns a valid
# value. This allows for non-blocking functions.
while True:
result = fn(item)
with self._metrics_context():
result = fn(item)
yield result
if not isinstance(result, _NextValueNotReady):
break
@@ -664,7 +664,8 @@ class LocalIterator(Generic[T]):
# Avoids calling on_fetch_start repeatedly if we are
# yielding _NextValueNotReady.
if new_item:
fn._on_fetch_start()
with self._metrics_context():
fn._on_fetch_start()
new_item = False
item = next(it)
if not isinstance(item, _NextValueNotReady):
@@ -675,19 +676,20 @@ class LocalIterator(Generic[T]):
return LocalIterator(
self.base_iterator,
self.metrics,
self.shared_metrics,
self.local_transforms + [apply_foreach],
name=self.name + ".for_each()")
def filter(self, fn: Callable[[T], bool]) -> "LocalIterator[T]":
def apply_filter(it):
for item in it:
if isinstance(item, _NextValueNotReady) or fn(item):
yield item
with self._metrics_context():
if isinstance(item, _NextValueNotReady) or fn(item):
yield item
return LocalIterator(
self.base_iterator,
self.metrics,
self.shared_metrics,
self.local_transforms + [apply_filter],
name=self.name + ".filter()")
@@ -707,7 +709,7 @@ class LocalIterator(Generic[T]):
return LocalIterator(
self.base_iterator,
self.metrics,
self.shared_metrics,
self.local_transforms + [apply_batch],
name=self.name + ".batch({})".format(n))
@@ -722,7 +724,7 @@ class LocalIterator(Generic[T]):
return LocalIterator(
self.base_iterator,
self.metrics,
self.shared_metrics,
self.local_transforms + [apply_flatten],
name=self.name + ".flatten()")
@@ -760,7 +762,7 @@ class LocalIterator(Generic[T]):
return LocalIterator(
self.base_iterator,
self.metrics,
self.shared_metrics,
self.local_transforms + [apply_shuffle],
name=self.name +
".shuffle(shuffle_buffer_size={}, seed={})".format(
@@ -836,7 +838,7 @@ class LocalIterator(Generic[T]):
iterators.append(
LocalIterator(
make_next(i),
self.metrics, [],
self.shared_metrics, [],
name=self.name + ".duplicate[{}]".format(i)))
return iterators
@@ -862,8 +864,10 @@ class LocalIterator(Generic[T]):
timeout = 0
active = []
shared_metrics = MetricsContext()
for it in [self] + list(others):
parent_iters = [self] + list(others)
shared_metrics = SharedMetrics(
parents=[p.shared_metrics for p in parent_iters])
for it in parent_iters:
active.append(
LocalIterator(
it.base_iterator,
@@ -945,7 +949,7 @@ class ParallelIteratorWorker(object):
def par_iter_init(self, transforms):
"""Implements ParallelIterator worker init."""
it = LocalIterator(lambda timeout: self.item_generator,
MetricsContext())
SharedMetrics())
for fn in transforms:
it = fn(it)
assert it is not None, fn
+24 -6
View File
@@ -1,4 +1,5 @@
import collections
from typing import List
from ray.util.timer import _Timer
@@ -18,8 +19,6 @@ class MetricsContext:
current_actor (ActorHandle): reference to the actor handle that
produced the current iterator output. This is automatically set
for gather_async().
parent_metrics (list): list of other MetricsContexts that have been
attached to this due to LocalIterator.union().
"""
def __init__(self):
@@ -27,7 +26,6 @@ class MetricsContext:
self.timers = collections.defaultdict(_Timer)
self.info = {}
self.current_actor = None
self.parent_metrics = []
def save(self):
"""Return a serializable copy of this context."""
@@ -35,7 +33,6 @@ class MetricsContext:
"counters": dict(self.counters),
"info": dict(self.info),
"timers": None, # TODO(ekl) consider persisting timers too
"parent_state": [u.save() for u in self.parent_metrics],
}
def restore(self, values):
@@ -44,5 +41,26 @@ class MetricsContext:
self.counters.update(values["counters"])
self.timers.clear()
self.info = values["info"]
for u, state in zip(self.parent_metrics, values["parent_state"]):
u.restore(state)
class SharedMetrics:
"""Holds an indirect reference to a (shared) metrics context.
This is used by LocalIterator.union() to point the metrics contexts of
entirely separate iterator chains to the same underlying context."""
def __init__(self,
metrics: MetricsContext = None,
parents: List["SharedMetrics"] = None):
self.metrics = metrics or MetricsContext()
self.parents = parents or []
self.set(self.metrics)
def set(self, metrics):
"""Recursively set self and parents to point to the same metrics."""
self.metrics = metrics
for parent in self.parents:
parent.set(metrics)
def get(self):
return self.metrics