mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 21:26:08 +08:00
[rllib] Switch to tune.run() instead of run_experiments() (#4515)
This commit is contained in:
@@ -9,9 +9,9 @@ import tensorflow as tf
|
||||
import tensorflow.contrib.slim as slim
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.models import Model, ModelCatalog
|
||||
from ray.rllib.models.misc import normc_initializer
|
||||
from ray.tune import run_experiments
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--num-iters", type=int, default=200)
|
||||
@@ -47,18 +47,14 @@ if __name__ == "__main__":
|
||||
ray.init()
|
||||
|
||||
ModelCatalog.register_custom_model("bn_model", BatchNormModel)
|
||||
run_experiments({
|
||||
"batch_norm_demo": {
|
||||
"run": args.run,
|
||||
tune.run(
|
||||
args.run,
|
||||
stop={"training_iteration": args.num_iters},
|
||||
config={
|
||||
"env": "Pendulum-v0" if args.run == "DDPG" else "CartPole-v0",
|
||||
"stop": {
|
||||
"training_iteration": args.num_iters
|
||||
},
|
||||
"config": {
|
||||
"model": {
|
||||
"custom_model": "bn_model",
|
||||
},
|
||||
"num_workers": 0,
|
||||
"model": {
|
||||
"custom_model": "bn_model",
|
||||
},
|
||||
"num_workers": 0,
|
||||
},
|
||||
})
|
||||
)
|
||||
|
||||
@@ -179,20 +179,15 @@ if __name__ == "__main__":
|
||||
},
|
||||
}
|
||||
|
||||
tune.run_experiments({
|
||||
"test": {
|
||||
"env": "cartpole_stateless",
|
||||
"run": args.run,
|
||||
"stop": {
|
||||
"episode_reward_mean": args.stop
|
||||
},
|
||||
"config": dict(
|
||||
configs[args.run], **{
|
||||
"model": {
|
||||
"use_lstm": True,
|
||||
"lstm_use_prev_action_reward": args.
|
||||
use_prev_action_reward,
|
||||
},
|
||||
}),
|
||||
}
|
||||
})
|
||||
tune.run(
|
||||
args.run,
|
||||
stop={"episode_reward_mean": args.stop},
|
||||
config=dict(
|
||||
configs[args.run], **{
|
||||
"env": "cartpole_stateless",
|
||||
"model": {
|
||||
"use_lstm": True,
|
||||
"lstm_use_prev_action_reward": args.use_prev_action_reward,
|
||||
},
|
||||
}),
|
||||
)
|
||||
|
||||
@@ -18,7 +18,8 @@ from ray.rllib.models import FullyConnectedNetwork, Model, ModelCatalog
|
||||
from gym.spaces import Discrete, Box
|
||||
|
||||
import ray
|
||||
from ray.tune import run_experiments, grid_search
|
||||
from ray import tune
|
||||
from ray.tune import grid_search
|
||||
|
||||
|
||||
class SimpleCorridor(gym.Env):
|
||||
@@ -66,22 +67,20 @@ if __name__ == "__main__":
|
||||
# register_env("corridor", lambda config: SimpleCorridor(config))
|
||||
ray.init()
|
||||
ModelCatalog.register_custom_model("my_model", CustomModel)
|
||||
run_experiments({
|
||||
"demo": {
|
||||
"run": "PPO",
|
||||
tune.run(
|
||||
"PPO",
|
||||
stop={
|
||||
"timesteps_total": 10000,
|
||||
},
|
||||
config={
|
||||
"env": SimpleCorridor, # or "corridor" if registered above
|
||||
"stop": {
|
||||
"timesteps_total": 10000,
|
||||
"model": {
|
||||
"custom_model": "my_model",
|
||||
},
|
||||
"config": {
|
||||
"model": {
|
||||
"custom_model": "my_model",
|
||||
},
|
||||
"lr": grid_search([1e-2, 1e-4, 1e-6]), # try different lrs
|
||||
"num_workers": 1, # parallelism
|
||||
"env_config": {
|
||||
"corridor_length": 5,
|
||||
},
|
||||
"lr": grid_search([1e-2, 1e-4, 1e-6]), # try different lrs
|
||||
"num_workers": 1, # parallelism
|
||||
"env_config": {
|
||||
"corridor_length": 5,
|
||||
},
|
||||
},
|
||||
})
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ import os
|
||||
import tensorflow as tf
|
||||
|
||||
import ray
|
||||
from ray.tune import run_experiments
|
||||
from ray import tune
|
||||
from ray.rllib.models import (Categorical, FullyConnectedNetwork, Model,
|
||||
ModelCatalog)
|
||||
from ray.rllib.models.model import restore_original_dimensions
|
||||
@@ -82,21 +82,19 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
ModelCatalog.register_custom_model("custom_loss", CustomLossModel)
|
||||
run_experiments({
|
||||
"custom_loss": {
|
||||
"run": "PG",
|
||||
tune.run(
|
||||
"PG",
|
||||
stop={
|
||||
"training_iteration": args.iters,
|
||||
},
|
||||
config={
|
||||
"env": "CartPole-v0",
|
||||
"stop": {
|
||||
"training_iteration": args.iters,
|
||||
},
|
||||
"config": {
|
||||
"num_workers": 0,
|
||||
"model": {
|
||||
"custom_model": "custom_loss",
|
||||
"custom_options": {
|
||||
"input_files": args.input_files,
|
||||
},
|
||||
"num_workers": 0,
|
||||
"model": {
|
||||
"custom_model": "custom_loss",
|
||||
"custom_options": {
|
||||
"input_files": args.input_files,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
)
|
||||
|
||||
@@ -50,24 +50,22 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
ray.init()
|
||||
trials = tune.run_experiments({
|
||||
"test": {
|
||||
trials = tune.run(
|
||||
"PG",
|
||||
stop={
|
||||
"training_iteration": args.num_iters,
|
||||
},
|
||||
config={
|
||||
"env": "CartPole-v0",
|
||||
"run": "PG",
|
||||
"stop": {
|
||||
"training_iteration": args.num_iters,
|
||||
"callbacks": {
|
||||
"on_episode_start": tune.function(on_episode_start),
|
||||
"on_episode_step": tune.function(on_episode_step),
|
||||
"on_episode_end": tune.function(on_episode_end),
|
||||
"on_sample_end": tune.function(on_sample_end),
|
||||
"on_train_result": tune.function(on_train_result),
|
||||
},
|
||||
"config": {
|
||||
"callbacks": {
|
||||
"on_episode_start": tune.function(on_episode_start),
|
||||
"on_episode_step": tune.function(on_episode_step),
|
||||
"on_episode_end": tune.function(on_episode_end),
|
||||
"on_sample_end": tune.function(on_sample_end),
|
||||
"on_train_result": tune.function(on_train_result),
|
||||
},
|
||||
},
|
||||
}
|
||||
})
|
||||
},
|
||||
)
|
||||
|
||||
# verify custom metrics for integration tests
|
||||
custom_metrics = trials[0].last_result["custom_metrics"]
|
||||
|
||||
@@ -11,8 +11,8 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.agents.ppo import PPOAgent
|
||||
from ray.tune import run_experiments
|
||||
|
||||
|
||||
def my_train_fn(config, reporter):
|
||||
@@ -40,15 +40,13 @@ def my_train_fn(config, reporter):
|
||||
|
||||
if __name__ == "__main__":
|
||||
ray.init()
|
||||
run_experiments({
|
||||
"demo": {
|
||||
"run": my_train_fn,
|
||||
"resources_per_trial": {
|
||||
"cpu": 1,
|
||||
},
|
||||
"config": {
|
||||
"lr": 0.01,
|
||||
"num_workers": 0,
|
||||
},
|
||||
tune.run(
|
||||
my_train_fn,
|
||||
resources_per_trial={
|
||||
"cpu": 1,
|
||||
},
|
||||
})
|
||||
config={
|
||||
"lr": 0.01,
|
||||
"num_workers": 0,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -33,7 +33,8 @@ from gym.spaces import Box, Discrete, Tuple
|
||||
import logging
|
||||
|
||||
import ray
|
||||
from ray.tune import run_experiments, function
|
||||
from ray import tune
|
||||
from ray.tune import function
|
||||
from ray.rllib.env import MultiAgentEnv
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -184,15 +185,13 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
ray.init()
|
||||
if args.flat:
|
||||
run_experiments({
|
||||
"maze_single": {
|
||||
"run": "PPO",
|
||||
tune.run(
|
||||
"PPO",
|
||||
config={
|
||||
"env": WindyMazeEnv,
|
||||
"config": {
|
||||
"num_workers": 0,
|
||||
},
|
||||
"num_workers": 0,
|
||||
},
|
||||
})
|
||||
)
|
||||
else:
|
||||
maze = WindyMazeEnv(None)
|
||||
|
||||
@@ -202,30 +201,28 @@ if __name__ == "__main__":
|
||||
else:
|
||||
return "high_level_policy"
|
||||
|
||||
run_experiments({
|
||||
"maze_hier": {
|
||||
"run": "PPO",
|
||||
tune.run(
|
||||
"PPO",
|
||||
config={
|
||||
"env": HierarchicalWindyMazeEnv,
|
||||
"config": {
|
||||
"num_workers": 0,
|
||||
"log_level": "INFO",
|
||||
"entropy_coeff": 0.01,
|
||||
"multiagent": {
|
||||
"policy_graphs": {
|
||||
"high_level_policy": (None, maze.observation_space,
|
||||
Discrete(4), {
|
||||
"gamma": 0.9
|
||||
}),
|
||||
"low_level_policy": (None,
|
||||
Tuple([
|
||||
maze.observation_space,
|
||||
Discrete(4)
|
||||
]), maze.action_space, {
|
||||
"gamma": 0.0
|
||||
}),
|
||||
},
|
||||
"policy_mapping_fn": function(policy_mapping_fn),
|
||||
"num_workers": 0,
|
||||
"log_level": "INFO",
|
||||
"entropy_coeff": 0.01,
|
||||
"multiagent": {
|
||||
"policy_graphs": {
|
||||
"high_level_policy": (None, maze.observation_space,
|
||||
Discrete(4), {
|
||||
"gamma": 0.9
|
||||
}),
|
||||
"low_level_policy": (None,
|
||||
Tuple([
|
||||
maze.observation_space,
|
||||
Discrete(4)
|
||||
]), maze.action_space, {
|
||||
"gamma": 0.0
|
||||
}),
|
||||
},
|
||||
"policy_mapping_fn": function(policy_mapping_fn),
|
||||
},
|
||||
},
|
||||
})
|
||||
)
|
||||
|
||||
@@ -23,7 +23,6 @@ import ray
|
||||
from ray import tune
|
||||
from ray.rllib.models import Model, ModelCatalog
|
||||
from ray.rllib.tests.test_multi_agent_env import MultiCartpole
|
||||
from ray.tune import run_experiments
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -98,21 +97,17 @@ if __name__ == "__main__":
|
||||
}
|
||||
policy_ids = list(policy_graphs.keys())
|
||||
|
||||
run_experiments({
|
||||
"test": {
|
||||
"run": "PPO",
|
||||
tune.run(
|
||||
"PPO",
|
||||
stop={"training_iteration": args.num_iters},
|
||||
config={
|
||||
"env": "multi_cartpole",
|
||||
"stop": {
|
||||
"training_iteration": args.num_iters
|
||||
"log_level": "DEBUG",
|
||||
"num_sgd_iter": 10,
|
||||
"multiagent": {
|
||||
"policy_graphs": policy_graphs,
|
||||
"policy_mapping_fn": tune.function(
|
||||
lambda agent_id: random.choice(policy_ids)),
|
||||
},
|
||||
"config": {
|
||||
"log_level": "DEBUG",
|
||||
"num_sgd_iter": 10,
|
||||
"multiagent": {
|
||||
"policy_graphs": policy_graphs,
|
||||
"policy_mapping_fn": tune.function(
|
||||
lambda agent_id: random.choice(policy_ids)),
|
||||
},
|
||||
},
|
||||
}
|
||||
})
|
||||
},
|
||||
)
|
||||
|
||||
@@ -27,9 +27,9 @@ import tensorflow as tf
|
||||
import tensorflow.contrib.slim as slim
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.models import Model, ModelCatalog
|
||||
from ray.rllib.models.misc import normc_initializer
|
||||
from ray.tune import run_experiments
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -178,18 +178,16 @@ if __name__ == "__main__":
|
||||
}
|
||||
else:
|
||||
cfg = {} # PG, IMPALA, A2C, etc.
|
||||
run_experiments({
|
||||
"parametric_cartpole": {
|
||||
"run": args.run,
|
||||
"env": "pa_cartpole",
|
||||
"stop": {
|
||||
"episode_reward_mean": args.stop,
|
||||
},
|
||||
"config": dict({
|
||||
"model": {
|
||||
"custom_model": "pa_model",
|
||||
},
|
||||
"num_workers": 0,
|
||||
}, **cfg),
|
||||
tune.run(
|
||||
args.run,
|
||||
stop={
|
||||
"episode_reward_mean": args.stop,
|
||||
},
|
||||
})
|
||||
config=dict({
|
||||
"env": "pa_cartpole",
|
||||
"model": {
|
||||
"custom_model": "pa_model",
|
||||
},
|
||||
"num_workers": 0,
|
||||
}, **cfg),
|
||||
)
|
||||
|
||||
@@ -8,7 +8,8 @@ import argparse
|
||||
from gym.spaces import Tuple, Discrete
|
||||
|
||||
import ray
|
||||
from ray.tune import register_env, run_experiments, grid_search
|
||||
from ray import tune
|
||||
from ray.tune import register_env, grid_search
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -108,13 +109,12 @@ if __name__ == "__main__":
|
||||
group = False
|
||||
|
||||
ray.init()
|
||||
run_experiments({
|
||||
"two_step": {
|
||||
"run": args.run,
|
||||
"env": "grouped_twostep" if group else TwoStepGame,
|
||||
"stop": {
|
||||
"timesteps_total": args.stop,
|
||||
},
|
||||
"config": config,
|
||||
tune.run(
|
||||
args.run,
|
||||
stop={
|
||||
"timesteps_total": args.stop,
|
||||
},
|
||||
})
|
||||
config=dict(config, **{
|
||||
"env": "grouped_twostep" if group else TwoStepGame,
|
||||
}),
|
||||
)
|
||||
|
||||
@@ -10,7 +10,7 @@ from ray.rllib.evaluation.sample_batch import SampleBatch, MultiAgentBatch
|
||||
|
||||
_logged = set()
|
||||
_disabled = False
|
||||
_periodic_log = True
|
||||
_periodic_log = False
|
||||
_last_logged = 0.0
|
||||
_printer = pprint.PrettyPrinter(indent=2, width=60)
|
||||
|
||||
|
||||
@@ -95,7 +95,7 @@ def run(run_or_experiment,
|
||||
|
||||
Args:
|
||||
run_or_experiment (function|class|str|Experiment): If
|
||||
function|class|str, this is the algorithm or model to train.
|
||||
function|class|str, this is the algorithm or model to train.
|
||||
This may refer to the name of a built-in algorithm
|
||||
(e.g. RLlib's DQN or PPO), a user-defined trainable
|
||||
function or class, or the string identifier of a
|
||||
|
||||
Reference in New Issue
Block a user