From 4db86404ad069daa634565fbebe9722747d1d097 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Thu, 11 Feb 2021 18:58:46 +0100 Subject: [PATCH] [RLlib] Issue #13507: Fix MB-MPO CartPole Env's reward function as well as MB-MPO running into a traj. view API related issue. (#14037) --- rllib/BUILD | 12 ++-- rllib/agents/mbmpo/model_ensemble.py | 3 + rllib/examples/env/mbmpo_env.py | 82 ++++++++++++++-------------- rllib/policy/dynamic_tf_policy.py | 8 ++- rllib/policy/policy.py | 9 ++- 5 files changed, 63 insertions(+), 51 deletions(-) diff --git a/rllib/BUILD b/rllib/BUILD index 431f6b75a..a09a549b1 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -542,12 +542,12 @@ py_test( ) # MBMPOTrainer -#py_test( -# name = "test_mbmpo", -# tags = ["agents_dir"], -# size = "medium", -# srcs = ["agents/mbmpo/tests/test_mbmpo.py"] -#) +py_test( + name = "test_mbmpo", + tags = ["agents_dir"], + size = "medium", + srcs = ["agents/mbmpo/tests/test_mbmpo.py"] +) # PGTrainer py_test( diff --git a/rllib/agents/mbmpo/model_ensemble.py b/rllib/agents/mbmpo/model_ensemble.py index f7cb35b6f..1d0f13b71 100644 --- a/rllib/agents/mbmpo/model_ensemble.py +++ b/rllib/agents/mbmpo/model_ensemble.py @@ -200,6 +200,9 @@ class DynamicsEnsembleCustomModel(TorchModelV2, nn.Module): def fit(self): # Add env samples to Replay Buffer local_worker = get_global_worker() + for pid, pol in local_worker.policy_map.items(): + pol.view_requirements[ + SampleBatch.NEXT_OBS].used_for_training = True new_samples = local_worker.sample() # Initial Exploration of 8000 timesteps if not self.global_itr: diff --git a/rllib/examples/env/mbmpo_env.py b/rllib/examples/env/mbmpo_env.py index c49ef77be..87c367611 100644 --- a/rllib/examples/env/mbmpo_env.py +++ b/rllib/examples/env/mbmpo_env.py @@ -1,12 +1,12 @@ -import gym from gym.envs.classic_control import PendulumEnv, CartPoleEnv import numpy as np # MuJoCo may not be installed. HalfCheetahEnv = HopperEnv = None + try: from gym.envs.mujoco import HalfCheetahEnv, HopperEnv -except (ImportError, gym.error.DependencyNotInstalled): +except Exception: pass @@ -22,11 +22,12 @@ class CartPoleWrapper(CartPoleEnv): x = obs_next[:, 0] theta = obs_next[:, 2] - rew = (x < -self.x_threshold) | (x > self.x_threshold) | ( - theta < -self.theta_threshold_radians) | ( - theta > self.theta_threshold_radians) + # 1.0 if we are still on, 0.0 if we are terminated due to bounds + # (angular or x-axis) being breached. + rew = 1.0 - ((x < -self.x_threshold) | (x > self.x_threshold) | + (theta < -self.theta_threshold_radians) | + (theta > self.theta_threshold_radians)).astype(np.float32) - rew = rew.astype(float) return rew @@ -54,44 +55,43 @@ class PendulumWrapper(PendulumEnv): return (((x + np.pi) % (2 * np.pi)) - np.pi) -if HalfCheetahEnv: +class HalfCheetahWrapper(HalfCheetahEnv or object): + """Wrapper for the MuJoCo HalfCheetah-v2 environment. - class HalfCheetahWrapper(HalfCheetahEnv): - """Wrapper for the MuJoCo HalfCheetah-v2 environment. + Adds an additional `reward` method for some model-based RL algos (e.g. + MB-MPO). + """ - Adds an additional `reward` method for some model-based RL algos (e.g. - MB-MPO). - """ - - def reward(self, obs, action, obs_next): - if obs.ndim == 2 and action.ndim == 2: - assert obs.shape == obs_next.shape - forward_vel = obs_next[:, 8] - ctrl_cost = 0.1 * np.sum(np.square(action), axis=1) - reward = forward_vel - ctrl_cost - return np.minimum(np.maximum(-1000.0, reward), 1000.0) - else: - forward_vel = obs_next[8] - ctrl_cost = 0.1 * np.square(action).sum() - reward = forward_vel - ctrl_cost - return np.minimum(np.maximum(-1000.0, reward), 1000.0) - - class HopperWrapper(HopperEnv): - """Wrapper for the MuJoCo Hopper-v2 environment. - - Adds an additional `reward` method for some model-based RL algos (e.g. - MB-MPO). - """ - - def reward(self, obs, action, obs_next): - alive_bonus = 1.0 - assert obs.ndim == 2 and action.ndim == 2 - assert (obs.shape == obs_next.shape - and action.shape[0] == obs.shape[0]) - vel = obs_next[:, 5] - ctrl_cost = 1e-3 * np.sum(np.square(action), axis=1) - reward = vel + alive_bonus - ctrl_cost + def reward(self, obs, action, obs_next): + if obs.ndim == 2 and action.ndim == 2: + assert obs.shape == obs_next.shape + forward_vel = obs_next[:, 8] + ctrl_cost = 0.1 * np.sum(np.square(action), axis=1) + reward = forward_vel - ctrl_cost return np.minimum(np.maximum(-1000.0, reward), 1000.0) + else: + forward_vel = obs_next[8] + ctrl_cost = 0.1 * np.square(action).sum() + reward = forward_vel - ctrl_cost + return np.minimum(np.maximum(-1000.0, reward), 1000.0) + + +class HopperWrapper(HopperEnv or object): + """Wrapper for the MuJoCo Hopper-v2 environment. + + Adds an additional `reward` method for some model-based RL algos (e.g. + MB-MPO). + """ + + def reward(self, obs, action, obs_next): + alive_bonus = 1.0 + assert obs.ndim == 2 and action.ndim == 2 + assert (obs.shape == obs_next.shape + and action.shape[0] == obs.shape[0]) + vel = obs_next[:, 5] + ctrl_cost = 1e-3 * np.sum(np.square(action), axis=1) + reward = vel + alive_bonus - ctrl_cost + return np.minimum(np.maximum(-1000.0, reward), 1000.0) if __name__ == "__main__": diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index a5b01db87..e56691370 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -580,10 +580,14 @@ class DynamicTFPolicy(TFPolicy): # Add those needed for postprocessing and training. all_accessed_keys = train_batch.accessed_keys | \ batch_for_postproc.accessed_keys - # Tag those only needed for post-processing. + # Tag those only needed for post-processing (with some exceptions). for key in batch_for_postproc.accessed_keys: if key not in train_batch.accessed_keys and \ - key not in self.model.view_requirements: + key not in self.model.view_requirements and \ + key not in [ + SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX, + SampleBatch.UNROLL_ID, SampleBatch.DONES, + SampleBatch.REWARDS, SampleBatch.INFOS]: if key in self.view_requirements: self.view_requirements[key].used_for_training = False if key in self._loss_input_dict: diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index d208c7d15..277ec5c24 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -668,11 +668,16 @@ class Policy(metaclass=ABCMeta): if key not in self.view_requirements: self.view_requirements[key] = ViewRequirement() if self._loss: - # Tag those only needed for post-processing. + # Tag those only needed for post-processing (with some + # exceptions). for key in batch_for_postproc.accessed_keys: if key not in train_batch.accessed_keys and \ key in self.view_requirements and \ - key not in self.model.view_requirements: + key not in self.model.view_requirements and \ + key not in [ + SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX, + SampleBatch.UNROLL_ID, SampleBatch.DONES, + SampleBatch.REWARDS, SampleBatch.INFOS]: self.view_requirements[key].used_for_training = False # Remove those not needed at all (leave those that are needed # by Sampler to properly execute sample collection).