From f5d12a958be73349fb36c92536f3576d40d0d0fc Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 12 Mar 2020 00:54:08 -0700 Subject: [PATCH] [rllib] Port Ape-X to distributed execution API (#7497) --- .travis.yml | 8 +- python/ray/tests/test_iter.py | 20 +- python/ray/util/iter.py | 39 +++- rllib/BUILD | 6 +- rllib/agents/a3c/a2c.py | 6 +- rllib/agents/a3c/a3c.py | 6 +- rllib/agents/a3c/a3c_pipeline.py | 19 -- rllib/agents/a3c/tests/test_a2c.py | 10 +- rllib/agents/dqn/apex.py | 103 +++++++++- rllib/agents/dqn/dqn.py | 6 +- rllib/agents/pg/pg.py | 6 +- rllib/agents/pg/tests/test_pg.py | 4 +- rllib/agents/trainer.py | 6 +- rllib/agents/trainer_template.py | 36 ++-- rllib/optimizers/async_replay_optimizer.py | 41 ++-- .../{test_pipeline.py => test_exec_api.py} | 17 +- rllib/utils/experimental_dsl.py | 184 ++++++++++++++++-- 17 files changed, 401 insertions(+), 116 deletions(-) delete mode 100644 rllib/agents/a3c/a3c_pipeline.py rename rllib/tests/{test_pipeline.py => test_exec_api.py} (76%) diff --git a/.travis.yml b/.travis.yml index 8ee95bc0f..7fda0410d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -274,9 +274,7 @@ matrix: - ./ci/suppress_output ./ci/travis/install-ray.sh script: - if [ $RAY_CI_RLLIB_FULL_AFFECTED != "1" ]; then exit; fi - - ./ci/keep_alive bazel test --build_tests_only --test_tag_filters=tests_dir_A,tests_dir_C,tests_dir_D --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=errors rllib/... - - ./ci/keep_alive bazel test --build_tests_only --test_tag_filters=tests_dir_E --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=errors rllib/... 
- - ./ci/keep_alive bazel test --build_tests_only --test_tag_filters=tests_dir_F,tests_dir_I --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=errors rllib/... + - ./ci/keep_alive bazel test --build_tests_only --test_tag_filters=tests_dir_A,tests_dir_B,tests_dir_C,tests_dir_D,tests_dir_E,tests_dir_F,tests_dir_G,tests_dir_H,tests_dir_I --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=errors rllib/... # RLlib: tests_dir: Everything in rllib/tests/ directory (J-Z). - os: linux @@ -296,9 +294,7 @@ matrix: - ./ci/suppress_output ./ci/travis/install-ray.sh script: - if [ $RAY_CI_RLLIB_FULL_AFFECTED != "1" ]; then exit; fi - - ./ci/keep_alive bazel test --build_tests_only --test_tag_filters=tests_dir_L,tests_dir_M --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=errors rllib/... - - ./ci/keep_alive bazel test --build_tests_only --test_tag_filters=tests_dir_N,tests_dir_O,test_dir_P --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=errors rllib/... - - ./ci/keep_alive bazel test --build_tests_only --test_tag_filters=tests_dir_R,tests_dir_S --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=errors rllib/... 
+ - ./ci/keep_alive bazel test --build_tests_only --test_tag_filters=tests_dir_J,tests_dir_K,tests_dir_L,tests_dir_M,tests_dir_N,tests_dir_O,tests_dir_P,tests_dir_Q,tests_dir_R,tests_dir_S,tests_dir_T,tests_dir_U,tests_dir_V,tests_dir_W,tests_dir_X,tests_dir_Y,tests_dir_Z --spawn_strategy=local --flaky_test_attempts=3 --nocache_test_results --test_verbose_timeout_warnings --progress_report_interval=100 --show_progress_rate_limit=100 --show_timestamps --test_output=errors rllib/... install: - eval `python $TRAVIS_BUILD_DIR/ci/travis/determine_tests_to_run.py` diff --git a/python/ray/tests/test_iter.py b/python/ray/tests/test_iter.py index 86f17d7e9..c31064e3b 100644 --- a/python/ray/tests/test_iter.py +++ b/python/ray/tests/test_iter.py @@ -1,4 +1,5 @@ import time +import collections from collections import Counter import pytest @@ -32,6 +33,16 @@ def test_metrics(ray_start_regular_shared): LocalIterator.get_metrics() +def test_zip_with_source_actor(ray_start_regular_shared): + it = from_items([1, 2, 3, 4], num_shards=2) + counts = collections.defaultdict(int) + for actor, value in it.gather_async().zip_with_source_actor(): + counts[actor] += 1 + assert len(counts) == 2 + for a, count in counts.items(): + assert count == 2 + + def test_metrics_union(ray_start_regular_shared): it1 = from_items([1, 2, 3, 4], num_shards=1) it2 = from_items([1, 2, 3, 4], num_shards=1) @@ -49,7 +60,8 @@ def test_metrics_union(ray_start_regular_shared): def verify_metrics(x): metrics = LocalIterator.get_metrics() metrics.counters["n"] += 1 - if metrics.counters["n"] > 2: + # Check the metrics context is shared. 
+ if metrics.counters["n"] >= 2: assert "foo" in metrics.counters assert "bar" in metrics.counters return x @@ -238,6 +250,12 @@ def test_gather_async(ray_start_regular_shared): assert sorted(it) == [0, 1, 2, 3] +def test_gather_async_queue(ray_start_regular_shared): + it = from_range(100) + it = it.gather_async(async_queue_depth=4) + assert sorted(it) == list(range(100)) + + def test_batch_across_shards(ray_start_regular_shared): it = from_iterators([[0, 1], [2, 3]]) it = it.batch_across_shards() diff --git a/python/ray/util/iter.py b/python/ray/util/iter.py index 514dfbc04..31a37c838 100644 --- a/python/ray/util/iter.py +++ b/python/ray/util/iter.py @@ -414,12 +414,17 @@ class ParallelIterator(Generic[T]): name = "{}.batch_across_shards()".format(self) return LocalIterator(base_iterator, MetricsContext(), name=name) - def gather_async(self) -> "LocalIterator[T]": + def gather_async(self, async_queue_depth=1) -> "LocalIterator[T]": """Returns a local iterable for asynchronous iteration. New items will be fetched from the shards asynchronously as soon as the previous one is computed. Items arrive in non-deterministic order. + Arguments: + async_queue_depth (int): The max number of async requests in flight + per actor. Increasing this improves the amount of pipeline + parallelism in the iterator. + Examples: >>> it = from_range(100, 1).gather_async() >>> next(it) @@ -430,16 +435,19 @@ class ParallelIterator(Generic[T]): ... 
1 """ - metrics = MetricsContext() + if async_queue_depth < 1: + raise ValueError("queue depth must be positive") def base_iterator(timeout=None): + metrics = LocalIterator.get_metrics() all_actors = [] for actor_set in self.actor_sets: actor_set.init_actors() all_actors.extend(actor_set.actors) futures = {} - for a in all_actors: - futures[a.par_iter_next.remote()] = a + for _ in range(async_queue_depth): + for a in all_actors: + futures[a.par_iter_next.remote()] = a while futures: pending = list(futures) if timeout is None: @@ -455,7 +463,7 @@ class ParallelIterator(Generic[T]): for obj_id in ready: actor = futures.pop(obj_id) try: - metrics.cur_actor = actor + metrics.current_actor = actor yield ray.get(obj_id) futures[actor.par_iter_next.remote()] = actor except StopIteration: @@ -465,7 +473,7 @@ class ParallelIterator(Generic[T]): yield _NextValueNotReady() name = "{}.gather_async()".format(self) - return LocalIterator(base_iterator, metrics, name=name) + return LocalIterator(base_iterator, MetricsContext(), name=name) def take(self, n: int) -> List[T]: """Return up to the first n items from this iterator.""" @@ -638,7 +646,13 @@ class LocalIterator(Generic[T]): if isinstance(item, _NextValueNotReady): yield item else: - yield fn(item) + # Keep retrying the function until it returns a valid + # value. This allows for non-blocking functions. 
+ while True: + result = fn(item) + yield result + if not isinstance(result, _NextValueNotReady): + break if hasattr(fn, LocalIterator.ON_FETCH_START_HOOK_NAME): unwrapped = apply_foreach @@ -758,6 +772,17 @@ class LocalIterator(Generic[T]): it.name = self.name + ".combine()" return it + def zip_with_source_actor(self): + def zip_with_source(item): + metrics = LocalIterator.get_metrics() + if metrics.current_actor is None: + raise ValueError("Could not identify source actor of item") + return metrics.current_actor, item + + it = self.for_each(zip_with_source) + it.name = self.name + ".zip_with_source_actor()" + return it + def take(self, n: int) -> List[T]: """Return up to the first n items from this iterator.""" out = [] diff --git a/rllib/BUILD b/rllib/BUILD index a2d07ba0d..9b89e5672 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1060,10 +1060,10 @@ py_test( ) py_test( - name = "tests/test_pipeline", - tags = ["tests_dir", "tests_dir_P"], + name = "tests/test_exec_api", + tags = ["tests_dir", "tests_dir_E"], size = "small", - srcs = ["tests/test_pipeline.py"] + srcs = ["tests/test_exec_api.py"] ) py_test( diff --git a/rllib/agents/a3c/a2c.py b/rllib/agents/a3c/a2c.py index eb002d92c..89e61252c 100644 --- a/rllib/agents/a3c/a2c.py +++ b/rllib/agents/a3c/a2c.py @@ -37,8 +37,8 @@ def choose_policy_optimizer(workers, config): workers, train_batch_size=config["train_batch_size"]) -# Experimental pipeline-based impl; enable with "use_pipeline_impl": True. -def training_pipeline(workers, config): +# Experimental distributed execution impl; enable with "use_exec_api": True. 
+def execution_plan(workers, config): rollouts = ParallelRollouts(workers, mode="bulk_sync") if config["microbatch_size"]: @@ -72,4 +72,4 @@ A2CTrainer = build_trainer( get_policy_class=get_policy_class, make_policy_optimizer=choose_policy_optimizer, validate_config=validate_config, - training_pipeline=training_pipeline) + execution_plan=execution_plan) diff --git a/rllib/agents/a3c/a3c.py b/rllib/agents/a3c/a3c.py index 36339f11e..ec94b4701 100644 --- a/rllib/agents/a3c/a3c.py +++ b/rllib/agents/a3c/a3c.py @@ -65,8 +65,8 @@ def make_async_optimizer(workers, config): return AsyncGradientsOptimizer(workers, **config["optimizer"]) -# Experimental pipeline-based impl; enable with "use_pipeline_impl": True. -def training_pipeline(workers, config): +# Experimental distributed execution impl; enable with "use_exec_api": True. +def execution_plan(workers, config): # For A3C, compute policy gradients remotely on the rollout workers. grads = AsyncGradients(workers) @@ -84,4 +84,4 @@ A3CTrainer = build_trainer( get_policy_class=get_policy_class, validate_config=validate_config, make_policy_optimizer=make_async_optimizer, - training_pipeline=training_pipeline) + execution_plan=execution_plan) diff --git a/rllib/agents/a3c/a3c_pipeline.py b/rllib/agents/a3c/a3c_pipeline.py deleted file mode 100644 index 359c729a8..000000000 --- a/rllib/agents/a3c/a3c_pipeline.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Experimental pipeline-based impl; run this with --run='A3C_pl'""" - -from ray.rllib.agents.a3c.a3c import A3CTrainer -from ray.rllib.utils.experimental_dsl import (AsyncGradients, ApplyGradients, - StandardMetricsReporting) - - -def training_pipeline(workers, config): - # For A3C, compute policy gradients remotely on the rollout workers. - grads = AsyncGradients(workers) - - # Apply the gradients as they arrive. We set update_all to False so that - # only the worker sending the gradient is updated with new weights. 
- train_op = grads.for_each(ApplyGradients(workers, update_all=False)) - - return StandardMetricsReporting(train_op, workers, config) - - -A3CPipeline = A3CTrainer.with_updates(training_pipeline=training_pipeline) diff --git a/rllib/agents/a3c/tests/test_a2c.py b/rllib/agents/a3c/tests/test_a2c.py index 19ff7bf89..db055cc2d 100644 --- a/rllib/agents/a3c/tests/test_a2c.py +++ b/rllib/agents/a3c/tests/test_a2c.py @@ -5,7 +5,7 @@ from ray.rllib.agents.a3c import A2CTrainer class TestA2C(unittest.TestCase): - """Sanity tests for A2C pipeline.""" + """Sanity tests for A2C exec impl.""" def setUp(self): ray.init() @@ -13,22 +13,22 @@ class TestA2C(unittest.TestCase): def tearDown(self): ray.shutdown() - def test_a2c_pipeline(ray_start_regular): + def test_a2c_exec_impl(ray_start_regular): trainer = A2CTrainer( env="CartPole-v0", config={ "min_iter_time_s": 0, - "use_pipeline_impl": True + "use_exec_api": True }) assert isinstance(trainer.train(), dict) - def test_a2c_pipeline_microbatch(ray_start_regular): + def test_a2c_exec_impl_microbatch(ray_start_regular): trainer = A2CTrainer( env="CartPole-v0", config={ "min_iter_time_s": 0, "microbatch_size": 10, - "use_pipeline_impl": True, + "use_exec_api": True, }) assert isinstance(trainer.train(), dict) diff --git a/rllib/agents/dqn/apex.py b/rllib/agents/dqn/apex.py index fc5c3e677..dffae0614 100644 --- a/rllib/agents/dqn/apex.py +++ b/rllib/agents/dqn/apex.py @@ -1,6 +1,17 @@ +import collections + +import ray from ray.rllib.agents.dqn.dqn import DQNTrainer, DEFAULT_CONFIG as DQN_CONFIG from ray.rllib.optimizers import AsyncReplayOptimizer +from ray.rllib.optimizers.async_replay_optimizer import ReplayActor from ray.rllib.utils import merge_dicts +from ray.rllib.utils.actors import create_colocated +from ray.rllib.utils.experimental_dsl import ( + ParallelRollouts, Concurrently, ParallelReplay, StandardMetricsReporting, + StoreToReplayActors, UpdateTargetNetwork, Enqueue, Dequeue, + STEPS_TRAINED_COUNTER) +from 
ray.rllib.optimizers.async_replay_optimizer import LearnerThread +from ray.util.iter import LocalIterator # yapf: disable # __sphinx_doc_begin__ @@ -70,6 +81,93 @@ def update_target_based_on_num_steps_trained(trainer, fetches): trainer.state["num_target_updates"] += 1 +# Experimental distributed execution impl; enable with "use_exec_api": True. +def execution_plan(workers, config): + # Create a number of replay buffer actors. + # TODO(ekl) support batch replay options + num_replay_buffer_shards = config["optimizer"]["num_replay_buffer_shards"] + replay_actors = create_colocated(ReplayActor, [ + num_replay_buffer_shards, + config["learning_starts"], + config["buffer_size"], + config["train_batch_size"], + config["prioritized_replay_alpha"], + config["prioritized_replay_beta"], + config["prioritized_replay_eps"], + ], num_replay_buffer_shards) + + # Update experience priorities post learning. + def update_prio_and_stats(item): + actor, prio_dict, count = item + actor.update_priorities.remote(prio_dict) + metrics = LocalIterator.get_metrics() + metrics.counters[STEPS_TRAINED_COUNTER] += count + metrics.timers["learner_dequeue"] = learner_thread.queue_timer + metrics.timers["learner_grad"] = learner_thread.grad_timer + metrics.timers["learner_overall"] = learner_thread.overall_timer + + # Update worker weights as they finish generating experiences. + class UpdateWorkerWeights: + def __init__(self, learner_thread, workers, max_weight_sync_delay): + self.learner_thread = learner_thread + self.workers = workers + self.steps_since_update = collections.defaultdict(int) + self.max_weight_sync_delay = max_weight_sync_delay + self.weights = None + + def __call__(self, item): + actor, batch = item + self.steps_since_update[actor] += batch.count + if self.steps_since_update[actor] >= self.max_weight_sync_delay: + # Note that it's important to pull new weights once + # updated to avoid excessive correlation between actors. 
+ if self.weights is None or self.learner_thread.weights_updated: + self.learner_thread.weights_updated = False + self.weights = ray.put( + self.workers.local_worker().get_weights()) + actor.set_weights.remote(self.weights) + self.steps_since_update[actor] = 0 + # Update metrics. + metrics = LocalIterator.get_metrics() + metrics.counters["num_weight_syncs"] += 1 + + # Start the learner thread. + learner_thread = LearnerThread(workers.local_worker()) + learner_thread.start() + + # We execute the following steps concurrently: + # (1) Generate rollouts and store them in our replay buffer actors. Update + # the weights of the worker that generated the batch. + rollouts = ParallelRollouts(workers, mode="async", async_queue_depth=2) + store_op = rollouts \ + .for_each(StoreToReplayActors(replay_actors)) \ + .zip_with_source_actor() \ + .for_each(UpdateWorkerWeights( + learner_thread, workers, + max_weight_sync_delay=config["optimizer"]["max_weight_sync_delay"]) + ) + + # (2) Read experiences from the replay buffer actors and send to the + # learner thread via its in-queue. + replay_op = ParallelReplay(replay_actors, async_queue_depth=4) \ + .zip_with_source_actor() \ + .for_each(Enqueue(learner_thread.inqueue)) + + # (3) Get priorities back from learner thread and apply them to the + # replay buffer actors. + update_op = Dequeue( + learner_thread.outqueue, check=learner_thread.is_alive) \ + .for_each(update_prio_and_stats) \ + .for_each(UpdateTargetNetwork( + workers, config["target_network_update_freq"], + by_steps_trained=True)) + + # Execute (1), (2), (3) asynchronously as fast as possible. 
+ merged_op = Concurrently([store_op, replay_op, update_op], mode="async") + + return StandardMetricsReporting(merged_op, workers, config) + + APEX_TRAINER_PROPERTIES = { "make_workers": defer_make_workers, "make_policy_optimizer": make_async_optimizer, @@ -77,4 +175,7 @@ APEX_TRAINER_PROPERTIES = { } ApexTrainer = DQNTrainer.with_updates( - name="APEX", default_config=APEX_DEFAULT_CONFIG, **APEX_TRAINER_PROPERTIES) + name="APEX", + default_config=APEX_DEFAULT_CONFIG, + execution_plan=execution_plan, + **APEX_TRAINER_PROPERTIES) diff --git a/rllib/agents/dqn/dqn.py b/rllib/agents/dqn/dqn.py index 8a98a6cb1..ca8bc43b9 100644 --- a/rllib/agents/dqn/dqn.py +++ b/rllib/agents/dqn/dqn.py @@ -312,8 +312,8 @@ def update_target_if_needed(trainer, fetches): trainer.state["num_target_updates"] += 1 -# Experimental pipeline-based impl; enable with "use_pipeline_impl": True. -def training_pipeline(workers, config): +# Experimental distributed execution impl; enable with "use_exec_api": True. +def execution_plan(workers, config): local_replay_buffer = ReplayBuffer(config["buffer_size"]) rollouts = ParallelRollouts(workers, mode="bulk_sync") @@ -346,7 +346,7 @@ GenericOffPolicyTrainer = build_trainer( before_train_step=update_worker_exploration, after_optimizer_step=update_target_if_needed, after_train_result=after_train_result, - training_pipeline=training_pipeline) + execution_plan=execution_plan) DQNTrainer = GenericOffPolicyTrainer.with_updates( name="DQN", default_policy=DQNTFPolicy, default_config=DEFAULT_CONFIG) diff --git a/rllib/agents/pg/pg.py b/rllib/agents/pg/pg.py index 6a495f024..d9790ea01 100644 --- a/rllib/agents/pg/pg.py +++ b/rllib/agents/pg/pg.py @@ -24,8 +24,8 @@ def get_policy_class(config): return PGTFPolicy -# Experimental pipeline-based impl; enable with "use_pipeline_impl": True. -def training_pipeline(workers, config): +# Experimental distributed execution impl; enable with "use_exec_api": True. 
+def execution_plan(workers, config): # Collects experiences in parallel from multiple RolloutWorker actors. rollouts = ParallelRollouts(workers, mode="bulk_sync") @@ -46,4 +46,4 @@ PGTrainer = build_trainer( default_config=DEFAULT_CONFIG, default_policy=PGTFPolicy, get_policy_class=get_policy_class, - training_pipeline=training_pipeline) + execution_plan=execution_plan) diff --git a/rllib/agents/pg/tests/test_pg.py b/rllib/agents/pg/tests/test_pg.py index 5ada2e144..e0558132c 100644 --- a/rllib/agents/pg/tests/test_pg.py +++ b/rllib/agents/pg/tests/test_pg.py @@ -18,12 +18,12 @@ class TestPG(unittest.TestCase): def tearDown(self): ray.shutdown() - def test_pg_pipeline(ray_start_regular): + def test_pg_exec_impl(ray_start_regular): trainer = PGTrainer( env="CartPole-v0", config={ "min_iter_time_s": 0, - "use_pipeline_impl": True + "use_exec_api": True }) assert isinstance(trainer.train(), dict) diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index f3baf2d2e..a49814c89 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -213,9 +213,9 @@ COMMON_CONFIG = { # trainer guarantees all eval workers have the latest policy state before # this function is called. "custom_eval_function": None, - # EXPERIMENTAL: use the pipeline based implementation of the algo. Can also - # be enabled by setting RLLIB_USE_PIPELINE_IMPL=1. - "use_pipeline_impl": False, + # EXPERIMENTAL: use the execution plan based API impl of the algo. Can also + # be enabled by setting RLLIB_EXEC_API=1. 
+ "use_exec_api": False, # === Advanced Rollout Settings === # Use a background thread for sampling (slightly off-policy, usually not diff --git a/rllib/agents/trainer_template.py b/rllib/agents/trainer_template.py index 87e2a14a6..ddf4d9afc 100644 --- a/rllib/agents/trainer_template.py +++ b/rllib/agents/trainer_template.py @@ -27,7 +27,7 @@ def build_trainer(name, collect_metrics_fn=None, before_evaluate_fn=None, mixins=None, - training_pipeline=None): + execution_plan=None): """Helper function for defining a custom trainer. Functions will be run in this order to initialize the trainer: @@ -74,8 +74,8 @@ def build_trainer(name, mixins (Optional[List[class]]): Optional list of mixin class(es) for the returned trainer class. These mixins will be applied in order and will have higher precedence than the Trainer class. - training_pipeline (Optional[callable]): Experimental support for custom - training pipelines. This overrides `make_policy_optimizer`. + execution_plan (Optional[callable]): Experimental distributed execution + API. This overrides `make_policy_optimizer`. Returns: a Trainer instance that uses the specified args. @@ -107,22 +107,24 @@ def build_trainer(name, if before_init: before_init(self) + use_exec_api = (execution_plan + and (self.config["use_exec_api"] + or "RLLIB_EXEC_API" in os.environ)) # Creating all workers (excluding evaluation workers). 
- if make_workers: + if make_workers and not use_exec_api: self.workers = make_workers(self, env_creator, self._policy, config) else: self.workers = self._make_workers(env_creator, self._policy, config, self.config["num_workers"]) - self.train_pipeline = None + self.train_exec_impl = None self.optimizer = None - if training_pipeline and (self.config["use_pipeline_impl"] or - "RLLIB_USE_PIPELINE_IMPL" in os.environ): - logger.warning("Using experimental pipeline based impl.") - self.train_pipeline = training_pipeline(self.workers, config) + if use_exec_api: + logger.warning("Using experimental execution plan impl.") + self.train_exec_impl = execution_plan(self.workers, config) elif make_policy_optimizer: self.optimizer = make_policy_optimizer(self.workers, config) else: @@ -136,8 +138,8 @@ def build_trainer(name, @override(Trainer) def _train(self): - if self.train_pipeline: - return self._train_pipeline() + if self.train_exec_impl: + return self._train_exec_impl() if before_train_step: before_train_step(self) @@ -166,10 +168,10 @@ def build_trainer(name, after_train_result(self, res) return res - def _train_pipeline(self): + def _train_exec_impl(self): if before_train_step: logger.warning("Ignoring before_train_step callback") - res = next(self.train_pipeline) + res = next(self.train_exec_impl) if after_train_result: logger.warning("Ignoring after_train_result callback") return res @@ -182,15 +184,15 @@ def build_trainer(name, def __getstate__(self): state = Trainer.__getstate__(self) state["trainer_state"] = self.state.copy() - if self.train_pipeline: - state["train_pipeline"] = self.train_pipeline.metrics.save() + if self.train_exec_impl: + state["train_exec_impl"] = self.train_exec_impl.metrics.save() return state def __setstate__(self, state): Trainer.__setstate__(self, state) self.state = state["trainer_state"].copy() - if self.train_pipeline: - self.train_pipeline.metrics.restore(state["train_pipeline"]) + if self.train_exec_impl: + 
self.train_exec_impl.metrics.restore(state["train_exec_impl"]) def with_updates(**overrides): """Build a copy of this trainer with the specified overrides. diff --git a/rllib/optimizers/async_replay_optimizer.py b/rllib/optimizers/async_replay_optimizer.py index 59a5ade1e..c4cac0e3f 100644 --- a/rllib/optimizers/async_replay_optimizer.py +++ b/rllib/optimizers/async_replay_optimizer.py @@ -13,6 +13,7 @@ import time import ray from ray.exceptions import RayError +from ray.util.iter import ParallelIteratorWorker from ray.rllib.evaluation.metrics import get_learner_stats from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ MultiAgentBatch @@ -283,7 +284,7 @@ class AsyncReplayOptimizer(PolicyOptimizer): @ray.remote(num_cpus=0) -class ReplayActor: +class ReplayActor(ParallelIteratorWorker): """A replay buffer shard. Ray actors are single-threaded, so for scalability multiple replay actors @@ -298,6 +299,12 @@ class ReplayActor: self.prioritized_replay_beta = prioritized_replay_beta self.prioritized_replay_eps = prioritized_replay_eps + def gen_replay(): + while True: + yield self.replay() + + ParallelIteratorWorker.__init__(self, gen_replay, False) + def new_buffer(): return PrioritizedReplayBuffer( self.buffer_size, alpha=prioritized_replay_alpha) @@ -435,6 +442,7 @@ class LearnerThread(threading.Thread): self.outqueue = queue.Queue() self.queue_timer = TimerStat() self.grad_timer = TimerStat() + self.overall_timer = TimerStat() self.daemon = True self.weights_updated = False self.stopped = False @@ -445,17 +453,20 @@ class LearnerThread(threading.Thread): self.step() def step(self): - with self.queue_timer: - ra, replay = self.inqueue.get() - if replay is not None: - prio_dict = {} - with self.grad_timer: - grad_out = self.local_worker.learn_on_batch(replay) - for pid, info in grad_out.items(): - prio_dict[pid] = ( - replay.policy_batches[pid].data.get("batch_indexes"), - info.get("td_error")) - self.stats[pid] = get_learner_stats(info) - 
self.outqueue.put((ra, prio_dict, replay.count)) - self.learner_queue_size.push(self.inqueue.qsize()) - self.weights_updated = True + with self.overall_timer: + with self.queue_timer: + ra, replay = self.inqueue.get() + if replay is not None: + prio_dict = {} + with self.grad_timer: + grad_out = self.local_worker.learn_on_batch(replay) + for pid, info in grad_out.items(): + prio_dict[pid] = (replay.policy_batches[pid].data.get( + "batch_indexes"), info.get("td_error")) + self.stats[pid] = get_learner_stats(info) + self.grad_timer.push_units_processed(replay.count) + self.outqueue.put((ra, prio_dict, replay.count)) + self.learner_queue_size.push(self.inqueue.qsize()) + self.weights_updated = True + self.overall_timer.push_units_processed(replay and replay.count + or 0) diff --git a/rllib/tests/test_pipeline.py b/rllib/tests/test_exec_api.py similarity index 76% rename from rllib/tests/test_pipeline.py rename to rllib/tests/test_exec_api.py index 68c445199..971ca17af 100644 --- a/rllib/tests/test_pipeline.py +++ b/rllib/tests/test_exec_api.py @@ -4,8 +4,8 @@ import ray from ray.rllib.agents.a3c import A2CTrainer -class TestPipeline(unittest.TestCase): - """General tests for the pipeline API.""" +class TestDistributedExecution(unittest.TestCase): + """General tests for the distributed execution API.""" @classmethod def setUpClass(cls): @@ -15,12 +15,12 @@ class TestPipeline(unittest.TestCase): def tearDownClass(cls): ray.shutdown() - def test_pipeline_stats(ray_start_regular): + def test_exec_plan_stats(ray_start_regular): trainer = A2CTrainer( env="CartPole-v0", config={ "min_iter_time_s": 0, - "use_pipeline_impl": True + "use_exec_api": True }) result = trainer.train() assert isinstance(result, dict) @@ -35,22 +35,23 @@ class TestPipeline(unittest.TestCase): assert "sample_throughput" in result["timers"] assert "update_time_ms" in result["timers"] - def test_pipeline_save_restore(ray_start_regular): + def test_exec_plan_save_restore(ray_start_regular): trainer = 
A2CTrainer( env="CartPole-v0", config={ "min_iter_time_s": 0, - "use_pipeline_impl": True + "use_exec_api": True }) res1 = trainer.train() checkpoint = trainer.save() - res2 = trainer.train() + for _ in range(2): + res2 = trainer.train() assert res2["timesteps_total"] > res1["timesteps_total"], (res1, res2) trainer.restore(checkpoint) # Should restore the timesteps counter to the same as res2. res3 = trainer.train() - assert res3["timesteps_total"] == res2["timesteps_total"], (res2, res3) + assert res3["timesteps_total"] < res2["timesteps_total"], (res2, res3) if __name__ == "__main__": diff --git a/rllib/utils/experimental_dsl.py b/rllib/utils/experimental_dsl.py index f2e581798..fa70d182c 100644 --- a/rllib/utils/experimental_dsl.py +++ b/rllib/utils/experimental_dsl.py @@ -1,16 +1,19 @@ -"""Experimental operators for defining distributed training pipelines. +"""Experimental distributed execution API. TODO(ekl): describe the concepts.""" import logging from typing import List, Any, Tuple, Union import numpy as np +import queue +import random import time import ray -from ray.util.iter import from_actors, LocalIterator +from ray.util.iter import from_actors, LocalIterator, _NextValueNotReady from ray.util.iter_metrics import MetricsContext -from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer +from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer, \ + ReplayBuffer from ray.rllib.evaluation.metrics import collect_episodes, \ summarize_episodes, get_learner_stats from ray.rllib.evaluation.rollout_worker import get_global_worker @@ -58,8 +61,8 @@ def _get_global_vars(): return {"timestep": metrics.counters[STEPS_SAMPLED_COUNTER]} -def ParallelRollouts(workers: WorkerSet, - mode="bulk_sync") -> LocalIterator[SampleBatch]: +def ParallelRollouts(workers: WorkerSet, mode="bulk_sync", + async_queue_depth=1) -> LocalIterator[SampleBatch]: """Operator to collect experiences in parallel from rollout workers. 
If there are no remote workers, experiences will be collected serially from @@ -72,6 +75,8 @@ def ParallelRollouts(workers: WorkerSet, computed by rollout workers with no order guarantees. - In 'bulk_sync' mode, we collect one batch from each worker and concatenate them together into a large batch to return. + async_queue_depth (int): In async mode, the max number of async + requests in flight per actor. Returns: A local iterator over experiences collected in parallel. @@ -116,7 +121,8 @@ def ParallelRollouts(workers: WorkerSet, .for_each(lambda batches: SampleBatch.concat_samples(batches)) \ .for_each(report_timesteps) elif mode == "async": - return rollouts.gather_async().for_each(report_timesteps) + return rollouts.gather_async( + async_queue_depth=async_queue_depth).for_each(report_timesteps) else: raise ValueError( "mode must be one of 'bulk_sync', 'async', got '{}'".format(mode)) @@ -437,14 +443,15 @@ class ApplyGradients: for e in self.workers.remote_workers(): e.set_weights.remote(weights, _get_global_vars()) else: - if metrics.cur_actor is None: - raise ValueError("Could not find actor to update. When " - "update_all=False, `cur_actor` must be set " - "in the iterator context.") + if metrics.current_actor is None: + raise ValueError( + "Could not find actor to update. When " + "update_all=False, `current_actor` must be set " + "in the iterator context.") with metrics.timers[WORKER_UPDATE_TIMER]: weights = self.workers.local_worker().get_weights() - metrics.cur_actor.set_weights.remote(weights, - _get_global_vars()) + metrics.current_actor.set_weights.remote( + weights, _get_global_vars()) class AverageGradients: @@ -475,7 +482,21 @@ class AverageGradients: class StoreToReplayBuffer: - def __init__(self, replay_buffer): + """Callable that stores data into a local replay buffer. + + This should be used with the .for_each() operator on a rollouts iterator. + The batch that was stored is returned. 
+ + Examples: + >>> buf = ReplayBuffer(1000) + >>> rollouts = ParallelRollouts(...) + >>> store_op = rollouts.for_each(StoreToReplayBuffer(buf)) + >>> next(store_op) + SampleBatch(...) + """ + + def __init__(self, replay_buffer: ReplayBuffer): + assert isinstance(replay_buffer, ReplayBuffer) self.replay_buffers = {DEFAULT_POLICY_ID: replay_buffer} def __call__(self, batch: SampleBatchType): @@ -492,11 +513,73 @@ class StoreToReplayBuffer: pack_if_needed(row["new_obs"]), row["dones"], weight=None) + return batch -def LocalReplay(replay_buffer, train_batch_size): +class StoreToReplayActors: + """Callable that stores data into replay buffer actors. + + This should be used with the .for_each() operator on a rollouts iterator. + The batch that was stored is returned. + + Examples: + >>> actors = [ReplayActor.remote() for _ in range(4)] + >>> rollouts = ParallelRollouts(...) + >>> store_op = rollouts.for_each(StoreToReplayActors(actors)) + >>> next(store_op) + SampleBatch(...) + """ + + def __init__(self, replay_actors: List["ActorHandle"]): + self.replay_actors = replay_actors + + def __call__(self, batch: SampleBatchType): + actor = random.choice(self.replay_actors) + actor.add_batch.remote(batch) + return batch + + +def ParallelReplay(replay_actors: List["ActorHandle"], async_queue_depth=4): + """Replay experiences in parallel from the given actors. + + This should be combined with the StoreToReplayActors operation using the + Concurrently() operator. + + Arguments: + replay_actors (list): List of replay actors. + async_queue_depth (int): In async mode, the max number of async + requests in flight per actor. + + Examples: + >>> actors = [ReplayActor.remote() for _ in range(4)] + >>> replay_op = ParallelReplay(actors) + >>> next(replay_op) + SampleBatch(...) 
+ """ + replay = from_actors(replay_actors) + return replay.gather_async( + async_queue_depth=async_queue_depth).filter(lambda x: x is not None) + + +def LocalReplay(replay_buffer: ReplayBuffer, train_batch_size: int): + """Replay experiences from a local buffer instance. + + This should be combined with the StoreToReplayBuffer operation using the + Concurrently() operator. + + Arguments: + replay_buffer (ReplayBuffer): Buffer to replay experiences from. + train_batch_size (int): Batch size of fetches from the buffer. + + Examples: + >>> buf = ReplayBuffer(1000) + >>> replay_op = LocalReplay(buf, train_batch_size=32) + >>> next(replay_op) + SampleBatch(...) + """ + assert isinstance(replay_buffer, ReplayBuffer) replay_buffers = {DEFAULT_POLICY_ID: replay_buffer} - # TODO(ekl) support more options + # TODO(ekl) support more options, or combine with ParallelReplay (?) synchronize_sampling = False prioritized_replay_beta = None @@ -581,16 +664,83 @@ class UpdateTargetNetwork: track when we should update the target next. """ - def __init__(self, workers, target_update_freq): + def __init__(self, workers, target_update_freq, by_steps_trained=False): self.workers = workers self.target_update_freq = target_update_freq + if by_steps_trained: + self.metric = STEPS_TRAINED_COUNTER + else: + self.metric = STEPS_SAMPLED_COUNTER def __call__(self, _): metrics = LocalIterator.get_metrics() - cur_ts = metrics.counters[STEPS_SAMPLED_COUNTER] + cur_ts = metrics.counters[self.metric] last_update = metrics.counters[LAST_TARGET_UPDATE_TS] if cur_ts - last_update > self.target_update_freq: self.workers.local_worker().foreach_trainable_policy( lambda p, _: p.update_target()) metrics.counters[NUM_TARGET_UPDATES] += 1 metrics.counters[LAST_TARGET_UPDATE_TS] = cur_ts + + +class Enqueue: + """Enqueue data items into a queue.Queue instance. + + The enqueue is non-blocking, so Enqueue operations can be executed with + Dequeue via the Concurrently() operator.
+ + Examples: + >>> queue = queue.Queue(100) + >>> write_op = ParallelRollouts(...).for_each(Enqueue(queue)) + >>> read_op = Dequeue(queue) + >>> combined_op = Concurrently([write_op, read_op], mode="async") + >>> next(combined_op) + SampleBatch(...) + """ + + def __init__(self, output_queue: queue.Queue): + if not isinstance(output_queue, queue.Queue): + raise ValueError("Expected queue.Queue, got {}".format( + type(output_queue))) + self.queue = output_queue + + def __call__(self, x): + try: + self.queue.put_nowait(x) + except queue.Full: + return _NextValueNotReady() + + +def Dequeue(input_queue: queue.Queue, check=lambda: True): + """Dequeue data items from a queue.Queue instance. + + The dequeue is non-blocking, so Dequeue operations can be executed with + Enqueue via the Concurrently() operator. + + Arguments: + input_queue (Queue): queue to pull items from. + check (fn): liveness check. When this function returns false, + Dequeue() will raise an error to halt execution. + + Examples: + >>> queue = queue.Queue(100) + >>> write_op = ParallelRollouts(...).for_each(Enqueue(queue)) + >>> read_op = Dequeue(queue) + >>> combined_op = Concurrently([write_op, read_op], mode="async") + >>> next(combined_op) + SampleBatch(...) + """ + if not isinstance(input_queue, queue.Queue): + raise ValueError("Expected queue.Queue, got {}".format( + type(input_queue))) + + def base_iterator(timeout=None): + while check(): + try: + item = input_queue.get_nowait() + yield item + except queue.Empty: + yield _NextValueNotReady() + raise RuntimeError("Error raised reading from queue") + + return LocalIterator(base_iterator, MetricsContext())