[rllib] Add more regression tests and autogenerate (#2324)

This commit is contained in:
Richard Liaw
2018-07-02 08:20:53 -07:00
committed by GitHub
parent 8aa56c12e6
commit f0ed1c1674
10 changed files with 55 additions and 49 deletions
@@ -0,0 +1,29 @@
#!/usr/bin/env python
# This script generates all the regression tests for RLlib.
import glob
import re
import os
import os.path as osp
CONFIG_DIR = osp.join(osp.dirname(osp.abspath(__file__)), "regression_tests")
TEMPLATE = """
class Test{name}(Regression):
_file = "{filename}"
def setup_cache(self):
return _evaulate_config(self._file)
"""
if __name__ == '__main__':
os.chdir(CONFIG_DIR)
with open("regression_test.py", "a") as f:
for filename in sorted(glob.glob("*.yaml")):
splits = re.findall(r"\w+", osp.splitext(filename)[0])
test_name = "".join([s.capitalize() for s in splits])
f.write(TEMPLATE.format(name=test_name, filename=filename))
@@ -1,6 +1,7 @@
cartpole-a3c:
env: CartPole-v0
run: A3C
repeat: 3
stop:
episode_reward_mean: 200
time_total_s: 600
@@ -1,6 +1,7 @@
cartpole-a3c:
env: CartPole-v0
run: A3C
repeat: 3
stop:
episode_reward_mean: 200
time_total_s: 600
@@ -1,6 +1,7 @@
cartpole-dqn:
env: CartPole-v0
run: DQN
repeat: 3
stop:
episode_reward_mean: 200
time_total_s: 600
@@ -1,6 +1,7 @@
cartpole-pg:
env: CartPole-v0
run: PG
repeat: 3
stop:
episode_reward_mean: 200
time_total_s: 300
@@ -1,6 +1,7 @@
cartpole-ppo:
env: CartPole-v0
run: PPO
repeat: 3
stop:
episode_reward_mean: 200
time_total_s: 300
@@ -1,6 +1,7 @@
pendulum-ddpg:
env: Pendulum-v0
run: DDPG
repeat: 3
stop:
episode_reward_mean: -160
time_total_s: 900
@@ -0,0 +1,19 @@
pendulum-ppo:
env: Pendulum-v0
run: PPO
repeat: 3
stop:
episode_reward_mean: -160
# expect -140 within 300-500k steps
timesteps_total: 600000
config:
timesteps_per_batch: 2048
num_workers: 4
lambda: 0.1
gamma: 0.95
sgd_stepsize: 0.0003
sgd_batchsize: 64
num_sgd_iter: 10
model:
fcnet_hiddens: [64, 64]
squash_to_range: True
@@ -51,52 +51,3 @@ class Regression():
def track_iterations(self, result):
return result["training_iteration"]
class TestCartPolePPO(Regression):
_file = "cartpole-ppo.yaml"
def setup_cache(self):
return _evaulate_config(self._file)
class TestCartPolePG(Regression):
_file = "cartpole-pg.yaml"
def setup_cache(self):
return _evaulate_config(self._file)
class TestPendulumDDPG(Regression):
_file = "pendulum-ddpg.yaml"
def setup_cache(self):
return _evaulate_config(self._file)
class TestCartPoleES(Regression):
_file = "cartpole-es.yaml"
def setup_cache(self):
return _evaulate_config(self._file)
class TestCartPoleDQN(Regression):
_file = "cartpole-dqn.yaml"
def setup_cache(self):
return _evaulate_config(self._file)
class TestCartPoleA3C(Regression):
_file = "cartpole-a3c.yaml"
def setup_cache(self):
return _evaulate_config(self._file)
class TestCartPoleA3CPyTorch(Regression):
_file = "cartpole-a3c-pytorch.yaml"
def setup_cache(self):
return _evaulate_config(self._file)