mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 18:22:26 +08:00
[rllib] Port remainder of algorithms to build_trainer() pattern (#4920)
This commit is contained in:
@@ -2,15 +2,14 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.dqn.apex import APEX_TRAINER_PROPERTIES
|
||||
from ray.rllib.agents.ddpg.ddpg import DDPGTrainer, \
|
||||
DEFAULT_CONFIG as DDPG_CONFIG
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils import merge_dicts
|
||||
|
||||
APEX_DDPG_DEFAULT_CONFIG = merge_dicts(
|
||||
DDPG_CONFIG, # see also the options in ddpg.py, which are also supported
|
||||
{
|
||||
"optimizer_class": "AsyncReplayOptimizer",
|
||||
"optimizer": merge_dicts(
|
||||
DDPG_CONFIG["optimizer"], {
|
||||
"max_weight_sync_delay": 400,
|
||||
@@ -32,23 +31,7 @@ APEX_DDPG_DEFAULT_CONFIG = merge_dicts(
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class ApexDDPGTrainer(DDPGTrainer):
|
||||
"""DDPG variant that uses the Ape-X distributed policy optimizer.
|
||||
|
||||
By default, this is configured for a large single node (32 cores). For
|
||||
running in a large cluster, increase the `num_workers` config var.
|
||||
"""
|
||||
|
||||
_name = "APEX_DDPG"
|
||||
_default_config = APEX_DDPG_DEFAULT_CONFIG
|
||||
|
||||
@override(DDPGTrainer)
|
||||
def update_target_if_needed(self):
|
||||
# Ape-X updates based on num steps trained, not sampled
|
||||
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
|
||||
self.config["target_network_update_freq"]:
|
||||
self.workers.local_worker().foreach_trainable_policy(
|
||||
lambda p, _: p.update_target())
|
||||
self.last_target_update_ts = self.optimizer.num_steps_trained
|
||||
self.num_target_updates += 1
|
||||
ApexDDPGTrainer = DDPGTrainer.with_updates(
|
||||
name="APEX_DDPG",
|
||||
default_config=APEX_DDPG_DEFAULT_CONFIG,
|
||||
**APEX_TRAINER_PROPERTIES)
|
||||
|
||||
@@ -3,9 +3,9 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.trainer import with_common_config
|
||||
from ray.rllib.agents.dqn.dqn import DQNTrainer
|
||||
from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer, \
|
||||
update_worker_explorations
|
||||
from ray.rllib.agents.ddpg.ddpg_policy import DDPGTFPolicy
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule
|
||||
|
||||
# yapf: disable
|
||||
@@ -97,6 +97,11 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# optimization on initial policy parameters. Note that this will be
|
||||
# disabled when the action noise scale is set to 0 (e.g. during evaluation).
|
||||
"pure_exploration_steps": 1000,
|
||||
# Extra configuration that disables exploration.
|
||||
"evaluation_config": {
|
||||
"exploration_fraction": 0,
|
||||
"exploration_final_eps": 0,
|
||||
},
|
||||
|
||||
# === Replay buffer ===
|
||||
# Size of the replay buffer. Note that if async_updates is set, then
|
||||
@@ -108,6 +113,11 @@ DEFAULT_CONFIG = with_common_config({
|
||||
"prioritized_replay_alpha": 0.6,
|
||||
# Beta parameter for sampling from prioritized replay buffer.
|
||||
"prioritized_replay_beta": 0.4,
|
||||
# Fraction of entire training period over which the beta parameter is
|
||||
# annealed
|
||||
"beta_annealing_fraction": 0.2,
|
||||
# Final value of beta
|
||||
"final_prioritized_replay_beta": 0.4,
|
||||
# Epsilon to add to the TD errors when updating priorities.
|
||||
"prioritized_replay_eps": 1e-6,
|
||||
# Whether to LZ4 compress observations
|
||||
@@ -146,8 +156,6 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# to increase if your environment is particularly slow to sample, or if
|
||||
# you're using the Async or Ape-X optimizers.
|
||||
"num_workers": 0,
|
||||
# Optimizer class to use.
|
||||
"optimizer_class": "SyncReplayOptimizer",
|
||||
# Whether to use a distribution of epsilons across workers for exploration.
|
||||
"per_worker_exploration": False,
|
||||
# Whether to compute priorities on workers.
|
||||
@@ -159,47 +167,56 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# yapf: enable
|
||||
|
||||
|
||||
class DDPGTrainer(DQNTrainer):
|
||||
"""DDPG implementation in TensorFlow."""
|
||||
_name = "DDPG"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy = DDPGTFPolicy
|
||||
|
||||
@override(DQNTrainer)
|
||||
def _train(self):
|
||||
pure_expl_steps = self.config["pure_exploration_steps"]
|
||||
if pure_expl_steps:
|
||||
# tell workers whether they should do pure exploration
|
||||
only_explore = self.global_timestep < pure_expl_steps
|
||||
self.workers.local_worker().foreach_trainable_policy(
|
||||
lambda p, _: p.set_pure_exploration_phase(only_explore))
|
||||
for e in self.workers.remote_workers():
|
||||
e.foreach_trainable_policy.remote(
|
||||
lambda p, _: p.set_pure_exploration_phase(only_explore))
|
||||
return super(DDPGTrainer, self)._train()
|
||||
|
||||
@override(DQNTrainer)
|
||||
def _make_exploration_schedule(self, worker_index):
|
||||
# Override DQN's schedule to take into account
|
||||
# `exploration_ou_noise_scale`
|
||||
if self.config["per_worker_exploration"]:
|
||||
assert self.config["num_workers"] > 1, \
|
||||
"This requires multiple workers"
|
||||
if worker_index >= 0:
|
||||
# FIXME: what do magic constants mean? (0.4, 7)
|
||||
max_index = float(self.config["num_workers"] - 1)
|
||||
exponent = 1 + worker_index / max_index * 7
|
||||
return ConstantSchedule(0.4**exponent)
|
||||
else:
|
||||
# local ev should have zero exploration so that eval rollouts
|
||||
# run properly
|
||||
return ConstantSchedule(0.0)
|
||||
elif self.config["exploration_should_anneal"]:
|
||||
return LinearSchedule(
|
||||
schedule_timesteps=int(self.config["exploration_fraction"] *
|
||||
self.config["schedule_max_timesteps"]),
|
||||
initial_p=1.0,
|
||||
final_p=self.config["exploration_final_scale"])
|
||||
def make_exploration_schedule(config, worker_index):
|
||||
# Modification of DQN's schedule to take into account
|
||||
# `exploration_ou_noise_scale`
|
||||
if config["per_worker_exploration"]:
|
||||
assert config["num_workers"] > 1, "This requires multiple workers"
|
||||
if worker_index >= 0:
|
||||
# FIXME: what do magic constants mean? (0.4, 7)
|
||||
max_index = float(config["num_workers"] - 1)
|
||||
exponent = 1 + worker_index / max_index * 7
|
||||
return ConstantSchedule(0.4**exponent)
|
||||
else:
|
||||
# *always* add exploration noise
|
||||
return ConstantSchedule(1.0)
|
||||
# local ev should have zero exploration so that eval rollouts
|
||||
# run properly
|
||||
return ConstantSchedule(0.0)
|
||||
elif config["exploration_should_anneal"]:
|
||||
return LinearSchedule(
|
||||
schedule_timesteps=int(config["exploration_fraction"] *
|
||||
config["schedule_max_timesteps"]),
|
||||
initial_p=1.0,
|
||||
final_p=config["exploration_final_scale"])
|
||||
else:
|
||||
# *always* add exploration noise
|
||||
return ConstantSchedule(1.0)
|
||||
|
||||
|
||||
def setup_ddpg_exploration(trainer):
|
||||
trainer.exploration0 = make_exploration_schedule(trainer.config, -1)
|
||||
trainer.explorations = [
|
||||
make_exploration_schedule(trainer.config, i)
|
||||
for i in range(trainer.config["num_workers"])
|
||||
]
|
||||
|
||||
|
||||
def add_pure_exploration_phase(trainer):
|
||||
global_timestep = trainer.optimizer.num_steps_sampled
|
||||
pure_expl_steps = trainer.config["pure_exploration_steps"]
|
||||
if pure_expl_steps:
|
||||
# tell workers whether they should do pure exploration
|
||||
only_explore = global_timestep < pure_expl_steps
|
||||
trainer.workers.local_worker().foreach_trainable_policy(
|
||||
lambda p, _: p.set_pure_exploration_phase(only_explore))
|
||||
for e in trainer.workers.remote_workers():
|
||||
e.foreach_trainable_policy.remote(
|
||||
lambda p, _: p.set_pure_exploration_phase(only_explore))
|
||||
update_worker_explorations(trainer)
|
||||
|
||||
|
||||
DDPGTrainer = GenericOffPolicyTrainer.with_updates(
|
||||
name="DDPG",
|
||||
default_config=DEFAULT_CONFIG,
|
||||
default_policy=DDPGTFPolicy,
|
||||
before_init=setup_ddpg_exploration,
|
||||
before_train_step=add_pure_exploration_phase)
|
||||
|
||||
@@ -1,3 +1,9 @@
|
||||
"""A more stable successor to TD3.
|
||||
|
||||
By default, this uses a near-identical configuration to that reported in the
|
||||
TD3 paper.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
@@ -36,7 +42,6 @@ TD3_DEFAULT_CONFIG = merge_dicts(
|
||||
"train_batch_size": 100,
|
||||
"use_huber": False,
|
||||
"target_network_update_freq": 0,
|
||||
"optimizer_class": "SyncReplayOptimizer",
|
||||
"num_workers": 0,
|
||||
"num_gpus_per_worker": 0,
|
||||
"per_worker_exploration": False,
|
||||
@@ -48,10 +53,5 @@ TD3_DEFAULT_CONFIG = merge_dicts(
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class TD3Trainer(DDPGTrainer):
|
||||
"""A more stable successor to TD3. By default, this uses a near-identical
|
||||
configuration to that reported in the TD3 paper."""
|
||||
|
||||
_name = "TD3"
|
||||
_default_config = TD3_DEFAULT_CONFIG
|
||||
TD3Trainer = DDPGTrainer.with_updates(
|
||||
name="TD3", default_config=TD3_DEFAULT_CONFIG)
|
||||
|
||||
@@ -3,15 +3,14 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.dqn.dqn import DQNTrainer, DEFAULT_CONFIG as DQN_CONFIG
|
||||
from ray.rllib.optimizers import AsyncReplayOptimizer
|
||||
from ray.rllib.utils import merge_dicts
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
APEX_DEFAULT_CONFIG = merge_dicts(
|
||||
DQN_CONFIG, # see also the options in dqn.py, which are also supported
|
||||
{
|
||||
"optimizer_class": "AsyncReplayOptimizer",
|
||||
"optimizer": merge_dicts(
|
||||
DQN_CONFIG["optimizer"], {
|
||||
"max_weight_sync_delay": 400,
|
||||
@@ -36,22 +35,50 @@ APEX_DEFAULT_CONFIG = merge_dicts(
|
||||
# yapf: enable
|
||||
|
||||
|
||||
class ApexTrainer(DQNTrainer):
|
||||
"""DQN variant that uses the Ape-X distributed policy optimizer.
|
||||
def defer_make_workers(trainer, env_creator, policy, config):
|
||||
# Hack to workaround https://github.com/ray-project/ray/issues/2541
|
||||
# The workers will be created later, after the optimizer is created
|
||||
return trainer._make_workers(env_creator, policy, config, 0)
|
||||
|
||||
By default, this is configured for a large single node (32 cores). For
|
||||
running in a large cluster, increase the `num_workers` config var.
|
||||
"""
|
||||
|
||||
_name = "APEX"
|
||||
_default_config = APEX_DEFAULT_CONFIG
|
||||
def make_async_optimizer(workers, config):
|
||||
assert len(workers.remote_workers()) == 0
|
||||
extra_config = config["optimizer"].copy()
|
||||
for key in [
|
||||
"prioritized_replay", "prioritized_replay_alpha",
|
||||
"prioritized_replay_beta", "prioritized_replay_eps"
|
||||
]:
|
||||
if key in config:
|
||||
extra_config[key] = config[key]
|
||||
opt = AsyncReplayOptimizer(
|
||||
workers,
|
||||
learning_starts=config["learning_starts"],
|
||||
buffer_size=config["buffer_size"],
|
||||
train_batch_size=config["train_batch_size"],
|
||||
sample_batch_size=config["sample_batch_size"],
|
||||
**extra_config)
|
||||
workers.add_workers(config["num_workers"])
|
||||
opt._set_workers(workers.remote_workers())
|
||||
return opt
|
||||
|
||||
@override(DQNTrainer)
|
||||
def update_target_if_needed(self):
|
||||
# Ape-X updates based on num steps trained, not sampled
|
||||
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
|
||||
self.config["target_network_update_freq"]:
|
||||
self.workers.local_worker().foreach_trainable_policy(
|
||||
lambda p, _: p.update_target())
|
||||
self.last_target_update_ts = self.optimizer.num_steps_trained
|
||||
self.num_target_updates += 1
|
||||
|
||||
def update_target_based_on_num_steps_trained(trainer, fetches):
|
||||
# Ape-X updates based on num steps trained, not sampled
|
||||
if (trainer.optimizer.num_steps_trained -
|
||||
trainer.state["last_target_update_ts"] >
|
||||
trainer.config["target_network_update_freq"]):
|
||||
trainer.workers.local_worker().foreach_trainable_policy(
|
||||
lambda p, _: p.update_target())
|
||||
trainer.state["last_target_update_ts"] = (
|
||||
trainer.optimizer.num_steps_trained)
|
||||
trainer.state["num_target_updates"] += 1
|
||||
|
||||
|
||||
APEX_TRAINER_PROPERTIES = {
|
||||
"make_workers": defer_make_workers,
|
||||
"make_policy_optimizer": make_async_optimizer,
|
||||
"after_optimizer_step": update_target_based_on_num_steps_trained,
|
||||
}
|
||||
|
||||
ApexTrainer = DQNTrainer.with_updates(
|
||||
name="APEX", default_config=APEX_DEFAULT_CONFIG, **APEX_TRAINER_PROPERTIES)
|
||||
|
||||
+164
-197
@@ -3,27 +3,17 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from ray import tune
|
||||
from ray.rllib import optimizers
|
||||
from ray.rllib.agents.trainer import Trainer, with_common_config
|
||||
from ray.rllib.agents.trainer import with_common_config
|
||||
from ray.rllib.agents.trainer_template import build_trainer
|
||||
from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
from ray.rllib.optimizers import SyncReplayOptimizer
|
||||
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OPTIMIZER_SHARED_CONFIGS = [
|
||||
"buffer_size", "prioritized_replay", "prioritized_replay_alpha",
|
||||
"prioritized_replay_beta", "schedule_max_timesteps",
|
||||
"beta_annealing_fraction", "final_prioritized_replay_beta",
|
||||
"prioritized_replay_eps", "sample_batch_size", "train_batch_size",
|
||||
"learning_starts"
|
||||
]
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
DEFAULT_CONFIG = with_common_config({
|
||||
@@ -53,7 +43,8 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# 1.0 to exploration_fraction over this number of timesteps scaled by
|
||||
# exploration_fraction
|
||||
"schedule_max_timesteps": 100000,
|
||||
# Number of env steps to optimize for before returning
|
||||
# Minimum env steps to optimize for per train call. This value does
|
||||
# not affect learning, only the length of iterations.
|
||||
"timesteps_per_iteration": 1000,
|
||||
# Fraction of entire training period over which the exploration rate is
|
||||
# annealed
|
||||
@@ -70,6 +61,11 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# If True parameter space noise will be used for exploration
|
||||
# See https://blog.openai.com/better-exploration-with-parameter-noise/
|
||||
"parameter_noise": False,
|
||||
# Extra configuration that disables exploration.
|
||||
"evaluation_config": {
|
||||
"exploration_fraction": 0,
|
||||
"exploration_final_eps": 0,
|
||||
},
|
||||
|
||||
# === Replay buffer ===
|
||||
# Size of the replay buffer. Note that if async_updates is set, then
|
||||
@@ -115,8 +111,6 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# to increase if your environment is particularly slow to sample, or if
|
||||
# you're using the Async or Ape-X optimizers.
|
||||
"num_workers": 0,
|
||||
# Optimizer class to use.
|
||||
"optimizer_class": "SyncReplayOptimizer",
|
||||
# Whether to use a distribution of epsilons across workers for exploration.
|
||||
"per_worker_exploration": False,
|
||||
# Whether to compute priorities on workers.
|
||||
@@ -128,202 +122,175 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# yapf: enable
|
||||
|
||||
|
||||
class DQNTrainer(Trainer):
|
||||
"""DQN implementation in TensorFlow."""
|
||||
def make_optimizer(workers, config):
|
||||
return SyncReplayOptimizer(
|
||||
workers,
|
||||
learning_starts=config["learning_starts"],
|
||||
buffer_size=config["buffer_size"],
|
||||
prioritized_replay=config["prioritized_replay"],
|
||||
prioritized_replay_alpha=config["prioritized_replay_alpha"],
|
||||
prioritized_replay_beta=config["prioritized_replay_beta"],
|
||||
schedule_max_timesteps=config["schedule_max_timesteps"],
|
||||
beta_annealing_fraction=config["beta_annealing_fraction"],
|
||||
final_prioritized_replay_beta=config["final_prioritized_replay_beta"],
|
||||
prioritized_replay_eps=config["prioritized_replay_eps"],
|
||||
train_batch_size=config["train_batch_size"],
|
||||
sample_batch_size=config["sample_batch_size"],
|
||||
**config["optimizer"])
|
||||
|
||||
_name = "DQN"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy = DQNTFPolicy
|
||||
_optimizer_shared_configs = OPTIMIZER_SHARED_CONFIGS
|
||||
|
||||
@override(Trainer)
|
||||
def _init(self, config, env_creator):
|
||||
self._validate_config()
|
||||
def check_config_and_setup_param_noise(config):
|
||||
"""Update the config based on settings.
|
||||
|
||||
# Update effective batch size to include n-step
|
||||
adjusted_batch_size = max(config["sample_batch_size"],
|
||||
config.get("n_step", 1))
|
||||
config["sample_batch_size"] = adjusted_batch_size
|
||||
Rewrites sample_batch_size to take into account n_step truncation, and also
|
||||
adds the necessary callbacks to support parameter space noise exploration.
|
||||
"""
|
||||
|
||||
self.exploration0 = self._make_exploration_schedule(-1)
|
||||
self.explorations = [
|
||||
self._make_exploration_schedule(i)
|
||||
for i in range(config["num_workers"])
|
||||
]
|
||||
# Update effective batch size to include n-step
|
||||
adjusted_batch_size = max(config["sample_batch_size"],
|
||||
config.get("n_step", 1))
|
||||
config["sample_batch_size"] = adjusted_batch_size
|
||||
|
||||
for k in self._optimizer_shared_configs:
|
||||
if self._name != "DQN" and k in [
|
||||
"schedule_max_timesteps", "beta_annealing_fraction",
|
||||
"final_prioritized_replay_beta"
|
||||
]:
|
||||
# only Rainbow needs annealing prioritized_replay_beta
|
||||
continue
|
||||
if k not in config["optimizer"]:
|
||||
config["optimizer"][k] = config[k]
|
||||
|
||||
if config.get("parameter_noise", False):
|
||||
if config["callbacks"]["on_episode_start"]:
|
||||
start_callback = config["callbacks"]["on_episode_start"]
|
||||
else:
|
||||
start_callback = None
|
||||
|
||||
def on_episode_start(info):
|
||||
# as a callback function to sample and pose parameter space
|
||||
# noise on the parameters of network
|
||||
policies = info["policy"]
|
||||
for pi in policies.values():
|
||||
pi.add_parameter_noise()
|
||||
if start_callback:
|
||||
start_callback(info)
|
||||
|
||||
config["callbacks"]["on_episode_start"] = tune.function(
|
||||
on_episode_start)
|
||||
if config["callbacks"]["on_episode_end"]:
|
||||
end_callback = config["callbacks"]["on_episode_end"]
|
||||
else:
|
||||
end_callback = None
|
||||
|
||||
def on_episode_end(info):
|
||||
# as a callback function to monitor the distance
|
||||
# between noisy policy and original policy
|
||||
policies = info["policy"]
|
||||
episode = info["episode"]
|
||||
episode.custom_metrics["policy_distance"] = policies[
|
||||
DEFAULT_POLICY_ID].pi_distance
|
||||
if end_callback:
|
||||
end_callback(info)
|
||||
|
||||
config["callbacks"]["on_episode_end"] = tune.function(
|
||||
on_episode_end)
|
||||
|
||||
if config["optimizer_class"] != "AsyncReplayOptimizer":
|
||||
self.workers = self._make_workers(
|
||||
env_creator,
|
||||
self._policy,
|
||||
config,
|
||||
num_workers=self.config["num_workers"])
|
||||
workers_needed = 0
|
||||
if config.get("parameter_noise", False):
|
||||
if config["batch_mode"] != "complete_episodes":
|
||||
raise ValueError("Exploration with parameter space noise requires "
|
||||
"batch_mode to be complete_episodes.")
|
||||
if config.get("noisy", False):
|
||||
raise ValueError(
|
||||
"Exploration with parameter space noise and noisy network "
|
||||
"cannot be used at the same time.")
|
||||
if config["callbacks"]["on_episode_start"]:
|
||||
start_callback = config["callbacks"]["on_episode_start"]
|
||||
else:
|
||||
# Hack to workaround https://github.com/ray-project/ray/issues/2541
|
||||
self.workers = self._make_workers(
|
||||
env_creator, self._policy, config, num_workers=0)
|
||||
workers_needed = self.config["num_workers"]
|
||||
start_callback = None
|
||||
|
||||
self.optimizer = getattr(optimizers, config["optimizer_class"])(
|
||||
self.workers, **config["optimizer"])
|
||||
def on_episode_start(info):
|
||||
# as a callback function to sample and pose parameter space
|
||||
# noise on the parameters of network
|
||||
policies = info["policy"]
|
||||
for pi in policies.values():
|
||||
pi.add_parameter_noise()
|
||||
if start_callback:
|
||||
start_callback(info)
|
||||
|
||||
# Create the remote workers *after* the replay actors
|
||||
if workers_needed > 0:
|
||||
self.workers.add_workers(workers_needed)
|
||||
self.optimizer._set_workers(self.workers.remote_workers())
|
||||
|
||||
self.last_target_update_ts = 0
|
||||
self.num_target_updates = 0
|
||||
|
||||
@override(Trainer)
|
||||
def _train(self):
|
||||
start_timestep = self.global_timestep
|
||||
|
||||
# Update worker explorations
|
||||
exp_vals = [self.exploration0.value(self.global_timestep)]
|
||||
self.workers.local_worker().foreach_trainable_policy(
|
||||
lambda p, _: p.set_epsilon(exp_vals[0]))
|
||||
for i, e in enumerate(self.workers.remote_workers()):
|
||||
exp_val = self.explorations[i].value(self.global_timestep)
|
||||
e.foreach_trainable_policy.remote(
|
||||
lambda p, _: p.set_epsilon(exp_val))
|
||||
exp_vals.append(exp_val)
|
||||
|
||||
# Do optimization steps
|
||||
start = time.time()
|
||||
while (self.global_timestep - start_timestep <
|
||||
self.config["timesteps_per_iteration"]
|
||||
) or time.time() - start < self.config["min_iter_time_s"]:
|
||||
self.optimizer.step()
|
||||
self.update_target_if_needed()
|
||||
|
||||
if self.config["per_worker_exploration"]:
|
||||
# Only collect metrics from the third of workers with lowest eps
|
||||
result = self.collect_metrics(
|
||||
selected_workers=self.workers.remote_workers()[
|
||||
-len(self.workers.remote_workers()) // 3:])
|
||||
config["callbacks"]["on_episode_start"] = tune.function(
|
||||
on_episode_start)
|
||||
if config["callbacks"]["on_episode_end"]:
|
||||
end_callback = config["callbacks"]["on_episode_end"]
|
||||
else:
|
||||
result = self.collect_metrics()
|
||||
end_callback = None
|
||||
|
||||
result.update(
|
||||
timesteps_this_iter=self.global_timestep - start_timestep,
|
||||
info=dict({
|
||||
"min_exploration": min(exp_vals),
|
||||
"max_exploration": max(exp_vals),
|
||||
"num_target_updates": self.num_target_updates,
|
||||
}, **self.optimizer.stats()))
|
||||
def on_episode_end(info):
|
||||
# as a callback function to monitor the distance
|
||||
# between noisy policy and original policy
|
||||
policies = info["policy"]
|
||||
episode = info["episode"]
|
||||
episode.custom_metrics["policy_distance"] = policies[
|
||||
DEFAULT_POLICY_ID].pi_distance
|
||||
if end_callback:
|
||||
end_callback(info)
|
||||
|
||||
return result
|
||||
config["callbacks"]["on_episode_end"] = tune.function(on_episode_end)
|
||||
|
||||
def update_target_if_needed(self):
|
||||
if self.global_timestep - self.last_target_update_ts > \
|
||||
self.config["target_network_update_freq"]:
|
||||
self.workers.local_worker().foreach_trainable_policy(
|
||||
lambda p, _: p.update_target())
|
||||
self.last_target_update_ts = self.global_timestep
|
||||
self.num_target_updates += 1
|
||||
|
||||
@property
|
||||
def global_timestep(self):
|
||||
return self.optimizer.num_steps_sampled
|
||||
def get_initial_state(config):
|
||||
return {
|
||||
"last_target_update_ts": 0,
|
||||
"num_target_updates": 0,
|
||||
}
|
||||
|
||||
def _evaluate(self):
|
||||
logger.info("Evaluating current policy for {} episodes".format(
|
||||
self.config["evaluation_num_episodes"]))
|
||||
self.evaluation_workers.local_worker().restore(
|
||||
self.workers.local_worker().save())
|
||||
self.evaluation_workers.local_worker().foreach_policy(
|
||||
lambda p, _: p.set_epsilon(0))
|
||||
for _ in range(self.config["evaluation_num_episodes"]):
|
||||
self.evaluation_workers.local_worker().sample()
|
||||
metrics = collect_metrics(self.evaluation_workers.local_worker())
|
||||
return {"evaluation": metrics}
|
||||
|
||||
def _make_exploration_schedule(self, worker_index):
|
||||
# Use either a different `eps` per worker, or a linear schedule.
|
||||
if self.config["per_worker_exploration"]:
|
||||
assert self.config["num_workers"] > 1, \
|
||||
"This requires multiple workers"
|
||||
if worker_index >= 0:
|
||||
exponent = (
|
||||
1 +
|
||||
worker_index / float(self.config["num_workers"] - 1) * 7)
|
||||
return ConstantSchedule(0.4**exponent)
|
||||
else:
|
||||
# local ev should have zero exploration so that eval rollouts
|
||||
# run properly
|
||||
return ConstantSchedule(0.0)
|
||||
return LinearSchedule(
|
||||
schedule_timesteps=int(self.config["exploration_fraction"] *
|
||||
self.config["schedule_max_timesteps"]),
|
||||
initial_p=1.0,
|
||||
final_p=self.config["exploration_final_eps"])
|
||||
def make_exploration_schedule(config, worker_index):
|
||||
# Use either a different `eps` per worker, or a linear schedule.
|
||||
if config["per_worker_exploration"]:
|
||||
assert config["num_workers"] > 1, \
|
||||
"This requires multiple workers"
|
||||
if worker_index >= 0:
|
||||
exponent = (
|
||||
1 + worker_index / float(config["num_workers"] - 1) * 7)
|
||||
return ConstantSchedule(0.4**exponent)
|
||||
else:
|
||||
# local ev should have zero exploration so that eval rollouts
|
||||
# run properly
|
||||
return ConstantSchedule(0.0)
|
||||
return LinearSchedule(
|
||||
schedule_timesteps=int(
|
||||
config["exploration_fraction"] * config["schedule_max_timesteps"]),
|
||||
initial_p=1.0,
|
||||
final_p=config["exploration_final_eps"])
|
||||
|
||||
def __getstate__(self):
|
||||
state = Trainer.__getstate__(self)
|
||||
state.update({
|
||||
"num_target_updates": self.num_target_updates,
|
||||
"last_target_update_ts": self.last_target_update_ts,
|
||||
})
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
Trainer.__setstate__(self, state)
|
||||
self.num_target_updates = state["num_target_updates"]
|
||||
self.last_target_update_ts = state["last_target_update_ts"]
|
||||
def setup_exploration(trainer):
|
||||
trainer.exploration0 = make_exploration_schedule(trainer.config, -1)
|
||||
trainer.explorations = [
|
||||
make_exploration_schedule(trainer.config, i)
|
||||
for i in range(trainer.config["num_workers"])
|
||||
]
|
||||
|
||||
def _validate_config(self):
|
||||
if self.config.get("parameter_noise", False):
|
||||
if self.config["batch_mode"] != "complete_episodes":
|
||||
raise ValueError(
|
||||
"Exploration with parameter space noise requires "
|
||||
"batch_mode to be complete_episodes.")
|
||||
if self.config.get("noisy", False):
|
||||
raise ValueError(
|
||||
"Exploration with parameter space noise and noisy network "
|
||||
"cannot be used at the same time.")
|
||||
|
||||
def update_worker_explorations(trainer):
|
||||
global_timestep = trainer.optimizer.num_steps_sampled
|
||||
exp_vals = [trainer.exploration0.value(global_timestep)]
|
||||
trainer.workers.local_worker().foreach_trainable_policy(
|
||||
lambda p, _: p.set_epsilon(exp_vals[0]))
|
||||
for i, e in enumerate(trainer.workers.remote_workers()):
|
||||
exp_val = trainer.explorations[i].value(global_timestep)
|
||||
e.foreach_trainable_policy.remote(lambda p, _: p.set_epsilon(exp_val))
|
||||
exp_vals.append(exp_val)
|
||||
trainer.train_start_timestep = global_timestep
|
||||
trainer.cur_exp_vals = exp_vals
|
||||
|
||||
|
||||
def add_trainer_metrics(trainer, result):
|
||||
global_timestep = trainer.optimizer.num_steps_sampled
|
||||
result.update(
|
||||
timesteps_this_iter=global_timestep - trainer.train_start_timestep,
|
||||
info=dict({
|
||||
"min_exploration": min(trainer.cur_exp_vals),
|
||||
"max_exploration": max(trainer.cur_exp_vals),
|
||||
"num_target_updates": trainer.state["num_target_updates"],
|
||||
}, **trainer.optimizer.stats()))
|
||||
|
||||
|
||||
def update_target_if_needed(trainer, fetches):
|
||||
global_timestep = trainer.optimizer.num_steps_sampled
|
||||
if global_timestep - trainer.state["last_target_update_ts"] > \
|
||||
trainer.config["target_network_update_freq"]:
|
||||
trainer.workers.local_worker().foreach_trainable_policy(
|
||||
lambda p, _: p.update_target())
|
||||
trainer.state["last_target_update_ts"] = global_timestep
|
||||
trainer.state["num_target_updates"] += 1
|
||||
|
||||
|
||||
def collect_metrics(trainer):
|
||||
if trainer.config["per_worker_exploration"]:
|
||||
# Only collect metrics from the third of workers with lowest eps
|
||||
result = trainer.collect_metrics(
|
||||
selected_workers=trainer.workers.remote_workers()[
|
||||
-len(trainer.workers.remote_workers()) // 3:])
|
||||
else:
|
||||
result = trainer.collect_metrics()
|
||||
return result
|
||||
|
||||
|
||||
def disable_exploration(trainer):
|
||||
trainer.evaluation_workers.local_worker().foreach_policy(
|
||||
lambda p, _: p.set_epsilon(0))
|
||||
|
||||
|
||||
GenericOffPolicyTrainer = build_trainer(
|
||||
name="GenericOffPolicyAlgorithm",
|
||||
default_policy=None,
|
||||
default_config=DEFAULT_CONFIG,
|
||||
validate_config=check_config_and_setup_param_noise,
|
||||
get_initial_state=get_initial_state,
|
||||
make_policy_optimizer=make_optimizer,
|
||||
before_init=setup_exploration,
|
||||
before_train_step=update_worker_explorations,
|
||||
after_optimizer_step=update_target_if_needed,
|
||||
after_train_result=add_trainer_metrics,
|
||||
collect_metrics_fn=collect_metrics,
|
||||
before_evaluate_fn=disable_exploration)
|
||||
|
||||
DQNTrainer = GenericOffPolicyTrainer.with_updates(
|
||||
name="DQN", default_policy=DQNTFPolicy, default_config=DEFAULT_CONFIG)
|
||||
|
||||
@@ -2,33 +2,16 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import time
|
||||
|
||||
from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
|
||||
from ray.rllib.agents.impala.vtrace_policy import VTraceTFPolicy
|
||||
from ray.rllib.agents.trainer import Trainer, with_common_config
|
||||
from ray.rllib.agents.trainer_template import build_trainer
|
||||
from ray.rllib.optimizers import AsyncSamplesOptimizer
|
||||
from ray.rllib.optimizers.aso_tree_aggregator import TreeAggregator
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.trial import Resources
|
||||
|
||||
OPTIMIZER_SHARED_CONFIGS = [
|
||||
"lr",
|
||||
"num_envs_per_worker",
|
||||
"num_gpus",
|
||||
"sample_batch_size",
|
||||
"train_batch_size",
|
||||
"replay_buffer_num_slots",
|
||||
"replay_proportion",
|
||||
"num_data_loader_buffers",
|
||||
"max_sample_requests_in_flight_per_worker",
|
||||
"broadcast_interval",
|
||||
"num_sgd_iter",
|
||||
"minibatch_buffer_size",
|
||||
"num_aggregation_workers",
|
||||
]
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
DEFAULT_CONFIG = with_common_config({
|
||||
@@ -100,37 +83,57 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# yapf: enable
|
||||
|
||||
|
||||
class ImpalaTrainer(Trainer):
|
||||
"""IMPALA implementation using DeepMind's V-trace."""
|
||||
def choose_policy(config):
|
||||
if config["vtrace"]:
|
||||
return VTraceTFPolicy
|
||||
else:
|
||||
return A3CTFPolicy
|
||||
|
||||
_name = "IMPALA"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy = VTraceTFPolicy
|
||||
|
||||
@override(Trainer)
|
||||
def _init(self, config, env_creator):
|
||||
for k in OPTIMIZER_SHARED_CONFIGS:
|
||||
if k not in config["optimizer"]:
|
||||
config["optimizer"][k] = config[k]
|
||||
policy_cls = self._get_policy()
|
||||
self.workers = self._make_workers(
|
||||
self.env_creator, policy_cls, self.config, num_workers=0)
|
||||
def validate_config(config):
|
||||
if config["entropy_coeff"] < 0:
|
||||
raise DeprecationWarning("entropy_coeff must be >= 0")
|
||||
|
||||
if self.config["num_aggregation_workers"] > 0:
|
||||
# Create co-located aggregator actors first for placement pref
|
||||
aggregators = TreeAggregator.precreate_aggregators(
|
||||
self.config["num_aggregation_workers"])
|
||||
|
||||
self.workers.add_workers(config["num_workers"])
|
||||
self.optimizer = AsyncSamplesOptimizer(self.workers,
|
||||
**config["optimizer"])
|
||||
if config["entropy_coeff"] < 0:
|
||||
raise DeprecationWarning("entropy_coeff must be >= 0")
|
||||
def defer_make_workers(trainer, env_creator, policy, config):
|
||||
# Defer worker creation to after the optimizer has been created.
|
||||
return trainer._make_workers(env_creator, policy, config, 0)
|
||||
|
||||
if self.config["num_aggregation_workers"] > 0:
|
||||
# Assign the pre-created aggregators to the optimizer
|
||||
self.optimizer.aggregator.init(aggregators)
|
||||
|
||||
def make_aggregators_and_optimizer(workers, config):
|
||||
if config["num_aggregation_workers"] > 0:
|
||||
# Create co-located aggregator actors first for placement pref
|
||||
aggregators = TreeAggregator.precreate_aggregators(
|
||||
config["num_aggregation_workers"])
|
||||
else:
|
||||
aggregators = None
|
||||
workers.add_workers(config["num_workers"])
|
||||
|
||||
optimizer = AsyncSamplesOptimizer(
|
||||
workers,
|
||||
lr=config["lr"],
|
||||
num_envs_per_worker=config["num_envs_per_worker"],
|
||||
num_gpus=config["num_gpus"],
|
||||
sample_batch_size=config["sample_batch_size"],
|
||||
train_batch_size=config["train_batch_size"],
|
||||
replay_buffer_num_slots=config["replay_buffer_num_slots"],
|
||||
replay_proportion=config["replay_proportion"],
|
||||
num_data_loader_buffers=config["num_data_loader_buffers"],
|
||||
max_sample_requests_in_flight_per_worker=config[
|
||||
"max_sample_requests_in_flight_per_worker"],
|
||||
broadcast_interval=config["broadcast_interval"],
|
||||
num_sgd_iter=config["num_sgd_iter"],
|
||||
minibatch_buffer_size=config["minibatch_buffer_size"],
|
||||
num_aggregation_workers=config["num_aggregation_workers"],
|
||||
**config["optimizer"])
|
||||
|
||||
if aggregators:
|
||||
# Assign the pre-created aggregators to the optimizer
|
||||
optimizer.aggregator.init(aggregators)
|
||||
return optimizer
|
||||
|
||||
|
||||
class OverrideDefaultResourceRequest(object):
|
||||
@classmethod
|
||||
@override(Trainable)
|
||||
def default_resource_request(cls, config):
|
||||
@@ -143,22 +146,13 @@ class ImpalaTrainer(Trainer):
|
||||
cf["num_aggregation_workers"],
|
||||
extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])
|
||||
|
||||
@override(Trainer)
|
||||
def _train(self):
|
||||
prev_steps = self.optimizer.num_steps_sampled
|
||||
start = time.time()
|
||||
self.optimizer.step()
|
||||
while (time.time() - start < self.config["min_iter_time_s"]
|
||||
or self.optimizer.num_steps_sampled == prev_steps):
|
||||
self.optimizer.step()
|
||||
result = self.collect_metrics()
|
||||
result.update(timesteps_this_iter=self.optimizer.num_steps_sampled -
|
||||
prev_steps)
|
||||
return result
|
||||
|
||||
def _get_policy(self):
|
||||
if self.config["vtrace"]:
|
||||
policy_cls = self._policy
|
||||
else:
|
||||
policy_cls = A3CTFPolicy
|
||||
return policy_cls
|
||||
ImpalaTrainer = build_trainer(
|
||||
name="IMPALA",
|
||||
default_config=DEFAULT_CONFIG,
|
||||
default_policy=VTraceTFPolicy,
|
||||
validate_config=validate_config,
|
||||
get_policy_class=choose_policy,
|
||||
make_workers=defer_make_workers,
|
||||
make_policy_optimizer=make_aggregators_and_optimizer,
|
||||
mixins=[OverrideDefaultResourceRequest])
|
||||
|
||||
@@ -2,10 +2,10 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.trainer import Trainer, with_common_config
|
||||
from ray.rllib.agents.trainer import with_common_config
|
||||
from ray.rllib.agents.trainer_template import build_trainer
|
||||
from ray.rllib.agents.marwil.marwil_policy import MARWILPolicy
|
||||
from ray.rllib.optimizers import SyncBatchReplayOptimizer
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
@@ -39,30 +39,17 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# yapf: enable
|
||||
|
||||
|
||||
class MARWILTrainer(Trainer):
|
||||
"""MARWIL implementation in TensorFlow."""
|
||||
def make_optimizer(workers, config):
|
||||
return SyncBatchReplayOptimizer(
|
||||
workers,
|
||||
learning_starts=config["learning_starts"],
|
||||
buffer_size=config["replay_buffer_size"],
|
||||
train_batch_size=config["train_batch_size"],
|
||||
)
|
||||
|
||||
_name = "MARWIL"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy = MARWILPolicy
|
||||
|
||||
@override(Trainer)
|
||||
def _init(self, config, env_creator):
|
||||
self.workers = self._make_workers(env_creator, self._policy, config,
|
||||
config["num_workers"])
|
||||
self.optimizer = SyncBatchReplayOptimizer(
|
||||
self.workers,
|
||||
learning_starts=config["learning_starts"],
|
||||
buffer_size=config["replay_buffer_size"],
|
||||
train_batch_size=config["train_batch_size"],
|
||||
)
|
||||
|
||||
@override(Trainer)
|
||||
def _train(self):
|
||||
prev_steps = self.optimizer.num_steps_sampled
|
||||
fetches = self.optimizer.step()
|
||||
res = self.collect_metrics()
|
||||
res.update(
|
||||
timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps,
|
||||
info=dict(fetches, **res.get("info", {})))
|
||||
return res
|
||||
MARWILTrainer = build_trainer(
|
||||
name="MARWIL",
|
||||
default_config=DEFAULT_CONFIG,
|
||||
default_policy=MARWILPolicy,
|
||||
make_policy_optimizer=make_optimizer)
|
||||
|
||||
@@ -5,7 +5,6 @@ from __future__ import print_function
|
||||
from ray.rllib.agents.ppo.appo_policy import AsyncPPOTFPolicy
|
||||
from ray.rllib.agents.trainer import with_base_config
|
||||
from ray.rllib.agents import impala
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
@@ -51,14 +50,8 @@ DEFAULT_CONFIG = with_base_config(impala.DEFAULT_CONFIG, {
|
||||
# __sphinx_doc_end__
|
||||
# yapf: enable
|
||||
|
||||
|
||||
class APPOTrainer(impala.ImpalaTrainer):
|
||||
"""PPO surrogate loss with IMPALA-architecture."""
|
||||
|
||||
_name = "APPO"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy = AsyncPPOTFPolicy
|
||||
|
||||
@override(impala.ImpalaTrainer)
|
||||
def _get_policy(self):
|
||||
return AsyncPPOTFPolicy
|
||||
APPOTrainer = impala.ImpalaTrainer.with_updates(
|
||||
name="APPO",
|
||||
default_config=DEFAULT_CONFIG,
|
||||
default_policy=AsyncPPOTFPolicy,
|
||||
get_policy_class=lambda _: AsyncPPOTFPolicy)
|
||||
|
||||
@@ -4,15 +4,14 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.dqn.apex import APEX_TRAINER_PROPERTIES
|
||||
from ray.rllib.agents.qmix.qmix import QMixTrainer, \
|
||||
DEFAULT_CONFIG as QMIX_CONFIG
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils import merge_dicts
|
||||
|
||||
APEX_QMIX_DEFAULT_CONFIG = merge_dicts(
|
||||
QMIX_CONFIG, # see also the options in qmix.py, which are also supported
|
||||
{
|
||||
"optimizer_class": "AsyncReplayOptimizer",
|
||||
"optimizer": merge_dicts(
|
||||
QMIX_CONFIG["optimizer"],
|
||||
{
|
||||
@@ -34,23 +33,7 @@ APEX_QMIX_DEFAULT_CONFIG = merge_dicts(
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class ApexQMixTrainer(QMixTrainer):
|
||||
"""QMIX variant that uses the Ape-X distributed policy optimizer.
|
||||
|
||||
By default, this is configured for a large single node (32 cores). For
|
||||
running in a large cluster, increase the `num_workers` config var.
|
||||
"""
|
||||
|
||||
_name = "APEX_QMIX"
|
||||
_default_config = APEX_QMIX_DEFAULT_CONFIG
|
||||
|
||||
@override(QMixTrainer)
|
||||
def update_target_if_needed(self):
|
||||
# Ape-X updates based on num steps trained, not sampled
|
||||
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
|
||||
self.config["target_network_update_freq"]:
|
||||
self.workers.local_worker().foreach_trainable_policy(
|
||||
lambda p, _: p.update_target())
|
||||
self.last_target_update_ts = self.optimizer.num_steps_trained
|
||||
self.num_target_updates += 1
|
||||
ApexQMixTrainer = QMixTrainer.with_updates(
|
||||
name="APEX_QMIX",
|
||||
default_config=APEX_QMIX_DEFAULT_CONFIG,
|
||||
**APEX_TRAINER_PROPERTIES)
|
||||
|
||||
@@ -3,8 +3,9 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.trainer import with_common_config
|
||||
from ray.rllib.agents.dqn.dqn import DQNTrainer
|
||||
from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer
|
||||
from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy
|
||||
from ray.rllib.optimizers import SyncBatchReplayOptimizer
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
@@ -71,8 +72,6 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# to increase if your environment is particularly slow to sample, or if
|
||||
# you're using the Async or Ape-X optimizers.
|
||||
"num_workers": 0,
|
||||
# Optimizer class to use.
|
||||
"optimizer_class": "SyncBatchReplayOptimizer",
|
||||
# Whether to use a distribution of epsilons across workers for exploration.
|
||||
"per_worker_exploration": False,
|
||||
# Whether to compute priorities on workers.
|
||||
@@ -90,12 +89,16 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# yapf: enable
|
||||
|
||||
|
||||
class QMixTrainer(DQNTrainer):
|
||||
"""QMix implementation in PyTorch."""
|
||||
def make_sync_batch_optimizer(workers, config):
|
||||
return SyncBatchReplayOptimizer(
|
||||
workers,
|
||||
learning_starts=config["learning_starts"],
|
||||
buffer_size=config["buffer_size"],
|
||||
train_batch_size=config["train_batch_size"])
|
||||
|
||||
_name = "QMIX"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy = QMixTorchPolicy
|
||||
_optimizer_shared_configs = [
|
||||
"learning_starts", "buffer_size", "train_batch_size"
|
||||
]
|
||||
|
||||
QMixTrainer = GenericOffPolicyTrainer.with_updates(
|
||||
name="QMIX",
|
||||
default_config=DEFAULT_CONFIG,
|
||||
default_policy=QMixTorchPolicy,
|
||||
make_policy_optimizer=make_sync_batch_optimizer)
|
||||
|
||||
@@ -189,6 +189,9 @@ COMMON_CONFIG = {
|
||||
"remote_env_batch_wait_ms": 0,
|
||||
# Minimum time per iteration
|
||||
"min_iter_time_s": 0,
|
||||
# Minimum env steps to optimize for per train call. This value does
|
||||
# not affect learning, only the length of iterations.
|
||||
"timesteps_per_iteration": 0,
|
||||
|
||||
# === Offline Datasets ===
|
||||
# Specify how to generate experiences:
|
||||
@@ -502,6 +505,7 @@ class Trainer(Trainable):
|
||||
|
||||
logger.info("Evaluating current policy for {} episodes".format(
|
||||
self.config["evaluation_num_episodes"]))
|
||||
self._before_evaluate()
|
||||
self.evaluation_workers.local_worker().restore(
|
||||
self.workers.local_worker().save())
|
||||
for _ in range(self.config["evaluation_num_episodes"]):
|
||||
@@ -510,6 +514,11 @@ class Trainer(Trainable):
|
||||
metrics = collect_metrics(self.evaluation_workers.local_worker())
|
||||
return {"evaluation": metrics}
|
||||
|
||||
@DeveloperAPI
|
||||
def _before_evaluate(self):
|
||||
"""Pre-evaluation callback."""
|
||||
pass
|
||||
|
||||
@PublicAPI
|
||||
def compute_action(self,
|
||||
observation,
|
||||
|
||||
@@ -6,6 +6,7 @@ import time
|
||||
|
||||
from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG
|
||||
from ray.rllib.optimizers import SyncSamplesOptimizer
|
||||
from ray.rllib.utils import add_mixins
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
|
||||
|
||||
@@ -13,25 +14,47 @@ from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
def build_trainer(name,
|
||||
default_policy,
|
||||
default_config=None,
|
||||
make_policy_optimizer=None,
|
||||
validate_config=None,
|
||||
get_initial_state=None,
|
||||
get_policy_class=None,
|
||||
before_init=None,
|
||||
make_workers=None,
|
||||
make_policy_optimizer=None,
|
||||
after_init=None,
|
||||
before_train_step=None,
|
||||
after_optimizer_step=None,
|
||||
after_train_result=None):
|
||||
after_train_result=None,
|
||||
collect_metrics_fn=None,
|
||||
before_evaluate_fn=None,
|
||||
mixins=None):
|
||||
"""Helper function for defining a custom trainer.
|
||||
|
||||
Functions will be run in this order to initialize the trainer:
|
||||
1. Config setup: validate_config, get_initial_state, get_policy_class
|
||||
2. Worker setup: before_init, make_workers, make_policy_optimizer
|
||||
3. Post setup: after_init
|
||||
|
||||
Arguments:
|
||||
name (str): name of the trainer (e.g., "PPO")
|
||||
default_policy (cls): the default Policy class to use
|
||||
default_config (dict): the default config dict of the algorithm,
|
||||
otherwise uses the Trainer default config
|
||||
make_policy_optimizer (func): optional function that returns a
|
||||
PolicyOptimizer instance given (WorkerSet, config)
|
||||
validate_config (func): optional callback that checks a given config
|
||||
for correctness. It may mutate the config as needed.
|
||||
get_initial_state (func): optional function that returns the initial
|
||||
state dict given the trainer instance as an argument. The state
|
||||
dict must be serializable so that it can be checkpointed, and will
|
||||
be available as the `trainer.state` variable.
|
||||
get_policy_class (func): optional callback that takes a config and
|
||||
returns the policy class to override the default with
|
||||
before_init (func): optional function to run at the start of trainer
|
||||
init that takes the trainer instance as argument
|
||||
make_workers (func): override the method that creates rollout workers.
|
||||
This takes in (trainer, env_creator, policy, config) as args.
|
||||
make_policy_optimizer (func): optional function that returns a
|
||||
PolicyOptimizer instance given (WorkerSet, config)
|
||||
after_init (func): optional function to run at the end of trainer init
|
||||
that takes the trainer instance as argument
|
||||
before_train_step (func): optional callback to run before each train()
|
||||
call. It takes the trainer instance as an argument.
|
||||
after_optimizer_step (func): optional callback to run after each
|
||||
@@ -40,27 +63,47 @@ def build_trainer(name,
|
||||
after_train_result (func): optional callback to run at the end of each
|
||||
train() call. It takes the trainer instance and result dict as
|
||||
arguments, and may mutate the result dict as needed.
|
||||
collect_metrics_fn (func): override the method used to collect metrics.
|
||||
It takes the trainer instance as argument.
|
||||
before_evaluate_fn (func): callback to run before evaluation. This
|
||||
takes the trainer instance as argument.
|
||||
mixins (list): list of any class mixins for the returned trainer class.
|
||||
These mixins will be applied in order and will have higher
|
||||
precedence than the Trainer class
|
||||
|
||||
Returns:
|
||||
a Trainer class that uses the specified args.
|
||||
"""
|
||||
|
||||
original_kwargs = locals().copy()
|
||||
base = add_mixins(Trainer, mixins)
|
||||
|
||||
class trainer_cls(Trainer):
|
||||
class trainer_cls(base):
|
||||
_name = name
|
||||
_default_config = default_config or COMMON_CONFIG
|
||||
_policy = default_policy
|
||||
|
||||
def __init__(self, config=None, env=None, logger_creator=None):
|
||||
Trainer.__init__(self, config, env, logger_creator)
|
||||
|
||||
def _init(self, config, env_creator):
|
||||
if validate_config:
|
||||
validate_config(config)
|
||||
if get_initial_state:
|
||||
self.state = get_initial_state(self)
|
||||
else:
|
||||
self.state = {}
|
||||
if get_policy_class is None:
|
||||
policy = default_policy
|
||||
else:
|
||||
policy = get_policy_class(config)
|
||||
self.workers = self._make_workers(env_creator, policy, config,
|
||||
self.config["num_workers"])
|
||||
if before_init:
|
||||
before_init(self)
|
||||
if make_workers:
|
||||
self.workers = make_workers(self, env_creator, policy, config)
|
||||
else:
|
||||
self.workers = self._make_workers(env_creator, policy, config,
|
||||
self.config["num_workers"])
|
||||
if make_policy_optimizer:
|
||||
self.optimizer = make_policy_optimizer(self.workers, config)
|
||||
else:
|
||||
@@ -69,6 +112,8 @@ def build_trainer(name,
|
||||
**{"train_batch_size": config["train_batch_size"]})
|
||||
self.optimizer = SyncSamplesOptimizer(self.workers,
|
||||
**optimizer_config)
|
||||
if after_init:
|
||||
after_init(self)
|
||||
|
||||
@override(Trainer)
|
||||
def _train(self):
|
||||
@@ -81,20 +126,46 @@ def build_trainer(name,
|
||||
fetches = self.optimizer.step()
|
||||
if after_optimizer_step:
|
||||
after_optimizer_step(self, fetches)
|
||||
if time.time() - start > self.config["min_iter_time_s"]:
|
||||
if (time.time() - start >= self.config["min_iter_time_s"]
|
||||
and self.optimizer.num_steps_sampled - prev_steps >=
|
||||
self.config["timesteps_per_iteration"]):
|
||||
break
|
||||
|
||||
res = self.collect_metrics()
|
||||
if collect_metrics_fn:
|
||||
res = collect_metrics_fn(self)
|
||||
else:
|
||||
res = self.collect_metrics()
|
||||
res.update(
|
||||
timesteps_this_iter=self.optimizer.num_steps_sampled -
|
||||
prev_steps,
|
||||
info=res.get("info", {}))
|
||||
|
||||
if after_train_result:
|
||||
after_train_result(self, res)
|
||||
return res
|
||||
|
||||
@override(Trainer)
|
||||
def _before_evaluate(self):
|
||||
if before_evaluate_fn:
|
||||
before_evaluate_fn(self)
|
||||
|
||||
def __getstate__(self):
|
||||
state = Trainer.__getstate__(self)
|
||||
state.update(self.state)
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
Trainer.__setstate__(self, state)
|
||||
self.state = state
|
||||
|
||||
@staticmethod
|
||||
def with_updates(**overrides):
|
||||
"""Build a copy of this trainer with the specified overrides.
|
||||
|
||||
Arguments:
|
||||
overrides (dict): use this to override any of the arguments
|
||||
originally passed to build_trainer() for this trainer.
|
||||
"""
|
||||
return build_trainer(**dict(original_kwargs, **overrides))
|
||||
|
||||
trainer_cls.with_updates = with_updates
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import print_function
|
||||
from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy
|
||||
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
from ray.rllib.utils import add_mixins
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
|
||||
|
||||
@@ -89,13 +90,7 @@ def build_tf_policy(name,
|
||||
"""
|
||||
|
||||
original_kwargs = locals().copy()
|
||||
base = DynamicTFPolicy
|
||||
while mixins:
|
||||
|
||||
class new_base(mixins.pop(), base):
|
||||
pass
|
||||
|
||||
base = new_base
|
||||
base = add_mixins(DynamicTFPolicy, mixins)
|
||||
|
||||
class policy_cls(base):
|
||||
def __init__(self,
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import print_function
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.torch_policy import TorchPolicy
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.utils import add_mixins
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
|
||||
|
||||
@@ -56,13 +57,7 @@ def build_torch_policy(name,
|
||||
"""
|
||||
|
||||
original_kwargs = locals().copy()
|
||||
base = TorchPolicy
|
||||
while mixins:
|
||||
|
||||
class new_base(mixins.pop(), base):
|
||||
pass
|
||||
|
||||
base = new_base
|
||||
base = add_mixins(TorchPolicy, mixins)
|
||||
|
||||
class policy_cls(base):
|
||||
def __init__(self, obs_space, action_space, config):
|
||||
|
||||
@@ -27,6 +27,21 @@ def renamed_class(cls, old_name):
|
||||
return DeprecationWrapper
|
||||
|
||||
|
||||
def add_mixins(base, mixins):
|
||||
"""Returns a new class with mixins applied in priority order."""
|
||||
|
||||
mixins = list(mixins or [])
|
||||
|
||||
while mixins:
|
||||
|
||||
class new_base(mixins.pop(), base):
|
||||
pass
|
||||
|
||||
base = new_base
|
||||
|
||||
return base
|
||||
|
||||
|
||||
def renamed_agent(cls):
|
||||
"""Helper class for renaming Agent => Trainer with a warning."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user