[rllib] Enable performance metrics reporting for RLlib pipelines, add A3C (#7299)

This commit is contained in:
Eric Liang
2020-02-28 16:44:17 -08:00
committed by GitHub
parent 50145e668d
commit 3c6b94f3f5
14 changed files with 543 additions and 113 deletions
+60 -2
View File
@@ -1,9 +1,68 @@
import time
from collections import Counter
import pytest
import ray
from ray.util.iter import from_items, from_iterators, from_range, \
from_actors, ParallelIteratorWorker
from_actors, ParallelIteratorWorker, LocalIterator
def test_metrics(ray_start_regular_shared):
it = from_items([1, 2, 3, 4], num_shards=1)
it2 = from_items([1, 2, 3, 4], num_shards=1)
def f(x):
metrics = LocalIterator.get_metrics()
metrics.counters["foo"] += x
return metrics.counters["foo"]
it = it.gather_sync().for_each(f)
it2 = it2.gather_sync().for_each(f)
# Context cannot be accessed outside the iterator.
with pytest.raises(ValueError):
LocalIterator.get_metrics()
# Tests iterators have isolated contexts.
assert it.take(4) == [1, 3, 6, 10]
assert it2.take(4) == [1, 3, 6, 10]
# Context cannot be accessed outside the iterator.
with pytest.raises(ValueError):
LocalIterator.get_metrics()
def test_metrics_union(ray_start_regular_shared):
it1 = from_items([1, 2, 3, 4], num_shards=1)
it2 = from_items([1, 2, 3, 4], num_shards=1)
def foo_metrics(x):
metrics = LocalIterator.get_metrics()
metrics.counters["foo"] += x
return metrics.counters["foo"]
def bar_metrics(x):
metrics = LocalIterator.get_metrics()
metrics.counters["bar"] += 100
return metrics.counters["bar"]
def verify_metrics(x):
metrics = LocalIterator.get_metrics()
metrics.counters["n"] += 1
# Check the unioned iterator gets a new metric context.
assert "foo" not in metrics.counters
assert "bar" not in metrics.counters
# Check parent metrics are accessible.
if metrics.counters["n"] > 2:
assert "foo" in metrics.parent_metrics[0].counters
assert "bar" in metrics.parent_metrics[1].counters
return x
it1 = it1.gather_async().for_each(foo_metrics)
it2 = it2.gather_async().for_each(bar_metrics)
it3 = it1.union(it2, deterministic=True)
it3 = it3.for_each(verify_metrics)
assert it3.take(10) == [1, 100, 3, 200, 6, 300, 10, 400]
def test_from_items(ray_start_regular_shared):
@@ -271,6 +330,5 @@ def test_serialization(ray_start_regular_shared):
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))
+103 -11
View File
@@ -1,8 +1,10 @@
import collections
import random
import threading
from typing import TypeVar, Generic, Iterable, List, Callable, Any
import ray
from ray.util.iter_metrics import MetricsContext
# The type of an iterator element.
T = TypeVar("T")
@@ -150,6 +152,8 @@ class ParallelIterator(Generic[T]):
"""
def __init__(self, actor_sets: List["_ActorSet"], name: str):
"""Create a parallel iterator (this is an internal function)."""
# We track multiple sets of actors to support parallel .union().
self.actor_sets = actor_sets
self.name = name
@@ -415,7 +419,7 @@ class ParallelIterator(Generic[T]):
futures = [a.par_iter_next.remote() for a in active]
name = "{}.batch_across_shards()".format(self)
return LocalIterator(base_iterator, name=name)
return LocalIterator(base_iterator, MetricsContext(), name=name)
def gather_async(self) -> "LocalIterator[T]":
"""Returns a local iterable for asynchronous iteration.
@@ -433,6 +437,8 @@ class ParallelIterator(Generic[T]):
... 1
"""
metrics = MetricsContext()
def base_iterator(timeout=None):
all_actors = []
for actor_set in self.actor_sets:
@@ -456,6 +462,7 @@ class ParallelIterator(Generic[T]):
for obj_id in ready:
actor = futures.pop(obj_id)
try:
metrics.cur_actor = actor
yield ray.get(obj_id)
futures[actor.par_iter_next.remote()] = actor
except StopIteration:
@@ -465,7 +472,7 @@ class ParallelIterator(Generic[T]):
yield _NextValueNotReady()
name = "{}.gather_async()".format(self)
return LocalIterator(base_iterator, name=name)
return LocalIterator(base_iterator, metrics, name=name)
def take(self, n: int) -> List[T]:
"""Return up to the first n items from this iterator."""
@@ -528,7 +535,7 @@ class ParallelIterator(Generic[T]):
break
name = self.name + ".shard[{}]".format(shard_index)
return LocalIterator(base_iterator, name=name)
return LocalIterator(base_iterator, MetricsContext(), name=name)
class LocalIterator(Generic[T]):
@@ -541,8 +548,16 @@ class LocalIterator(Generic[T]):
tasks and actors. However, it should be read from at most one process at
a time."""
# If a function passed to LocalIterator.for_each() has this method,
# we will call it at the beginning of each data fetch call. This can be
# used to measure the underlying wait latency for measurement purposes.
ON_FETCH_START_HOOK_NAME = "_on_fetch_start"
thread_local = threading.local()
def __init__(self,
base_iterator: Callable[[], Iterable[T]],
metrics: MetricsContext,
local_transforms: List[Callable[[Iterable], Any]] = None,
timeout: int = None,
name=None):
@@ -552,6 +567,8 @@ class LocalIterator(Generic[T]):
base_iterator (func): A function that produces the base iterator.
This is a function so that we can ensure LocalIterator is
serializable.
metrics (MetricsContext): Existing metrics context or a new
context. Should be the same for each chained iterator.
local_transforms (list): A list of transformation functions to be
applied on top of the base iterator. When iteration begins, we
create the base iterator and apply these functions. This lazy
@@ -565,14 +582,43 @@ class LocalIterator(Generic[T]):
self.base_iterator = base_iterator
self.built_iterator = None
self.local_transforms = local_transforms or []
self.metrics = metrics
self.timeout = timeout
self.name = name or "unknown"
@staticmethod
def get_metrics() -> MetricsContext:
"""Return the current metrics context.
This can only be called within an iterator function."""
if (not hasattr(LocalIterator.thread_local, "metrics")
or LocalIterator.thread_local.metrics is None):
raise ValueError("Cannot access context outside an iterator.")
return LocalIterator.thread_local.metrics
def _build_once(self):
if self.built_iterator is None:
it = iter(self.base_iterator(self.timeout))
for fn in self.local_transforms:
it = fn(it)
# This sets the iterator context during iterator execution, and
# clears it after so that multiple iterators can be used at a time.
def set_restore_context(it):
if hasattr(self.thread_local, "metrics"):
prev_metrics = self.thread_local.metrics
else:
prev_metrics = None
self.thread_local.metrics = self.metrics
try:
for item in it:
self.thread_local.metrics = prev_metrics
yield item
self.thread_local.metrics = self.metrics
finally:
self.thread_local.metrics = prev_metrics
it = set_restore_context(it)
self.built_iterator = it
def __iter__(self):
@@ -597,8 +643,28 @@ class LocalIterator(Generic[T]):
else:
yield fn(item)
if hasattr(fn, LocalIterator.ON_FETCH_START_HOOK_NAME):
unwrapped = apply_foreach
def add_wait_hooks(it):
it = unwrapped(it)
new_item = True
while True:
# Avoids calling on_fetch_start repeatedly if we are
# yielding _NextValueNotReady.
if new_item:
fn._on_fetch_start()
new_item = False
item = next(it)
if not isinstance(item, _NextValueNotReady):
new_item = True
yield item
apply_foreach = add_wait_hooks
return LocalIterator(
self.base_iterator,
self.metrics,
self.local_transforms + [apply_foreach],
name=self.name + ".for_each()")
@@ -610,6 +676,7 @@ class LocalIterator(Generic[T]):
return LocalIterator(
self.base_iterator,
self.metrics,
self.local_transforms + [apply_filter],
name=self.name + ".filter()")
@@ -629,6 +696,7 @@ class LocalIterator(Generic[T]):
return LocalIterator(
self.base_iterator,
self.metrics,
self.local_transforms + [apply_batch],
name=self.name + ".batch({})".format(n))
@@ -643,6 +711,7 @@ class LocalIterator(Generic[T]):
return LocalIterator(
self.base_iterator,
self.metrics,
self.local_transforms + [apply_flatten],
name=self.name + ".flatten()")
@@ -680,6 +749,7 @@ class LocalIterator(Generic[T]):
return LocalIterator(
self.base_iterator,
self.metrics,
self.local_transforms + [apply_shuffle],
name=self.name +
".shuffle(shuffle_buffer_size={}, seed={})".format(
@@ -709,12 +779,13 @@ class LocalIterator(Generic[T]):
if i >= n:
break
def union(self, other: "LocalIterator[T]") -> "LocalIterator[T]":
def union(self, other: "LocalIterator[T]",
deterministic: bool = False) -> "LocalIterator[T]":
"""Return an iterator that is the union of this and the other.
There are no ordering guarantees between the two iterators. We make a
best-effort attempt to return items from both as they become ready,
preventing starvation of any particular iterator.
If deterministic=True, we alternate between reading from one iterator
and the other. Otherwise we return items from iterators as they
become ready.
"""
if not isinstance(other, LocalIterator):
@@ -722,10 +793,21 @@ class LocalIterator(Generic[T]):
"other must be of type LocalIterator, got {}".format(
type(other)))
if deterministic:
timeout = None
else:
timeout = 0
it1 = LocalIterator(
self.base_iterator, self.local_transforms, timeout=0)
self.base_iterator,
self.metrics,
self.local_transforms,
timeout=timeout)
it2 = LocalIterator(
other.base_iterator, other.local_transforms, timeout=0)
other.base_iterator,
other.metrics,
other.local_transforms,
timeout=timeout)
active = [it1, it2]
def build_union(timeout=None):
@@ -740,13 +822,22 @@ class LocalIterator(Generic[T]):
break
else:
yield item
if deterministic:
break
except StopIteration:
active.remove(it)
if not active:
break
# TODO(ekl) is this the best way to represent union() of metrics?
new_ctx = MetricsContext()
new_ctx.parent_metrics.append(self.metrics)
new_ctx.parent_metrics.append(other.metrics)
return LocalIterator(
build_union, [], name="LocalUnion[{}, {}]".format(self, other))
build_union,
new_ctx, [],
name="LocalUnion[{}, {}]".format(self, other))
class ParallelIteratorWorker(object):
@@ -792,7 +883,8 @@ class ParallelIteratorWorker(object):
def par_iter_init(self, transforms):
"""Implements ParallelIterator worker init."""
it = LocalIterator(lambda timeout: self.item_generator)
it = LocalIterator(lambda timeout: self.item_generator,
MetricsContext())
for fn in transforms:
it = fn(it)
assert it is not None, fn
+48
View File
@@ -0,0 +1,48 @@
import collections
from ray.util.timer import _Timer
class MetricsContext:
"""Metrics context object for a local iterator.
This object is accessible by all operators of a local iterator. It can be
used to store and retrieve global execution metrics for the iterator.
It can be accessed by calling LocalIterator.get_metrics(), which is only
allowable inside iterator functions.
Attributes:
counters (defaultdict): dict storing increasing metrics.
timers (defaultdict): dict storing latency timers.
info (dict): dict storing misc metric values.
current_actor (ActorHandle): reference to the actor handle that
produced the current iterator output. This is automatically set
for gather_async().
parent_metrics (list): list of other MetricsContexts that have been
attached to this due to LocalIterator.union().
"""
def __init__(self):
self.counters = collections.defaultdict(int)
self.timers = collections.defaultdict(_Timer)
self.info = {}
self.current_actor = None
self.parent_metrics = []
def save(self):
"""Return a serializable copy of this context."""
return {
"counters": dict(self.counters),
"info": dict(self.info),
"timers": None, # TODO(ekl) consider persisting timers too
"parent_state": [u.save() for u in self.parent_metrics],
}
def restore(self, values):
"""Restores state given the output of save()."""
self.counters.clear()
self.counters.update(values["counters"])
self.timers.clear()
self.info = values["info"]
for u, state in zip(self.parent_metrics, values["parent_state"]):
u.restore(state)
+62
View File
@@ -0,0 +1,62 @@
import numpy as np
import time
class _Timer:
"""A running stat for conveniently logging the duration of a code block.
Example:
wait_timer = TimerStat()
with wait_timer:
ray.wait(...)
Note that this class is *not* thread-safe.
"""
def __init__(self, window_size=10):
self._window_size = window_size
self._samples = []
self._units_processed = []
self._start_time = None
self._total_time = 0.0
self.count = 0
def __enter__(self):
assert self._start_time is None, "concurrent updates not supported"
self._start_time = time.time()
def __exit__(self, exc_type, exc_value, tb):
assert self._start_time is not None
time_delta = time.time() - self._start_time
self.push(time_delta)
self._start_time = None
def push(self, time_delta):
self._samples.append(time_delta)
if len(self._samples) > self._window_size:
self._samples.pop(0)
self.count += 1
self._total_time += time_delta
def push_units_processed(self, n):
self._units_processed.append(n)
if len(self._units_processed) > self._window_size:
self._units_processed.pop(0)
def has_units_processed(self):
return len(self._units_processed) > 0
@property
def mean(self):
return np.mean(self._samples)
@property
def mean_units_processed(self):
return float(np.mean(self._units_processed))
@property
def mean_throughput(self):
time_total = sum(self._samples)
if not time_total:
return 0.0
return sum(self._units_processed) / time_total