diff --git a/rllib/agents/qmix/qmix.py b/rllib/agents/qmix/qmix.py
index 264267f5b..8a6222605 100644
--- a/rllib/agents/qmix/qmix.py
+++ b/rllib/agents/qmix/qmix.py
@@ -1,10 +1,12 @@
 from ray.rllib.agents.trainer import with_common_config
 from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer
 from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy
-from ray.rllib.execution.replay_ops import MixInReplay
+from ray.rllib.execution.replay_ops import SimpleReplayBuffer, Replay, \
+    StoreToReplayBuffer
 from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
 from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
 from ray.rllib.execution.metric_ops import StandardMetricsReporting
+from ray.rllib.execution.concurrency_ops import Concurrently
 from ray.rllib.optimizers import SyncBatchReplayOptimizer
 
 # yapf: disable
@@ -102,16 +104,22 @@ def make_sync_batch_optimizer(workers, config):
 # Experimental distributed execution impl; enable with "use_exec_api": True.
 def execution_plan(workers, config):
     rollouts = ParallelRollouts(workers, mode="bulk_sync")
+    replay_buffer = SimpleReplayBuffer(config["buffer_size"])
 
-    train_op = rollouts \
-        .for_each(MixInReplay(config["buffer_size"])) \
+    store_op = rollouts \
+        .for_each(StoreToReplayBuffer(local_buffer=replay_buffer))
+
+    train_op = Replay(local_buffer=replay_buffer) \
         .combine(
             ConcatBatches(min_batch_size=config["train_batch_size"])) \
         .for_each(TrainOneStep(workers)) \
         .for_each(UpdateTargetNetwork(
             workers, config["target_network_update_freq"]))
 
-    return StandardMetricsReporting(train_op, workers, config)
+    merged_op = Concurrently(
+        [store_op, train_op], mode="round_robin", output_indexes=[1])
+
+    return StandardMetricsReporting(merged_op, workers, config)
 
 
 QMixTrainer = GenericOffPolicyTrainer.with_updates(
diff --git a/rllib/execution/replay_ops.py b/rllib/execution/replay_ops.py
index fc2ec3993..516505dd1 100644
--- a/rllib/execution/replay_ops.py
+++ b/rllib/execution/replay_ops.py
@@ -93,6 +93,32 @@ def Replay(*,
     return LocalIterator(gen_replay, SharedMetrics())
 
 
+class SimpleReplayBuffer:
+    """Simple replay buffer that operates over batches."""
+
+    def __init__(self, num_slots, replay_proportion: float = None):
+        """Initialize SimpleReplayBuffer.
+
+        Args:
+            num_slots (int): Number of batches to store in total.
+        """
+        self.num_slots = num_slots
+        self.replay_batches = []
+        self.replay_index = 0
+
+    def add_batch(self, sample_batch):
+        if self.num_slots > 0:
+            if len(self.replay_batches) < self.num_slots:
+                self.replay_batches.append(sample_batch)
+            else:
+                self.replay_batches[self.replay_index] = sample_batch
+                self.replay_index += 1
+                self.replay_index %= self.num_slots
+
+    def replay(self):
+        return random.choice(self.replay_batches)
+
+
 class MixInReplay:
     """This operator adds replay to a stream of experiences.
 
@@ -102,63 +128,40 @@ class MixInReplay:
     number of replay slots.
     """
 
-    def __init__(self, num_slots, replay_proportion: float = None):
+    def __init__(self, num_slots, replay_proportion: float):
         """Initialize MixInReplay.
 
         Args:
             num_slots (int): Number of batches to store in total.
-            replay_proportion (float): If None, one batch will be replayed per
-                each input batch. Otherwise, the input batch will be returned
+            replay_proportion (float): The input batch will be returned
                 and an additional number of batches proportional to this value
                 will be added as well.
 
         Examples:
-            # 1:1 mode (default)
-            >>> replay_op = MixInReplay(rollouts, 100)
-            >>> print(next(replay_op))
-            SampleBatch()
-
-            # proportional mode
+            # replay proportion 2:1
             >>> replay_op = MixInReplay(rollouts, 100, replay_proportion=2)
             >>> print(next(replay_op))
             [SampleBatch(), SampleBatch(), SampleBatch()]
 
-            # proportional mode, replay disabled
+            # replay proportion 0:1, replay disabled
             >>> replay_op = MixInReplay(rollouts, 100, replay_proportion=0)
             >>> print(next(replay_op))
             [SampleBatch()]
         """
-        if replay_proportion is not None:
-            if replay_proportion > 0 and num_slots == 0:
-                raise ValueError(
-                    "You must set num_slots > 0 if replay_proportion > 0.")
-        elif num_slots == 0:
+        if replay_proportion > 0 and num_slots == 0:
             raise ValueError(
-                "You must set num_slots > 0 if replay_proportion = None.")
-        self.num_slots = num_slots
+                "You must set num_slots > 0 if replay_proportion > 0.")
+        self.replay_buffer = SimpleReplayBuffer(num_slots)
         self.replay_proportion = replay_proportion
-        self.replay_batches = []
-        self.replay_index = 0
 
     def __call__(self, sample_batch):
         # Put in replay buffer if enabled.
-        if self.num_slots > 0:
-            if len(self.replay_batches) < self.num_slots:
-                self.replay_batches.append(sample_batch)
-            else:
-                self.replay_batches[self.replay_index] = sample_batch
-                self.replay_index += 1
-                self.replay_index %= self.num_slots
+        self.replay_buffer.add_batch(sample_batch)
 
-        # 1:1 replay mode.
-        if self.replay_proportion is None:
-            return random.choice(self.replay_batches)
-
-        # Proportional replay mode.
+        # Proportional replay.
         output_batches = [sample_batch]
         f = self.replay_proportion
         while random.random() < f:
             f -= 1
-            replay_batch = random.choice(self.replay_batches)
-            output_batches.append(replay_batch)
+            output_batches.append(self.replay_buffer.replay())
         return output_batches