mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 22:06:00 +08:00
[rllib] Add more regression tests and autogenerate (#2324)
This commit is contained in:
@@ -0,0 +1,29 @@
|
||||
#!/usr/bin/env python
|
||||
# This script generates all the regression tests for RLlib.
|
||||
|
||||
import glob
|
||||
import re
|
||||
import os
|
||||
import os.path as osp
|
||||
|
||||
|
||||
CONFIG_DIR = osp.join(osp.dirname(osp.abspath(__file__)), "regression_tests")

# Source emitted for every YAML config. ``Regression`` and
# ``_evaulate_config`` (sic) are not defined in this script: the generated
# classes rely on the file they are appended to providing both names, so the
# spelling must be kept exactly as it is.
TEMPLATE = """
class Test{name}(Regression):
    _file = "{filename}"

    def setup_cache(self):
        return _evaulate_config(self._file)

"""


def _class_suffix(filename):
    """Return the CamelCase suffix used in the test class name.

    The extension is dropped, the stem is split into ``\\w+`` runs and each
    run is passed through ``str.capitalize`` (which also lower-cases the rest
    of the run), e.g. ``"cartpole-a3c.yaml"`` -> ``"CartpoleA3c"``.
    """
    stem = osp.splitext(filename)[0]
    return "".join(part.capitalize() for part in re.findall(r"\w+", stem))


def render_tests(filenames, existing_source=""):
    """Return the test-class source to append for ``filenames``.

    Args:
        filenames: iterable of YAML config file names, processed in the
            order given.
        existing_source: current text of the target file. A config whose
            ``class Test<Name>(`` header already occurs in it is skipped, so
            running the generator repeatedly does not pile up duplicate
            class definitions.

    Returns:
        The concatenated ``TEMPLATE`` expansions (an empty string when there
        is nothing new to add).
    """
    chunks = []
    for filename in filenames:
        name = _class_suffix(filename)
        if "class Test{}(".format(name) in existing_source:
            continue
        chunks.append(TEMPLATE.format(name=name, filename=filename))
    return "".join(chunks)


def main():
    """Append a test class for every ``*.yaml`` file in ``CONFIG_DIR``.

    The working directory is changed to ``CONFIG_DIR`` so that both the glob
    pattern and the output file name are resolved relative to it.
    """
    os.chdir(CONFIG_DIR)

    target = "regression_test.py"
    existing_source = ""
    if osp.exists(target):
        with open(target) as f:
            existing_source = f.read()

    # Append mode keeps whatever the target file already contains and creates
    # the file when it does not exist yet.
    with open(target, "a") as f:
        f.write(render_tests(sorted(glob.glob("*.yaml")), existing_source))


if __name__ == '__main__':
    main()
|
||||
@@ -1,6 +1,7 @@
|
||||
cartpole-a3c:
|
||||
env: CartPole-v0
|
||||
run: A3C
|
||||
repeat: 3
|
||||
stop:
|
||||
episode_reward_mean: 200
|
||||
time_total_s: 600
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
cartpole-a3c:
|
||||
env: CartPole-v0
|
||||
run: A3C
|
||||
repeat: 3
|
||||
stop:
|
||||
episode_reward_mean: 200
|
||||
time_total_s: 600
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
cartpole-dqn:
|
||||
env: CartPole-v0
|
||||
run: DQN
|
||||
repeat: 3
|
||||
stop:
|
||||
episode_reward_mean: 200
|
||||
time_total_s: 600
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
cartpole-pg:
|
||||
env: CartPole-v0
|
||||
run: PG
|
||||
repeat: 3
|
||||
stop:
|
||||
episode_reward_mean: 200
|
||||
time_total_s: 300
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
cartpole-ppo:
|
||||
env: CartPole-v0
|
||||
run: PPO
|
||||
repeat: 3
|
||||
stop:
|
||||
episode_reward_mean: 200
|
||||
time_total_s: 300
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
pendulum-ddpg:
|
||||
env: Pendulum-v0
|
||||
run: DDPG
|
||||
repeat: 3
|
||||
stop:
|
||||
episode_reward_mean: -160
|
||||
time_total_s: 900
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
# Experiment spec: PPO on Pendulum-v0.
pendulum-ppo:
    env: Pendulum-v0
    run: PPO
    # NOTE(review): presumably the number of times the trial is repeated --
    # confirm against the Tune experiment spec.
    repeat: 3
    # NOTE(review): presumably the run ends as soon as any one of these
    # criteria is met -- confirm against the Tune documentation.
    stop:
        episode_reward_mean: -160
        # expect -140 within 300-500k steps
        timesteps_total: 600000
    # Settings for the PPO agent selected by `run` above.
    config:
        timesteps_per_batch: 2048
        num_workers: 4
        lambda: 0.1
        gamma: 0.95
        sgd_stepsize: 0.0003
        sgd_batchsize: 64
        num_sgd_iter: 10
        model:
            fcnet_hiddens: [64, 64]
            squash_to_range: True
|
||||
@@ -51,52 +51,3 @@ class Regression():
|
||||
|
||||
def track_iterations(self, result):
    """Return the ``training_iteration`` entry of ``result``.

    Raises ``KeyError`` when ``result`` has no such key.
    """
    iterations = result["training_iteration"]
    return iterations
|
||||
|
||||
|
||||
class TestCartPolePPO(Regression):
    """Regression case bound to ``cartpole-ppo.yaml``."""

    # Config file that ``setup_cache`` hands to ``_evaulate_config``.
    _file = "cartpole-ppo.yaml"

    def setup_cache(self):
        """Evaluate this class's config file and return the result."""
        config_file = self._file
        return _evaulate_config(config_file)
|
||||
|
||||
|
||||
class TestCartPolePG(Regression):
    """Regression case bound to ``cartpole-pg.yaml``."""

    # Config file that ``setup_cache`` hands to ``_evaulate_config``.
    _file = "cartpole-pg.yaml"

    def setup_cache(self):
        """Evaluate this class's config file and return the result."""
        config_file = self._file
        return _evaulate_config(config_file)
|
||||
|
||||
|
||||
class TestPendulumDDPG(Regression):
    """Regression case bound to ``pendulum-ddpg.yaml``."""

    # Config file that ``setup_cache`` hands to ``_evaulate_config``.
    _file = "pendulum-ddpg.yaml"

    def setup_cache(self):
        """Evaluate this class's config file and return the result."""
        config_file = self._file
        return _evaulate_config(config_file)
|
||||
|
||||
|
||||
class TestCartPoleES(Regression):
    """Regression case bound to ``cartpole-es.yaml``."""

    # Config file that ``setup_cache`` hands to ``_evaulate_config``.
    _file = "cartpole-es.yaml"

    def setup_cache(self):
        """Evaluate this class's config file and return the result."""
        config_file = self._file
        return _evaulate_config(config_file)
|
||||
|
||||
|
||||
class TestCartPoleDQN(Regression):
    """Regression case bound to ``cartpole-dqn.yaml``."""

    # Config file that ``setup_cache`` hands to ``_evaulate_config``.
    _file = "cartpole-dqn.yaml"

    def setup_cache(self):
        """Evaluate this class's config file and return the result."""
        config_file = self._file
        return _evaulate_config(config_file)
|
||||
|
||||
|
||||
class TestCartPoleA3C(Regression):
    """Regression case bound to ``cartpole-a3c.yaml``."""

    # Config file that ``setup_cache`` hands to ``_evaulate_config``.
    _file = "cartpole-a3c.yaml"

    def setup_cache(self):
        """Evaluate this class's config file and return the result."""
        config_file = self._file
        return _evaulate_config(config_file)
|
||||
|
||||
|
||||
class TestCartPoleA3CPyTorch(Regression):
    """Regression case bound to ``cartpole-a3c-pytorch.yaml``."""

    # Config file that ``setup_cache`` hands to ``_evaulate_config``.
    _file = "cartpole-a3c-pytorch.yaml"

    def setup_cache(self):
        """Evaluate this class's config file and return the result."""
        config_file = self._file
        return _evaulate_config(config_file)
|
||||
|
||||
Reference in New Issue
Block a user