[rllib] Port remainder of algorithms to build_trainer() pattern (#4920)

This commit is contained in:
Eric Liang
2019-06-07 16:45:36 -07:00
committed by GitHub
parent 9e328fbe6f
commit 77689d1116
16 changed files with 489 additions and 464 deletions
+5 -22
View File
@@ -2,15 +2,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.dqn.apex import APEX_TRAINER_PROPERTIES
from ray.rllib.agents.ddpg.ddpg import DDPGTrainer, \
DEFAULT_CONFIG as DDPG_CONFIG
from ray.rllib.utils.annotations import override
from ray.rllib.utils import merge_dicts
APEX_DDPG_DEFAULT_CONFIG = merge_dicts(
DDPG_CONFIG, # see also the options in ddpg.py, which are also supported
{
"optimizer_class": "AsyncReplayOptimizer",
"optimizer": merge_dicts(
DDPG_CONFIG["optimizer"], {
"max_weight_sync_delay": 400,
@@ -32,23 +31,7 @@ APEX_DDPG_DEFAULT_CONFIG = merge_dicts(
},
)
class ApexDDPGTrainer(DDPGTrainer):
"""DDPG variant that uses the Ape-X distributed policy optimizer.
By default, this is configured for a large single node (32 cores). For
running in a large cluster, increase the `num_workers` config var.
"""
_name = "APEX_DDPG"
_default_config = APEX_DDPG_DEFAULT_CONFIG
@override(DDPGTrainer)
def update_target_if_needed(self):
# Ape-X updates based on num steps trained, not sampled
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
self.config["target_network_update_freq"]:
self.workers.local_worker().foreach_trainable_policy(
lambda p, _: p.update_target())
self.last_target_update_ts = self.optimizer.num_steps_trained
self.num_target_updates += 1
ApexDDPGTrainer = DDPGTrainer.with_updates(
name="APEX_DDPG",
default_config=APEX_DDPG_DEFAULT_CONFIG,
**APEX_TRAINER_PROPERTIES)
+64 -47
View File
@@ -3,9 +3,9 @@ from __future__ import division
from __future__ import print_function
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.dqn import DQNTrainer
from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer, \
update_worker_explorations
from ray.rllib.agents.ddpg.ddpg_policy import DDPGTFPolicy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule
# yapf: disable
@@ -97,6 +97,11 @@ DEFAULT_CONFIG = with_common_config({
# optimization on initial policy parameters. Note that this will be
# disabled when the action noise scale is set to 0 (e.g during evaluation).
"pure_exploration_steps": 1000,
# Extra configuration that disables exploration.
"evaluation_config": {
"exploration_fraction": 0,
"exploration_final_eps": 0,
},
# === Replay buffer ===
# Size of the replay buffer. Note that if async_updates is set, then
@@ -108,6 +113,11 @@ DEFAULT_CONFIG = with_common_config({
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Fraction of entire training period over which the beta parameter is
# annealed
"beta_annealing_fraction": 0.2,
# Final value of beta
"final_prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
# Whether to LZ4 compress observations
@@ -146,8 +156,6 @@ DEFAULT_CONFIG = with_common_config({
# to increase if your environment is particularly slow to sample, or if
# you're using the Async or Ape-X optimizers.
"num_workers": 0,
# Optimizer class to use.
"optimizer_class": "SyncReplayOptimizer",
# Whether to use a distribution of epsilons across workers for exploration.
"per_worker_exploration": False,
# Whether to compute priorities on workers.
@@ -159,47 +167,56 @@ DEFAULT_CONFIG = with_common_config({
# yapf: enable
class DDPGTrainer(DQNTrainer):
"""DDPG implementation in TensorFlow."""
_name = "DDPG"
_default_config = DEFAULT_CONFIG
_policy = DDPGTFPolicy
@override(DQNTrainer)
def _train(self):
pure_expl_steps = self.config["pure_exploration_steps"]
if pure_expl_steps:
# tell workers whether they should do pure exploration
only_explore = self.global_timestep < pure_expl_steps
self.workers.local_worker().foreach_trainable_policy(
lambda p, _: p.set_pure_exploration_phase(only_explore))
for e in self.workers.remote_workers():
e.foreach_trainable_policy.remote(
lambda p, _: p.set_pure_exploration_phase(only_explore))
return super(DDPGTrainer, self)._train()
@override(DQNTrainer)
def _make_exploration_schedule(self, worker_index):
# Override DQN's schedule to take into account
# `exploration_ou_noise_scale`
if self.config["per_worker_exploration"]:
assert self.config["num_workers"] > 1, \
"This requires multiple workers"
if worker_index >= 0:
# FIXME: what do magic constants mean? (0.4, 7)
max_index = float(self.config["num_workers"] - 1)
exponent = 1 + worker_index / max_index * 7
return ConstantSchedule(0.4**exponent)
else:
# local ev should have zero exploration so that eval rollouts
# run properly
return ConstantSchedule(0.0)
elif self.config["exploration_should_anneal"]:
return LinearSchedule(
schedule_timesteps=int(self.config["exploration_fraction"] *
self.config["schedule_max_timesteps"]),
initial_p=1.0,
final_p=self.config["exploration_final_scale"])
def make_exploration_schedule(config, worker_index):
# Modification of DQN's schedule to take into account
# `exploration_ou_noise_scale`
if config["per_worker_exploration"]:
assert config["num_workers"] > 1, "This requires multiple workers"
if worker_index >= 0:
# FIXME: what do magic constants mean? (0.4, 7)
max_index = float(config["num_workers"] - 1)
exponent = 1 + worker_index / max_index * 7
return ConstantSchedule(0.4**exponent)
else:
# *always* add exploration noise
return ConstantSchedule(1.0)
# local ev should have zero exploration so that eval rollouts
# run properly
return ConstantSchedule(0.0)
elif config["exploration_should_anneal"]:
return LinearSchedule(
schedule_timesteps=int(config["exploration_fraction"] *
config["schedule_max_timesteps"]),
initial_p=1.0,
final_p=config["exploration_final_scale"])
else:
# *always* add exploration noise
return ConstantSchedule(1.0)
def setup_ddpg_exploration(trainer):
trainer.exploration0 = make_exploration_schedule(trainer.config, -1)
trainer.explorations = [
make_exploration_schedule(trainer.config, i)
for i in range(trainer.config["num_workers"])
]
def add_pure_exploration_phase(trainer):
global_timestep = trainer.optimizer.num_steps_sampled
pure_expl_steps = trainer.config["pure_exploration_steps"]
if pure_expl_steps:
# tell workers whether they should do pure exploration
only_explore = global_timestep < pure_expl_steps
trainer.workers.local_worker().foreach_trainable_policy(
lambda p, _: p.set_pure_exploration_phase(only_explore))
for e in trainer.workers.remote_workers():
e.foreach_trainable_policy.remote(
lambda p, _: p.set_pure_exploration_phase(only_explore))
update_worker_explorations(trainer)
DDPGTrainer = GenericOffPolicyTrainer.with_updates(
name="DDPG",
default_config=DEFAULT_CONFIG,
default_policy=DDPGTFPolicy,
before_init=setup_ddpg_exploration,
before_train_step=add_pure_exploration_phase)
+8 -8
View File
@@ -1,3 +1,9 @@
"""A more stable successor to TD3.
By default, this uses a near-identical configuration to that reported in the
TD3 paper.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -36,7 +42,6 @@ TD3_DEFAULT_CONFIG = merge_dicts(
"train_batch_size": 100,
"use_huber": False,
"target_network_update_freq": 0,
"optimizer_class": "SyncReplayOptimizer",
"num_workers": 0,
"num_gpus_per_worker": 0,
"per_worker_exploration": False,
@@ -48,10 +53,5 @@ TD3_DEFAULT_CONFIG = merge_dicts(
},
)
class TD3Trainer(DDPGTrainer):
"""A more stable successor to TD3. By default, this uses a near-identical
configuration to that reported in the TD3 paper."""
_name = "TD3"
_default_config = TD3_DEFAULT_CONFIG
TD3Trainer = DDPGTrainer.with_updates(
name="TD3", default_config=TD3_DEFAULT_CONFIG)
+45 -18
View File
@@ -3,15 +3,14 @@ from __future__ import division
from __future__ import print_function
from ray.rllib.agents.dqn.dqn import DQNTrainer, DEFAULT_CONFIG as DQN_CONFIG
from ray.rllib.optimizers import AsyncReplayOptimizer
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import override
# yapf: disable
# __sphinx_doc_begin__
APEX_DEFAULT_CONFIG = merge_dicts(
DQN_CONFIG, # see also the options in dqn.py, which are also supported
{
"optimizer_class": "AsyncReplayOptimizer",
"optimizer": merge_dicts(
DQN_CONFIG["optimizer"], {
"max_weight_sync_delay": 400,
@@ -36,22 +35,50 @@ APEX_DEFAULT_CONFIG = merge_dicts(
# yapf: enable
class ApexTrainer(DQNTrainer):
"""DQN variant that uses the Ape-X distributed policy optimizer.
def defer_make_workers(trainer, env_creator, policy, config):
# Hack to workaround https://github.com/ray-project/ray/issues/2541
# The workers will be creatd later, after the optimizer is created
return trainer._make_workers(env_creator, policy, config, 0)
By default, this is configured for a large single node (32 cores). For
running in a large cluster, increase the `num_workers` config var.
"""
_name = "APEX"
_default_config = APEX_DEFAULT_CONFIG
def make_async_optimizer(workers, config):
assert len(workers.remote_workers()) == 0
extra_config = config["optimizer"].copy()
for key in [
"prioritized_replay", "prioritized_replay_alpha",
"prioritized_replay_beta", "prioritized_replay_eps"
]:
if key in config:
extra_config[key] = config[key]
opt = AsyncReplayOptimizer(
workers,
learning_starts=config["learning_starts"],
buffer_size=config["buffer_size"],
train_batch_size=config["train_batch_size"],
sample_batch_size=config["sample_batch_size"],
**extra_config)
workers.add_workers(config["num_workers"])
opt._set_workers(workers.remote_workers())
return opt
@override(DQNTrainer)
def update_target_if_needed(self):
# Ape-X updates based on num steps trained, not sampled
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
self.config["target_network_update_freq"]:
self.workers.local_worker().foreach_trainable_policy(
lambda p, _: p.update_target())
self.last_target_update_ts = self.optimizer.num_steps_trained
self.num_target_updates += 1
def update_target_based_on_num_steps_trained(trainer, fetches):
# Ape-X updates based on num steps trained, not sampled
if (trainer.optimizer.num_steps_trained -
trainer.state["last_target_update_ts"] >
trainer.config["target_network_update_freq"]):
trainer.workers.local_worker().foreach_trainable_policy(
lambda p, _: p.update_target())
trainer.state["last_target_update_ts"] = (
trainer.optimizer.num_steps_trained)
trainer.state["num_target_updates"] += 1
APEX_TRAINER_PROPERTIES = {
"make_workers": defer_make_workers,
"make_policy_optimizer": make_async_optimizer,
"after_optimizer_step": update_target_based_on_num_steps_trained,
}
ApexTrainer = DQNTrainer.with_updates(
name="APEX", default_config=APEX_DEFAULT_CONFIG, **APEX_TRAINER_PROPERTIES)
+164 -197
View File
@@ -3,27 +3,17 @@ from __future__ import division
from __future__ import print_function
import logging
import time
from ray import tune
from ray.rllib import optimizers
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.optimizers import SyncReplayOptimizer
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override
from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule
logger = logging.getLogger(__name__)
OPTIMIZER_SHARED_CONFIGS = [
"buffer_size", "prioritized_replay", "prioritized_replay_alpha",
"prioritized_replay_beta", "schedule_max_timesteps",
"beta_annealing_fraction", "final_prioritized_replay_beta",
"prioritized_replay_eps", "sample_batch_size", "train_batch_size",
"learning_starts"
]
# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
@@ -53,7 +43,8 @@ DEFAULT_CONFIG = with_common_config({
# 1.0 to exploration_fraction over this number of timesteps scaled by
# exploration_fraction
"schedule_max_timesteps": 100000,
# Number of env steps to optimize for before returning
# Minimum env steps to optimize for per train call. This value does
# not affect learning, only the length of iterations.
"timesteps_per_iteration": 1000,
# Fraction of entire training period over which the exploration rate is
# annealed
@@ -70,6 +61,11 @@ DEFAULT_CONFIG = with_common_config({
# If True parameter space noise will be used for exploration
# See https://blog.openai.com/better-exploration-with-parameter-noise/
"parameter_noise": False,
# Extra configuration that disables exploration.
"evaluation_config": {
"exploration_fraction": 0,
"exploration_final_eps": 0,
},
# === Replay buffer ===
# Size of the replay buffer. Note that if async_updates is set, then
@@ -115,8 +111,6 @@ DEFAULT_CONFIG = with_common_config({
# to increase if your environment is particularly slow to sample, or if
# you"re using the Async or Ape-X optimizers.
"num_workers": 0,
# Optimizer class to use.
"optimizer_class": "SyncReplayOptimizer",
# Whether to use a distribution of epsilons across workers for exploration.
"per_worker_exploration": False,
# Whether to compute priorities on workers.
@@ -128,202 +122,175 @@ DEFAULT_CONFIG = with_common_config({
# yapf: enable
class DQNTrainer(Trainer):
"""DQN implementation in TensorFlow."""
def make_optimizer(workers, config):
return SyncReplayOptimizer(
workers,
learning_starts=config["learning_starts"],
buffer_size=config["buffer_size"],
prioritized_replay=config["prioritized_replay"],
prioritized_replay_alpha=config["prioritized_replay_alpha"],
prioritized_replay_beta=config["prioritized_replay_beta"],
schedule_max_timesteps=config["schedule_max_timesteps"],
beta_annealing_fraction=config["beta_annealing_fraction"],
final_prioritized_replay_beta=config["final_prioritized_replay_beta"],
prioritized_replay_eps=config["prioritized_replay_eps"],
train_batch_size=config["train_batch_size"],
sample_batch_size=config["sample_batch_size"],
**config["optimizer"])
_name = "DQN"
_default_config = DEFAULT_CONFIG
_policy = DQNTFPolicy
_optimizer_shared_configs = OPTIMIZER_SHARED_CONFIGS
@override(Trainer)
def _init(self, config, env_creator):
self._validate_config()
def check_config_and_setup_param_noise(config):
"""Update the config based on settings.
# Update effective batch size to include n-step
adjusted_batch_size = max(config["sample_batch_size"],
config.get("n_step", 1))
config["sample_batch_size"] = adjusted_batch_size
Rewrites sample_batch_size to take into account n_step truncation, and also
adds the necessary callbacks to support parameter space noise exploration.
"""
self.exploration0 = self._make_exploration_schedule(-1)
self.explorations = [
self._make_exploration_schedule(i)
for i in range(config["num_workers"])
]
# Update effective batch size to include n-step
adjusted_batch_size = max(config["sample_batch_size"],
config.get("n_step", 1))
config["sample_batch_size"] = adjusted_batch_size
for k in self._optimizer_shared_configs:
if self._name != "DQN" and k in [
"schedule_max_timesteps", "beta_annealing_fraction",
"final_prioritized_replay_beta"
]:
# only Rainbow needs annealing prioritized_replay_beta
continue
if k not in config["optimizer"]:
config["optimizer"][k] = config[k]
if config.get("parameter_noise", False):
if config["callbacks"]["on_episode_start"]:
start_callback = config["callbacks"]["on_episode_start"]
else:
start_callback = None
def on_episode_start(info):
# as a callback function to sample and pose parameter space
# noise on the parameters of network
policies = info["policy"]
for pi in policies.values():
pi.add_parameter_noise()
if start_callback:
start_callback(info)
config["callbacks"]["on_episode_start"] = tune.function(
on_episode_start)
if config["callbacks"]["on_episode_end"]:
end_callback = config["callbacks"]["on_episode_end"]
else:
end_callback = None
def on_episode_end(info):
# as a callback function to monitor the distance
# between noisy policy and original policy
policies = info["policy"]
episode = info["episode"]
episode.custom_metrics["policy_distance"] = policies[
DEFAULT_POLICY_ID].pi_distance
if end_callback:
end_callback(info)
config["callbacks"]["on_episode_end"] = tune.function(
on_episode_end)
if config["optimizer_class"] != "AsyncReplayOptimizer":
self.workers = self._make_workers(
env_creator,
self._policy,
config,
num_workers=self.config["num_workers"])
workers_needed = 0
if config.get("parameter_noise", False):
if config["batch_mode"] != "complete_episodes":
raise ValueError("Exploration with parameter space noise requires "
"batch_mode to be complete_episodes.")
if config.get("noisy", False):
raise ValueError(
"Exploration with parameter space noise and noisy network "
"cannot be used at the same time.")
if config["callbacks"]["on_episode_start"]:
start_callback = config["callbacks"]["on_episode_start"]
else:
# Hack to workaround https://github.com/ray-project/ray/issues/2541
self.workers = self._make_workers(
env_creator, self._policy, config, num_workers=0)
workers_needed = self.config["num_workers"]
start_callback = None
self.optimizer = getattr(optimizers, config["optimizer_class"])(
self.workers, **config["optimizer"])
def on_episode_start(info):
# as a callback function to sample and pose parameter space
# noise on the parameters of network
policies = info["policy"]
for pi in policies.values():
pi.add_parameter_noise()
if start_callback:
start_callback(info)
# Create the remote workers *after* the replay actors
if workers_needed > 0:
self.workers.add_workers(workers_needed)
self.optimizer._set_workers(self.workers.remote_workers())
self.last_target_update_ts = 0
self.num_target_updates = 0
@override(Trainer)
def _train(self):
start_timestep = self.global_timestep
# Update worker explorations
exp_vals = [self.exploration0.value(self.global_timestep)]
self.workers.local_worker().foreach_trainable_policy(
lambda p, _: p.set_epsilon(exp_vals[0]))
for i, e in enumerate(self.workers.remote_workers()):
exp_val = self.explorations[i].value(self.global_timestep)
e.foreach_trainable_policy.remote(
lambda p, _: p.set_epsilon(exp_val))
exp_vals.append(exp_val)
# Do optimization steps
start = time.time()
while (self.global_timestep - start_timestep <
self.config["timesteps_per_iteration"]
) or time.time() - start < self.config["min_iter_time_s"]:
self.optimizer.step()
self.update_target_if_needed()
if self.config["per_worker_exploration"]:
# Only collect metrics from the third of workers with lowest eps
result = self.collect_metrics(
selected_workers=self.workers.remote_workers()[
-len(self.workers.remote_workers()) // 3:])
config["callbacks"]["on_episode_start"] = tune.function(
on_episode_start)
if config["callbacks"]["on_episode_end"]:
end_callback = config["callbacks"]["on_episode_end"]
else:
result = self.collect_metrics()
end_callback = None
result.update(
timesteps_this_iter=self.global_timestep - start_timestep,
info=dict({
"min_exploration": min(exp_vals),
"max_exploration": max(exp_vals),
"num_target_updates": self.num_target_updates,
}, **self.optimizer.stats()))
def on_episode_end(info):
# as a callback function to monitor the distance
# between noisy policy and original policy
policies = info["policy"]
episode = info["episode"]
episode.custom_metrics["policy_distance"] = policies[
DEFAULT_POLICY_ID].pi_distance
if end_callback:
end_callback(info)
return result
config["callbacks"]["on_episode_end"] = tune.function(on_episode_end)
def update_target_if_needed(self):
if self.global_timestep - self.last_target_update_ts > \
self.config["target_network_update_freq"]:
self.workers.local_worker().foreach_trainable_policy(
lambda p, _: p.update_target())
self.last_target_update_ts = self.global_timestep
self.num_target_updates += 1
@property
def global_timestep(self):
return self.optimizer.num_steps_sampled
def get_initial_state(config):
return {
"last_target_update_ts": 0,
"num_target_updates": 0,
}
def _evaluate(self):
logger.info("Evaluating current policy for {} episodes".format(
self.config["evaluation_num_episodes"]))
self.evaluation_workers.local_worker().restore(
self.workers.local_worker().save())
self.evaluation_workers.local_worker().foreach_policy(
lambda p, _: p.set_epsilon(0))
for _ in range(self.config["evaluation_num_episodes"]):
self.evaluation_workers.local_worker().sample()
metrics = collect_metrics(self.evaluation_workers.local_worker())
return {"evaluation": metrics}
def _make_exploration_schedule(self, worker_index):
# Use either a different `eps` per worker, or a linear schedule.
if self.config["per_worker_exploration"]:
assert self.config["num_workers"] > 1, \
"This requires multiple workers"
if worker_index >= 0:
exponent = (
1 +
worker_index / float(self.config["num_workers"] - 1) * 7)
return ConstantSchedule(0.4**exponent)
else:
# local ev should have zero exploration so that eval rollouts
# run properly
return ConstantSchedule(0.0)
return LinearSchedule(
schedule_timesteps=int(self.config["exploration_fraction"] *
self.config["schedule_max_timesteps"]),
initial_p=1.0,
final_p=self.config["exploration_final_eps"])
def make_exploration_schedule(config, worker_index):
# Use either a different `eps` per worker, or a linear schedule.
if config["per_worker_exploration"]:
assert config["num_workers"] > 1, \
"This requires multiple workers"
if worker_index >= 0:
exponent = (
1 + worker_index / float(config["num_workers"] - 1) * 7)
return ConstantSchedule(0.4**exponent)
else:
# local ev should have zero exploration so that eval rollouts
# run properly
return ConstantSchedule(0.0)
return LinearSchedule(
schedule_timesteps=int(
config["exploration_fraction"] * config["schedule_max_timesteps"]),
initial_p=1.0,
final_p=config["exploration_final_eps"])
def __getstate__(self):
state = Trainer.__getstate__(self)
state.update({
"num_target_updates": self.num_target_updates,
"last_target_update_ts": self.last_target_update_ts,
})
return state
def __setstate__(self, state):
Trainer.__setstate__(self, state)
self.num_target_updates = state["num_target_updates"]
self.last_target_update_ts = state["last_target_update_ts"]
def setup_exploration(trainer):
trainer.exploration0 = make_exploration_schedule(trainer.config, -1)
trainer.explorations = [
make_exploration_schedule(trainer.config, i)
for i in range(trainer.config["num_workers"])
]
def _validate_config(self):
if self.config.get("parameter_noise", False):
if self.config["batch_mode"] != "complete_episodes":
raise ValueError(
"Exploration with parameter space noise requires "
"batch_mode to be complete_episodes.")
if self.config.get("noisy", False):
raise ValueError(
"Exploration with parameter space noise and noisy network "
"cannot be used at the same time.")
def update_worker_explorations(trainer):
global_timestep = trainer.optimizer.num_steps_sampled
exp_vals = [trainer.exploration0.value(global_timestep)]
trainer.workers.local_worker().foreach_trainable_policy(
lambda p, _: p.set_epsilon(exp_vals[0]))
for i, e in enumerate(trainer.workers.remote_workers()):
exp_val = trainer.explorations[i].value(global_timestep)
e.foreach_trainable_policy.remote(lambda p, _: p.set_epsilon(exp_val))
exp_vals.append(exp_val)
trainer.train_start_timestep = global_timestep
trainer.cur_exp_vals = exp_vals
def add_trainer_metrics(trainer, result):
global_timestep = trainer.optimizer.num_steps_sampled
result.update(
timesteps_this_iter=global_timestep - trainer.train_start_timestep,
info=dict({
"min_exploration": min(trainer.cur_exp_vals),
"max_exploration": max(trainer.cur_exp_vals),
"num_target_updates": trainer.state["num_target_updates"],
}, **trainer.optimizer.stats()))
def update_target_if_needed(trainer, fetches):
global_timestep = trainer.optimizer.num_steps_sampled
if global_timestep - trainer.state["last_target_update_ts"] > \
trainer.config["target_network_update_freq"]:
trainer.workers.local_worker().foreach_trainable_policy(
lambda p, _: p.update_target())
trainer.state["last_target_update_ts"] = global_timestep
trainer.state["num_target_updates"] += 1
def collect_metrics(trainer):
if trainer.config["per_worker_exploration"]:
# Only collect metrics from the third of workers with lowest eps
result = trainer.collect_metrics(
selected_workers=trainer.workers.remote_workers()[
-len(trainer.workers.remote_workers()) // 3:])
else:
result = trainer.collect_metrics()
return result
def disable_exploration(trainer):
trainer.evaluation_workers.local_worker().foreach_policy(
lambda p, _: p.set_epsilon(0))
GenericOffPolicyTrainer = build_trainer(
name="GenericOffPolicyAlgorithm",
default_policy=None,
default_config=DEFAULT_CONFIG,
validate_config=check_config_and_setup_param_noise,
get_initial_state=get_initial_state,
make_policy_optimizer=make_optimizer,
before_init=setup_exploration,
before_train_step=update_worker_explorations,
after_optimizer_step=update_target_if_needed,
after_train_result=add_trainer_metrics,
collect_metrics_fn=collect_metrics,
before_evaluate_fn=disable_exploration)
DQNTrainer = GenericOffPolicyTrainer.with_updates(
name="DQN", default_policy=DQNTFPolicy, default_config=DEFAULT_CONFIG)
+55 -61
View File
@@ -2,33 +2,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
from ray.rllib.agents.impala.vtrace_policy import VTraceTFPolicy
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.optimizers import AsyncSamplesOptimizer
from ray.rllib.optimizers.aso_tree_aggregator import TreeAggregator
from ray.rllib.utils.annotations import override
from ray.tune.trainable import Trainable
from ray.tune.trial import Resources
OPTIMIZER_SHARED_CONFIGS = [
"lr",
"num_envs_per_worker",
"num_gpus",
"sample_batch_size",
"train_batch_size",
"replay_buffer_num_slots",
"replay_proportion",
"num_data_loader_buffers",
"max_sample_requests_in_flight_per_worker",
"broadcast_interval",
"num_sgd_iter",
"minibatch_buffer_size",
"num_aggregation_workers",
]
# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
@@ -100,37 +83,57 @@ DEFAULT_CONFIG = with_common_config({
# yapf: enable
class ImpalaTrainer(Trainer):
"""IMPALA implementation using DeepMind's V-trace."""
def choose_policy(config):
if config["vtrace"]:
return VTraceTFPolicy
else:
return A3CTFPolicy
_name = "IMPALA"
_default_config = DEFAULT_CONFIG
_policy = VTraceTFPolicy
@override(Trainer)
def _init(self, config, env_creator):
for k in OPTIMIZER_SHARED_CONFIGS:
if k not in config["optimizer"]:
config["optimizer"][k] = config[k]
policy_cls = self._get_policy()
self.workers = self._make_workers(
self.env_creator, policy_cls, self.config, num_workers=0)
def validate_config(config):
if config["entropy_coeff"] < 0:
raise DeprecationWarning("entropy_coeff must be >= 0")
if self.config["num_aggregation_workers"] > 0:
# Create co-located aggregator actors first for placement pref
aggregators = TreeAggregator.precreate_aggregators(
self.config["num_aggregation_workers"])
self.workers.add_workers(config["num_workers"])
self.optimizer = AsyncSamplesOptimizer(self.workers,
**config["optimizer"])
if config["entropy_coeff"] < 0:
raise DeprecationWarning("entropy_coeff must be >= 0")
def defer_make_workers(trainer, env_creator, policy, config):
# Defer worker creation to after the optimizer has been created.
return trainer._make_workers(env_creator, policy, config, 0)
if self.config["num_aggregation_workers"] > 0:
# Assign the pre-created aggregators to the optimizer
self.optimizer.aggregator.init(aggregators)
def make_aggregators_and_optimizer(workers, config):
if config["num_aggregation_workers"] > 0:
# Create co-located aggregator actors first for placement pref
aggregators = TreeAggregator.precreate_aggregators(
config["num_aggregation_workers"])
else:
aggregators = None
workers.add_workers(config["num_workers"])
optimizer = AsyncSamplesOptimizer(
workers,
lr=config["lr"],
num_envs_per_worker=config["num_envs_per_worker"],
num_gpus=config["num_gpus"],
sample_batch_size=config["sample_batch_size"],
train_batch_size=config["train_batch_size"],
replay_buffer_num_slots=config["replay_buffer_num_slots"],
replay_proportion=config["replay_proportion"],
num_data_loader_buffers=config["num_data_loader_buffers"],
max_sample_requests_in_flight_per_worker=config[
"max_sample_requests_in_flight_per_worker"],
broadcast_interval=config["broadcast_interval"],
num_sgd_iter=config["num_sgd_iter"],
minibatch_buffer_size=config["minibatch_buffer_size"],
num_aggregation_workers=config["num_aggregation_workers"],
**config["optimizer"])
if aggregators:
# Assign the pre-created aggregators to the optimizer
optimizer.aggregator.init(aggregators)
return optimizer
class OverrideDefaultResourceRequest(object):
@classmethod
@override(Trainable)
def default_resource_request(cls, config):
@@ -143,22 +146,13 @@ class ImpalaTrainer(Trainer):
cf["num_aggregation_workers"],
extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])
@override(Trainer)
def _train(self):
prev_steps = self.optimizer.num_steps_sampled
start = time.time()
self.optimizer.step()
while (time.time() - start < self.config["min_iter_time_s"]
or self.optimizer.num_steps_sampled == prev_steps):
self.optimizer.step()
result = self.collect_metrics()
result.update(timesteps_this_iter=self.optimizer.num_steps_sampled -
prev_steps)
return result
def _get_policy(self):
if self.config["vtrace"]:
policy_cls = self._policy
else:
policy_cls = A3CTFPolicy
return policy_cls
ImpalaTrainer = build_trainer(
name="IMPALA",
default_config=DEFAULT_CONFIG,
default_policy=VTraceTFPolicy,
validate_config=validate_config,
get_policy_class=choose_policy,
make_workers=defer_make_workers,
make_policy_optimizer=make_aggregators_and_optimizer,
mixins=[OverrideDefaultResourceRequest])
+14 -27
View File
@@ -2,10 +2,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.marwil.marwil_policy import MARWILPolicy
from ray.rllib.optimizers import SyncBatchReplayOptimizer
from ray.rllib.utils.annotations import override
# yapf: disable
# __sphinx_doc_begin__
@@ -39,30 +39,17 @@ DEFAULT_CONFIG = with_common_config({
# yapf: enable
class MARWILTrainer(Trainer):
"""MARWIL implementation in TensorFlow."""
def make_optimizer(workers, config):
return SyncBatchReplayOptimizer(
workers,
learning_starts=config["learning_starts"],
buffer_size=config["replay_buffer_size"],
train_batch_size=config["train_batch_size"],
)
_name = "MARWIL"
_default_config = DEFAULT_CONFIG
_policy = MARWILPolicy
@override(Trainer)
def _init(self, config, env_creator):
self.workers = self._make_workers(env_creator, self._policy, config,
config["num_workers"])
self.optimizer = SyncBatchReplayOptimizer(
self.workers,
learning_starts=config["learning_starts"],
buffer_size=config["replay_buffer_size"],
train_batch_size=config["train_batch_size"],
)
@override(Trainer)
def _train(self):
prev_steps = self.optimizer.num_steps_sampled
fetches = self.optimizer.step()
res = self.collect_metrics()
res.update(
timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps,
info=dict(fetches, **res.get("info", {})))
return res
MARWILTrainer = build_trainer(
name="MARWIL",
default_config=DEFAULT_CONFIG,
default_policy=MARWILPolicy,
make_policy_optimizer=make_optimizer)
+5 -12
View File
@@ -5,7 +5,6 @@ from __future__ import print_function
from ray.rllib.agents.ppo.appo_policy import AsyncPPOTFPolicy
from ray.rllib.agents.trainer import with_base_config
from ray.rllib.agents import impala
from ray.rllib.utils.annotations import override
# yapf: disable
# __sphinx_doc_begin__
@@ -51,14 +50,8 @@ DEFAULT_CONFIG = with_base_config(impala.DEFAULT_CONFIG, {
# __sphinx_doc_end__
# yapf: enable
class APPOTrainer(impala.ImpalaTrainer):
"""PPO surrogate loss with IMPALA-architecture."""
_name = "APPO"
_default_config = DEFAULT_CONFIG
_policy = AsyncPPOTFPolicy
@override(impala.ImpalaTrainer)
def _get_policy(self):
return AsyncPPOTFPolicy
APPOTrainer = impala.ImpalaTrainer.with_updates(
name="APPO",
default_config=DEFAULT_CONFIG,
default_policy=AsyncPPOTFPolicy,
get_policy_class=lambda _: AsyncPPOTFPolicy)
+5 -22
View File
@@ -4,15 +4,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.dqn.apex import APEX_TRAINER_PROPERTIES
from ray.rllib.agents.qmix.qmix import QMixTrainer, \
DEFAULT_CONFIG as QMIX_CONFIG
from ray.rllib.utils.annotations import override
from ray.rllib.utils import merge_dicts
APEX_QMIX_DEFAULT_CONFIG = merge_dicts(
QMIX_CONFIG, # see also the options in qmix.py, which are also supported
{
"optimizer_class": "AsyncReplayOptimizer",
"optimizer": merge_dicts(
QMIX_CONFIG["optimizer"],
{
@@ -34,23 +33,7 @@ APEX_QMIX_DEFAULT_CONFIG = merge_dicts(
},
)
class ApexQMixTrainer(QMixTrainer):
"""QMIX variant that uses the Ape-X distributed policy optimizer.
By default, this is configured for a large single node (32 cores). For
running in a large cluster, increase the `num_workers` config var.
"""
_name = "APEX_QMIX"
_default_config = APEX_QMIX_DEFAULT_CONFIG
@override(QMixTrainer)
def update_target_if_needed(self):
# Ape-X updates based on num steps trained, not sampled
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
self.config["target_network_update_freq"]:
self.workers.local_worker().foreach_trainable_policy(
lambda p, _: p.update_target())
self.last_target_update_ts = self.optimizer.num_steps_trained
self.num_target_updates += 1
ApexQMixTrainer = QMixTrainer.with_updates(
name="APEX_QMIX",
default_config=APEX_QMIX_DEFAULT_CONFIG,
**APEX_TRAINER_PROPERTIES)
+14 -11
View File
@@ -3,8 +3,9 @@ from __future__ import division
from __future__ import print_function
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.dqn import DQNTrainer
from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer
from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy
from ray.rllib.optimizers import SyncBatchReplayOptimizer
# yapf: disable
# __sphinx_doc_begin__
@@ -71,8 +72,6 @@ DEFAULT_CONFIG = with_common_config({
# to increase if your environment is particularly slow to sample, or if
# you"re using the Async or Ape-X optimizers.
"num_workers": 0,
# Optimizer class to use.
"optimizer_class": "SyncBatchReplayOptimizer",
# Whether to use a distribution of epsilons across workers for exploration.
"per_worker_exploration": False,
# Whether to compute priorities on workers.
@@ -90,12 +89,16 @@ DEFAULT_CONFIG = with_common_config({
# yapf: enable
class QMixTrainer(DQNTrainer):
"""QMix implementation in PyTorch."""
def make_sync_batch_optimizer(workers, config):
return SyncBatchReplayOptimizer(
workers,
learning_starts=config["learning_starts"],
buffer_size=config["buffer_size"],
train_batch_size=config["train_batch_size"])
_name = "QMIX"
_default_config = DEFAULT_CONFIG
_policy = QMixTorchPolicy
_optimizer_shared_configs = [
"learning_starts", "buffer_size", "train_batch_size"
]
QMixTrainer = GenericOffPolicyTrainer.with_updates(
name="QMIX",
default_config=DEFAULT_CONFIG,
default_policy=QMixTorchPolicy,
make_policy_optimizer=make_sync_batch_optimizer)
+9
View File
@@ -189,6 +189,9 @@ COMMON_CONFIG = {
"remote_env_batch_wait_ms": 0,
# Minimum time per iteration
"min_iter_time_s": 0,
# Minimum env steps to optimize for per train call. This value does
# not affect learning, only the length of iterations.
"timesteps_per_iteration": 0,
# === Offline Datasets ===
# Specify how to generate experiences:
@@ -502,6 +505,7 @@ class Trainer(Trainable):
logger.info("Evaluating current policy for {} episodes".format(
self.config["evaluation_num_episodes"]))
self._before_evaluate()
self.evaluation_workers.local_worker().restore(
self.workers.local_worker().save())
for _ in range(self.config["evaluation_num_episodes"]):
@@ -510,6 +514,11 @@ class Trainer(Trainable):
metrics = collect_metrics(self.evaluation_workers.local_worker())
return {"evaluation": metrics}
@DeveloperAPI
def _before_evaluate(self):
"""Pre-evaluation callback."""
pass
@PublicAPI
def compute_action(self,
observation,
+80 -9
View File
@@ -6,6 +6,7 @@ import time
from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG
from ray.rllib.optimizers import SyncSamplesOptimizer
from ray.rllib.utils import add_mixins
from ray.rllib.utils.annotations import override, DeveloperAPI
@@ -13,25 +14,47 @@ from ray.rllib.utils.annotations import override, DeveloperAPI
def build_trainer(name,
default_policy,
default_config=None,
make_policy_optimizer=None,
validate_config=None,
get_initial_state=None,
get_policy_class=None,
before_init=None,
make_workers=None,
make_policy_optimizer=None,
after_init=None,
before_train_step=None,
after_optimizer_step=None,
after_train_result=None):
after_train_result=None,
collect_metrics_fn=None,
before_evaluate_fn=None,
mixins=None):
"""Helper function for defining a custom trainer.
Functions will be run in this order to initialize the trainer:
1. Config setup: validate_config, get_initial_state, get_policy
2. Worker setup: before_init, make_workers, make_policy_optimizer
3. Post setup: after_init
Arguments:
name (str): name of the trainer (e.g., "PPO")
default_policy (cls): the default Policy class to use
default_config (dict): the default config dict of the algorithm,
otherwises uses the Trainer default config
make_policy_optimizer (func): optional function that returns a
PolicyOptimizer instance given (WorkerSet, config)
validate_config (func): optional callback that checks a given config
for correctness. It may mutate the config as needed.
get_initial_state (func): optional function that returns the initial
state dict given the trainer instance as an argument. The state
dict must be serializable so that it can be checkpointed, and will
be available as the `trainer.state` variable.
get_policy_class (func): optional callback that takes a config and
returns the policy class to override the default with
before_init (func): optional function to run at the start of trainer
init that takes the trainer instance as argument
make_workers (func): override the method that creates rollout workers.
This takes in (trainer, env_creator, policy, config) as args.
make_policy_optimizer (func): optional function that returns a
PolicyOptimizer instance given (WorkerSet, config)
after_init (func): optional function to run at the end of trainer init
that takes the trainer instance as argument
before_train_step (func): optional callback to run before each train()
call. It takes the trainer instance as an argument.
after_optimizer_step (func): optional callback to run after each
@@ -40,27 +63,47 @@ def build_trainer(name,
after_train_result (func): optional callback to run at the end of each
train() call. It takes the trainer instance and result dict as
arguments, and may mutate the result dict as needed.
collect_metrics_fn (func): override the method used to collect metrics.
It takes the trainer instance as argumnt.
before_evaluate_fn (func): callback to run before evaluation. This
takes the trainer instance as argument.
mixins (list): list of any class mixins for the returned trainer class.
These mixins will be applied in order and will have higher
precedence than the Trainer class
Returns:
a Trainer instance that uses the specified args.
"""
original_kwargs = locals().copy()
base = add_mixins(Trainer, mixins)
class trainer_cls(Trainer):
class trainer_cls(base):
_name = name
_default_config = default_config or COMMON_CONFIG
_policy = default_policy
def __init__(self, config=None, env=None, logger_creator=None):
Trainer.__init__(self, config, env, logger_creator)
def _init(self, config, env_creator):
if validate_config:
validate_config(config)
if get_initial_state:
self.state = get_initial_state(self)
else:
self.state = {}
if get_policy_class is None:
policy = default_policy
else:
policy = get_policy_class(config)
self.workers = self._make_workers(env_creator, policy, config,
self.config["num_workers"])
if before_init:
before_init(self)
if make_workers:
self.workers = make_workers(self, env_creator, policy, config)
else:
self.workers = self._make_workers(env_creator, policy, config,
self.config["num_workers"])
if make_policy_optimizer:
self.optimizer = make_policy_optimizer(self.workers, config)
else:
@@ -69,6 +112,8 @@ def build_trainer(name,
**{"train_batch_size": config["train_batch_size"]})
self.optimizer = SyncSamplesOptimizer(self.workers,
**optimizer_config)
if after_init:
after_init(self)
@override(Trainer)
def _train(self):
@@ -81,20 +126,46 @@ def build_trainer(name,
fetches = self.optimizer.step()
if after_optimizer_step:
after_optimizer_step(self, fetches)
if time.time() - start > self.config["min_iter_time_s"]:
if (time.time() - start >= self.config["min_iter_time_s"]
and self.optimizer.num_steps_sampled - prev_steps >=
self.config["timesteps_per_iteration"]):
break
res = self.collect_metrics()
if collect_metrics_fn:
res = collect_metrics_fn(self)
else:
res = self.collect_metrics()
res.update(
timesteps_this_iter=self.optimizer.num_steps_sampled -
prev_steps,
info=res.get("info", {}))
if after_train_result:
after_train_result(self, res)
return res
@override(Trainer)
def _before_evaluate(self):
if before_evaluate_fn:
before_evaluate_fn(self)
def __getstate__(self):
state = Trainer.__getstate__(self)
state.update(self.state)
return state
def __setstate__(self, state):
Trainer.__setstate__(self, state)
self.state = state
@staticmethod
def with_updates(**overrides):
"""Build a copy of this trainer with the specified overrides.
Arguments:
overrides (dict): use this to override any of the arguments
originally passed to build_trainer() for this policy.
"""
return build_trainer(**dict(original_kwargs, **overrides))
trainer_cls.with_updates = with_updates
@@ -5,6 +5,7 @@ from __future__ import print_function
from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.utils import add_mixins
from ray.rllib.utils.annotations import override, DeveloperAPI
@@ -89,13 +90,7 @@ def build_tf_policy(name,
"""
original_kwargs = locals().copy()
base = DynamicTFPolicy
while mixins:
class new_base(mixins.pop(), base):
pass
base = new_base
base = add_mixins(DynamicTFPolicy, mixins)
class policy_cls(base):
def __init__(self,
@@ -5,6 +5,7 @@ from __future__ import print_function
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils import add_mixins
from ray.rllib.utils.annotations import override, DeveloperAPI
@@ -56,13 +57,7 @@ def build_torch_policy(name,
"""
original_kwargs = locals().copy()
base = TorchPolicy
while mixins:
class new_base(mixins.pop(), base):
pass
base = new_base
base = add_mixins(TorchPolicy, mixins)
class policy_cls(base):
def __init__(self, obs_space, action_space, config):
+15
View File
@@ -27,6 +27,21 @@ def renamed_class(cls, old_name):
return DeprecationWrapper
def add_mixins(base, mixins):
"""Returns a new class with mixins applied in priority order."""
mixins = list(mixins or [])
while mixins:
class new_base(mixins.pop(), base):
pass
base = new_base
return base
def renamed_agent(cls):
"""Helper class for renaming Agent => Trainer with a warning."""