mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 03:34:48 +08:00
[rllib] Basic regression tests on CartPole (#1608)
* Sun Feb 25 21:36:22 PST 2018 * Sun Feb 25 21:42:09 PST 2018 * Sun Feb 25 21:44:30 PST 2018 * fix lint * Wed Feb 28 12:41:49 PST 2018
This commit is contained in:
@@ -84,6 +84,8 @@ DEFAULT_CONFIG = dict(
|
||||
grad_norm_clipping=10,
|
||||
# Arguments to pass to the rllib optimizer
|
||||
optimizer={},
|
||||
# Smooth the current average reward over this many previous episodes.
|
||||
smoothing_num_episodes=100,
|
||||
|
||||
# === Tensorflow ===
|
||||
# Arguments to pass to tensorflow
|
||||
|
||||
@@ -157,8 +157,9 @@ class DQNEvaluator(TFMultiGPUSupport):
|
||||
return ret
|
||||
|
||||
def stats(self):
|
||||
mean_100ep_reward = round(np.mean(self.episode_rewards[-101:-1]), 5)
|
||||
mean_100ep_length = round(np.mean(self.episode_lengths[-101:-1]), 5)
|
||||
n = self.config["smoothing_num_episodes"] + 1
|
||||
mean_100ep_reward = round(np.mean(self.episode_rewards[-n:-1]), 5)
|
||||
mean_100ep_length = round(np.mean(self.episode_lengths[-n:-1]), 5)
|
||||
exploration = self.exploration.value(self.global_timestep)
|
||||
return {
|
||||
"mean_100ep_reward": mean_100ep_reward,
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
cartpole-a3c:
|
||||
env: CartPole-v0
|
||||
run: A3C
|
||||
stop:
|
||||
episode_reward_mean: 200
|
||||
time_total_s: 600
|
||||
resources:
|
||||
cpu: 2
|
||||
config:
|
||||
num_workers: 4
|
||||
gamma: 0.95
|
||||
@@ -0,0 +1,12 @@
|
||||
cartpole-dqn:
|
||||
env: CartPole-v0
|
||||
run: DQN
|
||||
stop:
|
||||
episode_reward_mean: 200
|
||||
time_total_s: 600
|
||||
resources:
|
||||
cpu: 1
|
||||
config:
|
||||
n_step: 3
|
||||
gamma: 0.95
|
||||
smoothing_num_episodes: 10
|
||||
@@ -0,0 +1,12 @@
|
||||
cartpole-es:
|
||||
env: CartPole-v0
|
||||
run: ES
|
||||
stop:
|
||||
episode_reward_mean: 200
|
||||
time_total_s: 300
|
||||
resources:
|
||||
cpu: 2
|
||||
config:
|
||||
num_workers: 2
|
||||
noise_size: 25000000
|
||||
episodes_per_batch: 50
|
||||
@@ -0,0 +1,10 @@
|
||||
cartpole-ppo:
|
||||
env: CartPole-v0
|
||||
run: PPO
|
||||
stop:
|
||||
episode_reward_mean: 200
|
||||
time_total_s: 300
|
||||
resources:
|
||||
cpu: 1
|
||||
config:
|
||||
num_workers: 1
|
||||
+33
@@ -0,0 +1,33 @@
|
||||
#!/usr/bin/env python
# This script runs all the integration tests for RLlib.
# TODO(ekl) add large-scale tests on different envs here.
#
# Every YAML file under regression_tests/ (resolved relative to the current
# working directory) is expected to hold a mapping of experiment name ->
# experiment spec. All specs are merged and run together; the script raises
# if any trial finishes below its own `episode_reward_mean` stop threshold.

import glob
import yaml

import ray
from ray.tune import run_experiments


if __name__ == '__main__':
    experiments = {}

    # Sort the matches so that, if two files define the same experiment
    # name, the one that wins does not depend on filesystem ordering.
    for test in sorted(glob.glob("regression_tests/*.yaml")):
        # Use a context manager so the file handle is closed promptly, and
        # safe_load so a config file cannot construct arbitrary Python
        # objects (plain yaml.load without a Loader is also rejected by
        # recent PyYAML releases).
        with open(test) as f:
            config = yaml.safe_load(f)
        # An empty YAML document parses to None; treat it as "no experiments"
        # instead of crashing in dict.update().
        experiments.update(config or {})

    print("== Test config ==")
    print(yaml.dump(experiments))

    ray.init()
    trials = run_experiments(experiments)

    # A trial counts as failed when its final mean episode reward is still
    # below the reward threshold listed in its own stopping criterion, i.e.
    # it stopped for some other reason (such as the time limit).
    num_failures = 0
    for t in trials:
        if (t.last_result.episode_reward_mean <
                t.stopping_criterion["episode_reward_mean"]):
            num_failures += 1

    if num_failures:
        raise Exception(
            "{} trials did not converge".format(num_failures))
|
||||
Reference in New Issue
Block a user