diff --git a/python/ray/rllib/dqn/dqn.py b/python/ray/rllib/dqn/dqn.py index a0a5dbc01..51e501c53 100644 --- a/python/ray/rllib/dqn/dqn.py +++ b/python/ray/rllib/dqn/dqn.py @@ -84,6 +84,8 @@ DEFAULT_CONFIG = dict( grad_norm_clipping=10, # Arguments to pass to the rllib optimizer optimizer={}, + # Smooth the current average reward over this many previous episodes. + smoothing_num_episodes=100, # === Tensorflow === # Arguments to pass to tensorflow diff --git a/python/ray/rllib/dqn/dqn_evaluator.py b/python/ray/rllib/dqn/dqn_evaluator.py index 20a269cbf..2bae4aed8 100644 --- a/python/ray/rllib/dqn/dqn_evaluator.py +++ b/python/ray/rllib/dqn/dqn_evaluator.py @@ -157,8 +157,9 @@ class DQNEvaluator(TFMultiGPUSupport): return ret def stats(self): - mean_100ep_reward = round(np.mean(self.episode_rewards[-101:-1]), 5) - mean_100ep_length = round(np.mean(self.episode_lengths[-101:-1]), 5) + n = self.config["smoothing_num_episodes"] + 1 + mean_100ep_reward = round(np.mean(self.episode_rewards[-n:-1]), 5) + mean_100ep_length = round(np.mean(self.episode_lengths[-n:-1]), 5) exploration = self.exploration.value(self.global_timestep) return { "mean_100ep_reward": mean_100ep_reward, diff --git a/python/ray/rllib/tuned_examples/regression_tests/cartpole-a3c.yaml b/python/ray/rllib/tuned_examples/regression_tests/cartpole-a3c.yaml new file mode 100644 index 000000000..2b16e48e0 --- /dev/null +++ b/python/ray/rllib/tuned_examples/regression_tests/cartpole-a3c.yaml @@ -0,0 +1,11 @@ +cartpole-a3c: + env: CartPole-v0 + run: A3C + stop: + episode_reward_mean: 200 + time_total_s: 600 + resources: + cpu: 2 + config: + num_workers: 4 + gamma: 0.95 diff --git a/python/ray/rllib/tuned_examples/regression_tests/cartpole-dqn.yaml b/python/ray/rllib/tuned_examples/regression_tests/cartpole-dqn.yaml new file mode 100644 index 000000000..da916312f --- /dev/null +++ b/python/ray/rllib/tuned_examples/regression_tests/cartpole-dqn.yaml @@ -0,0 +1,12 @@ +cartpole-dqn: + env: CartPole-v0 + run: DQN + stop: + episode_reward_mean: 200 + time_total_s: 600 + resources: + cpu: 1 + config: + n_step: 3 + gamma: 0.95 + smoothing_num_episodes: 10 diff --git a/python/ray/rllib/tuned_examples/regression_tests/cartpole-es.yaml b/python/ray/rllib/tuned_examples/regression_tests/cartpole-es.yaml new file mode 100644 index 000000000..45392e156 --- /dev/null +++ b/python/ray/rllib/tuned_examples/regression_tests/cartpole-es.yaml @@ -0,0 +1,12 @@ +cartpole-es: + env: CartPole-v0 + run: ES + stop: + episode_reward_mean: 200 + time_total_s: 300 + resources: + cpu: 2 + config: + num_workers: 2 + noise_size: 25000000 + episodes_per_batch: 50 diff --git a/python/ray/rllib/tuned_examples/regression_tests/cartpole-ppo.yaml b/python/ray/rllib/tuned_examples/regression_tests/cartpole-ppo.yaml new file mode 100644 index 000000000..0f7c23ff2 --- /dev/null +++ b/python/ray/rllib/tuned_examples/regression_tests/cartpole-ppo.yaml @@ -0,0 +1,10 @@ +cartpole-ppo: + env: CartPole-v0 + run: PPO + stop: + episode_reward_mean: 200 + time_total_s: 300 + resources: + cpu: 1 + config: + num_workers: 1 diff --git a/python/ray/rllib/tuned_examples/run_regression_tests.py b/python/ray/rllib/tuned_examples/run_regression_tests.py new file mode 100755 index 000000000..3bb7d5224 --- /dev/null +++ b/python/ray/rllib/tuned_examples/run_regression_tests.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python +# This script runs all the integration tests for RLlib. +# TODO(ekl) add large-scale tests on different envs here. + +import glob +import yaml + +import ray +from ray.tune import run_experiments + + +if __name__ == '__main__': + experiments = {} + + for test in glob.glob("regression_tests/*.yaml"): + config = yaml.load(open(test).read()) + experiments.update(config) + + print("== Test config ==") + print(yaml.dump(experiments)) + + ray.init() + trials = run_experiments(experiments) + + num_failures = 0 + for t in trials: + if (t.last_result.episode_reward_mean < + t.stopping_criterion["episode_reward_mean"]): + num_failures += 1 + + if num_failures: + raise Exception( + "{} trials did not converge".format(num_failures))