[rllib] Fix shared metrics context in parallel iterators (#7666)

* debug

* build

* update

* wip

* wpi

* update

* recurisve sync

* comment

* stream

* fix

* Update .travis.yml
This commit is contained in:
Eric Liang
2020-03-22 14:15:01 -07:00
committed by GitHub
parent 2fb219a658
commit 288933ec6b
7 changed files with 131 additions and 72 deletions
+2 -3
View File
@@ -186,7 +186,7 @@ matrix:
- ./ci/suppress_output ./ci/travis/install-ray.sh
script:
- if [ $RAY_CI_RLLIB_AFFECTED != "1" ]; then exit; fi
- travis_wait 60 bazel test --build_tests_only --test_tag_filters=learning_tests --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=errors rllib/...
- travis_wait 60 bazel test --build_tests_only --test_tag_filters=learning_tests --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=streamed rllib/...
# RLlib: Learning tests with tf=1.x (from rllib/tuned_examples/regression_tests/*.yaml).
# Requested by Edi (MS): Test all learning capabilities with tf1.x
@@ -207,7 +207,7 @@ matrix:
- ./ci/suppress_output ./ci/travis/install-ray.sh
script:
- if [ $RAY_CI_RLLIB_FULL_AFFECTED != "1" ]; then exit; fi
- travis_wait 60 bazel test --build_tests_only --test_tag_filters=learning_tests --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=errors rllib/...
- travis_wait 60 bazel test --build_tests_only --test_tag_filters=learning_tests --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=streamed rllib/...
# RLlib: Quick Agent train.py runs (compilation & running, no(!) learning).
# Agent single tests (compilation, loss-funcs, etc..).
@@ -376,4 +376,3 @@ deploy:
repo: ray-project/ray
all_branches: true
condition: $LINUX_WHEELS = 1 || $MAC_WHEELS = 1
+1 -1
View File
@@ -18,7 +18,7 @@ Objects and methods annotated with ``@PublicAPI`` or ``@DeveloperAPI`` have the
Features
--------
Feature development and upcoming priorities are tracked on the `RLlib project board <https://github.com/ray-project/ray/projects/6>`__ (note that this may not include all development efforts). For discussion of issues and new features, we use the `Ray dev list <https://groups.google.com/forum/#!forum/ray-dev>`__ and `GitHub issues page <https://github.com/ray-project/ray/issues>`__.
Feature development, discussion, and upcoming priorities are tracked on the `GitHub issues page <https://github.com/ray-project/ray/issues>`__ (note that this may not include all development efforts).
Benchmarks
----------
+40
View File
@@ -73,6 +73,46 @@ def test_metrics_union(ray_start_regular_shared):
assert it3.take(10) == [1, 100, 3, 200, 6, 300, 10, 400]
def test_metrics_union_recursive(ray_start_regular_shared):
it1 = from_items([1, 2, 3, 4], num_shards=1)
it2 = from_items([1, 2, 3, 4], num_shards=1)
it3 = from_items([1, 2, 3, 4], num_shards=1)
def foo_metrics(x):
metrics = LocalIterator.get_metrics()
metrics.counters["foo"] += 1
return metrics.counters["foo"]
def bar_metrics(x):
metrics = LocalIterator.get_metrics()
metrics.counters["bar"] += 1
return metrics.counters["bar"]
def baz_metrics(x):
metrics = LocalIterator.get_metrics()
metrics.counters["baz"] += 1
return metrics.counters["baz"]
def verify_metrics(x):
metrics = LocalIterator.get_metrics()
metrics.counters["n"] += 1
# Check the metrics context is shared recursively.
print(metrics.counters)
if metrics.counters["n"] >= 3:
assert "foo" in metrics.counters
assert "bar" in metrics.counters
assert "baz" in metrics.counters
return x
it1 = it1.gather_async().for_each(foo_metrics)
it2 = it2.gather_async().for_each(bar_metrics)
it3 = it3.gather_async().for_each(baz_metrics)
it12 = it1.union(it2, deterministic=True)
it123 = it12.union(it3, deterministic=True)
out = it123.for_each(verify_metrics)
assert out.take(20) == [1, 1, 1, 2, 2, 3, 2, 4, 3, 3, 4, 4]
def test_from_items(ray_start_regular_shared):
it = from_items([1, 2, 3, 4])
assert repr(it) == "ParallelIterator[from_items[int, 4, shards=2]]"
+44 -40
View File
@@ -1,10 +1,11 @@
from contextlib import contextmanager
import collections
import random
import threading
from typing import TypeVar, Generic, Iterable, List, Callable, Any
import ray
from ray.util.iter_metrics import MetricsContext
from ray.util.iter_metrics import MetricsContext, SharedMetrics
# The type of an iterator element.
T = TypeVar("T")
@@ -412,7 +413,7 @@ class ParallelIterator(Generic[T]):
futures = [a.par_iter_next.remote() for a in active]
name = "{}.batch_across_shards()".format(self)
return LocalIterator(base_iterator, MetricsContext(), name=name)
return LocalIterator(base_iterator, SharedMetrics(), name=name)
def gather_async(self, async_queue_depth=1) -> "LocalIterator[T]":
"""Returns a local iterable for asynchronous iteration.
@@ -438,8 +439,10 @@ class ParallelIterator(Generic[T]):
if async_queue_depth < 1:
raise ValueError("queue depth must be positive")
# Forward reference to the returned iterator.
local_iter = None
def base_iterator(timeout=None):
metrics = LocalIterator.get_metrics()
all_actors = []
for actor_set in self.actor_sets:
actor_set.init_actors()
@@ -463,7 +466,7 @@ class ParallelIterator(Generic[T]):
for obj_id in ready:
actor = futures.pop(obj_id)
try:
metrics.current_actor = actor
local_iter.shared_metrics.get().current_actor = actor
yield ray.get(obj_id)
futures[actor.par_iter_next.remote()] = actor
except StopIteration:
@@ -473,7 +476,8 @@ class ParallelIterator(Generic[T]):
yield _NextValueNotReady()
name = "{}.gather_async()".format(self)
return LocalIterator(base_iterator, MetricsContext(), name=name)
local_iter = LocalIterator(base_iterator, SharedMetrics(), name=name)
return local_iter
def take(self, n: int) -> List[T]:
"""Return up to the first n items from this iterator."""
@@ -540,7 +544,7 @@ class ParallelIterator(Generic[T]):
break
name = self.name + ".shard[{}]".format(shard_index)
return LocalIterator(base_iterator, MetricsContext(), name=name)
return LocalIterator(base_iterator, SharedMetrics(), name=name)
class LocalIterator(Generic[T]):
@@ -562,7 +566,7 @@ class LocalIterator(Generic[T]):
def __init__(self,
base_iterator: Callable[[], Iterable[T]],
metrics: MetricsContext,
shared_metrics: SharedMetrics,
local_transforms: List[Callable[[Iterable], Any]] = None,
timeout: int = None,
name=None):
@@ -572,7 +576,7 @@ class LocalIterator(Generic[T]):
base_iterator (func): A function that produces the base iterator.
This is a function so that we can ensure LocalIterator is
serializable.
metrics (MetricsContext): Existing metrics context or a new
shared_metrics (SharedMetrics): Existing metrics context or a new
context. Should be the same for each chained iterator.
local_transforms (list): A list of transformation functions to be
applied on top of the base iterator. When iteration begins, we
@@ -584,10 +588,11 @@ class LocalIterator(Generic[T]):
blocking.
name (str): Optional name for this iterator.
"""
assert isinstance(shared_metrics, SharedMetrics)
self.base_iterator = base_iterator
self.built_iterator = None
self.local_transforms = local_transforms or []
self.metrics = metrics
self.shared_metrics = shared_metrics
self.timeout = timeout
self.name = name or "unknown"
@@ -606,26 +611,20 @@ class LocalIterator(Generic[T]):
it = iter(self.base_iterator(self.timeout))
for fn in self.local_transforms:
it = fn(it)
# This sets the iterator context during iterator execution, and
# clears it after so that multiple iterators can be used at a time.
def set_restore_context(it):
if hasattr(self.thread_local, "metrics"):
prev_metrics = self.thread_local.metrics
else:
prev_metrics = None
self.thread_local.metrics = self.metrics
try:
for item in it:
self.thread_local.metrics = prev_metrics
yield item
self.thread_local.metrics = self.metrics
finally:
self.thread_local.metrics = prev_metrics
it = set_restore_context(it)
self.built_iterator = it
@contextmanager
def _metrics_context(self):
if hasattr(self.thread_local, "metrics"):
prev_metrics = self.thread_local.metrics
else:
prev_metrics = None
try:
self.thread_local.metrics = self.shared_metrics.get()
yield
finally:
self.thread_local.metrics = prev_metrics
def __iter__(self):
self._build_once()
return self.built_iterator
@@ -649,7 +648,8 @@ class LocalIterator(Generic[T]):
# Keep retrying the function until it returns a valid
# value. This allows for non-blocking functions.
while True:
result = fn(item)
with self._metrics_context():
result = fn(item)
yield result
if not isinstance(result, _NextValueNotReady):
break
@@ -664,7 +664,8 @@ class LocalIterator(Generic[T]):
# Avoids calling on_fetch_start repeatedly if we are
# yielding _NextValueNotReady.
if new_item:
fn._on_fetch_start()
with self._metrics_context():
fn._on_fetch_start()
new_item = False
item = next(it)
if not isinstance(item, _NextValueNotReady):
@@ -675,19 +676,20 @@ class LocalIterator(Generic[T]):
return LocalIterator(
self.base_iterator,
self.metrics,
self.shared_metrics,
self.local_transforms + [apply_foreach],
name=self.name + ".for_each()")
def filter(self, fn: Callable[[T], bool]) -> "LocalIterator[T]":
def apply_filter(it):
for item in it:
if isinstance(item, _NextValueNotReady) or fn(item):
yield item
with self._metrics_context():
if isinstance(item, _NextValueNotReady) or fn(item):
yield item
return LocalIterator(
self.base_iterator,
self.metrics,
self.shared_metrics,
self.local_transforms + [apply_filter],
name=self.name + ".filter()")
@@ -707,7 +709,7 @@ class LocalIterator(Generic[T]):
return LocalIterator(
self.base_iterator,
self.metrics,
self.shared_metrics,
self.local_transforms + [apply_batch],
name=self.name + ".batch({})".format(n))
@@ -722,7 +724,7 @@ class LocalIterator(Generic[T]):
return LocalIterator(
self.base_iterator,
self.metrics,
self.shared_metrics,
self.local_transforms + [apply_flatten],
name=self.name + ".flatten()")
@@ -760,7 +762,7 @@ class LocalIterator(Generic[T]):
return LocalIterator(
self.base_iterator,
self.metrics,
self.shared_metrics,
self.local_transforms + [apply_shuffle],
name=self.name +
".shuffle(shuffle_buffer_size={}, seed={})".format(
@@ -836,7 +838,7 @@ class LocalIterator(Generic[T]):
iterators.append(
LocalIterator(
make_next(i),
self.metrics, [],
self.shared_metrics, [],
name=self.name + ".duplicate[{}]".format(i)))
return iterators
@@ -862,8 +864,10 @@ class LocalIterator(Generic[T]):
timeout = 0
active = []
shared_metrics = MetricsContext()
for it in [self] + list(others):
parent_iters = [self] + list(others)
shared_metrics = SharedMetrics(
parents=[p.shared_metrics for p in parent_iters])
for it in parent_iters:
active.append(
LocalIterator(
it.base_iterator,
@@ -945,7 +949,7 @@ class ParallelIteratorWorker(object):
def par_iter_init(self, transforms):
"""Implements ParallelIterator worker init."""
it = LocalIterator(lambda timeout: self.item_generator,
MetricsContext())
SharedMetrics())
for fn in transforms:
it = fn(it)
assert it is not None, fn
+24 -6
View File
@@ -1,4 +1,5 @@
import collections
from typing import List
from ray.util.timer import _Timer
@@ -18,8 +19,6 @@ class MetricsContext:
current_actor (ActorHandle): reference to the actor handle that
produced the current iterator output. This is automatically set
for gather_async().
parent_metrics (list): list of other MetricsContexts that have been
attached to this due to LocalIterator.union().
"""
def __init__(self):
@@ -27,7 +26,6 @@ class MetricsContext:
self.timers = collections.defaultdict(_Timer)
self.info = {}
self.current_actor = None
self.parent_metrics = []
def save(self):
"""Return a serializable copy of this context."""
@@ -35,7 +33,6 @@ class MetricsContext:
"counters": dict(self.counters),
"info": dict(self.info),
"timers": None, # TODO(ekl) consider persisting timers too
"parent_state": [u.save() for u in self.parent_metrics],
}
def restore(self, values):
@@ -44,5 +41,26 @@ class MetricsContext:
self.counters.update(values["counters"])
self.timers.clear()
self.info = values["info"]
for u, state in zip(self.parent_metrics, values["parent_state"]):
u.restore(state)
class SharedMetrics:
"""Holds an indirect reference to a (shared) metrics context.
This is used by LocalIterator.union() to point the metrics contexts of
entirely separate iterator chains to the same underlying context."""
def __init__(self,
metrics: MetricsContext = None,
parents: List["SharedMetrics"] = None):
self.metrics = metrics or MetricsContext()
self.parents = parents or []
self.set(self.metrics)
def set(self, metrics):
"""Recursively set self and parents to point to the same metrics."""
self.metrics = metrics
for parent in self.parents:
parent.set(metrics)
def get(self):
return self.metrics
+4 -2
View File
@@ -189,14 +189,16 @@ def build_trainer(name,
state = Trainer.__getstate__(self)
state["trainer_state"] = self.state.copy()
if self.train_exec_impl:
state["train_exec_impl"] = self.train_exec_impl.metrics.save()
state["train_exec_impl"] = (
self.train_exec_impl.shared_metrics.get().save())
return state
def __setstate__(self, state):
Trainer.__setstate__(self, state)
self.state = state["trainer_state"].copy()
if self.train_exec_impl:
self.train_exec_impl.metrics.restore(state["train_exec_impl"])
self.train_exec_impl.shared_metrics.get().restore(
state["train_exec_impl"])
def with_updates(**overrides):
"""Build a copy of this trainer with the specified overrides.
+16 -20
View File
@@ -11,7 +11,7 @@ import time
import ray
from ray.util.iter import from_actors, LocalIterator, _NextValueNotReady
from ray.util.iter_metrics import MetricsContext
from ray.util.iter_metrics import SharedMetrics
from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer, \
ReplayBuffer
from ray.rllib.evaluation.metrics import collect_episodes, \
@@ -109,7 +109,7 @@ def ParallelRollouts(workers: WorkerSet, mode="bulk_sync",
while True:
yield workers.local_worker().sample()
return (LocalIterator(sampler, MetricsContext())
return (LocalIterator(sampler, SharedMetrics())
.for_each(report_timesteps))
# Create a parallel iterator over generated experiences.
@@ -316,25 +316,21 @@ class CollectMetrics:
# Add in iterator metrics.
metrics = LocalIterator.get_metrics()
if metrics.parent_metrics:
print("TODO: support nested metrics better")
all_metrics = [metrics] + metrics.parent_metrics
timers = {}
counters = {}
info = {}
for metrics in all_metrics:
info.update(metrics.info)
for k, counter in metrics.counters.items():
counters[k] = counter
for k, timer in metrics.timers.items():
timers["{}_time_ms".format(k)] = round(timer.mean * 1000, 3)
if timer.has_units_processed():
timers["{}_throughput".format(k)] = round(
timer.mean_throughput, 3)
res.update({
"num_healthy_workers": len(self.workers.remote_workers()),
"timesteps_total": metrics.counters[STEPS_SAMPLED_COUNTER],
})
info.update(metrics.info)
for k, counter in metrics.counters.items():
counters[k] = counter
for k, timer in metrics.timers.items():
timers["{}_time_ms".format(k)] = round(timer.mean * 1000, 3)
if timer.has_units_processed():
timers["{}_throughput".format(k)] = round(
timer.mean_throughput, 3)
res.update({
"num_healthy_workers": len(self.workers.remote_workers()),
"timesteps_total": metrics.counters[STEPS_SAMPLED_COUNTER],
})
res["timers"] = timers
res["info"] = info
res["info"].update(counters)
@@ -617,7 +613,7 @@ def LocalReplay(replay_buffer: ReplayBuffer, train_batch_size: int):
})
yield MultiAgentBatch(samples, train_batch_size)
return LocalIterator(gen_replay, MetricsContext())
return LocalIterator(gen_replay, SharedMetrics())
def Concurrently(ops: List[LocalIterator], mode="round_robin"):
@@ -743,4 +739,4 @@ def Dequeue(input_queue: queue.Queue, check=lambda: True):
yield _NextValueNotReady()
raise RuntimeError("Error raised reading from queue")
return LocalIterator(base_iterator, MetricsContext())
return LocalIterator(base_iterator, SharedMetrics())