mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 22:54:17 +08:00
[RLlib] PyTorch MAML fix for more than two workers with discrete actions (#13835)
This commit is contained in:
@@ -8,8 +8,8 @@ from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import setup_config
|
||||
from ray.rllib.agents.ppo.ppo_torch_policy import vf_preds_fetches, \
|
||||
ValueNetworkMixin
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_ops import apply_grad_clipping
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
@@ -178,7 +178,7 @@ class MAMLLoss(object):
|
||||
|
||||
# Meta Update
|
||||
ppo_loss, s_loss, kl_loss, v_loss, ent = self.compute_losses(
|
||||
fnet, self.inner_adaptation_steps, i, clip_loss=True)
|
||||
fnet, self.inner_adaptation_steps - 1, i, clip_loss=True)
|
||||
|
||||
inner_loss = torch.mean(
|
||||
torch.stack([
|
||||
@@ -271,8 +271,14 @@ def maml_loss(policy, model, dist_class, train_batch):
|
||||
|
||||
# `split` may not exist yet (during test-loss call), use a dummy value.
|
||||
# Cannot use get here due to train_batch being a TrackingDict.
|
||||
split = train_batch["split"] if "split" in train_batch else \
|
||||
torch.tensor([[8, 8], [8, 8]])
|
||||
if "split" in train_batch:
|
||||
split = train_batch["split"]
|
||||
else:
|
||||
split_shape = (policy.config["inner_adaptation_steps"],
|
||||
policy.config["num_workers"])
|
||||
split_const = int(train_batch["obs"].shape[0] //
|
||||
(split_shape[0] * split_shape[1]))
|
||||
split = torch.ones(split_shape, dtype=int) * split_const
|
||||
policy.loss_obj = MAMLLoss(
|
||||
model=model,
|
||||
dist_class=dist_class,
|
||||
|
||||
Reference in New Issue
Block a user