diff --git a/doc/source/rllib-algorithms.rst b/doc/source/rllib-algorithms.rst index b4880cfc2..b30d52259 100644 --- a/doc/source/rllib-algorithms.rst +++ b/doc/source/rllib-algorithms.rst @@ -19,6 +19,7 @@ Algorithm Frameworks Discrete Actions Continuous Actions Multi- `DQN`_, `Rainbow`_ tf + torch **Yes** `+parametric`_ No **Yes** `APEX-DQN`_ tf + torch **Yes** `+parametric`_ No **Yes** `IMPALA`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_ +`MAML`_ tf + torch No **Yes** No `MARWIL`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_ `PG`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_ `PPO`_, `APPO`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_ @@ -406,6 +407,26 @@ HalfCheetah 13000 ~15000 :start-after: __sphinx_doc_begin__ :end-before: __sphinx_doc_end__ +.. _maml: + +Model-Agnostic Meta-Learning (MAML) +----------------------------------- +|pytorch| |tensorflow| +`[paper] `__ `[implementation] `__ + +RLlib's MAML implementation is a meta-learning method for learning and quick adaptation across different tasks for continuous control. Code here is adapted from https://github.com/jonasrothfuss, which outperforms vanilla MAML and avoids computation of the higher order gradients during the meta-update step. MAML is evaluated on custom environments that are described in greater detail `here `__. + +MAML uses additional metrics to measure performance; ``episode_reward_mean`` measures the agent's returns before adaptation, ``episode_reward_mean_adapt_N`` measures the agent's returns after N gradient steps of inner adaptation, and ``adaptation_delta`` measures the difference in performance before and after adaptation. Examples can be seen `here `__. + +Tuned examples: HalfCheetahRandDirecEnv (`Env `__, `Config `__), AntRandGoalEnv (`Env `__, `Config `__), PendulumMassEnv (`Env `__, `Config `__) + +**MAML-specific configs** (see also `common configs `__): + +.. literalinclude:: ../../rllib/agents/maml/maml.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ + Derivative-free ~~~~~~~~~~~~~~~ diff --git a/doc/source/rllib-toc.rst b/doc/source/rllib-toc.rst index 0f69f62a3..4210d9a50 100644 --- a/doc/source/rllib-toc.rst +++ b/doc/source/rllib-toc.rst @@ -106,6 +106,8 @@ Algorithms - |pytorch| |tensorflow| :ref:`Deep Q Networks (DQN, Rainbow, Parametric DQN) ` + - |pytorch| |tensorflow| :ref:`Model-Agnostic Meta-Learning (MAML) ` + - |pytorch| |tensorflow| :ref:`Policy Gradients ` - |pytorch| |tensorflow| :ref:`Proximal Policy Optimization (PPO) ` diff --git a/rllib/agents/maml/maml.py b/rllib/agents/maml/maml.py index 77cd4a851..92d2d4329 100644 --- a/rllib/agents/maml/maml.py +++ b/rllib/agents/maml/maml.py @@ -4,6 +4,7 @@ import numpy as np from ray.rllib.utils.sgd import standardized from ray.rllib.agents import with_common_config from ray.rllib.agents.maml.maml_tf_policy import MAMLTFPolicy +from ray.rllib.agents.maml.maml_torch_policy import MAMLTorchPolicy from ray.rllib.agents.trainer_template import build_trainer from typing import List from ray.rllib.evaluation.metrics import get_learner_stats @@ -198,9 +199,8 @@ def execution_plan(workers, config): def get_policy_class(config): - # @mluo: TODO if config["framework"] == "torch": - raise ValueError("MAML not implemented in Pytorch yet") + return MAMLTorchPolicy return MAMLTFPolicy diff --git a/rllib/agents/maml/maml_torch_policy.py b/rllib/agents/maml/maml_torch_policy.py new file mode 100644 index 000000000..e46876a5c --- /dev/null +++ b/rllib/agents/maml/maml_torch_policy.py @@ -0,0 +1,432 @@ +import logging + +import ray +from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.torch_policy_template import build_torch_policy +from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \ + setup_config +from ray.rllib.agents.ppo.ppo_torch_policy import vf_preds_fetches, \ + ValueNetworkMixin +from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping +from ray.rllib.utils.framework import get_activation_fn +from ray.rllib.utils.framework import try_import_torch + +torch, nn = try_import_torch() + +logger = logging.getLogger(__name__) + + +def PPOLoss(dist_class, + actions, + curr_logits, + behaviour_logits, + advantages, + value_fn, + value_targets, + vf_preds, + cur_kl_coeff, + entropy_coeff, + clip_param, + vf_clip_param, + vf_loss_coeff, + clip_loss=False): + def surrogate_loss(actions, curr_dist, prev_dist, advantages, clip_param, + clip_loss): + pi_new_logp = curr_dist.logp(actions) + pi_old_logp = prev_dist.logp(actions) + + logp_ratio = torch.exp(pi_new_logp - pi_old_logp) + if clip_loss: + return torch.min( + advantages * logp_ratio, + advantages * torch.clamp(logp_ratio, 1 - clip_param, + 1 + clip_param)) + return advantages * logp_ratio + + def kl_loss(curr_dist, prev_dist): + return prev_dist.kl(curr_dist) + + def entropy_loss(dist): + return dist.entropy() + + def vf_loss(value_fn, value_targets, vf_preds, vf_clip_param=0.1): + # GAE Value Function Loss + vf_loss1 = torch.pow(value_fn - value_targets, 2.0) + vf_clipped = vf_preds + torch.clamp(value_fn - vf_preds, + -vf_clip_param, vf_clip_param) + vf_loss2 = torch.pow(vf_clipped - value_targets, 2.0) + vf_loss = torch.max(vf_loss1, vf_loss2) + return vf_loss + + pi_new_dist = dist_class(curr_logits, None) + pi_old_dist = dist_class(behaviour_logits, None) + + surr_loss = torch.mean( + surrogate_loss(actions, pi_new_dist, pi_old_dist, advantages, + clip_param, clip_loss)) + kl_loss = torch.mean(kl_loss(pi_new_dist, pi_old_dist)) + vf_loss = torch.mean( + vf_loss(value_fn, value_targets, vf_preds, vf_clip_param)) + entropy_loss = torch.mean(entropy_loss(pi_new_dist)) + + total_loss = -surr_loss + cur_kl_coeff * kl_loss + total_loss += vf_loss_coeff * vf_loss - entropy_coeff * entropy_loss + return total_loss, surr_loss, kl_loss, vf_loss, entropy_loss + + +# This is the computation graph for workers (inner adaptation steps) +class WorkerLoss(object): + def __init__(self, + model, + dist_class, + actions, + curr_logits, + behaviour_logits, + advantages, + value_fn, + value_targets, + vf_preds, + cur_kl_coeff, + entropy_coeff, + clip_param, + vf_clip_param, + vf_loss_coeff, + clip_loss=False): + self.loss, surr_loss, kl_loss, vf_loss, ent_loss = PPOLoss( + dist_class=dist_class, + actions=actions, + curr_logits=curr_logits, + behaviour_logits=behaviour_logits, + advantages=advantages, + value_fn=value_fn, + value_targets=value_targets, + vf_preds=vf_preds, + cur_kl_coeff=cur_kl_coeff, + entropy_coeff=entropy_coeff, + clip_param=clip_param, + vf_clip_param=vf_clip_param, + vf_loss_coeff=vf_loss_coeff, + clip_loss=clip_loss) + print("Worker Loss: ", self.loss) + + +# This is the Meta-Update computation graph for main (meta-update step) +class MAMLLoss(object): + def __init__(self, + model, + config, + dist_class, + value_targets, + advantages, + actions, + behaviour_logits, + vf_preds, + cur_kl_coeff, + policy_vars, + obs, + num_tasks, + split, + inner_adaptation_steps=1, + entropy_coeff=0, + clip_param=0.3, + vf_clip_param=0.1, + vf_loss_coeff=1.0, + use_gae=True): + + self.config = config + self.num_tasks = num_tasks + self.inner_adaptation_steps = inner_adaptation_steps + self.clip_param = clip_param + self.dist_class = dist_class + self.cur_kl_coeff = cur_kl_coeff + + # Split episode tensors into [inner_adaptation_steps+1, num_tasks, -1] + self.obs = self.split_placeholders(obs, split) + self.actions = self.split_placeholders(actions, split) + self.behaviour_logits = self.split_placeholders( + behaviour_logits, split) + self.advantages = self.split_placeholders(advantages, split) + self.value_targets = self.split_placeholders(value_targets, split) + self.vf_preds = self.split_placeholders(vf_preds, split) + + # Construct name to tensor dictionary for easier indexing + self.policy_vars = {} + for name, w in policy_vars: + self.policy_vars[name] = w + + # Calculate pi_new for PPO + pi_new_logits, current_policy_vars, value_fns = [], [], [] + for i in range(self.num_tasks): + pi_new, value_fn = self.feed_forward( + self.obs[0][i], + self.policy_vars, + policy_config=config["model"]) + pi_new_logits.append(pi_new) + value_fns.append(value_fn) + current_policy_vars.append(self.policy_vars) + + inner_kls = [] + inner_ppo_loss = [] + + # Recompute weights for inner-adaptation (same weights as workers) + for step in range(self.inner_adaptation_steps): + kls = [] + for i in range(self.num_tasks): + # PPO Loss Function (only Surrogate) + ppo_loss, _, kl_loss, _, _ = PPOLoss( + dist_class=dist_class, + actions=self.actions[step][i], + curr_logits=pi_new_logits[i], + behaviour_logits=self.behaviour_logits[step][i], + advantages=self.advantages[step][i], + value_fn=value_fns[i], + value_targets=self.value_targets[step][i], + vf_preds=self.vf_preds[step][i], + cur_kl_coeff=0.0, + entropy_coeff=entropy_coeff, + clip_param=clip_param, + vf_clip_param=vf_clip_param, + vf_loss_coeff=vf_loss_coeff, + clip_loss=False) + + adapted_policy_vars = self.compute_updated_variables( + ppo_loss, current_policy_vars[i], model) + pi_new_logits[i], value_fns[i] = self.feed_forward( + self.obs[step + 1][i], + adapted_policy_vars, + policy_config=config["model"]) + current_policy_vars[i] = adapted_policy_vars + kls.append(kl_loss) + inner_ppo_loss.append(ppo_loss) + inner_kls.append(kls) + + mean_inner_kl = [torch.mean(torch.stack(kls)) for kls in inner_kls] + self.mean_inner_kl = mean_inner_kl + + ppo_obj = [] + for i in range(self.num_tasks): + ppo_loss, surr_loss, kl_loss, val_loss, entropy_loss = PPOLoss( + dist_class=dist_class, + actions=self.actions[self.inner_adaptation_steps][i], + curr_logits=pi_new_logits[i], + behaviour_logits=self.behaviour_logits[ + self.inner_adaptation_steps][i], + advantages=self.advantages[self.inner_adaptation_steps][i], + value_fn=value_fns[i], + value_targets=self.value_targets[self.inner_adaptation_steps][ + i], + vf_preds=self.vf_preds[self.inner_adaptation_steps][i], + cur_kl_coeff=0.0, + entropy_coeff=entropy_coeff, + clip_param=clip_param, + vf_clip_param=vf_clip_param, + vf_loss_coeff=vf_loss_coeff, + clip_loss=True) + ppo_obj.append(ppo_loss) + self.mean_policy_loss = surr_loss + self.mean_kl = kl_loss + self.mean_vf_loss = val_loss + self.mean_entropy = entropy_loss + + self.inner_kl_loss = torch.mean( + torch.stack( + [a * b for a, b in zip(self.cur_kl_coeff, mean_inner_kl)])) + self.loss = torch.mean(torch.stack(ppo_obj)) + self.inner_kl_loss + print("Meta-Loss: ", self.loss, ", Inner KL:", self.inner_kl_loss) + + def feed_forward(self, obs, policy_vars, policy_config): + # Hacky for now, reconstruct FC network with adapted weights + # @mluo: TODO for any network + def fc_network(inp, network_vars, hidden_nonlinearity, + output_nonlinearity, policy_config, hiddens_name, + logits_name): + x = inp + + hidden_w = [] + logits_w = [] + + for name, w in network_vars.items(): + if hiddens_name in name: + hidden_w.append(w) + elif logits_name in name: + logits_w.append(w) + else: + raise NameError + + assert len(hidden_w) % 2 == 0 and len(logits_w) == 2 + + while len(hidden_w) != 0: + x = nn.functional.linear(x, hidden_w.pop(0), hidden_w.pop(0)) + x = hidden_nonlinearity()(x) + + x = nn.functional.linear(x, logits_w.pop(0), logits_w.pop(0)) + x = output_nonlinearity()(x) + + return x + + policyn_vars = {} + valuen_vars = {} + log_std = None + for name, param in policy_vars.items(): + if "value" in name: + valuen_vars[name] = param + elif "log_std" in name: + log_std = param + else: + policyn_vars[name] = param + + output_nonlinearity = nn.Identity + hidden_nonlinearity = get_activation_fn( + policy_config["fcnet_activation"], framework="torch") + + pi_new_logits = fc_network(obs, policyn_vars, hidden_nonlinearity, + output_nonlinearity, policy_config, + "hidden_layers", "logits") + if log_std is not None: + pi_new_logits = torch.cat( + [ + pi_new_logits, + log_std.unsqueeze(0).repeat([len(pi_new_logits), 1]) + ], + axis=1) + + value_fn = fc_network(obs, valuen_vars, hidden_nonlinearity, + output_nonlinearity, policy_config, + "value_branch_separate", "value_branch") + + return pi_new_logits, torch.squeeze(value_fn) + + def compute_updated_variables(self, loss, network_vars, model): + + grad = torch.autograd.grad( + loss, + inputs=model.parameters(), + create_graph=True, + retain_graph=True, + only_inputs=True) + adapted_vars = {} + for i, tup in enumerate(network_vars.items()): + name, var = tup + if grad[i] is None: + adapted_vars[name] = var + else: + adapted_vars[name] = var - self.config["inner_lr"] * grad[i] + return adapted_vars + + def split_placeholders(self, placeholder, split): + inner_placeholder_list = torch.split( + placeholder, torch.sum(split, dim=1).tolist(), dim=0) + placeholder_list = [] + for index, split_placeholder in enumerate(inner_placeholder_list): + placeholder_list.append( + torch.split(split_placeholder, split[index].tolist(), dim=0)) + return placeholder_list + + +def maml_loss(policy, model, dist_class, train_batch): + logits, state = model.from_batch(train_batch) + policy.cur_lr = policy.config["lr"] + + if policy.config["worker_index"]: + policy.loss_obj = WorkerLoss( + model=model, + dist_class=dist_class, + actions=train_batch[SampleBatch.ACTIONS], + curr_logits=logits, + behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS], + advantages=train_batch[Postprocessing.ADVANTAGES], + value_fn=model.value_function(), + value_targets=train_batch[Postprocessing.VALUE_TARGETS], + vf_preds=train_batch[SampleBatch.VF_PREDS], + cur_kl_coeff=0.0, + entropy_coeff=policy.config["entropy_coeff"], + clip_param=policy.config["clip_param"], + vf_clip_param=policy.config["vf_clip_param"], + vf_loss_coeff=policy.config["vf_loss_coeff"], + clip_loss=False) + else: + policy.var_list = model.named_parameters() + + policy.loss_obj = MAMLLoss( + model=model, + dist_class=dist_class, + value_targets=train_batch[Postprocessing.VALUE_TARGETS], + advantages=train_batch[Postprocessing.ADVANTAGES], + actions=train_batch[SampleBatch.ACTIONS], + behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS], + vf_preds=train_batch[SampleBatch.VF_PREDS], + cur_kl_coeff=policy.kl_coeff_val, + policy_vars=policy.var_list, + obs=train_batch[SampleBatch.CUR_OBS], + num_tasks=policy.config["num_workers"], + split=train_batch["split"], + config=policy.config, + inner_adaptation_steps=policy.config["inner_adaptation_steps"], + entropy_coeff=policy.config["entropy_coeff"], + clip_param=policy.config["clip_param"], + vf_clip_param=policy.config["vf_clip_param"], + vf_loss_coeff=policy.config["vf_loss_coeff"], + use_gae=policy.config["use_gae"]) + + return policy.loss_obj.loss + + +def maml_stats(policy, train_batch): + if policy.config["worker_index"]: + return {"worker_loss": policy.loss_obj.loss} + else: + return { + "cur_kl_coeff": policy.kl_coeff_val, + "cur_lr": policy.cur_lr, + "total_loss": policy.loss_obj.loss, + "policy_loss": policy.loss_obj.mean_policy_loss, + "vf_loss": policy.loss_obj.mean_vf_loss, + "kl": policy.loss_obj.mean_kl, + "inner_kl": policy.loss_obj.mean_inner_kl, + "entropy": policy.loss_obj.mean_entropy, + } + + +class KLCoeffMixin: + def __init__(self, config): + self.kl_coeff_val = [config["kl_coeff"] + ] * config["inner_adaptation_steps"] + self.kl_target = self.config["kl_target"] + + def update_kls(self, sampled_kls): + for i, kl in enumerate(sampled_kls): + if kl < self.kl_target / 1.5: + self.kl_coeff_val[i] *= 0.5 + elif kl > 1.5 * self.kl_target: + self.kl_coeff_val[i] *= 2.0 + return self.kl_coeff_val + + +def maml_optimizer_fn(policy, config): + """ + Workers use simple SGD for inner adaptation + Meta-Policy uses Adam optimizer for meta-update + """ + if not config["worker_index"]: + return torch.optim.Adam(policy.model.parameters(), lr=config["lr"]) + return torch.optim.SGD(policy.model.parameters(), lr=config["inner_lr"]) + + +def setup_mixins(policy, obs_space, action_space, config): + ValueNetworkMixin.__init__(policy, obs_space, action_space, config) + KLCoeffMixin.__init__(policy, config) + + +MAMLTorchPolicy = build_torch_policy( + name="MAMLTorchPolicy", + get_default_config=lambda: ray.rllib.agents.maml.maml.DEFAULT_CONFIG, + loss_fn=maml_loss, + stats_fn=maml_stats, + optimizer_fn=maml_optimizer_fn, + extra_action_out_fn=vf_preds_fetches, + postprocess_fn=postprocess_ppo_gae, + extra_grad_process_fn=apply_grad_clipping, + before_init=setup_config, + after_init=setup_mixins, + mixins=[KLCoeffMixin]) diff --git a/rllib/agents/maml/tests/test_maml.py b/rllib/agents/maml/tests/test_maml.py index a8daf260d..8ee696555 100644 --- a/rllib/agents/maml/tests/test_maml.py +++ b/rllib/agents/maml/tests/test_maml.py @@ -24,7 +24,7 @@ class TestMAML(unittest.TestCase): num_iterations = 1 # Test for tf framework (torch not implemented yet). - for _ in framework_iterator(config, frameworks=("tf")): + for _ in framework_iterator(config, frameworks=("tf", "torch")): trainer = maml.MAMLTrainer( config=config, env="ray.rllib.examples.env.pendulum_mass.PendulumMassEnv") diff --git a/rllib/env/meta_env.py b/rllib/env/meta_env.py new file mode 100644 index 000000000..6e8331009 --- /dev/null +++ b/rllib/env/meta_env.py @@ -0,0 +1,38 @@ +import gym +from typing import List, Any + +TaskType = Any # Can be different types depending on env, e.g., int or dict + + +class MetaEnv(gym.Env): + """ + Extension of gym.Env to define a distribution of tasks to meta-learn over. + Your env must implement this interface in order to be used with MAML. + """ + + def sample_tasks(self, n_tasks: int) -> List[TaskType]: + """Samples task of the meta-environment + + Args: + n_tasks (int) : number of different meta-tasks needed + + Returns: + tasks (list) : an (n_tasks) length list of tasks + """ + raise NotImplementedError + + def set_task(self, task: TaskType) -> None: + """Sets the specified task to the current environment + + Args: + task: task of the meta-learning environment + """ + raise NotImplementedError + + def get_task(self) -> TaskType: + """Gets the task that the agent is performing in the current environment + + Returns: + task: task of the meta-learning environment + """ + raise NotImplementedError diff --git a/rllib/examples/env/ant_rand_goal.py b/rllib/examples/env/ant_rand_goal.py index 76cd8a136..ffb9745c3 100644 --- a/rllib/examples/env/ant_rand_goal.py +++ b/rllib/examples/env/ant_rand_goal.py @@ -1,9 +1,10 @@ import numpy as np import gym from gym.envs.mujoco.mujoco_env import MujocoEnv +from ray.rllib.env.meta_env import MetaEnv -class AntRandGoalEnv(gym.utils.EzPickle, MujocoEnv): +class AntRandGoalEnv(gym.utils.EzPickle, MujocoEnv, MetaEnv): """Ant Environment that randomizes goals as tasks Goals are randomly sampled 2D positions diff --git a/rllib/examples/env/halfcheetah_rand_direc.py b/rllib/examples/env/halfcheetah_rand_direc.py index 07c02a30e..823271eb3 100644 --- a/rllib/examples/env/halfcheetah_rand_direc.py +++ b/rllib/examples/env/halfcheetah_rand_direc.py @@ -1,9 +1,10 @@ import numpy as np import gym from gym.envs.mujoco.mujoco_env import MujocoEnv +from ray.rllib.env.meta_env import MetaEnv -class HalfCheetahRandDirecEnv(MujocoEnv, gym.utils.EzPickle): +class HalfCheetahRandDirecEnv(MujocoEnv, gym.utils.EzPickle, MetaEnv): """HalfCheetah Environment with two diff tasks, moving forwards or backwards Direction is defined as a scalar: +1.0 (forwards) or -1.0 (backwards) diff --git a/rllib/examples/env/pendulum_mass.py b/rllib/examples/env/pendulum_mass.py index e29359baa..c4dc93ed7 100644 --- a/rllib/examples/env/pendulum_mass.py +++ b/rllib/examples/env/pendulum_mass.py @@ -1,9 +1,10 @@ import numpy as np import gym from gym.envs.classic_control.pendulum import PendulumEnv +from ray.rllib.env.meta_env import MetaEnv -class PendulumMassEnv(PendulumEnv, gym.utils.EzPickle): +class PendulumMassEnv(PendulumEnv, gym.utils.EzPickle, MetaEnv): """PendulumMassEnv varies the weight of the pendulum Tasks are defined to be weight uniformly sampled between [0.5,2]