From 288933ec6bf2d1dea1e3a1ae853ca059f67b4a18 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sun, 22 Mar 2020 14:15:01 -0700 Subject: [PATCH] [rllib] Fix shared metrics context in parallel iterators (#7666) * debug * build * update * wip * wpi * update * recurisve sync * comment * stream * fix * Update .travis.yml --- .travis.yml | 5 +- doc/source/rllib-dev.rst | 2 +- python/ray/tests/test_iter.py | 40 +++++++++++++++ python/ray/util/iter.py | 84 +++++++++++++++++--------------- python/ray/util/iter_metrics.py | 30 +++++++++--- rllib/agents/trainer_template.py | 6 ++- rllib/utils/experimental_dsl.py | 36 ++++++-------- 7 files changed, 131 insertions(+), 72 deletions(-) diff --git a/.travis.yml b/.travis.yml index 5f992a3b9..f7ffd2be3 100644 --- a/.travis.yml +++ b/.travis.yml @@ -186,7 +186,7 @@ matrix: - ./ci/suppress_output ./ci/travis/install-ray.sh script: - if [ $RAY_CI_RLLIB_AFFECTED != "1" ]; then exit; fi - - travis_wait 60 bazel test --build_tests_only --test_tag_filters=learning_tests --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=errors rllib/... + - travis_wait 60 bazel test --build_tests_only --test_tag_filters=learning_tests --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=streamed rllib/... # RLlib: Learning tests with tf=1.x (from rllib/tuned_examples/regression_tests/*.yaml). # Requested by Edi (MS): Test all learning capabilities with tf1.x @@ -207,7 +207,7 @@ matrix: - ./ci/suppress_output ./ci/travis/install-ray.sh script: - if [ $RAY_CI_RLLIB_FULL_AFFECTED != "1" ]; then exit; fi - - travis_wait 60 bazel test --build_tests_only --test_tag_filters=learning_tests --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=errors rllib/... + - travis_wait 60 bazel test --build_tests_only --test_tag_filters=learning_tests --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=streamed rllib/... # RLlib: Quick Agent train.py runs (compilation & running, no(!) learning). # Agent single tests (compilation, loss-funcs, etc..). @@ -376,4 +376,3 @@ deploy: repo: ray-project/ray all_branches: true condition: $LINUX_WHEELS = 1 || $MAC_WHEELS = 1 - diff --git a/doc/source/rllib-dev.rst b/doc/source/rllib-dev.rst index 543beb78c..1b103978f 100644 --- a/doc/source/rllib-dev.rst +++ b/doc/source/rllib-dev.rst @@ -18,7 +18,7 @@ Objects and methods annotated with ``@PublicAPI`` or ``@DeveloperAPI`` have the Features -------- -Feature development and upcoming priorities are tracked on the `RLlib project board `__ (note that this may not include all development efforts). For discussion of issues and new features, we use the `Ray dev list `__ and `GitHub issues page `__. +Feature development, discussion, and upcoming priorities are tracked on the `GitHub issues page `__ (note that this may not include all development efforts). Benchmarks ---------- diff --git a/python/ray/tests/test_iter.py b/python/ray/tests/test_iter.py index c31064e3b..f4435e023 100644 --- a/python/ray/tests/test_iter.py +++ b/python/ray/tests/test_iter.py @@ -73,6 +73,46 @@ def test_metrics_union(ray_start_regular_shared): assert it3.take(10) == [1, 100, 3, 200, 6, 300, 10, 400] +def test_metrics_union_recursive(ray_start_regular_shared): + it1 = from_items([1, 2, 3, 4], num_shards=1) + it2 = from_items([1, 2, 3, 4], num_shards=1) + it3 = from_items([1, 2, 3, 4], num_shards=1) + + def foo_metrics(x): + metrics = LocalIterator.get_metrics() + metrics.counters["foo"] += 1 + return metrics.counters["foo"] + + def bar_metrics(x): + metrics = LocalIterator.get_metrics() + metrics.counters["bar"] += 1 + return metrics.counters["bar"] + + def baz_metrics(x): + metrics = LocalIterator.get_metrics() + metrics.counters["baz"] += 1 + return metrics.counters["baz"] + + def verify_metrics(x): + metrics = LocalIterator.get_metrics() + metrics.counters["n"] += 1 + # Check the metrics context is shared recursively. + print(metrics.counters) + if metrics.counters["n"] >= 3: + assert "foo" in metrics.counters + assert "bar" in metrics.counters + assert "baz" in metrics.counters + return x + + it1 = it1.gather_async().for_each(foo_metrics) + it2 = it2.gather_async().for_each(bar_metrics) + it3 = it3.gather_async().for_each(baz_metrics) + it12 = it1.union(it2, deterministic=True) + it123 = it12.union(it3, deterministic=True) + out = it123.for_each(verify_metrics) + assert out.take(20) == [1, 1, 1, 2, 2, 3, 2, 4, 3, 3, 4, 4] + + def test_from_items(ray_start_regular_shared): it = from_items([1, 2, 3, 4]) assert repr(it) == "ParallelIterator[from_items[int, 4, shards=2]]" diff --git a/python/ray/util/iter.py b/python/ray/util/iter.py index 31a37c838..ab1533e8c 100644 --- a/python/ray/util/iter.py +++ b/python/ray/util/iter.py @@ -1,10 +1,11 @@ +from contextlib import contextmanager import collections import random import threading from typing import TypeVar, Generic, Iterable, List, Callable, Any import ray -from ray.util.iter_metrics import MetricsContext +from ray.util.iter_metrics import MetricsContext, SharedMetrics # The type of an iterator element. T = TypeVar("T") @@ -412,7 +413,7 @@ class ParallelIterator(Generic[T]): futures = [a.par_iter_next.remote() for a in active] name = "{}.batch_across_shards()".format(self) - return LocalIterator(base_iterator, MetricsContext(), name=name) + return LocalIterator(base_iterator, SharedMetrics(), name=name) def gather_async(self, async_queue_depth=1) -> "LocalIterator[T]": """Returns a local iterable for asynchronous iteration. @@ -438,8 +439,10 @@ class ParallelIterator(Generic[T]): if async_queue_depth < 1: raise ValueError("queue depth must be positive") + # Forward reference to the returned iterator. + local_iter = None + def base_iterator(timeout=None): - metrics = LocalIterator.get_metrics() all_actors = [] for actor_set in self.actor_sets: actor_set.init_actors() @@ -463,7 +466,7 @@ class ParallelIterator(Generic[T]): for obj_id in ready: actor = futures.pop(obj_id) try: - metrics.current_actor = actor + local_iter.shared_metrics.get().current_actor = actor yield ray.get(obj_id) futures[actor.par_iter_next.remote()] = actor except StopIteration: @@ -473,7 +476,8 @@ class ParallelIterator(Generic[T]): yield _NextValueNotReady() name = "{}.gather_async()".format(self) - return LocalIterator(base_iterator, MetricsContext(), name=name) + local_iter = LocalIterator(base_iterator, SharedMetrics(), name=name) + return local_iter def take(self, n: int) -> List[T]: """Return up to the first n items from this iterator.""" @@ -540,7 +544,7 @@ class ParallelIterator(Generic[T]): break name = self.name + ".shard[{}]".format(shard_index) - return LocalIterator(base_iterator, MetricsContext(), name=name) + return LocalIterator(base_iterator, SharedMetrics(), name=name) class LocalIterator(Generic[T]): @@ -562,7 +566,7 @@ class LocalIterator(Generic[T]): def __init__(self, base_iterator: Callable[[], Iterable[T]], - metrics: MetricsContext, + shared_metrics: SharedMetrics, local_transforms: List[Callable[[Iterable], Any]] = None, timeout: int = None, name=None): @@ -572,7 +576,7 @@ class LocalIterator(Generic[T]): base_iterator (func): A function that produces the base iterator. This is a function so that we can ensure LocalIterator is serializable. - metrics (MetricsContext): Existing metrics context or a new + shared_metrics (SharedMetrics): Existing metrics context or a new context. Should be the same for each chained iterator. local_transforms (list): A list of transformation functions to be applied on top of the base iterator. When iteration begins, we @@ -584,10 +588,11 @@ class LocalIterator(Generic[T]): blocking. name (str): Optional name for this iterator. """ + assert isinstance(shared_metrics, SharedMetrics) self.base_iterator = base_iterator self.built_iterator = None self.local_transforms = local_transforms or [] - self.metrics = metrics + self.shared_metrics = shared_metrics self.timeout = timeout self.name = name or "unknown" @@ -606,26 +611,20 @@ class LocalIterator(Generic[T]): it = iter(self.base_iterator(self.timeout)) for fn in self.local_transforms: it = fn(it) - - # This sets the iterator context during iterator execution, and - # clears it after so that multiple iterators can be used at a time. - def set_restore_context(it): - if hasattr(self.thread_local, "metrics"): - prev_metrics = self.thread_local.metrics - else: - prev_metrics = None - self.thread_local.metrics = self.metrics - try: - for item in it: - self.thread_local.metrics = prev_metrics - yield item - self.thread_local.metrics = self.metrics - finally: - self.thread_local.metrics = prev_metrics - - it = set_restore_context(it) self.built_iterator = it + @contextmanager + def _metrics_context(self): + if hasattr(self.thread_local, "metrics"): + prev_metrics = self.thread_local.metrics + else: + prev_metrics = None + try: + self.thread_local.metrics = self.shared_metrics.get() + yield + finally: + self.thread_local.metrics = prev_metrics + def __iter__(self): self._build_once() return self.built_iterator @@ -649,7 +648,8 @@ class LocalIterator(Generic[T]): # Keep retrying the function until it returns a valid # value. This allows for non-blocking functions. while True: - result = fn(item) + with self._metrics_context(): + result = fn(item) yield result if not isinstance(result, _NextValueNotReady): break @@ -664,7 +664,8 @@ class LocalIterator(Generic[T]): # Avoids calling on_fetch_start repeatedly if we are # yielding _NextValueNotReady. if new_item: - fn._on_fetch_start() + with self._metrics_context(): + fn._on_fetch_start() new_item = False item = next(it) if not isinstance(item, _NextValueNotReady): @@ -675,19 +676,20 @@ class LocalIterator(Generic[T]): return LocalIterator( self.base_iterator, - self.metrics, + self.shared_metrics, self.local_transforms + [apply_foreach], name=self.name + ".for_each()") def filter(self, fn: Callable[[T], bool]) -> "LocalIterator[T]": def apply_filter(it): for item in it: - if isinstance(item, _NextValueNotReady) or fn(item): - yield item + with self._metrics_context(): + if isinstance(item, _NextValueNotReady) or fn(item): + yield item return LocalIterator( self.base_iterator, - self.metrics, + self.shared_metrics, self.local_transforms + [apply_filter], name=self.name + ".filter()") @@ -707,7 +709,7 @@ class LocalIterator(Generic[T]): return LocalIterator( self.base_iterator, - self.metrics, + self.shared_metrics, self.local_transforms + [apply_batch], name=self.name + ".batch({})".format(n)) @@ -722,7 +724,7 @@ class LocalIterator(Generic[T]): return LocalIterator( self.base_iterator, - self.metrics, + self.shared_metrics, self.local_transforms + [apply_flatten], name=self.name + ".flatten()") @@ -760,7 +762,7 @@ class LocalIterator(Generic[T]): return LocalIterator( self.base_iterator, - self.metrics, + self.shared_metrics, self.local_transforms + [apply_shuffle], name=self.name + ".shuffle(shuffle_buffer_size={}, seed={})".format( @@ -836,7 +838,7 @@ class LocalIterator(Generic[T]): iterators.append( LocalIterator( make_next(i), - self.metrics, [], + self.shared_metrics, [], name=self.name + ".duplicate[{}]".format(i))) return iterators @@ -862,8 +864,10 @@ class LocalIterator(Generic[T]): timeout = 0 active = [] - shared_metrics = MetricsContext() - for it in [self] + list(others): + parent_iters = [self] + list(others) + shared_metrics = SharedMetrics( + parents=[p.shared_metrics for p in parent_iters]) + for it in parent_iters: active.append( LocalIterator( it.base_iterator, @@ -945,7 +949,7 @@ class ParallelIteratorWorker(object): def par_iter_init(self, transforms): """Implements ParallelIterator worker init.""" it = LocalIterator(lambda timeout: self.item_generator, - MetricsContext()) + SharedMetrics()) for fn in transforms: it = fn(it) assert it is not None, fn diff --git a/python/ray/util/iter_metrics.py b/python/ray/util/iter_metrics.py index de9f1de88..db748429b 100644 --- a/python/ray/util/iter_metrics.py +++ b/python/ray/util/iter_metrics.py @@ -1,4 +1,5 @@ import collections +from typing import List from ray.util.timer import _Timer @@ -18,8 +19,6 @@ class MetricsContext: current_actor (ActorHandle): reference to the actor handle that produced the current iterator output. This is automatically set for gather_async(). - parent_metrics (list): list of other MetricsContexts that have been - attached to this due to LocalIterator.union(). """ def __init__(self): @@ -27,7 +26,6 @@ class MetricsContext: self.timers = collections.defaultdict(_Timer) self.info = {} self.current_actor = None - self.parent_metrics = [] def save(self): """Return a serializable copy of this context.""" @@ -35,7 +33,6 @@ class MetricsContext: "counters": dict(self.counters), "info": dict(self.info), "timers": None, # TODO(ekl) consider persisting timers too - "parent_state": [u.save() for u in self.parent_metrics], } def restore(self, values): @@ -44,5 +41,26 @@ class MetricsContext: self.counters.update(values["counters"]) self.timers.clear() self.info = values["info"] - for u, state in zip(self.parent_metrics, values["parent_state"]): - u.restore(state) + + +class SharedMetrics: + """Holds an indirect reference to a (shared) metrics context. + + This is used by LocalIterator.union() to point the metrics contexts of + entirely separate iterator chains to the same underlying context.""" + + def __init__(self, + metrics: MetricsContext = None, + parents: List["SharedMetrics"] = None): + self.metrics = metrics or MetricsContext() + self.parents = parents or [] + self.set(self.metrics) + + def set(self, metrics): + """Recursively set self and parents to point to the same metrics.""" + self.metrics = metrics + for parent in self.parents: + parent.set(metrics) + + def get(self): + return self.metrics diff --git a/rllib/agents/trainer_template.py b/rllib/agents/trainer_template.py index 8b10580a2..f8899862e 100644 --- a/rllib/agents/trainer_template.py +++ b/rllib/agents/trainer_template.py @@ -189,14 +189,16 @@ def build_trainer(name, state = Trainer.__getstate__(self) state["trainer_state"] = self.state.copy() if self.train_exec_impl: - state["train_exec_impl"] = self.train_exec_impl.metrics.save() + state["train_exec_impl"] = ( + self.train_exec_impl.shared_metrics.get().save()) return state def __setstate__(self, state): Trainer.__setstate__(self, state) self.state = state["trainer_state"].copy() if self.train_exec_impl: - self.train_exec_impl.metrics.restore(state["train_exec_impl"]) + self.train_exec_impl.shared_metrics.get().restore( + state["train_exec_impl"]) def with_updates(**overrides): """Build a copy of this trainer with the specified overrides. diff --git a/rllib/utils/experimental_dsl.py b/rllib/utils/experimental_dsl.py index 9fbdbdb2e..28658b44f 100644 --- a/rllib/utils/experimental_dsl.py +++ b/rllib/utils/experimental_dsl.py @@ -11,7 +11,7 @@ import time import ray from ray.util.iter import from_actors, LocalIterator, _NextValueNotReady -from ray.util.iter_metrics import MetricsContext +from ray.util.iter_metrics import SharedMetrics from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer, \ ReplayBuffer from ray.rllib.evaluation.metrics import collect_episodes, \ @@ -109,7 +109,7 @@ def ParallelRollouts(workers: WorkerSet, mode="bulk_sync", while True: yield workers.local_worker().sample() - return (LocalIterator(sampler, MetricsContext()) + return (LocalIterator(sampler, SharedMetrics()) .for_each(report_timesteps)) # Create a parallel iterator over generated experiences. @@ -316,25 +316,21 @@ class CollectMetrics: # Add in iterator metrics. metrics = LocalIterator.get_metrics() - if metrics.parent_metrics: - print("TODO: support nested metrics better") - all_metrics = [metrics] + metrics.parent_metrics timers = {} counters = {} info = {} - for metrics in all_metrics: - info.update(metrics.info) - for k, counter in metrics.counters.items(): - counters[k] = counter - for k, timer in metrics.timers.items(): - timers["{}_time_ms".format(k)] = round(timer.mean * 1000, 3) - if timer.has_units_processed(): - timers["{}_throughput".format(k)] = round( - timer.mean_throughput, 3) - res.update({ - "num_healthy_workers": len(self.workers.remote_workers()), - "timesteps_total": metrics.counters[STEPS_SAMPLED_COUNTER], - }) + info.update(metrics.info) + for k, counter in metrics.counters.items(): + counters[k] = counter + for k, timer in metrics.timers.items(): + timers["{}_time_ms".format(k)] = round(timer.mean * 1000, 3) + if timer.has_units_processed(): + timers["{}_throughput".format(k)] = round( + timer.mean_throughput, 3) + res.update({ + "num_healthy_workers": len(self.workers.remote_workers()), + "timesteps_total": metrics.counters[STEPS_SAMPLED_COUNTER], + }) res["timers"] = timers res["info"] = info res["info"].update(counters) @@ -617,7 +613,7 @@ def LocalReplay(replay_buffer: ReplayBuffer, train_batch_size: int): }) yield MultiAgentBatch(samples, train_batch_size) - return LocalIterator(gen_replay, MetricsContext()) + return LocalIterator(gen_replay, SharedMetrics()) def Concurrently(ops: List[LocalIterator], mode="round_robin"): @@ -743,4 +739,4 @@ def Dequeue(input_queue: queue.Queue, check=lambda: True): yield _NextValueNotReady() raise RuntimeError("Error raised reading from queue") - return LocalIterator(base_iterator, MetricsContext()) + return LocalIterator(base_iterator, SharedMetrics())