[RLlib] rllib rollout test and bug fixes. (#9779)

This commit is contained in:
Sven Mika
2020-07-30 16:17:03 +02:00
committed by GitHub
parent f6bd12eb18
commit e540e425e4
3 changed files with 107 additions and 9 deletions
+16 -2
View File
@@ -1296,13 +1296,27 @@ py_test(
srcs = ["tests/test_reproducibility.py"]
)
# Test train/rollout scripts (w/o confirming rollout performance).
py_test(
name = "test_rollout",
name = "test_rollout_no_learning",
main = "tests/test_rollout.py",
tags = ["tests_dir", "tests_dir_R"],
size = "large",
data = ["train.py", "rollout.py"],
srcs = ["tests/test_rollout.py"]
srcs = ["tests/test_rollout.py"],
args = ["TestRolloutSimple"]
)
# Test train/rollout scripts (and confirm `rllib rollout` performance is same
# as the final one from the `rllib train` run).
py_test(
name = "test_rollout_w_learning",
main = "tests/test_rollout.py",
tags = ["tests_dir", "tests_dir_R"],
size = "medium",
data = ["train.py", "rollout.py"],
srcs = ["tests/test_rollout.py"],
args = ["TestRolloutLearntPolicy"]
)
py_test(
+9 -4
View File
@@ -241,7 +241,6 @@ def create_parser(parser_creator=None):
def run(args, parser):
config = {}
# Load configuration from checkpoint file.
config_dir = os.path.dirname(args.checkpoint)
config_path = os.path.join(config_dir, "params.pkl")
@@ -255,6 +254,8 @@ def run(args, parser):
raise ValueError(
"Could not find params.pkl in either the checkpoint dir or "
"its parent directory AND no config given on command line!")
else:
config = args.config
# Load the config from pickled.
else:
@@ -265,10 +266,14 @@ def run(args, parser):
if "num_workers" in config:
config["num_workers"] = min(2, config["num_workers"])
# Merge with `evaluation_config`.
evaluation_config = copy.deepcopy(config.get("evaluation_config", {}))
# Merge with `evaluation_config` (first try from command line, then from
# pkl file).
evaluation_config = copy.deepcopy(
args.config.get("evaluation_config", config.get(
"evaluation_config", {})))
config = merge_dicts(config, evaluation_config)
# Merge with command line `--config` settings.
# Merge with command line `--config` settings (if not already the same
# anyways).
config = merge_dicts(config, args.config)
if not args.env:
if not config.get("env"):
+82 -3
View File
@@ -1,8 +1,9 @@
from pathlib import Path
import os
import sys
import re
import unittest
import ray
from ray.rllib.utils.test_utils import framework_iterator
@@ -62,7 +63,73 @@ def rollout_test(algo, env="CartPole-v0", test_episode_rollout=False):
os.popen("rm -rf \"{}\"".format(tmp_dir)).read()
class TestRollout(unittest.TestCase):
def learn_test_plus_rollout(algo, env="CartPole-v0"):
for fw in framework_iterator(frameworks="tf"):
fw_ = ", \\\"framework\\\": \\\"{}\\\"".format(fw)
tmp_dir = os.popen("mktemp -d").read()[:-1]
if not os.path.exists(tmp_dir):
# Last resort: Resolve via underlying tempdir (and cut tmp_.
tmp_dir = ray.utils.tempfile.gettempdir() + tmp_dir[4:]
if not os.path.exists(tmp_dir):
sys.exit(1)
print("Saving results to {}".format(tmp_dir))
rllib_dir = str(Path(__file__).parent.parent.absolute())
print("RLlib dir = {}\nexists={}".format(rllib_dir,
os.path.exists(rllib_dir)))
os.system("python {}/train.py --local-dir={} --run={} "
"--checkpoint-freq=1 --checkpoint-at-end ".format(
rllib_dir, tmp_dir, algo) +
"--config=\"{\\\"num_gpus\\\": 0, \\\"num_workers\\\": 1, "
"\\\"evaluation_config\\\": {\\\"explore\\\": false}" + fw_ +
"}\" " + "--stop=\"{\\\"episode_reward_mean\\\": 190.0}\"" +
" --env={}".format(env))
# Find last checkpoint and use that for the rollout.
checkpoint_path = os.popen("ls {}/default/*/checkpoint_*/"
"checkpoint-*".format(tmp_dir)).read()[:-1]
checkpoints = [
cp for cp in checkpoint_path.split("\n")
if re.match(r"^.+checkpoint-\d+$", cp)
]
# Sort by number and pick last (which should be the best checkpoint).
last_checkpoint = sorted(
checkpoints,
key=lambda x: int(re.match(r".+checkpoint-(\d+)", x).group(1)))[-1]
assert re.match(r"^.+checkpoint_\d+/checkpoint-\d+$", last_checkpoint)
if not os.path.exists(last_checkpoint):
sys.exit(1)
print("Best checkpoint={} (exists)".format(last_checkpoint))
# Test rolling out n steps.
result = os.popen(
"python {}/rollout.py --run={} "
"--steps=400 "
"--out=\"{}/rollouts_n_steps.pkl\" --no-render \"{}\"".format(
rllib_dir, algo, tmp_dir, last_checkpoint)).read()[:-1]
if not os.path.exists(tmp_dir + "/rollouts_n_steps.pkl"):
sys.exit(1)
print("Rollout output exists -> Checking reward ...".format(
checkpoint_path))
episodes = result.split("\n")
mean_reward = 0.0
num_episodes = 0
for ep in episodes:
mo = re.match(r"Episode .+reward: ([\d\.\-]+)", ep)
if mo:
mean_reward += float(mo.group(1))
num_episodes += 1
mean_reward /= num_episodes
print("Rollout's mean episode reward={}".format(mean_reward))
assert mean_reward >= 190.0
# Cleanup.
os.popen("rm -rf \"{}\"".format(tmp_dir)).read()
class TestRolloutSimple(unittest.TestCase):
def test_a3c(self):
rollout_test("A3C")
@@ -85,6 +152,18 @@ class TestRollout(unittest.TestCase):
rollout_test("SAC", env="Pendulum-v0")
class TestRolloutLearntPolicy(unittest.TestCase):
def test_ppo_train_then_rollout(self):
learn_test_plus_rollout("PPO")
if __name__ == "__main__":
import sys
import pytest
sys.exit(pytest.main(["-v", __file__]))
# One can specify the specific TestCase class to run.
# None for all unittest.TestCase classes in this file.
class_ = sys.argv[1] if len(sys.argv) > 1 else None
sys.exit(
pytest.main(
["-v", __file__ + ("" if class_ is None else "::" + class_)]))