mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 17:49:47 +08:00
[rllib] Fix shared metrics context in parallel iterators (#7666)
* debug * build * update * wip * wpi * update * recurisve sync * comment * stream * fix * Update .travis.yml
This commit is contained in:
+2
-3
@@ -186,7 +186,7 @@ matrix:
|
||||
- ./ci/suppress_output ./ci/travis/install-ray.sh
|
||||
script:
|
||||
- if [ $RAY_CI_RLLIB_AFFECTED != "1" ]; then exit; fi
|
||||
- travis_wait 60 bazel test --build_tests_only --test_tag_filters=learning_tests --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=errors rllib/...
|
||||
- travis_wait 60 bazel test --build_tests_only --test_tag_filters=learning_tests --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=streamed rllib/...
|
||||
|
||||
# RLlib: Learning tests with tf=1.x (from rllib/tuned_examples/regression_tests/*.yaml).
|
||||
# Requested by Edi (MS): Test all learning capabilities with tf1.x
|
||||
@@ -207,7 +207,7 @@ matrix:
|
||||
- ./ci/suppress_output ./ci/travis/install-ray.sh
|
||||
script:
|
||||
- if [ $RAY_CI_RLLIB_FULL_AFFECTED != "1" ]; then exit; fi
|
||||
- travis_wait 60 bazel test --build_tests_only --test_tag_filters=learning_tests --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=errors rllib/...
|
||||
- travis_wait 60 bazel test --build_tests_only --test_tag_filters=learning_tests --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=streamed rllib/...
|
||||
|
||||
# RLlib: Quick Agent train.py runs (compilation & running, no(!) learning).
|
||||
# Agent single tests (compilation, loss-funcs, etc..).
|
||||
@@ -376,4 +376,3 @@ deploy:
|
||||
repo: ray-project/ray
|
||||
all_branches: true
|
||||
condition: $LINUX_WHEELS = 1 || $MAC_WHEELS = 1
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ Objects and methods annotated with ``@PublicAPI`` or ``@DeveloperAPI`` have the
|
||||
Features
|
||||
--------
|
||||
|
||||
Feature development and upcoming priorities are tracked on the `RLlib project board <https://github.com/ray-project/ray/projects/6>`__ (note that this may not include all development efforts). For discussion of issues and new features, we use the `Ray dev list <https://groups.google.com/forum/#!forum/ray-dev>`__ and `GitHub issues page <https://github.com/ray-project/ray/issues>`__.
|
||||
Feature development, discussion, and upcoming priorities are tracked on the `GitHub issues page <https://github.com/ray-project/ray/issues>`__ (note that this may not include all development efforts).
|
||||
|
||||
Benchmarks
|
||||
----------
|
||||
|
||||
@@ -73,6 +73,46 @@ def test_metrics_union(ray_start_regular_shared):
|
||||
assert it3.take(10) == [1, 100, 3, 200, 6, 300, 10, 400]
|
||||
|
||||
|
||||
def test_metrics_union_recursive(ray_start_regular_shared):
|
||||
it1 = from_items([1, 2, 3, 4], num_shards=1)
|
||||
it2 = from_items([1, 2, 3, 4], num_shards=1)
|
||||
it3 = from_items([1, 2, 3, 4], num_shards=1)
|
||||
|
||||
def foo_metrics(x):
|
||||
metrics = LocalIterator.get_metrics()
|
||||
metrics.counters["foo"] += 1
|
||||
return metrics.counters["foo"]
|
||||
|
||||
def bar_metrics(x):
|
||||
metrics = LocalIterator.get_metrics()
|
||||
metrics.counters["bar"] += 1
|
||||
return metrics.counters["bar"]
|
||||
|
||||
def baz_metrics(x):
|
||||
metrics = LocalIterator.get_metrics()
|
||||
metrics.counters["baz"] += 1
|
||||
return metrics.counters["baz"]
|
||||
|
||||
def verify_metrics(x):
|
||||
metrics = LocalIterator.get_metrics()
|
||||
metrics.counters["n"] += 1
|
||||
# Check the metrics context is shared recursively.
|
||||
print(metrics.counters)
|
||||
if metrics.counters["n"] >= 3:
|
||||
assert "foo" in metrics.counters
|
||||
assert "bar" in metrics.counters
|
||||
assert "baz" in metrics.counters
|
||||
return x
|
||||
|
||||
it1 = it1.gather_async().for_each(foo_metrics)
|
||||
it2 = it2.gather_async().for_each(bar_metrics)
|
||||
it3 = it3.gather_async().for_each(baz_metrics)
|
||||
it12 = it1.union(it2, deterministic=True)
|
||||
it123 = it12.union(it3, deterministic=True)
|
||||
out = it123.for_each(verify_metrics)
|
||||
assert out.take(20) == [1, 1, 1, 2, 2, 3, 2, 4, 3, 3, 4, 4]
|
||||
|
||||
|
||||
def test_from_items(ray_start_regular_shared):
|
||||
it = from_items([1, 2, 3, 4])
|
||||
assert repr(it) == "ParallelIterator[from_items[int, 4, shards=2]]"
|
||||
|
||||
+44
-40
@@ -1,10 +1,11 @@
|
||||
from contextlib import contextmanager
|
||||
import collections
|
||||
import random
|
||||
import threading
|
||||
from typing import TypeVar, Generic, Iterable, List, Callable, Any
|
||||
|
||||
import ray
|
||||
from ray.util.iter_metrics import MetricsContext
|
||||
from ray.util.iter_metrics import MetricsContext, SharedMetrics
|
||||
|
||||
# The type of an iterator element.
|
||||
T = TypeVar("T")
|
||||
@@ -412,7 +413,7 @@ class ParallelIterator(Generic[T]):
|
||||
futures = [a.par_iter_next.remote() for a in active]
|
||||
|
||||
name = "{}.batch_across_shards()".format(self)
|
||||
return LocalIterator(base_iterator, MetricsContext(), name=name)
|
||||
return LocalIterator(base_iterator, SharedMetrics(), name=name)
|
||||
|
||||
def gather_async(self, async_queue_depth=1) -> "LocalIterator[T]":
|
||||
"""Returns a local iterable for asynchronous iteration.
|
||||
@@ -438,8 +439,10 @@ class ParallelIterator(Generic[T]):
|
||||
if async_queue_depth < 1:
|
||||
raise ValueError("queue depth must be positive")
|
||||
|
||||
# Forward reference to the returned iterator.
|
||||
local_iter = None
|
||||
|
||||
def base_iterator(timeout=None):
|
||||
metrics = LocalIterator.get_metrics()
|
||||
all_actors = []
|
||||
for actor_set in self.actor_sets:
|
||||
actor_set.init_actors()
|
||||
@@ -463,7 +466,7 @@ class ParallelIterator(Generic[T]):
|
||||
for obj_id in ready:
|
||||
actor = futures.pop(obj_id)
|
||||
try:
|
||||
metrics.current_actor = actor
|
||||
local_iter.shared_metrics.get().current_actor = actor
|
||||
yield ray.get(obj_id)
|
||||
futures[actor.par_iter_next.remote()] = actor
|
||||
except StopIteration:
|
||||
@@ -473,7 +476,8 @@ class ParallelIterator(Generic[T]):
|
||||
yield _NextValueNotReady()
|
||||
|
||||
name = "{}.gather_async()".format(self)
|
||||
return LocalIterator(base_iterator, MetricsContext(), name=name)
|
||||
local_iter = LocalIterator(base_iterator, SharedMetrics(), name=name)
|
||||
return local_iter
|
||||
|
||||
def take(self, n: int) -> List[T]:
|
||||
"""Return up to the first n items from this iterator."""
|
||||
@@ -540,7 +544,7 @@ class ParallelIterator(Generic[T]):
|
||||
break
|
||||
|
||||
name = self.name + ".shard[{}]".format(shard_index)
|
||||
return LocalIterator(base_iterator, MetricsContext(), name=name)
|
||||
return LocalIterator(base_iterator, SharedMetrics(), name=name)
|
||||
|
||||
|
||||
class LocalIterator(Generic[T]):
|
||||
@@ -562,7 +566,7 @@ class LocalIterator(Generic[T]):
|
||||
|
||||
def __init__(self,
|
||||
base_iterator: Callable[[], Iterable[T]],
|
||||
metrics: MetricsContext,
|
||||
shared_metrics: SharedMetrics,
|
||||
local_transforms: List[Callable[[Iterable], Any]] = None,
|
||||
timeout: int = None,
|
||||
name=None):
|
||||
@@ -572,7 +576,7 @@ class LocalIterator(Generic[T]):
|
||||
base_iterator (func): A function that produces the base iterator.
|
||||
This is a function so that we can ensure LocalIterator is
|
||||
serializable.
|
||||
metrics (MetricsContext): Existing metrics context or a new
|
||||
shared_metrics (SharedMetrics): Existing metrics context or a new
|
||||
context. Should be the same for each chained iterator.
|
||||
local_transforms (list): A list of transformation functions to be
|
||||
applied on top of the base iterator. When iteration begins, we
|
||||
@@ -584,10 +588,11 @@ class LocalIterator(Generic[T]):
|
||||
blocking.
|
||||
name (str): Optional name for this iterator.
|
||||
"""
|
||||
assert isinstance(shared_metrics, SharedMetrics)
|
||||
self.base_iterator = base_iterator
|
||||
self.built_iterator = None
|
||||
self.local_transforms = local_transforms or []
|
||||
self.metrics = metrics
|
||||
self.shared_metrics = shared_metrics
|
||||
self.timeout = timeout
|
||||
self.name = name or "unknown"
|
||||
|
||||
@@ -606,26 +611,20 @@ class LocalIterator(Generic[T]):
|
||||
it = iter(self.base_iterator(self.timeout))
|
||||
for fn in self.local_transforms:
|
||||
it = fn(it)
|
||||
|
||||
# This sets the iterator context during iterator execution, and
|
||||
# clears it after so that multiple iterators can be used at a time.
|
||||
def set_restore_context(it):
|
||||
if hasattr(self.thread_local, "metrics"):
|
||||
prev_metrics = self.thread_local.metrics
|
||||
else:
|
||||
prev_metrics = None
|
||||
self.thread_local.metrics = self.metrics
|
||||
try:
|
||||
for item in it:
|
||||
self.thread_local.metrics = prev_metrics
|
||||
yield item
|
||||
self.thread_local.metrics = self.metrics
|
||||
finally:
|
||||
self.thread_local.metrics = prev_metrics
|
||||
|
||||
it = set_restore_context(it)
|
||||
self.built_iterator = it
|
||||
|
||||
@contextmanager
|
||||
def _metrics_context(self):
|
||||
if hasattr(self.thread_local, "metrics"):
|
||||
prev_metrics = self.thread_local.metrics
|
||||
else:
|
||||
prev_metrics = None
|
||||
try:
|
||||
self.thread_local.metrics = self.shared_metrics.get()
|
||||
yield
|
||||
finally:
|
||||
self.thread_local.metrics = prev_metrics
|
||||
|
||||
def __iter__(self):
|
||||
self._build_once()
|
||||
return self.built_iterator
|
||||
@@ -649,7 +648,8 @@ class LocalIterator(Generic[T]):
|
||||
# Keep retrying the function until it returns a valid
|
||||
# value. This allows for non-blocking functions.
|
||||
while True:
|
||||
result = fn(item)
|
||||
with self._metrics_context():
|
||||
result = fn(item)
|
||||
yield result
|
||||
if not isinstance(result, _NextValueNotReady):
|
||||
break
|
||||
@@ -664,7 +664,8 @@ class LocalIterator(Generic[T]):
|
||||
# Avoids calling on_fetch_start repeatedly if we are
|
||||
# yielding _NextValueNotReady.
|
||||
if new_item:
|
||||
fn._on_fetch_start()
|
||||
with self._metrics_context():
|
||||
fn._on_fetch_start()
|
||||
new_item = False
|
||||
item = next(it)
|
||||
if not isinstance(item, _NextValueNotReady):
|
||||
@@ -675,19 +676,20 @@ class LocalIterator(Generic[T]):
|
||||
|
||||
return LocalIterator(
|
||||
self.base_iterator,
|
||||
self.metrics,
|
||||
self.shared_metrics,
|
||||
self.local_transforms + [apply_foreach],
|
||||
name=self.name + ".for_each()")
|
||||
|
||||
def filter(self, fn: Callable[[T], bool]) -> "LocalIterator[T]":
|
||||
def apply_filter(it):
|
||||
for item in it:
|
||||
if isinstance(item, _NextValueNotReady) or fn(item):
|
||||
yield item
|
||||
with self._metrics_context():
|
||||
if isinstance(item, _NextValueNotReady) or fn(item):
|
||||
yield item
|
||||
|
||||
return LocalIterator(
|
||||
self.base_iterator,
|
||||
self.metrics,
|
||||
self.shared_metrics,
|
||||
self.local_transforms + [apply_filter],
|
||||
name=self.name + ".filter()")
|
||||
|
||||
@@ -707,7 +709,7 @@ class LocalIterator(Generic[T]):
|
||||
|
||||
return LocalIterator(
|
||||
self.base_iterator,
|
||||
self.metrics,
|
||||
self.shared_metrics,
|
||||
self.local_transforms + [apply_batch],
|
||||
name=self.name + ".batch({})".format(n))
|
||||
|
||||
@@ -722,7 +724,7 @@ class LocalIterator(Generic[T]):
|
||||
|
||||
return LocalIterator(
|
||||
self.base_iterator,
|
||||
self.metrics,
|
||||
self.shared_metrics,
|
||||
self.local_transforms + [apply_flatten],
|
||||
name=self.name + ".flatten()")
|
||||
|
||||
@@ -760,7 +762,7 @@ class LocalIterator(Generic[T]):
|
||||
|
||||
return LocalIterator(
|
||||
self.base_iterator,
|
||||
self.metrics,
|
||||
self.shared_metrics,
|
||||
self.local_transforms + [apply_shuffle],
|
||||
name=self.name +
|
||||
".shuffle(shuffle_buffer_size={}, seed={})".format(
|
||||
@@ -836,7 +838,7 @@ class LocalIterator(Generic[T]):
|
||||
iterators.append(
|
||||
LocalIterator(
|
||||
make_next(i),
|
||||
self.metrics, [],
|
||||
self.shared_metrics, [],
|
||||
name=self.name + ".duplicate[{}]".format(i)))
|
||||
|
||||
return iterators
|
||||
@@ -862,8 +864,10 @@ class LocalIterator(Generic[T]):
|
||||
timeout = 0
|
||||
|
||||
active = []
|
||||
shared_metrics = MetricsContext()
|
||||
for it in [self] + list(others):
|
||||
parent_iters = [self] + list(others)
|
||||
shared_metrics = SharedMetrics(
|
||||
parents=[p.shared_metrics for p in parent_iters])
|
||||
for it in parent_iters:
|
||||
active.append(
|
||||
LocalIterator(
|
||||
it.base_iterator,
|
||||
@@ -945,7 +949,7 @@ class ParallelIteratorWorker(object):
|
||||
def par_iter_init(self, transforms):
|
||||
"""Implements ParallelIterator worker init."""
|
||||
it = LocalIterator(lambda timeout: self.item_generator,
|
||||
MetricsContext())
|
||||
SharedMetrics())
|
||||
for fn in transforms:
|
||||
it = fn(it)
|
||||
assert it is not None, fn
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import collections
|
||||
from typing import List
|
||||
|
||||
from ray.util.timer import _Timer
|
||||
|
||||
@@ -18,8 +19,6 @@ class MetricsContext:
|
||||
current_actor (ActorHandle): reference to the actor handle that
|
||||
produced the current iterator output. This is automatically set
|
||||
for gather_async().
|
||||
parent_metrics (list): list of other MetricsContexts that have been
|
||||
attached to this due to LocalIterator.union().
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -27,7 +26,6 @@ class MetricsContext:
|
||||
self.timers = collections.defaultdict(_Timer)
|
||||
self.info = {}
|
||||
self.current_actor = None
|
||||
self.parent_metrics = []
|
||||
|
||||
def save(self):
|
||||
"""Return a serializable copy of this context."""
|
||||
@@ -35,7 +33,6 @@ class MetricsContext:
|
||||
"counters": dict(self.counters),
|
||||
"info": dict(self.info),
|
||||
"timers": None, # TODO(ekl) consider persisting timers too
|
||||
"parent_state": [u.save() for u in self.parent_metrics],
|
||||
}
|
||||
|
||||
def restore(self, values):
|
||||
@@ -44,5 +41,26 @@ class MetricsContext:
|
||||
self.counters.update(values["counters"])
|
||||
self.timers.clear()
|
||||
self.info = values["info"]
|
||||
for u, state in zip(self.parent_metrics, values["parent_state"]):
|
||||
u.restore(state)
|
||||
|
||||
|
||||
class SharedMetrics:
|
||||
"""Holds an indirect reference to a (shared) metrics context.
|
||||
|
||||
This is used by LocalIterator.union() to point the metrics contexts of
|
||||
entirely separate iterator chains to the same underlying context."""
|
||||
|
||||
def __init__(self,
|
||||
metrics: MetricsContext = None,
|
||||
parents: List["SharedMetrics"] = None):
|
||||
self.metrics = metrics or MetricsContext()
|
||||
self.parents = parents or []
|
||||
self.set(self.metrics)
|
||||
|
||||
def set(self, metrics):
|
||||
"""Recursively set self and parents to point to the same metrics."""
|
||||
self.metrics = metrics
|
||||
for parent in self.parents:
|
||||
parent.set(metrics)
|
||||
|
||||
def get(self):
|
||||
return self.metrics
|
||||
|
||||
@@ -189,14 +189,16 @@ def build_trainer(name,
|
||||
state = Trainer.__getstate__(self)
|
||||
state["trainer_state"] = self.state.copy()
|
||||
if self.train_exec_impl:
|
||||
state["train_exec_impl"] = self.train_exec_impl.metrics.save()
|
||||
state["train_exec_impl"] = (
|
||||
self.train_exec_impl.shared_metrics.get().save())
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
Trainer.__setstate__(self, state)
|
||||
self.state = state["trainer_state"].copy()
|
||||
if self.train_exec_impl:
|
||||
self.train_exec_impl.metrics.restore(state["train_exec_impl"])
|
||||
self.train_exec_impl.shared_metrics.get().restore(
|
||||
state["train_exec_impl"])
|
||||
|
||||
def with_updates(**overrides):
|
||||
"""Build a copy of this trainer with the specified overrides.
|
||||
|
||||
@@ -11,7 +11,7 @@ import time
|
||||
|
||||
import ray
|
||||
from ray.util.iter import from_actors, LocalIterator, _NextValueNotReady
|
||||
from ray.util.iter_metrics import MetricsContext
|
||||
from ray.util.iter_metrics import SharedMetrics
|
||||
from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer, \
|
||||
ReplayBuffer
|
||||
from ray.rllib.evaluation.metrics import collect_episodes, \
|
||||
@@ -109,7 +109,7 @@ def ParallelRollouts(workers: WorkerSet, mode="bulk_sync",
|
||||
while True:
|
||||
yield workers.local_worker().sample()
|
||||
|
||||
return (LocalIterator(sampler, MetricsContext())
|
||||
return (LocalIterator(sampler, SharedMetrics())
|
||||
.for_each(report_timesteps))
|
||||
|
||||
# Create a parallel iterator over generated experiences.
|
||||
@@ -316,25 +316,21 @@ class CollectMetrics:
|
||||
|
||||
# Add in iterator metrics.
|
||||
metrics = LocalIterator.get_metrics()
|
||||
if metrics.parent_metrics:
|
||||
print("TODO: support nested metrics better")
|
||||
all_metrics = [metrics] + metrics.parent_metrics
|
||||
timers = {}
|
||||
counters = {}
|
||||
info = {}
|
||||
for metrics in all_metrics:
|
||||
info.update(metrics.info)
|
||||
for k, counter in metrics.counters.items():
|
||||
counters[k] = counter
|
||||
for k, timer in metrics.timers.items():
|
||||
timers["{}_time_ms".format(k)] = round(timer.mean * 1000, 3)
|
||||
if timer.has_units_processed():
|
||||
timers["{}_throughput".format(k)] = round(
|
||||
timer.mean_throughput, 3)
|
||||
res.update({
|
||||
"num_healthy_workers": len(self.workers.remote_workers()),
|
||||
"timesteps_total": metrics.counters[STEPS_SAMPLED_COUNTER],
|
||||
})
|
||||
info.update(metrics.info)
|
||||
for k, counter in metrics.counters.items():
|
||||
counters[k] = counter
|
||||
for k, timer in metrics.timers.items():
|
||||
timers["{}_time_ms".format(k)] = round(timer.mean * 1000, 3)
|
||||
if timer.has_units_processed():
|
||||
timers["{}_throughput".format(k)] = round(
|
||||
timer.mean_throughput, 3)
|
||||
res.update({
|
||||
"num_healthy_workers": len(self.workers.remote_workers()),
|
||||
"timesteps_total": metrics.counters[STEPS_SAMPLED_COUNTER],
|
||||
})
|
||||
res["timers"] = timers
|
||||
res["info"] = info
|
||||
res["info"].update(counters)
|
||||
@@ -617,7 +613,7 @@ def LocalReplay(replay_buffer: ReplayBuffer, train_batch_size: int):
|
||||
})
|
||||
yield MultiAgentBatch(samples, train_batch_size)
|
||||
|
||||
return LocalIterator(gen_replay, MetricsContext())
|
||||
return LocalIterator(gen_replay, SharedMetrics())
|
||||
|
||||
|
||||
def Concurrently(ops: List[LocalIterator], mode="round_robin"):
|
||||
@@ -743,4 +739,4 @@ def Dequeue(input_queue: queue.Queue, check=lambda: True):
|
||||
yield _NextValueNotReady()
|
||||
raise RuntimeError("Error raised reading from queue")
|
||||
|
||||
return LocalIterator(base_iterator, MetricsContext())
|
||||
return LocalIterator(base_iterator, SharedMetrics())
|
||||
|
||||
Reference in New Issue
Block a user