diff --git a/rllib/agents/maml/maml_torch_policy.py b/rllib/agents/maml/maml_torch_policy.py index 2e0e1e208..695826798 100644 --- a/rllib/agents/maml/maml_torch_policy.py +++ b/rllib/agents/maml/maml_torch_policy.py @@ -8,8 +8,8 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.agents.ppo.ppo_tf_policy import setup_config from ray.rllib.agents.ppo.ppo_torch_policy import vf_preds_fetches, \ ValueNetworkMixin -from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.torch_ops import apply_grad_clipping +from ray.rllib.utils.framework import try_import_torch torch, nn = try_import_torch() @@ -178,7 +178,7 @@ class MAMLLoss(object): # Meta Update ppo_loss, s_loss, kl_loss, v_loss, ent = self.compute_losses( - fnet, self.inner_adaptation_steps, i, clip_loss=True) + fnet, self.inner_adaptation_steps - 1, i, clip_loss=True) inner_loss = torch.mean( torch.stack([ @@ -271,8 +271,14 @@ def maml_loss(policy, model, dist_class, train_batch): # `split` may not exist yet (during test-loss call), use a dummy value. # Cannot use get here due to train_batch being a TrackingDict. - split = train_batch["split"] if "split" in train_batch else \ - torch.tensor([[8, 8], [8, 8]]) + if "split" in train_batch: + split = train_batch["split"] + else: + split_shape = (policy.config["inner_adaptation_steps"], + policy.config["num_workers"]) + split_const = int(train_batch["obs"].shape[0] // + (split_shape[0] * split_shape[1])) + split = torch.ones(split_shape, dtype=int) * split_const policy.loss_obj = MAMLLoss( model=model, dist_class=dist_class,