[rllib] Qmix replay ratio is wrong

This commit is contained in:
Eric Liang
2020-05-12 13:07:19 -07:00
committed by GitHub
parent bb494a3be8
commit 96f4d82cc3
2 changed files with 48 additions and 37 deletions
+12 -4
View File
@@ -1,10 +1,12 @@
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer
from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy
from ray.rllib.execution.replay_ops import MixInReplay
from ray.rllib.execution.replay_ops import SimpleReplayBuffer, Replay, \
StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.optimizers import SyncBatchReplayOptimizer
# yapf: disable
@@ -102,16 +104,22 @@ def make_sync_batch_optimizer(workers, config):
# Experimental distributed execution impl; enable with "use_exec_api": True.
def execution_plan(workers, config):
    """Build the QMIX execution pipeline backed by a batch replay buffer.

    Sampled batches are written into a SimpleReplayBuffer by one operator,
    while a second operator reads batches back from the same buffer,
    concatenates them up to the train batch size and trains on them.

    Args:
        workers: Rollout workers handed to the rollout, training and
            reporting operators.
        config (dict): Trainer config. Reads "buffer_size",
            "train_batch_size" and "target_network_update_freq".

    Returns:
        The result of StandardMetricsReporting over the merged operator.
    """
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    replay_buffer = SimpleReplayBuffer(config["buffer_size"])

    # Storing: every sampled batch goes into the shared replay buffer.
    store_op = rollouts \
        .for_each(StoreToReplayBuffer(local_buffer=replay_buffer))

    # Training: draw from the buffer rather than directly from the rollouts.
    train_op = Replay(local_buffer=replay_buffer) \
        .combine(
            ConcatBatches(min_batch_size=config["train_batch_size"])) \
        .for_each(TrainOneStep(workers)) \
        .for_each(UpdateTargetNetwork(
            workers, config["target_network_update_freq"]))

    # NOTE(review): output_indexes=[1] presumably means only train_op's
    # results are emitted downstream -- confirm against Concurrently.
    merged_op = Concurrently(
        [store_op, train_op], mode="round_robin", output_indexes=[1])

    return StandardMetricsReporting(merged_op, workers, config)
QMixTrainer = GenericOffPolicyTrainer.with_updates(
+36 -33
View File
@@ -93,6 +93,32 @@ def Replay(*,
return LocalIterator(gen_replay, SharedMetrics())
class SimpleReplayBuffer:
    """Simple replay buffer that operates over batches.

    Batches are kept in a fixed number of slots. Once every slot is in use,
    each new batch overwrites an existing one, cycling through the slots in
    insertion order so that the oldest stored batch is replaced first.
    """

    def __init__(self, num_slots, replay_proportion: float = None):
        """Initialize SimpleReplayBuffer.

        Args:
            num_slots (int): Number of batches to store in total. With a
                value of 0 or less, `add_batch` stores nothing.
            replay_proportion (float): Accepted but not used by this class.
        """
        self.num_slots = num_slots
        self.replay_batches = []
        # Slot that the next batch overwrites once the buffer is full.
        self.replay_index = 0

    def add_batch(self, sample_batch):
        """Store `sample_batch`, overwriting the oldest batch when full.

        Args:
            sample_batch: The batch to store; it is kept as-is.
        """
        if self.num_slots <= 0:
            # Storage is disabled.
            return
        if len(self.replay_batches) < self.num_slots:
            self.replay_batches.append(sample_batch)
            return
        self.replay_batches[self.replay_index] = sample_batch
        self.replay_index = (self.replay_index + 1) % self.num_slots

    def replay(self):
        """Return one stored batch picked with `random.choice`.

        Raises:
            IndexError: If no batch has been stored yet.
        """
        if not self.replay_batches:
            # Same exception type random.choice would raise, clearer message.
            raise IndexError("Cannot replay from an empty replay buffer.")
        return random.choice(self.replay_batches)
class MixInReplay:
    """This operator adds replay to a stream of experiences.

    Each call stores the input batch in an internal SimpleReplayBuffer and
    returns a list holding that batch followed by batches drawn from the
    buffer. The number of replayed batches is governed by
    `replay_proportion`; how many batches are retained is bounded by the
    number of replay slots.
    """

    def __init__(self, num_slots, replay_proportion: float):
        """Initialize MixInReplay.

        Args:
            num_slots (int): Number of batches to store in total.
            replay_proportion (float): The input batch will be returned
                and an additional number of batches proportional to this value
                will be added as well. The integer part is the number of
                replayed batches always added; a fractional part is the
                probability of adding one more.

        Raises:
            ValueError: If replay is requested (replay_proportion > 0) but
                num_slots is 0.

        Examples:
            # replay proportion 2:1
            >>> replay_op = MixInReplay(100, replay_proportion=2)
            >>> print(replay_op(batch))
            [SampleBatch(<input>), SampleBatch(<replay>), SampleBatch(<rep.>)]

            # replay proportion 0:1, replay disabled
            >>> replay_op = MixInReplay(100, replay_proportion=0)
            >>> print(replay_op(batch))
            [SampleBatch(<input>)]
        """
        if replay_proportion > 0 and num_slots == 0:
            raise ValueError(
                "You must set num_slots > 0 if replay_proportion > 0.")
        self.replay_buffer = SimpleReplayBuffer(num_slots)
        self.replay_proportion = replay_proportion

    def __call__(self, sample_batch):
        """Store `sample_batch` and return it together with replayed batches.

        Args:
            sample_batch: The newly sampled input batch.

        Returns:
            list: The input batch followed by zero or more batches drawn
                from the replay buffer.
        """
        # Store first, so the buffer is non-empty before any replay below
        # (with num_slots > 0 it then holds at least this batch).
        self.replay_buffer.add_batch(sample_batch)

        # Proportional replay: each pass through the loop consumes one unit
        # of the proportion; a remaining fraction acts as a probability.
        output_batches = [sample_batch]
        f = self.replay_proportion
        while random.random() < f:
            f -= 1
            output_batches.append(self.replay_buffer.replay())
        return output_batches