mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 05:43:03 +08:00
[rllib] Fix truncate episodes mode in central critic example (#8073)
This commit is contained in:
@@ -96,8 +96,6 @@ def centralized_critic_postprocessing(policy,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
if policy.loss_initialized():
|
||||
assert sample_batch["dones"][-1], \
|
||||
"Not implemented for train_batch_mode=truncate_episodes"
|
||||
assert other_agent_batches is not None
|
||||
[(_, opponent_batch)] = list(other_agent_batches.values())
|
||||
|
||||
@@ -116,11 +114,17 @@ def centralized_critic_postprocessing(policy,
|
||||
sample_batch[OPPONENT_ACTION] = np.zeros_like(
|
||||
sample_batch[SampleBatch.ACTIONS])
|
||||
sample_batch[SampleBatch.VF_PREDS] = np.zeros_like(
|
||||
sample_batch[SampleBatch.ACTIONS], dtype=np.float32)
|
||||
sample_batch[SampleBatch.REWARDS], dtype=np.float32)
|
||||
|
||||
completed = sample_batch["dones"][-1]
|
||||
if completed:
|
||||
last_r = 0.0
|
||||
else:
|
||||
last_r = sample_batch[SampleBatch.VF_PREDS][-1]
|
||||
|
||||
train_batch = compute_advantages(
|
||||
sample_batch,
|
||||
0.0,
|
||||
last_r,
|
||||
policy.config["gamma"],
|
||||
policy.config["lambda"],
|
||||
use_gae=policy.config["use_gae"])
|
||||
|
||||
Reference in New Issue
Block a user