From e540e425e4aee15706f4f3da4e6a0f7d73dbb52c Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Thu, 30 Jul 2020 16:17:03 +0200 Subject: [PATCH] [RLlib] `rllib rollout` test and bug fixes. (#9779) --- rllib/BUILD | 18 +++++++- rllib/rollout.py | 13 ++++-- rllib/tests/test_rollout.py | 85 +++++++++++++++++++++++++++++++++++-- 3 files changed, 107 insertions(+), 9 deletions(-) diff --git a/rllib/BUILD b/rllib/BUILD index de99dd58b..140abcb42 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1296,13 +1296,27 @@ py_test( srcs = ["tests/test_reproducibility.py"] ) +# Test train/rollout scripts (w/o confirming rollout performance). py_test( - name = "test_rollout", + name = "test_rollout_no_learning", main = "tests/test_rollout.py", tags = ["tests_dir", "tests_dir_R"], size = "large", data = ["train.py", "rollout.py"], - srcs = ["tests/test_rollout.py"] + srcs = ["tests/test_rollout.py"], + args = ["TestRolloutSimple"] +) + +# Test train/rollout scripts (and confirm `rllib rollout` performance is same +# as the final one from the `rllib train` run). +py_test( + name = "test_rollout_w_learning", + main = "tests/test_rollout.py", + tags = ["tests_dir", "tests_dir_R"], + size = "medium", + data = ["train.py", "rollout.py"], + srcs = ["tests/test_rollout.py"], + args = ["TestRolloutLearntPolicy"] ) py_test( diff --git a/rllib/rollout.py b/rllib/rollout.py index 87ff510d6..c036ec35c 100755 --- a/rllib/rollout.py +++ b/rllib/rollout.py @@ -241,7 +241,6 @@ def create_parser(parser_creator=None): def run(args, parser): - config = {} # Load configuration from checkpoint file. config_dir = os.path.dirname(args.checkpoint) config_path = os.path.join(config_dir, "params.pkl") @@ -255,6 +254,8 @@ def run(args, parser): raise ValueError( "Could not find params.pkl in either the checkpoint dir or " "its parent directory AND no config given on command line!") + else: + config = args.config # Load the config from pickled. 
def _pick_last_checkpoint(listing):
    """Returns the highest-numbered checkpoint file from an `ls` listing.

    Args:
        listing (str): Newline-separated paths, one per line.

    Returns:
        Optional[str]: The line ending in `checkpoint-<n>` with the largest
            `<n>`, or None if no line ends in `checkpoint-<n>`.
    """
    # Only keep the checkpoint files themselves (lines that end in the
    # checkpoint number), not any sibling files with a suffix.
    checkpoints = [
        cp for cp in listing.splitlines()
        if re.match(r"^.+checkpoint-\d+$", cp)
    ]
    if not checkpoints:
        return None
    # Sort by number (not lexicographically, which would rank
    # "checkpoint-9" above "checkpoint-10") and pick the last one.
    return sorted(
        checkpoints,
        key=lambda x: int(re.match(r".+checkpoint-(\d+)", x).group(1)))[-1]


def _mean_episode_reward(rollout_output):
    """Averages the rewards of all `Episode ... reward: <x>` output lines.

    Args:
        rollout_output (str): Captured stdout of the rollout script.

    Returns:
        Optional[float]: Mean over all parsed episode rewards, or None if
            no line matched (so callers never divide by zero).
    """
    rewards = []
    for line in rollout_output.splitlines():
        mo = re.match(r"Episode .+reward: ([\d\.\-]+)", line)
        if mo:
            rewards.append(float(mo.group(1)))
    if not rewards:
        return None
    return sum(rewards) / len(rewards)


def learn_test_plus_rollout(algo, env="CartPole-v0"):
    """Trains `algo` via train.py, then checks rollout.py on the result.

    For each framework yielded by `framework_iterator(frameworks="tf")`:
    runs train.py until `episode_reward_mean` reaches 190.0 (checkpointing
    every iteration and at the end), rolls out the highest-numbered
    checkpoint for 400 steps via rollout.py and asserts that the mean
    episode reward printed by the rollout is >= 190.0.

    Args:
        algo (str): Value passed as `--run` to train.py and rollout.py.
        env (str): Value passed as `--env` to train.py.

    Raises:
        SystemExit: If the temp dir, the checkpoint file or the rollout
            output file does not exist.
        AssertionError: If no checkpoint or no episode reward could be
            found, or if the mean rollout reward is below 190.0.
    """
    # Imported locally: this module only imports `sys` under its
    # `__main__` guard, so the name is undefined when pytest imports the
    # file as a regular module and `sys.exit` would raise a NameError.
    import sys

    for fw in framework_iterator(frameworks="tf"):
        fw_ = ", \\\"framework\\\": \\\"{}\\\"".format(fw)

        tmp_dir = os.popen("mktemp -d").read().strip()
        if not tmp_dir:
            # `mktemp` printed nothing. Do not fall through to the
            # fallback below: it would resolve to the system temp dir
            # itself, which the cleanup step would then delete.
            sys.exit(1)
        if not os.path.exists(tmp_dir):
            # Last resort: Resolve via underlying tempdir (and cut the
            # first 4 characters, presumably the leading "/tmp", from the
            # path that mktemp reported).
            tmp_dir = ray.utils.tempfile.gettempdir() + tmp_dir[4:]
            if not os.path.exists(tmp_dir):
                sys.exit(1)

        try:
            print("Saving results to {}".format(tmp_dir))

            rllib_dir = str(Path(__file__).parent.parent.absolute())
            print("RLlib dir = {}\nexists={}".format(
                rllib_dir, os.path.exists(rllib_dir)))
            os.system(
                "python {}/train.py --local-dir={} --run={} "
                "--checkpoint-freq=1 --checkpoint-at-end ".format(
                    rllib_dir, tmp_dir, algo) +
                "--config=\"{\\\"num_gpus\\\": 0, \\\"num_workers\\\": 1, "
                "\\\"evaluation_config\\\": {\\\"explore\\\": false}" + fw_ +
                "}\" " + "--stop=\"{\\\"episode_reward_mean\\\": 190.0}\"" +
                " --env={}".format(env))

            # Find last checkpoint and use that for the rollout.
            checkpoint_path = os.popen(
                "ls {}/default/*/checkpoint_*/"
                "checkpoint-*".format(tmp_dir)).read().strip()
            last_checkpoint = _pick_last_checkpoint(checkpoint_path)
            assert last_checkpoint is not None, \
                "No checkpoint files found under {}".format(tmp_dir)
            assert re.match(r"^.+checkpoint_\d+/checkpoint-\d+$",
                            last_checkpoint)
            if not os.path.exists(last_checkpoint):
                sys.exit(1)
            print("Best checkpoint={} (exists)".format(last_checkpoint))

            # Test rolling out n steps.
            result = os.popen(
                "python {}/rollout.py --run={} "
                "--steps=400 "
                "--out=\"{}/rollouts_n_steps.pkl\" --no-render \"{}\"".format(
                    rllib_dir, algo, tmp_dir, last_checkpoint)).read()
            if not os.path.exists(tmp_dir + "/rollouts_n_steps.pkl"):
                sys.exit(1)
            print("Rollout output exists -> Checking reward ...")
            mean_reward = _mean_episode_reward(result)
            assert mean_reward is not None, \
                "No episode rewards found in the rollout output"
            print("Rollout's mean episode reward={}".format(mean_reward))
            assert mean_reward >= 190.0
        finally:
            # Cleanup (also when one of the checks above failed, so failed
            # runs do not leave their results behind).
            os.popen("rm -rf \"{}\"".format(tmp_dir)).read()