From bb03e2499b2bc3a99ded1a6896b2ea32e6b477b6 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Mon, 30 Nov 2020 12:41:24 +0100 Subject: [PATCH] [RLlib] PyBullet Env native support via env str-specifier (if installed). (#12209) --- ci/travis/install-dependencies.sh | 2 +- python/requirements_rllib.txt | 2 ++ rllib/BUILD | 20 +++++++++++++++++++ rllib/agents/trainer.py | 19 ++++++++++++++---- .../sac/cartpole-continuous-pybullet-sac.yaml | 9 +++++++++ 5 files changed, 47 insertions(+), 5 deletions(-) create mode 100644 rllib/tuned_examples/sac/cartpole-continuous-pybullet-sac.yaml diff --git a/ci/travis/install-dependencies.sh b/ci/travis/install-dependencies.sh index 1f62dfd08..a6d2961f4 100755 --- a/ci/travis/install-dependencies.sh +++ b/ci/travis/install-dependencies.sh @@ -283,7 +283,7 @@ install_dependencies() { fi fi - # Additional RLlib dependencies. + # Additional RLlib test dependencies. if [ "${RLLIB_TESTING-}" = 1 ]; then pip install -r "${WORKSPACE_DIR}"/python/requirements_rllib.txt # install the following packages for testing on travis only diff --git a/python/requirements_rllib.txt b/python/requirements_rllib.txt index 442e3db2f..ac43812a5 100644 --- a/python/requirements_rllib.txt +++ b/python/requirements_rllib.txt @@ -5,5 +5,7 @@ torch>=1.6.0 # Version requirement to match Tune torchvision>=0.6.0 smart_open +# For testing in MuJoCo-like envs (in PyBullet). +pybullet # For tests on PettingZoo's multi-agent envs. pettingzoo>=1.4.0 diff --git a/rllib/BUILD b/rllib/BUILD index 89c784e25..67747559c 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -339,6 +339,16 @@ py_test( args = ["--yaml-dir=tuned_examples/sac"] ) +py_test( + name = "run_regression_tests_cartpole_continuous_pybullet_sac_tf", + main = "tests/run_regression_tests.py", + tags = ["learning_tests_tf", "learning_tests_cartpole"], + size = "large", + srcs = ["tests/run_regression_tests.py"], + data = ["tuned_examples/sac/cartpole-continuous-pybullet-sac.yaml"], + args = ["--yaml-dir=tuned_examples/sac"] +) + py_test( name = "run_regression_tests_cartpole_sac_torch", main = "tests/run_regression_tests.py", @@ -349,6 +359,16 @@ py_test( args = ["--yaml-dir=tuned_examples/sac", "--torch"] ) +py_test( + name = "run_regression_tests_cartpole_continuous_pybullet_sac_torch", + main = "tests/run_regression_tests.py", + tags = ["learning_tests_torch", "learning_tests_cartpole"], + size = "large", + srcs = ["tests/run_regression_tests.py"], + data = ["tuned_examples/sac/cartpole-continuous-pybullet-sac.yaml"], + args = ["--yaml-dir=tuned_examples/sac", "--torch"] +) + py_test( name = "run_regression_tests_pendulum_sac_tf", main = "tests/run_regression_tests.py", diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 02462f3c4..0c5c9cf94 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -553,11 +553,22 @@ class Trainer(Trainable): elif "." in env: self.env_creator = \ lambda env_context: from_config(env, env_context) - # Try gym. + # Try gym/PyBullet. else: - import gym # soft dependency - self.env_creator = \ - lambda env_context: gym.make(env, **env_context) + + def _creator(env_context): + import gym + # Allow for PyBullet envs to be used as well (via string). + # This allows for doing things like + # `env=CartPoleContinuousBulletEnv-v0`. + try: + import pybullet_envs + pybullet_envs.getList() + except (ModuleNotFoundError, ImportError): + pass + return gym.make(env, **env_context) + + self.env_creator = _creator else: self.env_creator = lambda env_config: None diff --git a/rllib/tuned_examples/sac/cartpole-continuous-pybullet-sac.yaml b/rllib/tuned_examples/sac/cartpole-continuous-pybullet-sac.yaml new file mode 100644 index 000000000..7713582ef --- /dev/null +++ b/rllib/tuned_examples/sac/cartpole-continuous-pybullet-sac.yaml @@ -0,0 +1,9 @@ +cartpole-sac: + env: CartPoleContinuousBulletEnv-v0 + run: SAC + stop: + episode_reward_mean: 100 + timesteps_total: 100000 + config: + # Works for both torch and tf. + framework: tf