From 2c599dbf05e41e338920ee2fbe692658bcbec4dd Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 7 May 2020 23:41:10 -0700 Subject: [PATCH] [rllib] Port QMIX, MADDPG to new execution API (#8344) --- rllib/agents/dqn/apex.py | 2 + rllib/agents/dqn/dqn.py | 8 +++ rllib/agents/qmix/qmix.py | 25 ++++++-- rllib/contrib/maddpg/maddpg.py | 24 +++++++- rllib/execution/replay_ops.py | 71 ++++++++++++++++++++++ rllib/optimizers/async_replay_optimizer.py | 15 ++++- 6 files changed, 135 insertions(+), 10 deletions(-) diff --git a/rllib/agents/dqn/apex.py b/rllib/agents/dqn/apex.py index 263319fac..f5b70c874 100644 --- a/rllib/agents/dqn/apex.py +++ b/rllib/agents/dqn/apex.py @@ -157,7 +157,9 @@ def execution_plan(workers: WorkerSet, config: dict): # (2) Read experiences from the replay buffer actors and send to the # learner thread via its in-queue. + post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b) replay_op = Replay(actors=replay_actors, num_async=4) \ + .for_each(lambda x: post_fn(x, workers, config)) \ .zip_with_source_actor() \ .for_each(Enqueue(learner_thread.inqueue)) diff --git a/rllib/agents/dqn/dqn.py b/rllib/agents/dqn/dqn.py index 694aeec8f..cd2889032 100644 --- a/rllib/agents/dqn/dqn.py +++ b/rllib/agents/dqn/dqn.py @@ -84,6 +84,11 @@ DEFAULT_CONFIG = with_common_config({ "prioritized_replay_eps": 1e-6, # Whether to LZ4 compress observations "compress_observations": False, + # In multi-agent mode, whether to replay experiences from the same time + # step for all policies. This is required for MADDPG. + "multiagent_sync_replay": False, + # Callback to run before learning on a multi-agent batch of experiences. + "before_learn_on_batch": None, # === Optimization === # Learning rate for adam optimizer @@ -312,6 +317,7 @@ def execution_plan(workers, config): learning_starts=config["learning_starts"], buffer_size=config["buffer_size"], replay_batch_size=config["train_batch_size"], + multiagent_sync_replay=config.get("multiagent_sync_replay"), **prio_args) rollouts = ParallelRollouts(workers, mode="bulk_sync") @@ -341,7 +347,9 @@ def execution_plan(workers, config): # (2) Read and train on experiences from the replay buffer. Every batch # returned from the LocalReplay() iterator is passed to TrainOneStep to # take a SGD step, and then we decide whether to update the target network. + post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b) replay_op = Replay(local_buffer=local_replay_buffer) \ + .for_each(lambda x: post_fn(x, workers, config)) \ .for_each(TrainOneStep(workers)) \ .for_each(update_prio) \ .for_each(UpdateTargetNetwork( diff --git a/rllib/agents/qmix/qmix.py b/rllib/agents/qmix/qmix.py index 07996f50e..264267f5b 100644 --- a/rllib/agents/qmix/qmix.py +++ b/rllib/agents/qmix/qmix.py @@ -1,6 +1,10 @@ from ray.rllib.agents.trainer import with_common_config from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy +from ray.rllib.execution.replay_ops import MixInReplay +from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches +from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork +from ray.rllib.execution.metric_ops import StandardMetricsReporting from ray.rllib.optimizers import SyncBatchReplayOptimizer # yapf: disable @@ -82,9 +86,6 @@ DEFAULT_CONFIG = with_common_config({ "lstm_cell_size": 64, "max_seq_len": 999999, }, - - # TODO(ekl) support sync batch replay. - "use_exec_api": False, }) # __sphinx_doc_end__ # yapf: enable @@ -98,9 +99,25 @@ def make_sync_batch_optimizer(workers, config): train_batch_size=config["train_batch_size"]) +# Experimental distributed execution impl; enable with "use_exec_api": True. +def execution_plan(workers, config): + rollouts = ParallelRollouts(workers, mode="bulk_sync") + + train_op = rollouts \ + .for_each(MixInReplay(config["buffer_size"])) \ + .combine( + ConcatBatches(min_batch_size=config["train_batch_size"])) \ + .for_each(TrainOneStep(workers)) \ + .for_each(UpdateTargetNetwork( + workers, config["target_network_update_freq"])) + + return StandardMetricsReporting(train_op, workers, config) + + QMixTrainer = GenericOffPolicyTrainer.with_updates( name="QMIX", default_config=DEFAULT_CONFIG, default_policy=QMixTorchPolicy, get_policy_class=None, - make_policy_optimizer=make_sync_batch_optimizer) + make_policy_optimizer=make_sync_batch_optimizer, + execution_plan=execution_plan) diff --git a/rllib/contrib/maddpg/maddpg.py b/rllib/contrib/maddpg/maddpg.py index a6ca8231c..a2775b8a3 100644 --- a/rllib/contrib/maddpg/maddpg.py +++ b/rllib/contrib/maddpg/maddpg.py @@ -67,6 +67,9 @@ DEFAULT_CONFIG = with_common_config({ # Observation compression. Note that compression makes simulation slow in # MPE. "compress_observations": False, + # In multi-agent mode, whether to replay experiences from the same time + # step for all policies. This is required for MADDPG. + "multiagent_sync_replay": True, # === Optimization === # Learning rate for the critic (Q-function) optimizer. @@ -100,9 +103,6 @@ DEFAULT_CONFIG = with_common_config({ "num_workers": 1, # Prevent iterations from going lower than this time span "min_iter_time_s": 0, - - # TODO(ekl) support synchronized sampling. - "use_exec_api": False, }) # __sphinx_doc_end__ # yapf: enable @@ -171,10 +171,28 @@ def collect_metrics(trainer): return result +def add_maddpg_postprocessing(config): + """Add the before learn on batch hook. + + This hook is called explicitly prior to TrainOneStep() in the execution + setups for DQN and APEX. + """ + + def f(batch, workers, config): + policies = dict(workers.local_worker() + .foreach_trainable_policy(lambda p, i: (i, p))) + return before_learn_on_batch(batch, policies, + config["train_batch_size"]) + + config["before_learn_on_batch"] = f + return config + + MADDPGTrainer = GenericOffPolicyTrainer.with_updates( name="MADDPG", default_config=DEFAULT_CONFIG, default_policy=MADDPGTFPolicy, + validate_config=add_maddpg_postprocessing, get_policy_class=None, before_init=None, before_train_step=set_global_timestep, diff --git a/rllib/execution/replay_ops.py b/rllib/execution/replay_ops.py index 0d2d20e18..fc2ec3993 100644 --- a/rllib/execution/replay_ops.py +++ b/rllib/execution/replay_ops.py @@ -91,3 +91,74 @@ def Replay(*, yield item return LocalIterator(gen_replay, SharedMetrics()) + + +class MixInReplay: + """This operator adds replay to a stream of experiences. + + It takes input batches, and returns a list of batches that include replayed + data as well. The number of replayed batches is determined by the + configured replay proportion. The max age of a batch is determined by the + number of replay slots. + """ + + def __init__(self, num_slots, replay_proportion: float = None): + """Initialize MixInReplay. + + Args: + num_slots (int): Number of batches to store in total. + replay_proportion (float): If None, one batch will be replayed per + each input batch. Otherwise, the input batch will be returned + and an additional number of batches proportional to this value + will be added as well. + + Examples: + # 1:1 mode (default) + >>> replay_op = MixInReplay(rollouts, 100) + >>> print(next(replay_op)) + SampleBatch() + + # proportional mode + >>> replay_op = MixInReplay(rollouts, 100, replay_proportion=2) + >>> print(next(replay_op)) + [SampleBatch(), SampleBatch(), SampleBatch()] + + # proportional mode, replay disabled + >>> replay_op = MixInReplay(rollouts, 100, replay_proportion=0) + >>> print(next(replay_op)) + [SampleBatch()] + """ + if replay_proportion is not None: + if replay_proportion > 0 and num_slots == 0: + raise ValueError( + "You must set num_slots > 0 if replay_proportion > 0.") + elif num_slots == 0: + raise ValueError( + "You must set num_slots > 0 if replay_proportion = None.") + self.num_slots = num_slots + self.replay_proportion = replay_proportion + self.replay_batches = [] + self.replay_index = 0 + + def __call__(self, sample_batch): + # Put in replay buffer if enabled. + if self.num_slots > 0: + if len(self.replay_batches) < self.num_slots: + self.replay_batches.append(sample_batch) + else: + self.replay_batches[self.replay_index] = sample_batch + self.replay_index += 1 + self.replay_index %= self.num_slots + + # 1:1 replay mode. + if self.replay_proportion is None: + return random.choice(self.replay_batches) + + # Proportional replay mode. + output_batches = [sample_batch] + f = self.replay_proportion + while random.random() < f: + f -= 1 + replay_batch = random.choice(self.replay_batches) + output_batches.append(replay_batch) + return output_batches diff --git a/rllib/optimizers/async_replay_optimizer.py b/rllib/optimizers/async_replay_optimizer.py index caac76c77..2c407effd 100644 --- a/rllib/optimizers/async_replay_optimizer.py +++ b/rllib/optimizers/async_replay_optimizer.py @@ -303,12 +303,14 @@ class LocalReplayBuffer(ParallelIteratorWorker): replay_batch_size, prioritized_replay_alpha=0.6, prioritized_replay_beta=0.4, - prioritized_replay_eps=1e-6): + prioritized_replay_eps=1e-6, + multiagent_sync_replay=False): self.replay_starts = learning_starts // num_shards self.buffer_size = buffer_size // num_shards self.replay_batch_size = replay_batch_size self.prioritized_replay_beta = prioritized_replay_beta self.prioritized_replay_eps = prioritized_replay_eps + self.multiagent_sync_replay = multiagent_sync_replay def gen_replay(): while True: @@ -369,10 +371,17 @@ class LocalReplayBuffer(ParallelIteratorWorker): with self.replay_timer: samples = {} + idxes = None for policy_id, replay_buffer in self.replay_buffers.items(): + if self.multiagent_sync_replay: + if idxes is None: + idxes = replay_buffer.sample_idxes( + self.replay_batch_size) + else: + idxes = replay_buffer.sample_idxes(self.replay_batch_size) (obses_t, actions, rewards, obses_tp1, dones, weights, - batch_indexes) = replay_buffer.sample( - self.replay_batch_size, beta=self.prioritized_replay_beta) + batch_indexes) = replay_buffer.sample_with_idxes( + idxes, beta=self.prioritized_replay_beta) samples[policy_id] = SampleBatch({ "obs": obses_t, "actions": actions,