mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 14:12:00 +08:00
[rllib] Fix shared metrics context in parallel iterators (#7666)
* debug * build * update * wip * wpi * update * recurisve sync * comment * stream * fix * Update .travis.yml
This commit is contained in:
@@ -73,6 +73,46 @@ def test_metrics_union(ray_start_regular_shared):
|
||||
assert it3.take(10) == [1, 100, 3, 200, 6, 300, 10, 400]
|
||||
|
||||
|
||||
def test_metrics_union_recursive(ray_start_regular_shared):
|
||||
it1 = from_items([1, 2, 3, 4], num_shards=1)
|
||||
it2 = from_items([1, 2, 3, 4], num_shards=1)
|
||||
it3 = from_items([1, 2, 3, 4], num_shards=1)
|
||||
|
||||
def foo_metrics(x):
|
||||
metrics = LocalIterator.get_metrics()
|
||||
metrics.counters["foo"] += 1
|
||||
return metrics.counters["foo"]
|
||||
|
||||
def bar_metrics(x):
|
||||
metrics = LocalIterator.get_metrics()
|
||||
metrics.counters["bar"] += 1
|
||||
return metrics.counters["bar"]
|
||||
|
||||
def baz_metrics(x):
|
||||
metrics = LocalIterator.get_metrics()
|
||||
metrics.counters["baz"] += 1
|
||||
return metrics.counters["baz"]
|
||||
|
||||
def verify_metrics(x):
|
||||
metrics = LocalIterator.get_metrics()
|
||||
metrics.counters["n"] += 1
|
||||
# Check the metrics context is shared recursively.
|
||||
print(metrics.counters)
|
||||
if metrics.counters["n"] >= 3:
|
||||
assert "foo" in metrics.counters
|
||||
assert "bar" in metrics.counters
|
||||
assert "baz" in metrics.counters
|
||||
return x
|
||||
|
||||
it1 = it1.gather_async().for_each(foo_metrics)
|
||||
it2 = it2.gather_async().for_each(bar_metrics)
|
||||
it3 = it3.gather_async().for_each(baz_metrics)
|
||||
it12 = it1.union(it2, deterministic=True)
|
||||
it123 = it12.union(it3, deterministic=True)
|
||||
out = it123.for_each(verify_metrics)
|
||||
assert out.take(20) == [1, 1, 1, 2, 2, 3, 2, 4, 3, 3, 4, 4]
|
||||
|
||||
|
||||
def test_from_items(ray_start_regular_shared):
|
||||
it = from_items([1, 2, 3, 4])
|
||||
assert repr(it) == "ParallelIterator[from_items[int, 4, shards=2]]"
|
||||
|
||||
+44
-40
@@ -1,10 +1,11 @@
|
||||
from contextlib import contextmanager
|
||||
import collections
|
||||
import random
|
||||
import threading
|
||||
from typing import TypeVar, Generic, Iterable, List, Callable, Any
|
||||
|
||||
import ray
|
||||
from ray.util.iter_metrics import MetricsContext
|
||||
from ray.util.iter_metrics import MetricsContext, SharedMetrics
|
||||
|
||||
# The type of an iterator element.
|
||||
T = TypeVar("T")
|
||||
@@ -412,7 +413,7 @@ class ParallelIterator(Generic[T]):
|
||||
futures = [a.par_iter_next.remote() for a in active]
|
||||
|
||||
name = "{}.batch_across_shards()".format(self)
|
||||
return LocalIterator(base_iterator, MetricsContext(), name=name)
|
||||
return LocalIterator(base_iterator, SharedMetrics(), name=name)
|
||||
|
||||
def gather_async(self, async_queue_depth=1) -> "LocalIterator[T]":
|
||||
"""Returns a local iterable for asynchronous iteration.
|
||||
@@ -438,8 +439,10 @@ class ParallelIterator(Generic[T]):
|
||||
if async_queue_depth < 1:
|
||||
raise ValueError("queue depth must be positive")
|
||||
|
||||
# Forward reference to the returned iterator.
|
||||
local_iter = None
|
||||
|
||||
def base_iterator(timeout=None):
|
||||
metrics = LocalIterator.get_metrics()
|
||||
all_actors = []
|
||||
for actor_set in self.actor_sets:
|
||||
actor_set.init_actors()
|
||||
@@ -463,7 +466,7 @@ class ParallelIterator(Generic[T]):
|
||||
for obj_id in ready:
|
||||
actor = futures.pop(obj_id)
|
||||
try:
|
||||
metrics.current_actor = actor
|
||||
local_iter.shared_metrics.get().current_actor = actor
|
||||
yield ray.get(obj_id)
|
||||
futures[actor.par_iter_next.remote()] = actor
|
||||
except StopIteration:
|
||||
@@ -473,7 +476,8 @@ class ParallelIterator(Generic[T]):
|
||||
yield _NextValueNotReady()
|
||||
|
||||
name = "{}.gather_async()".format(self)
|
||||
return LocalIterator(base_iterator, MetricsContext(), name=name)
|
||||
local_iter = LocalIterator(base_iterator, SharedMetrics(), name=name)
|
||||
return local_iter
|
||||
|
||||
def take(self, n: int) -> List[T]:
|
||||
"""Return up to the first n items from this iterator."""
|
||||
@@ -540,7 +544,7 @@ class ParallelIterator(Generic[T]):
|
||||
break
|
||||
|
||||
name = self.name + ".shard[{}]".format(shard_index)
|
||||
return LocalIterator(base_iterator, MetricsContext(), name=name)
|
||||
return LocalIterator(base_iterator, SharedMetrics(), name=name)
|
||||
|
||||
|
||||
class LocalIterator(Generic[T]):
|
||||
@@ -562,7 +566,7 @@ class LocalIterator(Generic[T]):
|
||||
|
||||
def __init__(self,
|
||||
base_iterator: Callable[[], Iterable[T]],
|
||||
metrics: MetricsContext,
|
||||
shared_metrics: SharedMetrics,
|
||||
local_transforms: List[Callable[[Iterable], Any]] = None,
|
||||
timeout: int = None,
|
||||
name=None):
|
||||
@@ -572,7 +576,7 @@ class LocalIterator(Generic[T]):
|
||||
base_iterator (func): A function that produces the base iterator.
|
||||
This is a function so that we can ensure LocalIterator is
|
||||
serializable.
|
||||
metrics (MetricsContext): Existing metrics context or a new
|
||||
shared_metrics (SharedMetrics): Existing metrics context or a new
|
||||
context. Should be the same for each chained iterator.
|
||||
local_transforms (list): A list of transformation functions to be
|
||||
applied on top of the base iterator. When iteration begins, we
|
||||
@@ -584,10 +588,11 @@ class LocalIterator(Generic[T]):
|
||||
blocking.
|
||||
name (str): Optional name for this iterator.
|
||||
"""
|
||||
assert isinstance(shared_metrics, SharedMetrics)
|
||||
self.base_iterator = base_iterator
|
||||
self.built_iterator = None
|
||||
self.local_transforms = local_transforms or []
|
||||
self.metrics = metrics
|
||||
self.shared_metrics = shared_metrics
|
||||
self.timeout = timeout
|
||||
self.name = name or "unknown"
|
||||
|
||||
@@ -606,26 +611,20 @@ class LocalIterator(Generic[T]):
|
||||
it = iter(self.base_iterator(self.timeout))
|
||||
for fn in self.local_transforms:
|
||||
it = fn(it)
|
||||
|
||||
# This sets the iterator context during iterator execution, and
|
||||
# clears it after so that multiple iterators can be used at a time.
|
||||
def set_restore_context(it):
|
||||
if hasattr(self.thread_local, "metrics"):
|
||||
prev_metrics = self.thread_local.metrics
|
||||
else:
|
||||
prev_metrics = None
|
||||
self.thread_local.metrics = self.metrics
|
||||
try:
|
||||
for item in it:
|
||||
self.thread_local.metrics = prev_metrics
|
||||
yield item
|
||||
self.thread_local.metrics = self.metrics
|
||||
finally:
|
||||
self.thread_local.metrics = prev_metrics
|
||||
|
||||
it = set_restore_context(it)
|
||||
self.built_iterator = it
|
||||
|
||||
@contextmanager
|
||||
def _metrics_context(self):
|
||||
if hasattr(self.thread_local, "metrics"):
|
||||
prev_metrics = self.thread_local.metrics
|
||||
else:
|
||||
prev_metrics = None
|
||||
try:
|
||||
self.thread_local.metrics = self.shared_metrics.get()
|
||||
yield
|
||||
finally:
|
||||
self.thread_local.metrics = prev_metrics
|
||||
|
||||
def __iter__(self):
|
||||
self._build_once()
|
||||
return self.built_iterator
|
||||
@@ -649,7 +648,8 @@ class LocalIterator(Generic[T]):
|
||||
# Keep retrying the function until it returns a valid
|
||||
# value. This allows for non-blocking functions.
|
||||
while True:
|
||||
result = fn(item)
|
||||
with self._metrics_context():
|
||||
result = fn(item)
|
||||
yield result
|
||||
if not isinstance(result, _NextValueNotReady):
|
||||
break
|
||||
@@ -664,7 +664,8 @@ class LocalIterator(Generic[T]):
|
||||
# Avoids calling on_fetch_start repeatedly if we are
|
||||
# yielding _NextValueNotReady.
|
||||
if new_item:
|
||||
fn._on_fetch_start()
|
||||
with self._metrics_context():
|
||||
fn._on_fetch_start()
|
||||
new_item = False
|
||||
item = next(it)
|
||||
if not isinstance(item, _NextValueNotReady):
|
||||
@@ -675,19 +676,20 @@ class LocalIterator(Generic[T]):
|
||||
|
||||
return LocalIterator(
|
||||
self.base_iterator,
|
||||
self.metrics,
|
||||
self.shared_metrics,
|
||||
self.local_transforms + [apply_foreach],
|
||||
name=self.name + ".for_each()")
|
||||
|
||||
def filter(self, fn: Callable[[T], bool]) -> "LocalIterator[T]":
|
||||
def apply_filter(it):
|
||||
for item in it:
|
||||
if isinstance(item, _NextValueNotReady) or fn(item):
|
||||
yield item
|
||||
with self._metrics_context():
|
||||
if isinstance(item, _NextValueNotReady) or fn(item):
|
||||
yield item
|
||||
|
||||
return LocalIterator(
|
||||
self.base_iterator,
|
||||
self.metrics,
|
||||
self.shared_metrics,
|
||||
self.local_transforms + [apply_filter],
|
||||
name=self.name + ".filter()")
|
||||
|
||||
@@ -707,7 +709,7 @@ class LocalIterator(Generic[T]):
|
||||
|
||||
return LocalIterator(
|
||||
self.base_iterator,
|
||||
self.metrics,
|
||||
self.shared_metrics,
|
||||
self.local_transforms + [apply_batch],
|
||||
name=self.name + ".batch({})".format(n))
|
||||
|
||||
@@ -722,7 +724,7 @@ class LocalIterator(Generic[T]):
|
||||
|
||||
return LocalIterator(
|
||||
self.base_iterator,
|
||||
self.metrics,
|
||||
self.shared_metrics,
|
||||
self.local_transforms + [apply_flatten],
|
||||
name=self.name + ".flatten()")
|
||||
|
||||
@@ -760,7 +762,7 @@ class LocalIterator(Generic[T]):
|
||||
|
||||
return LocalIterator(
|
||||
self.base_iterator,
|
||||
self.metrics,
|
||||
self.shared_metrics,
|
||||
self.local_transforms + [apply_shuffle],
|
||||
name=self.name +
|
||||
".shuffle(shuffle_buffer_size={}, seed={})".format(
|
||||
@@ -836,7 +838,7 @@ class LocalIterator(Generic[T]):
|
||||
iterators.append(
|
||||
LocalIterator(
|
||||
make_next(i),
|
||||
self.metrics, [],
|
||||
self.shared_metrics, [],
|
||||
name=self.name + ".duplicate[{}]".format(i)))
|
||||
|
||||
return iterators
|
||||
@@ -862,8 +864,10 @@ class LocalIterator(Generic[T]):
|
||||
timeout = 0
|
||||
|
||||
active = []
|
||||
shared_metrics = MetricsContext()
|
||||
for it in [self] + list(others):
|
||||
parent_iters = [self] + list(others)
|
||||
shared_metrics = SharedMetrics(
|
||||
parents=[p.shared_metrics for p in parent_iters])
|
||||
for it in parent_iters:
|
||||
active.append(
|
||||
LocalIterator(
|
||||
it.base_iterator,
|
||||
@@ -945,7 +949,7 @@ class ParallelIteratorWorker(object):
|
||||
def par_iter_init(self, transforms):
|
||||
"""Implements ParallelIterator worker init."""
|
||||
it = LocalIterator(lambda timeout: self.item_generator,
|
||||
MetricsContext())
|
||||
SharedMetrics())
|
||||
for fn in transforms:
|
||||
it = fn(it)
|
||||
assert it is not None, fn
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import collections
|
||||
from typing import List
|
||||
|
||||
from ray.util.timer import _Timer
|
||||
|
||||
@@ -18,8 +19,6 @@ class MetricsContext:
|
||||
current_actor (ActorHandle): reference to the actor handle that
|
||||
produced the current iterator output. This is automatically set
|
||||
for gather_async().
|
||||
parent_metrics (list): list of other MetricsContexts that have been
|
||||
attached to this due to LocalIterator.union().
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -27,7 +26,6 @@ class MetricsContext:
|
||||
self.timers = collections.defaultdict(_Timer)
|
||||
self.info = {}
|
||||
self.current_actor = None
|
||||
self.parent_metrics = []
|
||||
|
||||
def save(self):
|
||||
"""Return a serializable copy of this context."""
|
||||
@@ -35,7 +33,6 @@ class MetricsContext:
|
||||
"counters": dict(self.counters),
|
||||
"info": dict(self.info),
|
||||
"timers": None, # TODO(ekl) consider persisting timers too
|
||||
"parent_state": [u.save() for u in self.parent_metrics],
|
||||
}
|
||||
|
||||
def restore(self, values):
|
||||
@@ -44,5 +41,26 @@ class MetricsContext:
|
||||
self.counters.update(values["counters"])
|
||||
self.timers.clear()
|
||||
self.info = values["info"]
|
||||
for u, state in zip(self.parent_metrics, values["parent_state"]):
|
||||
u.restore(state)
|
||||
|
||||
|
||||
class SharedMetrics:
|
||||
"""Holds an indirect reference to a (shared) metrics context.
|
||||
|
||||
This is used by LocalIterator.union() to point the metrics contexts of
|
||||
entirely separate iterator chains to the same underlying context."""
|
||||
|
||||
def __init__(self,
|
||||
metrics: MetricsContext = None,
|
||||
parents: List["SharedMetrics"] = None):
|
||||
self.metrics = metrics or MetricsContext()
|
||||
self.parents = parents or []
|
||||
self.set(self.metrics)
|
||||
|
||||
def set(self, metrics):
|
||||
"""Recursively set self and parents to point to the same metrics."""
|
||||
self.metrics = metrics
|
||||
for parent in self.parents:
|
||||
parent.set(metrics)
|
||||
|
||||
def get(self):
|
||||
return self.metrics
|
||||
|
||||
Reference in New Issue
Block a user