mirror of
https://github.com/wassname/ray.git
synced 2026-07-04 19:07:39 +08:00
[rllib] Q-Mix implementation (Q-Mix, VDN, IQN, and Ape-X variants) (#3548)
This commit is contained in:
@@ -13,7 +13,6 @@ from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
|
||||
MultiAgentBatch
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.compression import pack_if_needed
|
||||
from ray.rllib.utils.filter import RunningStat
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.rllib.utils.schedules import LinearSchedule
|
||||
|
||||
@@ -54,7 +53,6 @@ class SyncReplayOptimizer(PolicyOptimizer):
|
||||
self.sample_timer = TimerStat()
|
||||
self.replay_timer = TimerStat()
|
||||
self.grad_timer = TimerStat()
|
||||
self.throughput = RunningStat()
|
||||
self.learner_stats = {}
|
||||
|
||||
# Set up replay buffer
|
||||
@@ -159,13 +157,13 @@ class SyncReplayOptimizer(PolicyOptimizer):
|
||||
dones) = replay_buffer.sample(self.train_batch_size)
|
||||
weights = np.ones_like(rewards)
|
||||
batch_indexes = -np.ones_like(rewards)
|
||||
samples[policy_id] = SampleBatch({
|
||||
"obs": obses_t,
|
||||
"actions": actions,
|
||||
"rewards": rewards,
|
||||
"new_obs": obses_tp1,
|
||||
"dones": dones,
|
||||
"weights": weights,
|
||||
"batch_indexes": batch_indexes
|
||||
})
|
||||
samples[policy_id] = SampleBatch({
|
||||
"obs": obses_t,
|
||||
"actions": actions,
|
||||
"rewards": rewards,
|
||||
"new_obs": obses_tp1,
|
||||
"dones": dones,
|
||||
"weights": weights,
|
||||
"batch_indexes": batch_indexes
|
||||
})
|
||||
return MultiAgentBatch(samples, self.train_batch_size)
|
||||
|
||||
Reference in New Issue
Block a user