From 9f04a65922e601e7bf444616e6262e67a8a06647 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 7 May 2020 23:40:29 -0700 Subject: [PATCH] [rllib] Add PPO+DQN two trainer multiagent workflow example (#8334) --- python/ray/util/iter.py | 37 +++---- rllib/BUILD | 8 ++ rllib/agents/dqn/apex.py | 69 ++++++------- rllib/examples/two_trainer_workflow.py | 132 +++++++++++++++++++++++++ rllib/execution/common.py | 7 ++ rllib/execution/rollout_ops.py | 4 +- rllib/execution/train_ops.py | 52 ++++++---- rllib/tests/test_execution.py | 2 +- 8 files changed, 241 insertions(+), 70 deletions(-) create mode 100644 rllib/examples/two_trainer_workflow.py diff --git a/python/ray/util/iter.py b/python/ray/util/iter.py index 8d97ff31b..fe00034b1 100644 --- a/python/ray/util/iter.py +++ b/python/ray/util/iter.py @@ -870,8 +870,12 @@ class LocalIterator(Generic[T]): def duplicate(self, n) -> List["LocalIterator[T]"]: """Copy this iterator `n` times, duplicating the data. + The child iterators will be prioritized by how much of the parent + stream they have consumed. That is, we will not allow children to fall + behind, since that can cause infinite memory buildup in this operator. + Returns: - List[LocalIterator[T]]: multiple iterators that each have a copy + List[LocalIterator[T]]: child iterators that each have a copy of the data of this iterator. """ @@ -891,9 +895,16 @@ class LocalIterator(Generic[T]): def make_next(i): def gen(timeout): while True: - if len(queues[i]) == 0: - fill_next(timeout) - yield queues[i].popleft() + my_len = len(queues[i]) + max_len = max(len(q) for q in queues) + # Yield to let other iterators that have fallen behind + # process more items. + if my_len < max_len: + yield _NextValueNotReady() + else: + if len(queues[i]) == 0: + fill_next(timeout) + yield queues[i].popleft() return gen @@ -939,21 +950,13 @@ class LocalIterator(Generic[T]): def build_union(timeout=None): while True: for it in list(active): - # Yield items from the iterator until _NextValueNotReady is - # found, then switch to the next iterator. - # To avoid starvation, we yield at most max_yield items per - # iterator before switching. - if deterministic: - max_yield = 1 # Forces round robin. - else: - max_yield = 20 try: - for _ in range(max_yield): - item = next(it) - if isinstance(item, _NextValueNotReady): - break - else: + item = next(it) + if isinstance(item, _NextValueNotReady): + if timeout is not None: yield item + else: + yield item except StopIteration: active.remove(it) if not active: diff --git a/rllib/BUILD b/rllib/BUILD index 3e0d895f7..027ae841a 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1454,6 +1454,14 @@ py_test( args = ["--num-iters=2"] ) +py_test( + name = "examples/two_trainer_workflow", + tags = ["examples", "examples_T"], + size = "medium", + srcs = ["examples/two_trainer_workflow.py"], + args = ["--num-iters=2"] +) + py_test( name = "examples/nested_action_spaces_ppo", main = "examples/nested_action_spaces.py", diff --git a/rllib/agents/dqn/apex.py b/rllib/agents/dqn/apex.py index 13a038afa..263319fac 100644 --- a/rllib/agents/dqn/apex.py +++ b/rllib/agents/dqn/apex.py @@ -3,7 +3,9 @@ import copy import ray from ray.rllib.agents.dqn.dqn import DQNTrainer, DEFAULT_CONFIG as DQN_CONFIG -from ray.rllib.execution.common import STEPS_TRAINED_COUNTER +from ray.rllib.execution.common import STEPS_TRAINED_COUNTER, \ + SampleBatchType, _get_shared_metrics +from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.execution.rollout_ops import ParallelRollouts from ray.rllib.execution.concurrency_ops import Concurrently, Enqueue, Dequeue from ray.rllib.execution.replay_ops import StoreToReplayBuffer, Replay @@ -84,8 +86,34 @@ def update_target_based_on_num_steps_trained(trainer, fetches): trainer.state["num_target_updates"] += 1 +# Update worker weights as they finish generating experiences. +class UpdateWorkerWeights: + def __init__(self, learner_thread, workers, max_weight_sync_delay): + self.learner_thread = learner_thread + self.workers = workers + self.steps_since_update = collections.defaultdict(int) + self.max_weight_sync_delay = max_weight_sync_delay + self.weights = None + + def __call__(self, item: ("ActorHandle", SampleBatchType)): + actor, batch = item + self.steps_since_update[actor] += batch.count + if self.steps_since_update[actor] >= self.max_weight_sync_delay: + # Note that it's important to pull new weights once + # updated to avoid excessive correlation between actors. + if self.weights is None or self.learner_thread.weights_updated: + self.learner_thread.weights_updated = False + self.weights = ray.put( + self.workers.local_worker().get_weights()) + actor.set_weights.remote(self.weights) + self.steps_since_update[actor] = 0 + # Update metrics. + metrics = LocalIterator.get_metrics() + metrics.counters["num_weight_syncs"] += 1 + + # Experimental distributed execution impl; enable with "use_exec_api": True. -def execution_plan(workers, config): +def execution_plan(workers: WorkerSet, config: dict): # Create a number of replay buffer actors. # TODO(ekl) support batch replay options num_replay_buffer_shards = config["optimizer"]["num_replay_buffer_shards"] @@ -99,11 +127,15 @@ def execution_plan(workers, config): config["prioritized_replay_eps"], ], num_replay_buffer_shards) + # Start the learner thread. + learner_thread = LearnerThread(workers.local_worker()) + learner_thread.start() + # Update experience priorities post learning. - def update_prio_and_stats(item): + def update_prio_and_stats(item: ("ActorHandle", dict, int)): actor, prio_dict, count = item actor.update_priorities.remote(prio_dict) - metrics = LocalIterator.get_metrics() + metrics = _get_shared_metrics() # Manually update the steps trained counter since the learner thread # is executing outside the pipeline. metrics.counters[STEPS_TRAINED_COUNTER] += count @@ -111,35 +143,6 @@ def execution_plan(workers, config): metrics.timers["learner_grad"] = learner_thread.grad_timer metrics.timers["learner_overall"] = learner_thread.overall_timer - # Update worker weights as they finish generating experiences. - class UpdateWorkerWeights: - def __init__(self, learner_thread, workers, max_weight_sync_delay): - self.learner_thread = learner_thread - self.workers = workers - self.steps_since_update = collections.defaultdict(int) - self.max_weight_sync_delay = max_weight_sync_delay - self.weights = None - - def __call__(self, item): - actor, batch = item - self.steps_since_update[actor] += batch.count - if self.steps_since_update[actor] >= self.max_weight_sync_delay: - # Note that it's important to pull new weights once - # updated to avoid excessive correlation between actors. - if self.weights is None or self.learner_thread.weights_updated: - self.learner_thread.weights_updated = False - self.weights = ray.put( - self.workers.local_worker().get_weights()) - actor.set_weights.remote(self.weights) - self.steps_since_update[actor] = 0 - # Update metrics. - metrics = LocalIterator.get_metrics() - metrics.counters["num_weight_syncs"] += 1 - - # Start the learner thread. - learner_thread = LearnerThread(workers.local_worker()) - learner_thread.start() - # We execute the following steps concurrently: # (1) Generate rollouts and store them in our replay buffer actors. Update # the weights of the worker that generated the batch. diff --git a/rllib/examples/two_trainer_workflow.py b/rllib/examples/two_trainer_workflow.py new file mode 100644 index 000000000..c0ade66f6 --- /dev/null +++ b/rllib/examples/two_trainer_workflow.py @@ -0,0 +1,132 @@ +"""Example of using a custom training workflow. + +Here we create a number of CartPole agents, some of which are trained with +DQN, and some of which are trained with PPO. Both are executed concurrently +via a custom training workflow. +""" + +import argparse +import gym + +import ray +from ray import tune +from ray.rllib.agents.trainer_template import build_trainer +from ray.rllib.agents.dqn.dqn import DEFAULT_CONFIG as DQN_CONFIG +from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy +from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG as PPO_CONFIG +from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy +from ray.rllib.evaluation.worker_set import WorkerSet +from ray.rllib.execution.common import _get_shared_metrics +from ray.rllib.execution.concurrency_ops import Concurrently +from ray.rllib.execution.metric_ops import StandardMetricsReporting +from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches, \ + StandardizeFields, SelectExperiences +from ray.rllib.execution.replay_ops import StoreToReplayBuffer, Replay +from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork +from ray.rllib.examples.env.multi_agent import MultiAgentCartPole +from ray.rllib.optimizers.async_replay_optimizer import LocalReplayBuffer +from ray.tune.registry import register_env + +parser = argparse.ArgumentParser() +parser.add_argument("--num-iters", type=int, default=20) + + +def custom_training_workflow(workers: WorkerSet, config: dict): + local_replay_buffer = LocalReplayBuffer( + num_shards=1, + learning_starts=1000, + buffer_size=50000, + replay_batch_size=64) + + def add_ppo_metrics(batch): + print("PPO policy learning on samples from", + batch.policy_batches.keys(), "env steps", batch.count, + "agent steps", batch.total()) + metrics = _get_shared_metrics() + metrics.counters["agent_steps_trained_PPO"] += batch.total() + return batch + + def add_dqn_metrics(batch): + print("DQN policy learning on samples from", + batch.policy_batches.keys(), "env steps", batch.count, + "agent steps", batch.total()) + metrics = _get_shared_metrics() + metrics.counters["agent_steps_trained_DQN"] += batch.total() + return batch + + # Generate common experiences. + rollouts = ParallelRollouts(workers, mode="bulk_sync") + r1, r2 = rollouts.duplicate(n=2) + + # DQN sub-flow. + dqn_store_op = r1.for_each(SelectExperiences(["dqn_policy"])) \ + .for_each( + StoreToReplayBuffer(local_buffer=local_replay_buffer)) + dqn_replay_op = Replay(local_buffer=local_replay_buffer) \ + .for_each(add_dqn_metrics) \ + .for_each(TrainOneStep(workers, policies=["dqn_policy"])) \ + .for_each(UpdateTargetNetwork( + workers, target_update_freq=500, policies=["dqn_policy"])) + dqn_train_op = Concurrently( + [dqn_store_op, dqn_replay_op], mode="round_robin", output_indexes=[1]) + + # PPO sub-flow. + ppo_train_op = r2.for_each(SelectExperiences(["ppo_policy"])) \ + .combine(ConcatBatches(min_batch_size=200)) \ + .for_each(add_ppo_metrics) \ + .for_each(StandardizeFields(["advantages"])) \ + .for_each(TrainOneStep( + workers, + policies=["ppo_policy"], + num_sgd_iter=10, + sgd_minibatch_size=128)) + + # Combined training flow + train_op = Concurrently( + [ppo_train_op, dqn_train_op], mode="async", output_indexes=[1]) + + return StandardMetricsReporting(train_op, workers, config) + + +if __name__ == "__main__": + args = parser.parse_args() + ray.init() + + # Simple environment with 4 independent cartpole entities + register_env("multi_agent_cartpole", + lambda _: MultiAgentCartPole({"num_agents": 4})) + single_env = gym.make("CartPole-v0") + obs_space = single_env.observation_space + act_space = single_env.action_space + + # Note that since the trainer below does not include a default policy or + # policy configs, we have to explicitly set it in the multiagent config: + policies = { + "ppo_policy": (PPOTFPolicy, obs_space, act_space, PPO_CONFIG), + "dqn_policy": (DQNTFPolicy, obs_space, act_space, DQN_CONFIG), + } + + def policy_mapping_fn(agent_id): + if agent_id % 2 == 0: + return "ppo_policy" + else: + return "dqn_policy" + + MyTrainer = build_trainer( + name="PPO_DQN_MultiAgent", + default_policy=None, + execution_plan=custom_training_workflow) + + tune.run( + MyTrainer, + stop={"training_iteration": args.num_iters}, + config={ + "rollout_fragment_length": 50, + "num_workers": 0, + "env": "multi_agent_cartpole", + "multiagent": { + "policies": policies, + "policy_mapping_fn": policy_mapping_fn, + "policies_to_train": ["dqn_policy", "ppo_policy"], + }, + }) diff --git a/rllib/execution/common.py b/rllib/execution/common.py index c0e6ed2cd..a7741e832 100644 --- a/rllib/execution/common.py +++ b/rllib/execution/common.py @@ -39,3 +39,10 @@ def _check_sample_batch_type(batch): def _get_global_vars(): metrics = LocalIterator.get_metrics() return {"timestep": metrics.counters[STEPS_SAMPLED_COUNTER]} + + +def _get_shared_metrics(): + """Return shared metrics for the training workflow. + + This only applies if this trainer has an execution plan.""" + return LocalIterator.get_metrics() diff --git a/rllib/execution/rollout_ops.py b/rllib/execution/rollout_ops.py index 6c97f6538..1b78d124a 100644 --- a/rllib/execution/rollout_ops.py +++ b/rllib/execution/rollout_ops.py @@ -10,6 +10,7 @@ from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.execution.common import GradientType, SampleBatchType, \ STEPS_SAMPLED_COUNTER, LEARNER_INFO, SAMPLE_TIMER, \ GRAD_WAIT_TIMER, _check_sample_batch_type +from ray.rllib.policy.policy import PolicyID from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ MultiAgentBatch from ray.rllib.utils.sgd import standardized @@ -190,7 +191,8 @@ class SelectExperiences: {"pol1", "pol2"} """ - def __init__(self, policy_ids: List[str]): + def __init__(self, policy_ids: List[PolicyID]): + assert isinstance(policy_ids, list), policy_ids self.policy_ids = policy_ids def __call__(self, samples: SampleBatchType) -> SampleBatchType: diff --git a/rllib/execution/train_ops.py b/rllib/execution/train_ops.py index ebd76b3ad..10168cfde 100644 --- a/rllib/execution/train_ops.py +++ b/rllib/execution/train_ops.py @@ -14,6 +14,7 @@ from ray.rllib.execution.common import SampleBatchType, \ LEARN_ON_BATCH_TIMER, LOAD_BATCH_TIMER, LAST_TARGET_UPDATE_TS, \ NUM_TARGET_UPDATES, _get_global_vars, _check_sample_batch_type from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer +from ray.rllib.policy.policy import PolicyID from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ MultiAgentBatch from ray.rllib.utils import try_import_tf @@ -42,11 +43,11 @@ class TrainOneStep: def __init__(self, workers: WorkerSet, + policies: List[PolicyID] = frozenset([]), num_sgd_iter: int = 1, sgd_minibatch_size: int = 0): self.workers = workers - self.policies = dict(self.workers.local_worker() - .foreach_trainable_policy(lambda p, i: (i, p))) + self.policies = policies or workers.local_worker().policies_to_train self.num_sgd_iter = num_sgd_iter self.sgd_minibatch_size = sgd_minibatch_size @@ -57,10 +58,11 @@ class TrainOneStep: learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER] with learn_timer: if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0: - info = do_minibatch_sgd(batch, self.policies, - self.workers.local_worker(), - self.num_sgd_iter, - self.sgd_minibatch_size, []) + w = self.workers.local_worker() + info = do_minibatch_sgd( + batch, {p: w.get_policy(p) + for p in self.policies}, w, self.num_sgd_iter, + self.sgd_minibatch_size, []) # TODO(ekl) shouldn't be returning learner stats directly here metrics.info[LEARNER_INFO] = info else: @@ -70,7 +72,8 @@ class TrainOneStep: metrics.counters[STEPS_TRAINED_COUNTER] += batch.count if self.workers.remote_workers(): with metrics.timers[WORKER_UPDATE_TIMER]: - weights = ray.put(self.workers.local_worker().get_weights()) + weights = ray.put(self.workers.local_worker().get_weights( + self.policies)) for e in self.workers.remote_workers(): e.set_weights.remote(weights, _get_global_vars()) # Also update global vars of the local worker. @@ -103,10 +106,10 @@ class TrainTFMultiGPU: num_envs_per_worker: int, train_batch_size: int, shuffle_sequences: bool, + policies: List[PolicyID] = frozenset([]), _fake_gpus: bool = False): self.workers = workers - self.policies = dict(self.workers.local_worker() - .foreach_trainable_policy(lambda p, i: (i, p))) + self.policies = policies or workers.local_worker().policies_to_train self.num_sgd_iter = num_sgd_iter self.sgd_minibatch_size = sgd_minibatch_size self.shuffle_sequences = shuffle_sequences @@ -132,7 +135,8 @@ class TrainTFMultiGPU: self.optimizers = {} with self.workers.local_worker().tf_sess.graph.as_default(): with self.workers.local_worker().tf_sess.as_default(): - for policy_id, policy in self.policies.items(): + for policy_id in self.policies: + policy = self.workers.local_worker().get_policy(policy_id) with tf.variable_scope(policy_id, reuse=tf.AUTO_REUSE): if policy._state_inputs: rnn_inputs = policy._state_inputs + [ @@ -170,7 +174,7 @@ class TrainTFMultiGPU: if policy_id not in self.policies: continue - policy = self.policies[policy_id] + policy = self.workers.local_worker().get_policy(policy_id) policy._debug_vars() tuples = policy._get_loss_inputs_dict( batch, shuffle=self.shuffle_sequences) @@ -213,7 +217,8 @@ class TrainTFMultiGPU: metrics.info[LEARNER_INFO] = fetches if self.workers.remote_workers(): with metrics.timers[WORKER_UPDATE_TIMER]: - weights = ray.put(self.workers.local_worker().get_weights()) + weights = ray.put(self.workers.local_worker().get_weights( + self.policies)) for e in self.workers.remote_workers(): e.set_weights.remote(weights, _get_global_vars()) # Also update global vars of the local worker. @@ -259,7 +264,10 @@ class ApplyGradients: Updates the STEPS_TRAINED_COUNTER counter in the local iterator context. """ - def __init__(self, workers, update_all=True): + def __init__(self, + workers, + policies: List[PolicyID] = frozenset([]), + update_all=True): """Creates an ApplyGradients instance. Arguments: @@ -269,6 +277,7 @@ class ApplyGradients: currently processing (i.e., A3C style). """ self.workers = workers + self.policies = policies or workers.local_worker().policies_to_train self.update_all = update_all def __call__(self, item): @@ -291,8 +300,8 @@ class ApplyGradients: if self.update_all: if self.workers.remote_workers(): with metrics.timers[WORKER_UPDATE_TIMER]: - weights = ray.put( - self.workers.local_worker().get_weights()) + weights = ray.put(self.workers.local_worker().get_weights( + self.policies)) for e in self.workers.remote_workers(): e.set_weights.remote(weights, _get_global_vars()) else: @@ -302,7 +311,8 @@ class ApplyGradients: "update_all=False, `current_actor` must be set " "in the iterator context.") with metrics.timers[WORKER_UPDATE_TIMER]: - weights = self.workers.local_worker().get_weights() + weights = self.workers.local_worker().get_weights( + self.policies) metrics.current_actor.set_weights.remote( weights, _get_global_vars()) @@ -352,9 +362,14 @@ class UpdateTargetNetwork: track when we should update the target next. """ - def __init__(self, workers, target_update_freq, by_steps_trained=False): + def __init__(self, + workers, + target_update_freq, + by_steps_trained=False, + policies=frozenset([])): self.workers = workers self.target_update_freq = target_update_freq + self.policies = (policies or workers.local_worker().policies_to_train) if by_steps_trained: self.metric = STEPS_TRAINED_COUNTER else: @@ -365,7 +380,8 @@ class UpdateTargetNetwork: cur_ts = metrics.counters[self.metric] last_update = metrics.counters[LAST_TARGET_UPDATE_TS] if cur_ts - last_update > self.target_update_freq: + to_update = self.policies self.workers.local_worker().foreach_trainable_policy( - lambda p, _: p.update_target()) + lambda p, p_id: p_id in to_update and p.update_target()) metrics.counters[NUM_TARGET_UPDATES] += 1 metrics.counters[LAST_TARGET_UPDATE_TS] = cur_ts diff --git a/rllib/tests/test_execution.py b/rllib/tests/test_execution.py index efca91954..152ea3070 100644 --- a/rllib/tests/test_execution.py +++ b/rllib/tests/test_execution.py @@ -50,7 +50,7 @@ def test_concurrently(ray_start_regular_shared): a = iter_list([1, 2, 3]) b = iter_list([4, 5, 6]) c = Concurrently([a, b], mode="async") - assert c.take(6) == [1, 2, 3, 4, 5, 6] + assert c.take(6) == [1, 4, 2, 5, 3, 6] def test_concurrently_output(ray_start_regular_shared):