mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 08:46:01 +08:00
[rllib] Qmix replay ratio is wrong
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
from ray.rllib.agents.trainer import with_common_config
|
||||
from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer
|
||||
from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy
|
||||
from ray.rllib.execution.replay_ops import MixInReplay
|
||||
from ray.rllib.execution.replay_ops import SimpleReplayBuffer, Replay, \
|
||||
StoreToReplayBuffer
|
||||
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
|
||||
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
|
||||
from ray.rllib.execution.metric_ops import StandardMetricsReporting
|
||||
from ray.rllib.execution.concurrency_ops import Concurrently
|
||||
from ray.rllib.optimizers import SyncBatchReplayOptimizer
|
||||
|
||||
# yapf: disable
|
||||
@@ -102,16 +104,22 @@ def make_sync_batch_optimizer(workers, config):
|
||||
# Experimental distributed execution impl; enable with "use_exec_api": True.
|
||||
def execution_plan(workers, config):
|
||||
rollouts = ParallelRollouts(workers, mode="bulk_sync")
|
||||
replay_buffer = SimpleReplayBuffer(config["buffer_size"])
|
||||
|
||||
train_op = rollouts \
|
||||
.for_each(MixInReplay(config["buffer_size"])) \
|
||||
store_op = rollouts \
|
||||
.for_each(StoreToReplayBuffer(local_buffer=replay_buffer))
|
||||
|
||||
train_op = Replay(local_buffer=replay_buffer) \
|
||||
.combine(
|
||||
ConcatBatches(min_batch_size=config["train_batch_size"])) \
|
||||
.for_each(TrainOneStep(workers)) \
|
||||
.for_each(UpdateTargetNetwork(
|
||||
workers, config["target_network_update_freq"]))
|
||||
|
||||
return StandardMetricsReporting(train_op, workers, config)
|
||||
merged_op = Concurrently(
|
||||
[store_op, train_op], mode="round_robin", output_indexes=[1])
|
||||
|
||||
return StandardMetricsReporting(merged_op, workers, config)
|
||||
|
||||
|
||||
QMixTrainer = GenericOffPolicyTrainer.with_updates(
|
||||
|
||||
@@ -93,6 +93,32 @@ def Replay(*,
|
||||
return LocalIterator(gen_replay, SharedMetrics())
|
||||
|
||||
|
||||
class SimpleReplayBuffer:
|
||||
"""Simple replay buffer that operates over batches."""
|
||||
|
||||
def __init__(self, num_slots, replay_proportion: float = None):
|
||||
"""Initialize SimpleReplayBuffer.
|
||||
|
||||
Args:
|
||||
num_slots (int): Number of batches to store in total.
|
||||
"""
|
||||
self.num_slots = num_slots
|
||||
self.replay_batches = []
|
||||
self.replay_index = 0
|
||||
|
||||
def add_batch(self, sample_batch):
|
||||
if self.num_slots > 0:
|
||||
if len(self.replay_batches) < self.num_slots:
|
||||
self.replay_batches.append(sample_batch)
|
||||
else:
|
||||
self.replay_batches[self.replay_index] = sample_batch
|
||||
self.replay_index += 1
|
||||
self.replay_index %= self.num_slots
|
||||
|
||||
def replay(self):
|
||||
return random.choice(self.replay_batches)
|
||||
|
||||
|
||||
class MixInReplay:
|
||||
"""This operator adds replay to a stream of experiences.
|
||||
|
||||
@@ -102,63 +128,40 @@ class MixInReplay:
|
||||
number of replay slots.
|
||||
"""
|
||||
|
||||
def __init__(self, num_slots, replay_proportion: float = None):
|
||||
def __init__(self, num_slots, replay_proportion: float):
|
||||
"""Initialize MixInReplay.
|
||||
|
||||
Args:
|
||||
num_slots (int): Number of batches to store in total.
|
||||
replay_proportion (float): If None, one batch will be replayed per
|
||||
each input batch. Otherwise, the input batch will be returned
|
||||
replay_proportion (float): The input batch will be returned
|
||||
and an additional number of batches proportional to this value
|
||||
will be added as well.
|
||||
|
||||
Examples:
|
||||
# 1:1 mode (default)
|
||||
>>> replay_op = MixInReplay(rollouts, 100)
|
||||
>>> print(next(replay_op))
|
||||
SampleBatch(<replay>)
|
||||
|
||||
# proportional mode
|
||||
# replay proportion 2:1
|
||||
>>> replay_op = MixInReplay(rollouts, 100, replay_proportion=2)
|
||||
>>> print(next(replay_op))
|
||||
[SampleBatch(<input>), SampleBatch(<replay>), SampleBatch(<rep.>)]
|
||||
|
||||
# proportional mode, replay disabled
|
||||
# replay proportion 0:1, replay disabled
|
||||
>>> replay_op = MixInReplay(rollouts, 100, replay_proportion=0)
|
||||
>>> print(next(replay_op))
|
||||
[SampleBatch(<input>)]
|
||||
"""
|
||||
if replay_proportion is not None:
|
||||
if replay_proportion > 0 and num_slots == 0:
|
||||
raise ValueError(
|
||||
"You must set num_slots > 0 if replay_proportion > 0.")
|
||||
elif num_slots == 0:
|
||||
if replay_proportion > 0 and num_slots == 0:
|
||||
raise ValueError(
|
||||
"You must set num_slots > 0 if replay_proportion = None.")
|
||||
self.num_slots = num_slots
|
||||
"You must set num_slots > 0 if replay_proportion > 0.")
|
||||
self.replay_buffer = SimpleReplayBuffer(num_slots)
|
||||
self.replay_proportion = replay_proportion
|
||||
self.replay_batches = []
|
||||
self.replay_index = 0
|
||||
|
||||
def __call__(self, sample_batch):
|
||||
# Put in replay buffer if enabled.
|
||||
if self.num_slots > 0:
|
||||
if len(self.replay_batches) < self.num_slots:
|
||||
self.replay_batches.append(sample_batch)
|
||||
else:
|
||||
self.replay_batches[self.replay_index] = sample_batch
|
||||
self.replay_index += 1
|
||||
self.replay_index %= self.num_slots
|
||||
self.buffer.add_batch(sample_batch)
|
||||
|
||||
# 1:1 replay mode.
|
||||
if self.replay_proportion is None:
|
||||
return random.choice(self.replay_batches)
|
||||
|
||||
# Proportional replay mode.
|
||||
# Proportional replay.
|
||||
output_batches = [sample_batch]
|
||||
f = self.replay_proportion
|
||||
while random.random() < f:
|
||||
f -= 1
|
||||
replay_batch = random.choice(self.replay_batches)
|
||||
output_batches.append(replay_batch)
|
||||
output_batches.append(self.replay_buffer.replay())
|
||||
return output_batches
|
||||
|
||||
Reference in New Issue
Block a user