[rllib] Switch to tune.run() instead of run_experiments() (#4515)

This commit is contained in:
Eric Liang
2019-03-30 14:07:50 -07:00
committed by GitHub
parent 5efb21e1d0
commit fce0062380
13 changed files with 174 additions and 208 deletions
+9 -13
View File
@@ -9,9 +9,9 @@ import tensorflow as tf
import tensorflow.contrib.slim as slim
import ray
from ray import tune
from ray.rllib.models import Model, ModelCatalog
from ray.rllib.models.misc import normc_initializer
from ray.tune import run_experiments
parser = argparse.ArgumentParser()
parser.add_argument("--num-iters", type=int, default=200)
@@ -47,18 +47,14 @@ if __name__ == "__main__":
ray.init()
ModelCatalog.register_custom_model("bn_model", BatchNormModel)
run_experiments({
"batch_norm_demo": {
"run": args.run,
tune.run(
args.run,
stop={"training_iteration": args.num_iters},
config={
"env": "Pendulum-v0" if args.run == "DDPG" else "CartPole-v0",
"stop": {
"training_iteration": args.num_iters
},
"config": {
"model": {
"custom_model": "bn_model",
},
"num_workers": 0,
"model": {
"custom_model": "bn_model",
},
"num_workers": 0,
},
})
)
+12 -17
View File
@@ -179,20 +179,15 @@ if __name__ == "__main__":
},
}
tune.run_experiments({
"test": {
"env": "cartpole_stateless",
"run": args.run,
"stop": {
"episode_reward_mean": args.stop
},
"config": dict(
configs[args.run], **{
"model": {
"use_lstm": True,
"lstm_use_prev_action_reward": args.
use_prev_action_reward,
},
}),
}
})
tune.run(
args.run,
stop={"episode_reward_mean": args.stop},
config=dict(
configs[args.run], **{
"env": "cartpole_stateless",
"model": {
"use_lstm": True,
"lstm_use_prev_action_reward": args.use_prev_action_reward,
},
}),
)
+15 -16
View File
@@ -18,7 +18,8 @@ from ray.rllib.models import FullyConnectedNetwork, Model, ModelCatalog
from gym.spaces import Discrete, Box
import ray
from ray.tune import run_experiments, grid_search
from ray import tune
from ray.tune import grid_search
class SimpleCorridor(gym.Env):
@@ -66,22 +67,20 @@ if __name__ == "__main__":
# register_env("corridor", lambda config: SimpleCorridor(config))
ray.init()
ModelCatalog.register_custom_model("my_model", CustomModel)
run_experiments({
"demo": {
"run": "PPO",
tune.run(
"PPO",
stop={
"timesteps_total": 10000,
},
config={
"env": SimpleCorridor, # or "corridor" if registered above
"stop": {
"timesteps_total": 10000,
"model": {
"custom_model": "my_model",
},
"config": {
"model": {
"custom_model": "my_model",
},
"lr": grid_search([1e-2, 1e-4, 1e-6]), # try different lrs
"num_workers": 1, # parallelism
"env_config": {
"corridor_length": 5,
},
"lr": grid_search([1e-2, 1e-4, 1e-6]), # try different lrs
"num_workers": 1, # parallelism
"env_config": {
"corridor_length": 5,
},
},
})
)
+13 -15
View File
@@ -18,7 +18,7 @@ import os
import tensorflow as tf
import ray
from ray.tune import run_experiments
from ray import tune
from ray.rllib.models import (Categorical, FullyConnectedNetwork, Model,
ModelCatalog)
from ray.rllib.models.model import restore_original_dimensions
@@ -82,21 +82,19 @@ if __name__ == "__main__":
args = parser.parse_args()
ModelCatalog.register_custom_model("custom_loss", CustomLossModel)
run_experiments({
"custom_loss": {
"run": "PG",
tune.run(
"PG",
stop={
"training_iteration": args.iters,
},
config={
"env": "CartPole-v0",
"stop": {
"training_iteration": args.iters,
},
"config": {
"num_workers": 0,
"model": {
"custom_model": "custom_loss",
"custom_options": {
"input_files": args.input_files,
},
"num_workers": 0,
"model": {
"custom_model": "custom_loss",
"custom_options": {
"input_files": args.input_files,
},
},
},
})
)
@@ -50,24 +50,22 @@ if __name__ == "__main__":
args = parser.parse_args()
ray.init()
trials = tune.run_experiments({
"test": {
trials = tune.run(
"PG",
stop={
"training_iteration": args.num_iters,
},
config={
"env": "CartPole-v0",
"run": "PG",
"stop": {
"training_iteration": args.num_iters,
"callbacks": {
"on_episode_start": tune.function(on_episode_start),
"on_episode_step": tune.function(on_episode_step),
"on_episode_end": tune.function(on_episode_end),
"on_sample_end": tune.function(on_sample_end),
"on_train_result": tune.function(on_train_result),
},
"config": {
"callbacks": {
"on_episode_start": tune.function(on_episode_start),
"on_episode_step": tune.function(on_episode_step),
"on_episode_end": tune.function(on_episode_end),
"on_sample_end": tune.function(on_sample_end),
"on_train_result": tune.function(on_train_result),
},
},
}
})
},
)
# verify custom metrics for integration tests
custom_metrics = trials[0].last_result["custom_metrics"]
+10 -12
View File
@@ -11,8 +11,8 @@ from __future__ import division
from __future__ import print_function
import ray
from ray import tune
from ray.rllib.agents.ppo import PPOAgent
from ray.tune import run_experiments
def my_train_fn(config, reporter):
@@ -40,15 +40,13 @@ def my_train_fn(config, reporter):
if __name__ == "__main__":
ray.init()
run_experiments({
"demo": {
"run": my_train_fn,
"resources_per_trial": {
"cpu": 1,
},
"config": {
"lr": 0.01,
"num_workers": 0,
},
tune.run(
my_train_fn,
resources_per_trial={
"cpu": 1,
},
})
config={
"lr": 0.01,
"num_workers": 0,
},
)
@@ -33,7 +33,8 @@ from gym.spaces import Box, Discrete, Tuple
import logging
import ray
from ray.tune import run_experiments, function
from ray import tune
from ray.tune import function
from ray.rllib.env import MultiAgentEnv
parser = argparse.ArgumentParser()
@@ -184,15 +185,13 @@ if __name__ == "__main__":
args = parser.parse_args()
ray.init()
if args.flat:
run_experiments({
"maze_single": {
"run": "PPO",
tune.run(
"PPO",
config={
"env": WindyMazeEnv,
"config": {
"num_workers": 0,
},
"num_workers": 0,
},
})
)
else:
maze = WindyMazeEnv(None)
@@ -202,30 +201,28 @@ if __name__ == "__main__":
else:
return "high_level_policy"
run_experiments({
"maze_hier": {
"run": "PPO",
tune.run(
"PPO",
config={
"env": HierarchicalWindyMazeEnv,
"config": {
"num_workers": 0,
"log_level": "INFO",
"entropy_coeff": 0.01,
"multiagent": {
"policy_graphs": {
"high_level_policy": (None, maze.observation_space,
Discrete(4), {
"gamma": 0.9
}),
"low_level_policy": (None,
Tuple([
maze.observation_space,
Discrete(4)
]), maze.action_space, {
"gamma": 0.0
}),
},
"policy_mapping_fn": function(policy_mapping_fn),
"num_workers": 0,
"log_level": "INFO",
"entropy_coeff": 0.01,
"multiagent": {
"policy_graphs": {
"high_level_policy": (None, maze.observation_space,
Discrete(4), {
"gamma": 0.9
}),
"low_level_policy": (None,
Tuple([
maze.observation_space,
Discrete(4)
]), maze.action_space, {
"gamma": 0.0
}),
},
"policy_mapping_fn": function(policy_mapping_fn),
},
},
})
)
@@ -23,7 +23,6 @@ import ray
from ray import tune
from ray.rllib.models import Model, ModelCatalog
from ray.rllib.tests.test_multi_agent_env import MultiCartpole
from ray.tune import run_experiments
from ray.tune.registry import register_env
parser = argparse.ArgumentParser()
@@ -98,21 +97,17 @@ if __name__ == "__main__":
}
policy_ids = list(policy_graphs.keys())
run_experiments({
"test": {
"run": "PPO",
tune.run(
"PPO",
stop={"training_iteration": args.num_iters},
config={
"env": "multi_cartpole",
"stop": {
"training_iteration": args.num_iters
"log_level": "DEBUG",
"num_sgd_iter": 10,
"multiagent": {
"policy_graphs": policy_graphs,
"policy_mapping_fn": tune.function(
lambda agent_id: random.choice(policy_ids)),
},
"config": {
"log_level": "DEBUG",
"num_sgd_iter": 10,
"multiagent": {
"policy_graphs": policy_graphs,
"policy_mapping_fn": tune.function(
lambda agent_id: random.choice(policy_ids)),
},
},
}
})
},
)
@@ -27,9 +27,9 @@ import tensorflow as tf
import tensorflow.contrib.slim as slim
import ray
from ray import tune
from ray.rllib.models import Model, ModelCatalog
from ray.rllib.models.misc import normc_initializer
from ray.tune import run_experiments
from ray.tune.registry import register_env
parser = argparse.ArgumentParser()
@@ -178,18 +178,16 @@ if __name__ == "__main__":
}
else:
cfg = {} # PG, IMPALA, A2C, etc.
run_experiments({
"parametric_cartpole": {
"run": args.run,
"env": "pa_cartpole",
"stop": {
"episode_reward_mean": args.stop,
},
"config": dict({
"model": {
"custom_model": "pa_model",
},
"num_workers": 0,
}, **cfg),
tune.run(
args.run,
stop={
"episode_reward_mean": args.stop,
},
})
config=dict({
"env": "pa_cartpole",
"model": {
"custom_model": "pa_model",
},
"num_workers": 0,
}, **cfg),
)
+10 -10
View File
@@ -8,7 +8,8 @@ import argparse
from gym.spaces import Tuple, Discrete
import ray
from ray.tune import register_env, run_experiments, grid_search
from ray import tune
from ray.tune import register_env, grid_search
from ray.rllib.env.multi_agent_env import MultiAgentEnv
parser = argparse.ArgumentParser()
@@ -108,13 +109,12 @@ if __name__ == "__main__":
group = False
ray.init()
run_experiments({
"two_step": {
"run": args.run,
"env": "grouped_twostep" if group else TwoStepGame,
"stop": {
"timesteps_total": args.stop,
},
"config": config,
tune.run(
args.run,
stop={
"timesteps_total": args.stop,
},
})
config=dict(config, **{
"env": "grouped_twostep" if group else TwoStepGame,
}),
)
+1 -1
View File
@@ -10,7 +10,7 @@ from ray.rllib.evaluation.sample_batch import SampleBatch, MultiAgentBatch
_logged = set()
_disabled = False
_periodic_log = True
_periodic_log = False
_last_logged = 0.0
_printer = pprint.PrettyPrinter(indent=2, width=60)
+1 -1
View File
@@ -95,7 +95,7 @@ def run(run_or_experiment,
Args:
run_or_experiment (function|class|str|Experiment): If
function|class|str, this is the algorithm or model to train.
function|class|str, this is the algorithm or model to train.
This may refer to the name of a built-on algorithm
(e.g. RLLib's DQN or PPO), a user-defined trainable
function or class, or the string identifier of a