diff --git a/python/requirements_rllib.txt b/python/requirements_rllib.txt
index 94ae9cdbb..0cefb0296 100644
--- a/python/requirements_rllib.txt
+++ b/python/requirements_rllib.txt
@@ -13,3 +13,6 @@ pettingzoo>=1.4.0
 # For tests on RecSim and Kaggle envs.
 recsim
 kaggle_environments
+
+# For MAML on PyTorch.
+higher
diff --git a/rllib/agents/maml/tests/test_maml.py b/rllib/agents/maml/tests/test_maml.py
index e5ef3cf69..b84e02857 100644
--- a/rllib/agents/maml/tests/test_maml.py
+++ b/rllib/agents/maml/tests/test_maml.py
@@ -23,15 +23,21 @@ class TestMAML(unittest.TestCase):
         num_iterations = 1
 
-        # Test for tf framework (torch not implemented yet).
-        for _ in framework_iterator(config, frameworks=("tf")):
-            trainer = maml.MAMLTrainer(
-                config=config,
-                env="ray.rllib.examples.env.pendulum_mass.PendulumMassEnv")
-            for i in range(num_iterations):
-                trainer.train()
-            check_compute_single_action(
-                trainer, include_prev_action_reward=True)
-            trainer.stop()
+        # Test for tf and torch (CartPole env is run on torch only).
+        for fw in framework_iterator(config, frameworks=("tf", "torch")):
+            for env in [
+                    "pendulum_mass.PendulumMassEnv",
+                    "cartpole_mass.CartPoleMassEnv"
+            ]:
+                if fw == "tf" and env.startswith("cartpole"):
+                    continue
+                print("env={}".format(env))
+                env_ = "ray.rllib.examples.env.{}".format(env)
+                trainer = maml.MAMLTrainer(config=config, env=env_)
+                for i in range(num_iterations):
+                    trainer.train()
+                check_compute_single_action(
+                    trainer, include_prev_action_reward=True)
+                trainer.stop()
 
 
 if __name__ == "__main__":
diff --git a/rllib/examples/env/cartpole_mass.py b/rllib/examples/env/cartpole_mass.py
new file mode 100644
index 000000000..a0519cb17
--- /dev/null
+++ b/rllib/examples/env/cartpole_mass.py
@@ -0,0 +1,46 @@
+import numpy as np
+import gym
+from gym.envs.classic_control.cartpole import CartPoleEnv
+from ray.rllib.env.meta_env import MetaEnv
+
+
+class CartPoleMassEnv(CartPoleEnv, gym.utils.EzPickle, MetaEnv):
+    """CartPoleMassEnv varies the masses of the cart and the pole.
+    """
+
+    def sample_tasks(self, n_tasks):
+        """
+        Args:
+            n_tasks (int): Number of tasks (mass pairs) to sample.
+
+        Returns:
+            np.ndarray: Array of shape (n_tasks, 2) with one
+                [cart mass, pole mass] row per task.
+        """
+        # Sample new cart- and pole masses (random floats between 0.5 and 2.0
+        # (cart) and between 0.05 and 0.2 (pole)).
+        cart_masses = np.random.uniform(low=0.5, high=2.0, size=(n_tasks, 1))
+        pole_masses = np.random.uniform(low=0.05, high=0.2, size=(n_tasks, 1))
+        return np.concatenate([cart_masses, pole_masses], axis=-1)
+
+    def set_task(self, task):
+        """
+        Args:
+            task (Sequence[float]): Masses of the cart (index 0) and the
+                pole (index 1), e.g. one row returned by `sample_tasks`.
+        """
+        self.masscart = task[0]
+        self.masspole = task[1]
+        # CartPoleEnv derives these two values from the masses once in its
+        # constructor and uses them in `step()`. Refresh them here, otherwise
+        # a new task would leave the simulated dynamics (partly) unchanged.
+        self.total_mass = self.masspole + self.masscart
+        self.polemass_length = self.masspole * self.length
+
+    def get_task(self):
+        """
+        Returns:
+            np.ndarray: The current masses of the cart and the pole (in
+                that order).
+        """
+        return np.array([self.masscart, self.masspole])
diff --git a/rllib/examples/env/pendulum_mass.py b/rllib/examples/env/pendulum_mass.py
index c4dc93ed7..b68b283e7 100644
--- a/rllib/examples/env/pendulum_mass.py
+++ b/rllib/examples/env/pendulum_mass.py
@@ -11,19 +11,22 @@ class PendulumMassEnv(PendulumEnv, gym.utils.EzPickle, MetaEnv):
     """
 
     def sample_tasks(self, n_tasks):
-        # Mass is a random float between 0.5 and 2
+        # Sample new pendulum masses (random floats between 0.5 and 2).
         return np.random.uniform(low=0.5, high=2.0, size=(n_tasks, ))
 
     def set_task(self, task):
         """
         Args:
-            task: task of the meta-learning environment
+            task (float): Task of the meta-learning environment (here: mass of
+                the pendulum).
         """
+        # self.m is the mass property of the pendulum.
         self.m = task
 
     def get_task(self):
         """
         Returns:
-            task: task of the meta-learning environment
+            float: The current mass of the pendulum (self.m in the PendulumEnv
+                object).
         """
         return self.m