[RLlib] PyTorch MAML fix for more than two workers with discrete actions (#13835)

This commit is contained in:
Chace Ashcraft
2021-02-08 04:06:02 -07:00
committed by GitHub
parent d001af3e59
commit ebeee1d59a
+10 -4
View File
@@ -8,8 +8,8 @@ from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.agents.ppo.ppo_tf_policy import setup_config
from ray.rllib.agents.ppo.ppo_torch_policy import vf_preds_fetches, \
ValueNetworkMixin
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import apply_grad_clipping
from ray.rllib.utils.framework import try_import_torch
torch, nn = try_import_torch()
@@ -178,7 +178,7 @@ class MAMLLoss(object):
# Meta Update
ppo_loss, s_loss, kl_loss, v_loss, ent = self.compute_losses(
fnet, self.inner_adaptation_steps, i, clip_loss=True)
fnet, self.inner_adaptation_steps - 1, i, clip_loss=True)
inner_loss = torch.mean(
torch.stack([
@@ -271,8 +271,14 @@ def maml_loss(policy, model, dist_class, train_batch):
# `split` may not exist yet (during test-loss call), use a dummy value.
# Cannot use get here due to train_batch being a TrackingDict.
split = train_batch["split"] if "split" in train_batch else \
torch.tensor([[8, 8], [8, 8]])
if "split" in train_batch:
split = train_batch["split"]
else:
split_shape = (policy.config["inner_adaptation_steps"],
policy.config["num_workers"])
split_const = int(train_batch["obs"].shape[0] //
(split_shape[0] * split_shape[1]))
split = torch.ones(split_shape, dtype=int) * split_const
policy.loss_obj = MAMLLoss(
model=model,
dist_class=dist_class,