diff --git a/.travis.yml b/.travis.yml index 8e7e13cc9..da15262fe 100644 --- a/.travis.yml +++ b/.travis.yml @@ -297,7 +297,7 @@ matrix: script: - if [ $RAY_CI_RLLIB_FULL_AFFECTED != "1" ]; then exit; fi - ./ci/keep_alive bazel test --build_tests_only --test_tag_filters=tests_dir_L,tests_dir_M --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=errors rllib/... - - ./ci/keep_alive bazel test --build_tests_only --test_tag_filters=tests_dir_N,tests_dir_O --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=errors rllib/... + - ./ci/keep_alive bazel test --build_tests_only --test_tag_filters=tests_dir_N,tests_dir_O,test_dir_P --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=errors rllib/... - ./ci/keep_alive bazel test --build_tests_only --test_tag_filters=tests_dir_R,tests_dir_S --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=errors rllib/... install: diff --git a/python/ray/tests/test_iter.py b/python/ray/tests/test_iter.py index 4aa097250..276b26a8c 100644 --- a/python/ray/tests/test_iter.py +++ b/python/ray/tests/test_iter.py @@ -1,9 +1,68 @@ import time from collections import Counter +import pytest import ray from ray.util.iter import from_items, from_iterators, from_range, \ - from_actors, ParallelIteratorWorker + from_actors, ParallelIteratorWorker, LocalIterator + + +def test_metrics(ray_start_regular_shared): + it = from_items([1, 2, 3, 4], num_shards=1) + it2 = from_items([1, 2, 3, 4], num_shards=1) + + def f(x): + metrics = LocalIterator.get_metrics() + metrics.counters["foo"] += x + return metrics.counters["foo"] + + it = it.gather_sync().for_each(f) + it2 = it2.gather_sync().for_each(f) + + # Context cannot be accessed outside the iterator. + with pytest.raises(ValueError): + LocalIterator.get_metrics() + + # Tests iterators have isolated contexts. + assert it.take(4) == [1, 3, 6, 10] + assert it2.take(4) == [1, 3, 6, 10] + + # Context cannot be accessed outside the iterator. + with pytest.raises(ValueError): + LocalIterator.get_metrics() + + +def test_metrics_union(ray_start_regular_shared): + it1 = from_items([1, 2, 3, 4], num_shards=1) + it2 = from_items([1, 2, 3, 4], num_shards=1) + + def foo_metrics(x): + metrics = LocalIterator.get_metrics() + metrics.counters["foo"] += x + return metrics.counters["foo"] + + def bar_metrics(x): + metrics = LocalIterator.get_metrics() + metrics.counters["bar"] += 100 + return metrics.counters["bar"] + + def verify_metrics(x): + metrics = LocalIterator.get_metrics() + metrics.counters["n"] += 1 + # Check the unioned iterator gets a new metric context. + assert "foo" not in metrics.counters + assert "bar" not in metrics.counters + # Check parent metrics are accessible. + if metrics.counters["n"] > 2: + assert "foo" in metrics.parent_metrics[0].counters + assert "bar" in metrics.parent_metrics[1].counters + return x + + it1 = it1.gather_async().for_each(foo_metrics) + it2 = it2.gather_async().for_each(bar_metrics) + it3 = it1.union(it2, deterministic=True) + it3 = it3.for_each(verify_metrics) + assert it3.take(10) == [1, 100, 3, 200, 6, 300, 10, 400] def test_from_items(ray_start_regular_shared): @@ -271,6 +330,5 @@ def test_serialization(ray_start_regular_shared): if __name__ == "__main__": - import pytest import sys sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/util/iter.py b/python/ray/util/iter.py index 532b00d54..c2455f93f 100644 --- a/python/ray/util/iter.py +++ b/python/ray/util/iter.py @@ -1,8 +1,10 @@ import collections import random +import threading from typing import TypeVar, Generic, Iterable, List, Callable, Any import ray +from ray.util.iter_metrics import MetricsContext # The type of an iterator element. T = TypeVar("T") @@ -150,6 +152,8 @@ class ParallelIterator(Generic[T]): """ def __init__(self, actor_sets: List["_ActorSet"], name: str): + """Create a parallel iterator (this is an internal function).""" + # We track multiple sets of actors to support parallel .union(). self.actor_sets = actor_sets self.name = name @@ -415,7 +419,7 @@ class ParallelIterator(Generic[T]): futures = [a.par_iter_next.remote() for a in active] name = "{}.batch_across_shards()".format(self) - return LocalIterator(base_iterator, name=name) + return LocalIterator(base_iterator, MetricsContext(), name=name) def gather_async(self) -> "LocalIterator[T]": """Returns a local iterable for asynchronous iteration. @@ -433,6 +437,8 @@ class ParallelIterator(Generic[T]): ... 1 """ + metrics = MetricsContext() + def base_iterator(timeout=None): all_actors = [] for actor_set in self.actor_sets: @@ -456,6 +462,7 @@ class ParallelIterator(Generic[T]): for obj_id in ready: actor = futures.pop(obj_id) try: + metrics.cur_actor = actor yield ray.get(obj_id) futures[actor.par_iter_next.remote()] = actor except StopIteration: @@ -465,7 +472,7 @@ class ParallelIterator(Generic[T]): yield _NextValueNotReady() name = "{}.gather_async()".format(self) - return LocalIterator(base_iterator, name=name) + return LocalIterator(base_iterator, metrics, name=name) def take(self, n: int) -> List[T]: """Return up to the first n items from this iterator.""" @@ -528,7 +535,7 @@ class ParallelIterator(Generic[T]): break name = self.name + ".shard[{}]".format(shard_index) - return LocalIterator(base_iterator, name=name) + return LocalIterator(base_iterator, MetricsContext(), name=name) class LocalIterator(Generic[T]): @@ -541,8 +548,16 @@ class LocalIterator(Generic[T]): tasks and actors. However, it should be read from at most one process at a time.""" + # If a function passed to LocalIterator.for_each() has this method, + # we will call it at the beginning of each data fetch call. This can be + # used to measure the underlying wait latency for measurement purposes. + ON_FETCH_START_HOOK_NAME = "_on_fetch_start" + + thread_local = threading.local() + def __init__(self, base_iterator: Callable[[], Iterable[T]], + metrics: MetricsContext, local_transforms: List[Callable[[Iterable], Any]] = None, timeout: int = None, name=None): @@ -552,6 +567,8 @@ class LocalIterator(Generic[T]): base_iterator (func): A function that produces the base iterator. This is a function so that we can ensure LocalIterator is serializable. + metrics (MetricsContext): Existing metrics context or a new + context. Should be the same for each chained iterator. local_transforms (list): A list of transformation functions to be applied on top of the base iterator. When iteration begins, we create the base iterator and apply these functions. This lazy @@ -565,14 +582,43 @@ class LocalIterator(Generic[T]): self.base_iterator = base_iterator self.built_iterator = None self.local_transforms = local_transforms or [] + self.metrics = metrics self.timeout = timeout self.name = name or "unknown" + @staticmethod + def get_metrics() -> MetricsContext: + """Return the current metrics context. + + This can only be called within an iterator function.""" + if (not hasattr(LocalIterator.thread_local, "metrics") + or LocalIterator.thread_local.metrics is None): + raise ValueError("Cannot access context outside an iterator.") + return LocalIterator.thread_local.metrics + def _build_once(self): if self.built_iterator is None: it = iter(self.base_iterator(self.timeout)) for fn in self.local_transforms: it = fn(it) + + # This sets the iterator context during iterator execution, and + # clears it after so that multiple iterators can be used at a time. + def set_restore_context(it): + if hasattr(self.thread_local, "metrics"): + prev_metrics = self.thread_local.metrics + else: + prev_metrics = None + self.thread_local.metrics = self.metrics + try: + for item in it: + self.thread_local.metrics = prev_metrics + yield item + self.thread_local.metrics = self.metrics + finally: + self.thread_local.metrics = prev_metrics + + it = set_restore_context(it) self.built_iterator = it def __iter__(self): @@ -597,8 +643,28 @@ class LocalIterator(Generic[T]): else: yield fn(item) + if hasattr(fn, LocalIterator.ON_FETCH_START_HOOK_NAME): + unwrapped = apply_foreach + + def add_wait_hooks(it): + it = unwrapped(it) + new_item = True + while True: + # Avoids calling on_fetch_start repeatedly if we are + # yielding _NextValueNotReady. + if new_item: + fn._on_fetch_start() + new_item = False + item = next(it) + if not isinstance(item, _NextValueNotReady): + new_item = True + yield item + + apply_foreach = add_wait_hooks + return LocalIterator( self.base_iterator, + self.metrics, self.local_transforms + [apply_foreach], name=self.name + ".for_each()") @@ -610,6 +676,7 @@ class LocalIterator(Generic[T]): return LocalIterator( self.base_iterator, + self.metrics, self.local_transforms + [apply_filter], name=self.name + ".filter()") @@ -629,6 +696,7 @@ class LocalIterator(Generic[T]): return LocalIterator( self.base_iterator, + self.metrics, self.local_transforms + [apply_batch], name=self.name + ".batch({})".format(n)) @@ -643,6 +711,7 @@ class LocalIterator(Generic[T]): return LocalIterator( self.base_iterator, + self.metrics, self.local_transforms + [apply_flatten], name=self.name + ".flatten()") @@ -680,6 +749,7 @@ class LocalIterator(Generic[T]): return LocalIterator( self.base_iterator, + self.metrics, self.local_transforms + [apply_shuffle], name=self.name + ".shuffle(shuffle_buffer_size={}, seed={})".format( @@ -709,12 +779,13 @@ class LocalIterator(Generic[T]): if i >= n: break - def union(self, other: "LocalIterator[T]") -> "LocalIterator[T]": + def union(self, other: "LocalIterator[T]", + deterministic: bool = False) -> "LocalIterator[T]": """Return an iterator that is the union of this and the other. - There are no ordering guarantees between the two iterators. We make a - best-effort attempt to return items from both as they become ready, - preventing starvation of any particular iterator. + If deterministic=True, we alternate between reading from one iterator + and the other. Otherwise we return items from iterators as they + become ready. """ if not isinstance(other, LocalIterator): @@ -722,10 +793,21 @@ class LocalIterator(Generic[T]): "other must be of type LocalIterator, got {}".format( type(other))) + if deterministic: + timeout = None + else: + timeout = 0 + it1 = LocalIterator( - self.base_iterator, self.local_transforms, timeout=0) + self.base_iterator, + self.metrics, + self.local_transforms, + timeout=timeout) it2 = LocalIterator( - other.base_iterator, other.local_transforms, timeout=0) + other.base_iterator, + other.metrics, + other.local_transforms, + timeout=timeout) active = [it1, it2] def build_union(timeout=None): @@ -740,13 +822,22 @@ class LocalIterator(Generic[T]): break else: yield item + if deterministic: + break except StopIteration: active.remove(it) if not active: break + # TODO(ekl) is this the best way to represent union() of metrics? + new_ctx = MetricsContext() + new_ctx.parent_metrics.append(self.metrics) + new_ctx.parent_metrics.append(other.metrics) + return LocalIterator( - build_union, [], name="LocalUnion[{}, {}]".format(self, other)) + build_union, + new_ctx, [], + name="LocalUnion[{}, {}]".format(self, other)) class ParallelIteratorWorker(object): @@ -792,7 +883,8 @@ class ParallelIteratorWorker(object): def par_iter_init(self, transforms): """Implements ParallelIterator worker init.""" - it = LocalIterator(lambda timeout: self.item_generator) + it = LocalIterator(lambda timeout: self.item_generator, + MetricsContext()) for fn in transforms: it = fn(it) assert it is not None, fn diff --git a/python/ray/util/iter_metrics.py b/python/ray/util/iter_metrics.py new file mode 100644 index 000000000..de9f1de88 --- /dev/null +++ b/python/ray/util/iter_metrics.py @@ -0,0 +1,48 @@ +import collections + +from ray.util.timer import _Timer + + +class MetricsContext: + """Metrics context object for a local iterator. + + This object is accessible by all operators of a local iterator. It can be + used to store and retrieve global execution metrics for the iterator. + It can be accessed by calling LocalIterator.get_metrics(), which is only + allowable inside iterator functions. + + Attributes: + counters (defaultdict): dict storing increasing metrics. + timers (defaultdict): dict storing latency timers. + info (dict): dict storing misc metric values. + current_actor (ActorHandle): reference to the actor handle that + produced the current iterator output. This is automatically set + for gather_async(). + parent_metrics (list): list of other MetricsContexts that have been + attached to this due to LocalIterator.union(). + """ + + def __init__(self): + self.counters = collections.defaultdict(int) + self.timers = collections.defaultdict(_Timer) + self.info = {} + self.current_actor = None + self.parent_metrics = [] + + def save(self): + """Return a serializable copy of this context.""" + return { + "counters": dict(self.counters), + "info": dict(self.info), + "timers": None, # TODO(ekl) consider persisting timers too + "parent_state": [u.save() for u in self.parent_metrics], + } + + def restore(self, values): + """Restores state given the output of save().""" + self.counters.clear() + self.counters.update(values["counters"]) + self.timers.clear() + self.info = values["info"] + for u, state in zip(self.parent_metrics, values["parent_state"]): + u.restore(state) diff --git a/python/ray/util/timer.py b/python/ray/util/timer.py new file mode 100644 index 000000000..dc1d1fca7 --- /dev/null +++ b/python/ray/util/timer.py @@ -0,0 +1,62 @@ +import numpy as np +import time + + +class _Timer: + """A running stat for conveniently logging the duration of a code block. + + Example: + wait_timer = TimerStat() + with wait_timer: + ray.wait(...) + + Note that this class is *not* thread-safe. + """ + + def __init__(self, window_size=10): + self._window_size = window_size + self._samples = [] + self._units_processed = [] + self._start_time = None + self._total_time = 0.0 + self.count = 0 + + def __enter__(self): + assert self._start_time is None, "concurrent updates not supported" + self._start_time = time.time() + + def __exit__(self, exc_type, exc_value, tb): + assert self._start_time is not None + time_delta = time.time() - self._start_time + self.push(time_delta) + self._start_time = None + + def push(self, time_delta): + self._samples.append(time_delta) + if len(self._samples) > self._window_size: + self._samples.pop(0) + self.count += 1 + self._total_time += time_delta + + def push_units_processed(self, n): + self._units_processed.append(n) + if len(self._units_processed) > self._window_size: + self._units_processed.pop(0) + + def has_units_processed(self): + return len(self._units_processed) > 0 + + @property + def mean(self): + return np.mean(self._samples) + + @property + def mean_units_processed(self): + return float(np.mean(self._units_processed)) + + @property + def mean_throughput(self): + time_total = sum(self._samples) + if not time_total: + return 0.0 + return sum(self._units_processed) / time_total diff --git a/rllib/BUILD b/rllib/BUILD index 3cfc72f78..d1c340235 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -999,6 +999,13 @@ py_test( srcs = ["tests/test_optimizers.py"] ) +py_test( + name = "tests/test_pipeline", + tags = ["tests_dir", "tests_dir_P"], + size = "small", + srcs = ["tests/test_pipeline.py"] +) + py_test( name = "tests/test_reproducibility", tags = ["tests_dir", "tests_dir_R"], diff --git a/rllib/agents/a3c/__init__.py b/rllib/agents/a3c/__init__.py index 2be1ed9cc..9467b770c 100644 --- a/rllib/agents/a3c/__init__.py +++ b/rllib/agents/a3c/__init__.py @@ -1,16 +1,8 @@ from ray.rllib.agents.a3c.a3c import A3CTrainer, DEFAULT_CONFIG from ray.rllib.agents.a3c.a2c import A2CTrainer from ray.rllib.agents.a3c.a2c_pipeline import A2CPipeline -from ray.rllib.utils import renamed_agent - -A2CAgent = renamed_agent(A2CTrainer) -A3CAgent = renamed_agent(A3CTrainer) +from ray.rllib.agents.a3c.a3c_pipeline import A3CPipeline __all__ = [ - "A2CAgent", - "A3CAgent", - "A2CTrainer", - "A3CTrainer", - "DEFAULT_CONFIG", - "A2CPipeline" + "A2CTrainer", "A3CTrainer", "DEFAULT_CONFIG", "A2CPipeline", "A3CPipeline" ] diff --git a/rllib/agents/a3c/a3c_pipeline.py b/rllib/agents/a3c/a3c_pipeline.py new file mode 100644 index 000000000..359c729a8 --- /dev/null +++ b/rllib/agents/a3c/a3c_pipeline.py @@ -0,0 +1,19 @@ +"""Experimental pipeline-based impl; run this with --run='A3C_pl'""" + +from ray.rllib.agents.a3c.a3c import A3CTrainer +from ray.rllib.utils.experimental_dsl import (AsyncGradients, ApplyGradients, + StandardMetricsReporting) + + +def training_pipeline(workers, config): + # For A3C, compute policy gradients remotely on the rollout workers. + grads = AsyncGradients(workers) + + # Apply the gradients as they arrive. We set update_all to False so that + # only the worker sending the gradient is updated with new weights. + train_op = grads.for_each(ApplyGradients(workers, update_all=False)) + + return StandardMetricsReporting(train_op, workers, config) + + +A3CPipeline = A3CTrainer.with_updates(training_pipeline=training_pipeline) diff --git a/rllib/agents/registry.py b/rllib/agents/registry.py index 1d98fc9ac..b8b7de285 100644 --- a/rllib/agents/registry.py +++ b/rllib/agents/registry.py @@ -105,6 +105,11 @@ def _import_a2c_pipeline(): return a3c.A2CPipeline +def _import_a3c_pipeline(): + from ray.rllib.agents import a3c + return a3c.A3CPipeline + + def _import_pg_pipeline(): from ray.rllib.agents import pg return pg.PGPipeline @@ -133,6 +138,7 @@ ALGORITHMS = { # Experimental pipeline-based impls. "A2C_pl": _import_a2c_pipeline, + "A3C_pl": _import_a3c_pipeline, "PG_pl": _import_pg_pipeline, } diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index f43d33df9..a00f9ab50 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -260,11 +260,11 @@ COMMON_CONFIG = { # but optimal value could be obtained by measuring your environment # step / reset and model inference perf. "remote_env_batch_wait_ms": 0, - # Minimum time per train iteration + # Minimum time per train iteration (frequency of metrics reporting). "min_iter_time_s": 0, # Minimum env steps to optimize for per train call. This value does # not affect learning, only the length of train iterations. - "timesteps_per_iteration": 0, + "timesteps_per_iteration": 0, # TODO(ekl) deprecate this # This argument, in conjunction with worker_index, sets the random seed of # each worker, so that identically configured trials will have identical # results. This makes experiments reproducible. @@ -613,7 +613,7 @@ class Trainer(Trainable): def _stop(self): if hasattr(self, "workers"): self.workers.stop() - if hasattr(self, "optimizer"): + if hasattr(self, "optimizer") and self.optimizer: self.optimizer.stop() @override(Trainable) diff --git a/rllib/agents/trainer_template.py b/rllib/agents/trainer_template.py index f9f33a7f1..d5734cbc8 100644 --- a/rllib/agents/trainer_template.py +++ b/rllib/agents/trainer_template.py @@ -167,11 +167,15 @@ def build_trainer(name, def __getstate__(self): state = Trainer.__getstate__(self) state["trainer_state"] = self.state.copy() + if self.train_pipeline: + state["train_pipeline"] = self.train_pipeline.metrics.save() return state def __setstate__(self, state): Trainer.__setstate__(self, state) self.state = state["trainer_state"].copy() + if self.train_pipeline: + self.train_pipeline.metrics.restore(state["train_pipeline"]) def with_updates(**overrides): """Build a copy of this trainer with the specified overrides. diff --git a/rllib/tests/test_pipeline.py b/rllib/tests/test_pipeline.py new file mode 100644 index 000000000..03c2db218 --- /dev/null +++ b/rllib/tests/test_pipeline.py @@ -0,0 +1,49 @@ +import unittest + +import ray +from ray.rllib.agents.a3c import a2c_pipeline + + +class TestPipeline(unittest.TestCase): + """General tests for the pipeline API.""" + + def setUp(self): + ray.init() + + def tearDown(self): + ray.shutdown() + + def test_pipeline_stats(ray_start_regular): + trainer = a2c_pipeline.A2CPipeline( + env="CartPole-v0", config={"min_iter_time_s": 0}) + result = trainer.train() + assert isinstance(result, dict) + assert "info" in result + assert "learner" in result["info"] + assert "num_steps_sampled" in result["info"] + assert "num_steps_trained" in result["info"] + assert "timers" in result + assert "learn_time_ms" in result["timers"] + assert "learn_throughput" in result["timers"] + assert "sample_time_ms" in result["timers"] + assert "sample_throughput" in result["timers"] + assert "update_time_ms" in result["timers"] + + def test_pipeline_save_restore(ray_start_regular): + trainer = a2c_pipeline.A2CPipeline( + env="CartPole-v0", config={"min_iter_time_s": 0}) + res1 = trainer.train() + checkpoint = trainer.save() + res2 = trainer.train() + assert res2["timesteps_total"] > res1["timesteps_total"], (res1, res2) + trainer.restore(checkpoint) + + # Should restore the timesteps counter to the same as res2. + res3 = trainer.train() + assert res3["timesteps_total"] == res2["timesteps_total"], (res2, res3) + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/utils/experimental_dsl.py b/rllib/utils/experimental_dsl.py index 598e700bd..e324c2b26 100644 --- a/rllib/utils/experimental_dsl.py +++ b/rllib/utils/experimental_dsl.py @@ -2,14 +2,34 @@ TODO(ekl): describe the concepts.""" -from typing import List, Any +import logging +from typing import List, Any, Tuple import time import ray from ray.util.iter import from_actors, LocalIterator +from ray.util.iter_metrics import MetricsContext from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes +from ray.rllib.evaluation.rollout_worker import get_global_worker from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.policy import LEARNER_STATS_KEY + +logger = logging.getLogger(__name__) + +# Metrics context key definitions. +STEPS_SAMPLED_COUNTER = "num_steps_sampled" +STEPS_TRAINED_COUNTER = "num_steps_trained" +APPLY_GRADS_TIMER = "apply_grad" +COMPUTE_GRADS_TIMER = "compute_grads" +WORKER_UPDATE_TIMER = "update" +GRAD_WAIT_TIMER = "grad_wait" +SAMPLE_TIMER = "sample" +LEARN_ON_BATCH_TIMER = "learn" +LEARNER_INFO = "learner" + +# Type aliases. +GradientType = dict def ParallelRollouts(workers: WorkerSet, @@ -40,15 +60,23 @@ def ParallelRollouts(workers: WorkerSet, >>> batch = next(rollouts) >>> print(batch.count) 200 # config.sample_batch_size * config.num_workers + + Updates the STEPS_SAMPLED_COUNTER counter in the local iterator context. """ + def report_timesteps(batch): + metrics = LocalIterator.get_metrics() + metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count + return batch + if not workers.remote_workers(): # Handle the serial sampling case. def sampler(_): while True: yield workers.local_worker().sample() - return LocalIterator(sampler) + return (LocalIterator(sampler, MetricsContext()) + .for_each(report_timesteps)) # Create a parallel iterator over generated experiences. rollouts = from_actors(workers.remote_workers()) @@ -56,16 +84,59 @@ def ParallelRollouts(workers: WorkerSet, if mode == "bulk_sync": return rollouts \ .batch_across_shards() \ - .for_each(lambda batches: SampleBatch.concat_samples(batches)) + .for_each(lambda batches: SampleBatch.concat_samples(batches)) \ + .for_each(report_timesteps) elif mode == "async": - return rollouts.gather_async() + return rollouts.gather_async().for_each(report_timesteps) else: raise ValueError( "mode must be one of 'bulk_sync', 'async', got '{}'".format(mode)) +def AsyncGradients( + workers: WorkerSet) -> LocalIterator[Tuple[GradientType, int]]: + """Operator to compute gradients in parallel from rollout workers. + + Arguments: + workers (WorkerSet): set of rollout workers to use. + + Returns: + A local iterator over policy gradients computed on rollout workers. + + Examples: + >>> grads_op = AsyncGradients(workers) + >>> print(next(grads_op)) + {"var_0": ..., ...}, 50 # grads, batch count + + Updates the STEPS_SAMPLED_COUNTER counter and LEARNER_INFO field in the + local iterator context. + """ + + # This function will be applied remotely on the workers. + def samples_to_grads(samples): + return get_global_worker().compute_gradients(samples), samples.count + + # Record learner metrics and pass through (grads, count). + class record_metrics: + def _on_fetch_start(self): + self.fetch_start_time = time.perf_counter() + + def __call__(self, item): + (grads, info), count = item + metrics = LocalIterator.get_metrics() + metrics.counters[STEPS_SAMPLED_COUNTER] += count + metrics.info[LEARNER_INFO] = info[LEARNER_STATS_KEY] + metrics.timers[GRAD_WAIT_TIMER].push(time.perf_counter() - + self.fetch_start_time) + return grads, count + + rollouts = from_actors(workers.remote_workers()) + grads = rollouts.for_each(samples_to_grads) + return grads.gather_async().for_each(record_metrics()) + + def StandardMetricsReporting(train_op: LocalIterator[Any], workers: WorkerSet, - config: dict): + config: dict) -> LocalIterator[dict]: """Operator to periodically collect and report metrics. Arguments: @@ -86,7 +157,7 @@ def StandardMetricsReporting(train_op: LocalIterator[Any], workers: WorkerSet, """ output_op = train_op \ - .filter(OncePerTimeInterval(config["min_iter_time_s"])) \ + .filter(OncePerTimeInterval(max(2, config["min_iter_time_s"]))) \ .for_each(CollectMetrics( workers, min_history=config["metrics_smoothing_episodes"], timeout_seconds=config["collect_metrics_timeout"])) @@ -109,6 +180,11 @@ class ConcatBatches: self.min_batch_size = min_batch_size self.buffer = [] self.count = 0 + self.batch_start_time = None + + def _on_fetch_start(self): + if self.batch_start_time is None: + self.batch_start_time = time.perf_counter() def __call__(self, batch: SampleBatch) -> List[SampleBatch]: if not isinstance(batch, SampleBatch): @@ -118,6 +194,10 @@ class ConcatBatches: self.count += batch.count if self.count >= self.min_batch_size: out = SampleBatch.concat_samples(self.buffer) + timer = LocalIterator.get_metrics().timers[SAMPLE_TIMER] + timer.push(time.perf_counter() - self.batch_start_time) + timer.push_units_processed(self.count) + self.batch_start_time = None self.buffer = [] self.count = 0 return [out] @@ -133,18 +213,28 @@ class TrainOneStep: >>> rollouts = ParallelRollouts(...) >>> train_op = rollouts.for_each(TrainOneStep(workers)) >>> print(next(train_op)) # This trains the policy on one batch. - {"learner_stats": {"policy_loss": ...}} + None + + Updates the STEPS_TRAINED_COUNTER counter and LEARNER_INFO field in the + local iterator context. """ def __init__(self, workers: WorkerSet): self.workers = workers def __call__(self, batch: SampleBatch) -> List[dict]: - info = self.workers.local_worker().learn_on_batch(batch) + metrics = LocalIterator.get_metrics() + learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER] + with learn_timer: + info = self.workers.local_worker().learn_on_batch(batch) + learn_timer.push_units_processed(batch.count) + metrics.counters[STEPS_TRAINED_COUNTER] += batch.count + metrics.info[LEARNER_INFO] = info[LEARNER_STATS_KEY] if self.workers.remote_workers(): - weights = ray.put(self.workers.local_worker().get_weights()) - for e in self.workers.remote_workers(): - e.set_weights.remote(weights) + with metrics.timers[WORKER_UPDATE_TIMER]: + weights = ray.put(self.workers.local_worker().get_weights()) + for e in self.workers.remote_workers(): + e.set_weights.remote(weights) return info @@ -169,7 +259,10 @@ class CollectMetrics: self.min_history = min_history self.timeout_seconds = timeout_seconds - def __call__(self, info): + def __call__(self, _): + metrics = LocalIterator.get_metrics() + if metrics.parent_metrics: + raise ValueError("TODO: support nested metrics") episodes, self.to_be_collected = collect_episodes( self.workers.local_worker(), self.workers.remote_workers(), @@ -183,7 +276,22 @@ class CollectMetrics: self.episode_history.extend(orig_episodes) self.episode_history = self.episode_history[-self.min_history:] res = summarize_episodes(episodes, orig_episodes) - res.update(info=info) + res.update(info=metrics.info) + res["info"].update({ + STEPS_SAMPLED_COUNTER: metrics.counters[STEPS_SAMPLED_COUNTER], + STEPS_TRAINED_COUNTER: metrics.counters[STEPS_TRAINED_COUNTER], + }) + timers = {} + for k, timer in metrics.timers.items(): + timers["{}_time_ms".format(k)] = round(timer.mean * 1000, 3) + if timer.has_units_processed(): + timers["{}_throughput".format(k)] = round( + timer.mean_throughput, 3) + res["timers"] = timers + res.update({ + "num_healthy_workers": len(self.workers.remote_workers()), + "timesteps_total": metrics.counters[STEPS_SAMPLED_COUNTER], + }) return res @@ -222,15 +330,20 @@ class ComputeGradients: Examples: >>> grads_op = rollouts.for_each(ComputeGradients(workers)) >>> print(next(grads_op)) - {"var_0": ..., ...}, {"learner_stats": ...} # grads, learner info + {"var_0": ..., ...}, 50 # grads, batch count + + Updates the LEARNER_INFO info field in the local iterator context. """ def __init__(self, workers): self.workers = workers def __call__(self, samples): - grad, info = self.workers.local_worker().compute_gradients(samples) - return grad, info + metrics = LocalIterator.get_metrics() + with metrics.timers[COMPUTE_GRADS_TIMER]: + grad, info = self.workers.local_worker().compute_gradients(samples) + metrics.info[LEARNER_INFO] = info[LEARNER_STATS_KEY] + return grad, samples.count class ApplyGradients: @@ -241,20 +354,52 @@ class ApplyGradients: Examples: >>> apply_op = grads_op.for_each(ApplyGradients(workers)) >>> print(next(apply_op)) - {"learner_stats": ...} # learner info + None + + Updates the STEPS_TRAINED_COUNTER counter in the local iterator context. """ - def __init__(self, workers): + def __init__(self, workers, update_all=True): + """Creates an ApplyGradients instance. + + Arguments: + workers (WorkerSet): workers to apply gradients to. + update_all (bool): If true, updates all workers. Otherwise, only + update the worker that produced the sample batch we are + currently processing (i.e., A3C style). + """ self.workers = workers + self.update_all = update_all def __call__(self, item): - gradients, info = item - self.workers.local_worker().apply_gradients(gradients) - if self.workers.remote_workers(): - weights = ray.put(self.workers.local_worker().get_weights()) - for e in self.workers.remote_workers(): - e.set_weights.remote(weights) - return info + if not isinstance(item, tuple) or len(item) != 2: + raise ValueError( + "Input must be a tuple of (grad_dict, count), got {}".format( + item)) + gradients, count = item + metrics = LocalIterator.get_metrics() + metrics.counters[STEPS_TRAINED_COUNTER] += count + + apply_timer = metrics.timers[APPLY_GRADS_TIMER] + with apply_timer: + self.workers.local_worker().apply_gradients(gradients) + apply_timer.push_units_processed(count) + + if self.update_all: + if self.workers.remote_workers(): + with metrics.timers[WORKER_UPDATE_TIMER]: + weights = ray.put( + self.workers.local_worker().get_weights()) + for e in self.workers.remote_workers(): + e.set_weights.remote(weights) + else: + if metrics.cur_actor is None: + raise ValueError("Could not find actor to update. When " + "update_all=False, `cur_actor` must be set " + "in the iterator context.") + with metrics.timers[WORKER_UPDATE_TIMER]: + weights = self.workers.local_worker().get_weights() + metrics.cur_actor.set_weights.remote(weights) class AverageGradients: @@ -267,14 +412,18 @@ class AverageGradients: >>> batched_grads = grads_op.batch(32) >>> avg_grads = batched_grads.for_each(AverageGradients()) >>> print(next(avg_grads)) - {"var_0": ..., ...}, {"learner_stats": ...} # avg grads, last info + {"var_0": ..., ...}, 1600 # averaged grads, summed batch count """ def __call__(self, gradients): acc = None - for grad, info in gradients: + sum_count = 0 + for grad, count in gradients: if acc is None: acc = grad else: acc = [a + b for a, b in zip(acc, grad)] - return acc, info + sum_count += count + logger.info("Computing average of {} microbatch gradients " + "({} samples total)".format(len(gradients), sum_count)) + return acc, sum_count diff --git a/rllib/utils/timer.py b/rllib/utils/timer.py index eca95b71f..448ae2656 100644 --- a/rllib/utils/timer.py +++ b/rllib/utils/timer.py @@ -1,59 +1,3 @@ -import numpy as np -import time +from ray.util.timer import _Timer - -class TimerStat: - """A running stat for conveniently logging the duration of a code block. - - Example: - wait_timer = TimerStat() - with wait_timer: - ray.wait(...) - - Note that this class is *not* thread-safe. - """ - - def __init__(self, window_size=10): - self._window_size = window_size - self._samples = [] - self._units_processed = [] - self._start_time = None - self._total_time = 0.0 - self.count = 0 - - def __enter__(self): - assert self._start_time is None, "concurrent updates not supported" - self._start_time = time.time() - - def __exit__(self, type, value, tb): - assert self._start_time is not None - time_delta = time.time() - self._start_time - self.push(time_delta) - self._start_time = None - - def push(self, time_delta): - self._samples.append(time_delta) - if len(self._samples) > self._window_size: - self._samples.pop(0) - self.count += 1 - self._total_time += time_delta - - def push_units_processed(self, n): - self._units_processed.append(n) - if len(self._units_processed) > self._window_size: - self._units_processed.pop(0) - - @property - def mean(self): - return np.mean(self._samples) - - @property - def mean_units_processed(self): - return float(np.mean(self._units_processed)) - - @property - def mean_throughput(self): - time_total = sum(self._samples) - if not time_total: - return 0.0 - return sum(self._units_processed) / time_total +TimerStat = _Timer # backwards compatibility alias