From ebeee1d59a3e3365a455987bf517ad0d8eac35d5 Mon Sep 17 00:00:00 2001 From: Chace Ashcraft Date: Mon, 8 Feb 2021 04:06:02 -0700 Subject: [PATCH] [RLlib] Pytorch MAML fix for more than two workers with discrete actions (#13835) --- rllib/agents/maml/maml_torch_policy.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/rllib/agents/maml/maml_torch_policy.py b/rllib/agents/maml/maml_torch_policy.py index 2e0e1e208..695826798 100644 --- a/rllib/agents/maml/maml_torch_policy.py +++ b/rllib/agents/maml/maml_torch_policy.py @@ -8,8 +8,8 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.agents.ppo.ppo_tf_policy import setup_config from ray.rllib.agents.ppo.ppo_torch_policy import vf_preds_fetches, \ ValueNetworkMixin -from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.torch_ops import apply_grad_clipping +from ray.rllib.utils.framework import try_import_torch torch, nn = try_import_torch() @@ -178,7 +178,7 @@ class MAMLLoss(object): # Meta Update ppo_loss, s_loss, kl_loss, v_loss, ent = self.compute_losses( - fnet, self.inner_adaptation_steps, i, clip_loss=True) + fnet, self.inner_adaptation_steps - 1, i, clip_loss=True) inner_loss = torch.mean( torch.stack([ @@ -271,8 +271,14 @@ def maml_loss(policy, model, dist_class, train_batch): # `split` may not exist yet (during test-loss call), use a dummy value. # Cannot use get here due to train_batch being a TrackingDict. - split = train_batch["split"] if "split" in train_batch else \ - torch.tensor([[8, 8], [8, 8]]) + if "split" in train_batch: + split = train_batch["split"] + else: + split_shape = (policy.config["inner_adaptation_steps"], + policy.config["num_workers"]) + split_const = int(train_batch["obs"].shape[0] // + (split_shape[0] * split_shape[1])) + split = torch.ones(split_shape, dtype=int) * split_const policy.loss_obj = MAMLLoss( model=model, dist_class=dist_class,