mirror of
https://github.com/wassname/ray.git
synced 2026-07-04 03:36:23 +08:00
[rllib] Enable performance metrics reporting for RLlib pipelines, add A3C (#7299)
This commit is contained in:
@@ -1,9 +1,68 @@
|
||||
import time
|
||||
from collections import Counter
|
||||
import pytest
|
||||
|
||||
import ray
|
||||
from ray.util.iter import from_items, from_iterators, from_range, \
|
||||
from_actors, ParallelIteratorWorker
|
||||
from_actors, ParallelIteratorWorker, LocalIterator
|
||||
|
||||
|
||||
def test_metrics(ray_start_regular_shared):
    """Check that each LocalIterator keeps its own metrics context.

    Two independent iterators run the same accumulating function; each one
    must only ever see its own "foo" counter, and the context must not be
    reachable outside of iterator execution.
    """

    def running_total(value):
        # Accumulate into the context of whichever iterator is executing.
        ctx = LocalIterator.get_metrics()
        ctx.counters["foo"] += value
        return ctx.counters["foo"]

    first = from_items([1, 2, 3, 4], num_shards=1)
    second = from_items([1, 2, 3, 4], num_shards=1)
    first = first.gather_sync().for_each(running_total)
    second = second.gather_sync().for_each(running_total)

    # Context cannot be accessed outside the iterator.
    with pytest.raises(ValueError):
        LocalIterator.get_metrics()

    # Tests iterators have isolated contexts: both produce the same running
    # sums, which would not hold if they shared a single "foo" counter.
    expected = [1, 3, 6, 10]
    assert first.take(4) == expected
    assert second.take(4) == expected

    # After iteration the context must be cleared again.
    with pytest.raises(ValueError):
        LocalIterator.get_metrics()
|
||||
|
||||
|
||||
def test_metrics_union(ray_start_regular_shared):
    """Check metrics handling when two local iterators are unioned.

    The unioned iterator is expected to run with a fresh metrics context,
    while the contexts of both inputs stay reachable via ``parent_metrics``.
    """

    def add_to_foo(value):
        ctx = LocalIterator.get_metrics()
        ctx.counters["foo"] += value
        return ctx.counters["foo"]

    def add_to_bar(value):
        # The item itself is ignored; "bar" grows by a fixed 100 per item.
        ctx = LocalIterator.get_metrics()
        ctx.counters["bar"] += 100
        return ctx.counters["bar"]

    def check_union_context(value):
        ctx = LocalIterator.get_metrics()
        ctx.counters["n"] += 1
        # Check the unioned iterator gets a new metric context.
        for key in ("foo", "bar"):
            assert key not in ctx.counters
        # Check parent metrics are accessible. Only verified from the third
        # item on, once both inputs have produced at least one element.
        if ctx.counters["n"] > 2:
            foo_parent, bar_parent = ctx.parent_metrics[0:2]
            assert "foo" in foo_parent.counters
            assert "bar" in bar_parent.counters
        return value

    left = from_items([1, 2, 3, 4], num_shards=1)
    right = from_items([1, 2, 3, 4], num_shards=1)
    left = left.gather_async().for_each(add_to_foo)
    right = right.gather_async().for_each(add_to_bar)

    # deterministic=True alternates strictly between the two inputs, which
    # makes the interleaved output below predictable.
    merged = left.union(right, deterministic=True)
    merged = merged.for_each(check_union_context)
    assert merged.take(10) == [1, 100, 3, 200, 6, 300, 10, 400]
|
||||
|
||||
|
||||
def test_from_items(ray_start_regular_shared):
|
||||
@@ -271,6 +330,5 @@ def test_serialization(ray_start_regular_shared):
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly: hand the file to pytest in
    # verbose mode and propagate its exit status to the shell.
    import pytest
    import sys

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|
||||
|
||||
+103
-11
@@ -1,8 +1,10 @@
|
||||
import collections
|
||||
import random
|
||||
import threading
|
||||
from typing import TypeVar, Generic, Iterable, List, Callable, Any
|
||||
|
||||
import ray
|
||||
from ray.util.iter_metrics import MetricsContext
|
||||
|
||||
# The type of an iterator element.
|
||||
T = TypeVar("T")
|
||||
@@ -150,6 +152,8 @@ class ParallelIterator(Generic[T]):
|
||||
"""
|
||||
|
||||
def __init__(self, actor_sets: List["_ActorSet"], name: str):
|
||||
"""Create a parallel iterator (this is an internal function)."""
|
||||
|
||||
# We track multiple sets of actors to support parallel .union().
|
||||
self.actor_sets = actor_sets
|
||||
self.name = name
|
||||
@@ -415,7 +419,7 @@ class ParallelIterator(Generic[T]):
|
||||
futures = [a.par_iter_next.remote() for a in active]
|
||||
|
||||
name = "{}.batch_across_shards()".format(self)
|
||||
return LocalIterator(base_iterator, name=name)
|
||||
return LocalIterator(base_iterator, MetricsContext(), name=name)
|
||||
|
||||
def gather_async(self) -> "LocalIterator[T]":
|
||||
"""Returns a local iterable for asynchronous iteration.
|
||||
@@ -433,6 +437,8 @@ class ParallelIterator(Generic[T]):
|
||||
... 1
|
||||
"""
|
||||
|
||||
metrics = MetricsContext()
|
||||
|
||||
def base_iterator(timeout=None):
|
||||
all_actors = []
|
||||
for actor_set in self.actor_sets:
|
||||
@@ -456,6 +462,7 @@ class ParallelIterator(Generic[T]):
|
||||
for obj_id in ready:
|
||||
actor = futures.pop(obj_id)
|
||||
try:
|
||||
metrics.cur_actor = actor
|
||||
yield ray.get(obj_id)
|
||||
futures[actor.par_iter_next.remote()] = actor
|
||||
except StopIteration:
|
||||
@@ -465,7 +472,7 @@ class ParallelIterator(Generic[T]):
|
||||
yield _NextValueNotReady()
|
||||
|
||||
name = "{}.gather_async()".format(self)
|
||||
return LocalIterator(base_iterator, name=name)
|
||||
return LocalIterator(base_iterator, metrics, name=name)
|
||||
|
||||
def take(self, n: int) -> List[T]:
|
||||
"""Return up to the first n items from this iterator."""
|
||||
@@ -528,7 +535,7 @@ class ParallelIterator(Generic[T]):
|
||||
break
|
||||
|
||||
name = self.name + ".shard[{}]".format(shard_index)
|
||||
return LocalIterator(base_iterator, name=name)
|
||||
return LocalIterator(base_iterator, MetricsContext(), name=name)
|
||||
|
||||
|
||||
class LocalIterator(Generic[T]):
|
||||
@@ -541,8 +548,16 @@ class LocalIterator(Generic[T]):
|
||||
tasks and actors. However, it should be read from at most one process at
|
||||
a time."""
|
||||
|
||||
# If a function passed to LocalIterator.for_each() has this method,
|
||||
# we will call it at the beginning of each data fetch call. This can be
|
||||
# used to measure the underlying wait latency.
|
||||
ON_FETCH_START_HOOK_NAME = "_on_fetch_start"
|
||||
|
||||
thread_local = threading.local()
|
||||
|
||||
def __init__(self,
|
||||
base_iterator: Callable[[], Iterable[T]],
|
||||
metrics: MetricsContext,
|
||||
local_transforms: List[Callable[[Iterable], Any]] = None,
|
||||
timeout: int = None,
|
||||
name=None):
|
||||
@@ -552,6 +567,8 @@ class LocalIterator(Generic[T]):
|
||||
base_iterator (func): A function that produces the base iterator.
|
||||
This is a function so that we can ensure LocalIterator is
|
||||
serializable.
|
||||
metrics (MetricsContext): Existing metrics context or a new
|
||||
context. Should be the same for each chained iterator.
|
||||
local_transforms (list): A list of transformation functions to be
|
||||
applied on top of the base iterator. When iteration begins, we
|
||||
create the base iterator and apply these functions. This lazy
|
||||
@@ -565,14 +582,43 @@ class LocalIterator(Generic[T]):
|
||||
self.base_iterator = base_iterator
|
||||
self.built_iterator = None
|
||||
self.local_transforms = local_transforms or []
|
||||
self.metrics = metrics
|
||||
self.timeout = timeout
|
||||
self.name = name or "unknown"
|
||||
|
||||
@staticmethod
def get_metrics() -> MetricsContext:
    """Return the metrics context of the iterator currently executing.

    The context is looked up in thread-local storage, so it is only
    available while an iterator function is running on the calling thread.

    Raises:
        ValueError: if no metrics context is set for the current thread.
    """
    active = getattr(LocalIterator.thread_local, "metrics", None)
    if active is None:
        raise ValueError("Cannot access context outside an iterator.")
    return active
|
||||
|
||||
def _build_once(self):
    """Build the underlying iterator on first use; later calls do nothing.

    The base iterator is created with this iterator's timeout, every local
    transform is applied on top of it in order, and the result is wrapped
    so that the thread-local metrics context is managed during iteration.
    """
    if self.built_iterator is not None:
        return

    pipeline = iter(self.base_iterator(self.timeout))
    for transform in self.local_transforms:
        pipeline = transform(pipeline)

    # This sets the iterator context during iterator execution, and
    # clears it after so that multiple iterators can be used at a time.
    def set_restore_context(inner):
        # The previous context is read once, when this generator is first
        # advanced; it is None if nothing was set on this thread yet.
        prev_metrics = getattr(self.thread_local, "metrics", None)
        self.thread_local.metrics = self.metrics
        try:
            for item in inner:
                # Hand control back to the consumer with the previous
                # context in place, then re-install ours on resume.
                self.thread_local.metrics = prev_metrics
                yield item
                self.thread_local.metrics = self.metrics
        finally:
            self.thread_local.metrics = prev_metrics

    self.built_iterator = set_restore_context(pipeline)
|
||||
|
||||
def __iter__(self):
|
||||
@@ -597,8 +643,28 @@ class LocalIterator(Generic[T]):
|
||||
else:
|
||||
yield fn(item)
|
||||
|
||||
if hasattr(fn, LocalIterator.ON_FETCH_START_HOOK_NAME):
|
||||
unwrapped = apply_foreach
|
||||
|
||||
def add_wait_hooks(it):
|
||||
it = unwrapped(it)
|
||||
new_item = True
|
||||
while True:
|
||||
# Avoids calling on_fetch_start repeatedly if we are
|
||||
# yielding _NextValueNotReady.
|
||||
if new_item:
|
||||
fn._on_fetch_start()
|
||||
new_item = False
|
||||
item = next(it)
|
||||
if not isinstance(item, _NextValueNotReady):
|
||||
new_item = True
|
||||
yield item
|
||||
|
||||
apply_foreach = add_wait_hooks
|
||||
|
||||
return LocalIterator(
|
||||
self.base_iterator,
|
||||
self.metrics,
|
||||
self.local_transforms + [apply_foreach],
|
||||
name=self.name + ".for_each()")
|
||||
|
||||
@@ -610,6 +676,7 @@ class LocalIterator(Generic[T]):
|
||||
|
||||
return LocalIterator(
|
||||
self.base_iterator,
|
||||
self.metrics,
|
||||
self.local_transforms + [apply_filter],
|
||||
name=self.name + ".filter()")
|
||||
|
||||
@@ -629,6 +696,7 @@ class LocalIterator(Generic[T]):
|
||||
|
||||
return LocalIterator(
|
||||
self.base_iterator,
|
||||
self.metrics,
|
||||
self.local_transforms + [apply_batch],
|
||||
name=self.name + ".batch({})".format(n))
|
||||
|
||||
@@ -643,6 +711,7 @@ class LocalIterator(Generic[T]):
|
||||
|
||||
return LocalIterator(
|
||||
self.base_iterator,
|
||||
self.metrics,
|
||||
self.local_transforms + [apply_flatten],
|
||||
name=self.name + ".flatten()")
|
||||
|
||||
@@ -680,6 +749,7 @@ class LocalIterator(Generic[T]):
|
||||
|
||||
return LocalIterator(
|
||||
self.base_iterator,
|
||||
self.metrics,
|
||||
self.local_transforms + [apply_shuffle],
|
||||
name=self.name +
|
||||
".shuffle(shuffle_buffer_size={}, seed={})".format(
|
||||
@@ -709,12 +779,13 @@ class LocalIterator(Generic[T]):
|
||||
if i >= n:
|
||||
break
|
||||
|
||||
def union(self, other: "LocalIterator[T]") -> "LocalIterator[T]":
|
||||
def union(self, other: "LocalIterator[T]",
|
||||
deterministic: bool = False) -> "LocalIterator[T]":
|
||||
"""Return an iterator that is the union of this and the other.
|
||||
|
||||
There are no ordering guarantees between the two iterators. We make a
|
||||
best-effort attempt to return items from both as they become ready,
|
||||
preventing starvation of any particular iterator.
|
||||
If deterministic=True, we alternate between reading from one iterator
|
||||
and the other. Otherwise we return items from iterators as they
|
||||
become ready.
|
||||
"""
|
||||
|
||||
if not isinstance(other, LocalIterator):
|
||||
@@ -722,10 +793,21 @@ class LocalIterator(Generic[T]):
|
||||
"other must be of type LocalIterator, got {}".format(
|
||||
type(other)))
|
||||
|
||||
if deterministic:
|
||||
timeout = None
|
||||
else:
|
||||
timeout = 0
|
||||
|
||||
it1 = LocalIterator(
|
||||
self.base_iterator, self.local_transforms, timeout=0)
|
||||
self.base_iterator,
|
||||
self.metrics,
|
||||
self.local_transforms,
|
||||
timeout=timeout)
|
||||
it2 = LocalIterator(
|
||||
other.base_iterator, other.local_transforms, timeout=0)
|
||||
other.base_iterator,
|
||||
other.metrics,
|
||||
other.local_transforms,
|
||||
timeout=timeout)
|
||||
active = [it1, it2]
|
||||
|
||||
def build_union(timeout=None):
|
||||
@@ -740,13 +822,22 @@ class LocalIterator(Generic[T]):
|
||||
break
|
||||
else:
|
||||
yield item
|
||||
if deterministic:
|
||||
break
|
||||
except StopIteration:
|
||||
active.remove(it)
|
||||
if not active:
|
||||
break
|
||||
|
||||
# TODO(ekl) is this the best way to represent union() of metrics?
|
||||
new_ctx = MetricsContext()
|
||||
new_ctx.parent_metrics.append(self.metrics)
|
||||
new_ctx.parent_metrics.append(other.metrics)
|
||||
|
||||
return LocalIterator(
|
||||
build_union, [], name="LocalUnion[{}, {}]".format(self, other))
|
||||
build_union,
|
||||
new_ctx, [],
|
||||
name="LocalUnion[{}, {}]".format(self, other))
|
||||
|
||||
|
||||
class ParallelIteratorWorker(object):
|
||||
@@ -792,7 +883,8 @@ class ParallelIteratorWorker(object):
|
||||
|
||||
def par_iter_init(self, transforms):
|
||||
"""Implements ParallelIterator worker init."""
|
||||
it = LocalIterator(lambda timeout: self.item_generator)
|
||||
it = LocalIterator(lambda timeout: self.item_generator,
|
||||
MetricsContext())
|
||||
for fn in transforms:
|
||||
it = fn(it)
|
||||
assert it is not None, fn
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
import collections
|
||||
|
||||
from ray.util.timer import _Timer
|
||||
|
||||
|
||||
class MetricsContext:
    """Metrics context object for a local iterator.

    This object is accessible by all operators of a local iterator. It can be
    used to store and retrieve global execution metrics for the iterator.
    It can be accessed by calling LocalIterator.get_metrics(), which is only
    allowable inside iterator functions.

    Attributes:
        counters (defaultdict): dict storing increasing metrics.
        timers (defaultdict): dict storing latency timers.
        info (dict): dict storing misc metric values.
        current_actor (ActorHandle): reference to the actor handle that
            produced the current iterator output. gather_async() sets this
            through the ``cur_actor`` alias.
        cur_actor (ActorHandle): read/write alias of ``current_actor``.
        parent_metrics (list): list of other MetricsContexts that have been
            attached to this due to LocalIterator.union().
    """

    def __init__(self):
        self.counters = collections.defaultdict(int)
        self.timers = collections.defaultdict(_Timer)
        self.info = {}
        self.current_actor = None
        self.parent_metrics = []

    @property
    def cur_actor(self):
        """Alias of ``current_actor``.

        gather_async() assigns ``metrics.cur_actor``; without this alias that
        assignment would create an unrelated attribute and the documented
        ``current_actor`` would stay None.
        """
        return self.current_actor

    @cur_actor.setter
    def cur_actor(self, actor):
        self.current_actor = actor

    def save(self):
        """Return a serializable copy of this context.

        Counters and info are shallow-copied into plain dicts; timers are not
        persisted. Parent contexts are saved recursively, in order.
        """
        return {
            "counters": dict(self.counters),
            "info": dict(self.info),
            "timers": None,  # TODO(ekl) consider persisting timers too
            "parent_state": [u.save() for u in self.parent_metrics],
        }

    def restore(self, values):
        """Restores state given the output of save().

        Counters and info are replaced by the saved values and all timers are
        reset. Parent contexts are restored pairwise with the saved parent
        states; if the two lists differ in length the extra entries on either
        side are left untouched.
        """
        self.counters.clear()
        self.counters.update(values["counters"])
        self.timers.clear()
        # Copy rather than alias: later writes to self.info must not leak
        # into the snapshot the caller passed in (save() copies as well).
        self.info = dict(values["info"])
        for u, state in zip(self.parent_metrics, values["parent_state"]):
            u.restore(state)
|
||||
@@ -0,0 +1,62 @@
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
|
||||
class _Timer:
|
||||
"""A running stat for conveniently logging the duration of a code block.
|
||||
|
||||
Example:
|
||||
wait_timer = TimerStat()
|
||||
with wait_timer:
|
||||
ray.wait(...)
|
||||
|
||||
Note that this class is *not* thread-safe.
|
||||
"""
|
||||
|
||||
def __init__(self, window_size=10):
|
||||
self._window_size = window_size
|
||||
self._samples = []
|
||||
self._units_processed = []
|
||||
self._start_time = None
|
||||
self._total_time = 0.0
|
||||
self.count = 0
|
||||
|
||||
def __enter__(self):
|
||||
assert self._start_time is None, "concurrent updates not supported"
|
||||
self._start_time = time.time()
|
||||
|
||||
def __exit__(self, exc_type, exc_value, tb):
|
||||
assert self._start_time is not None
|
||||
time_delta = time.time() - self._start_time
|
||||
self.push(time_delta)
|
||||
self._start_time = None
|
||||
|
||||
def push(self, time_delta):
|
||||
self._samples.append(time_delta)
|
||||
if len(self._samples) > self._window_size:
|
||||
self._samples.pop(0)
|
||||
self.count += 1
|
||||
self._total_time += time_delta
|
||||
|
||||
def push_units_processed(self, n):
|
||||
self._units_processed.append(n)
|
||||
if len(self._units_processed) > self._window_size:
|
||||
self._units_processed.pop(0)
|
||||
|
||||
def has_units_processed(self):
|
||||
return len(self._units_processed) > 0
|
||||
|
||||
@property
|
||||
def mean(self):
|
||||
return np.mean(self._samples)
|
||||
|
||||
@property
|
||||
def mean_units_processed(self):
|
||||
return float(np.mean(self._units_processed))
|
||||
|
||||
@property
|
||||
def mean_throughput(self):
|
||||
time_total = sum(self._samples)
|
||||
if not time_total:
|
||||
return 0.0
|
||||
return sum(self._units_processed) / time_total
|
||||
Reference in New Issue
Block a user