[rllib] Rename PolicyEvaluator => RolloutWorker (#4820)

This commit is contained in:
Eric Liang
2019-06-03 06:49:24 +08:00
committed by GitHub
parent 99eae05cf6
commit 7501ee51db
59 changed files with 1538 additions and 1474 deletions
+2 -1
View File
@@ -11,7 +11,7 @@ from ray.tune.registry import register_trainable
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.vector_env import VectorEnv
@@ -55,6 +55,7 @@ __all__ = [
"PolicyGraph",
"TFPolicy",
"TFPolicyGraph",
"RolloutWorker",
"PolicyEvaluator",
"SampleBatch",
"BaseEnv",
+10 -16
View File
@@ -2,9 +2,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.a3c.a3c import A3CTrainer, DEFAULT_CONFIG as A3C_CONFIG
from ray.rllib.optimizers import SyncSamplesOptimizer
from ray.rllib.utils.annotations import override
from ray.rllib.agents.a3c.a3c import DEFAULT_CONFIG as A3C_CONFIG, \
validate_config, get_policy_class
from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.utils import merge_dicts
A2C_DEFAULT_CONFIG = merge_dicts(
@@ -16,16 +17,9 @@ A2C_DEFAULT_CONFIG = merge_dicts(
},
)
class A2CTrainer(A3CTrainer):
"""Synchronous variant of the A3CTrainer."""
_name = "A2C"
_default_config = A2C_DEFAULT_CONFIG
@override(A3CTrainer)
def _make_optimizer(self):
return SyncSamplesOptimizer(
self.local_evaluator,
self.remote_evaluators,
train_batch_size=self.config["train_batch_size"])
# A2C trainer class, generated by build_trainer() from the A3C building
# blocks: it uses A2C_DEFAULT_CONFIG, A3CTFPolicy as the default policy, and
# the get_policy_class / validate_config callbacks imported from a3c.
# No make_policy_optimizer is passed here; per trainer_template, build_trainer
# then constructs a SyncSamplesOptimizer itself.
A2CTrainer = build_trainer(
name="A2C",
default_config=A2C_DEFAULT_CONFIG,
default_policy=A3CTFPolicy,
get_policy_class=get_policy_class,
validate_config=validate_config)
+21 -38
View File
@@ -2,12 +2,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.optimizers import AsyncGradientsOptimizer
from ray.rllib.utils.annotations import override
# yapf: disable
# __sphinx_doc_begin__
@@ -38,43 +36,28 @@ DEFAULT_CONFIG = with_common_config({
# yapf: enable
class A3CTrainer(Trainer):
"""A3C implementations in TensorFlow and PyTorch."""
def get_policy_class(config):
    """Return the A3C policy class selected by ``config["use_pytorch"]``.

    Args:
        config (dict): trainer config; must contain the "use_pytorch" key.

    Returns:
        A3CTorchPolicy if ``config["use_pytorch"]`` is truthy, otherwise
        A3CTFPolicy.
    """
    if config["use_pytorch"]:
        # Imported inside the branch, so the torch policy module is only
        # imported when the PyTorch implementation is requested.
        from ray.rllib.agents.a3c.a3c_torch_policy import \
            A3CTorchPolicy
        return A3CTorchPolicy
    else:
        return A3CTFPolicy
_name = "A3C"
_default_config = DEFAULT_CONFIG
_policy = A3CTFPolicy
@override(Trainer)
def _init(self, config, env_creator):
if config["use_pytorch"]:
from ray.rllib.agents.a3c.a3c_torch_policy import \
A3CTorchPolicy
policy_cls = A3CTorchPolicy
else:
policy_cls = self._policy
def validate_config(config):
    """Check an A3C trainer config for unsupported values.

    Args:
        config (dict): trainer config; must contain the "entropy_coeff" key.

    Raises:
        DeprecationWarning: if ``config["entropy_coeff"]`` is negative.
    """
    if config["entropy_coeff"] < 0:
        raise DeprecationWarning("entropy_coeff must be >= 0")
if config["entropy_coeff"] < 0:
raise DeprecationWarning("entropy_coeff must be >= 0")
self.local_evaluator = self.make_local_evaluator(
env_creator, policy_cls)
self.remote_evaluators = self.make_remote_evaluators(
env_creator, policy_cls, config["num_workers"])
self.optimizer = self._make_optimizer()
def make_async_optimizer(workers, config):
    """Create the AsyncGradientsOptimizer used by the A3C trainer.

    Args:
        workers: passed through unchanged as the optimizer's first
            positional argument (build_trainer's docstring describes it
            as a WorkerSet).
        config (dict): trainer config; ``config["optimizer"]`` is expanded
            into keyword arguments for the optimizer.

    Returns:
        The constructed AsyncGradientsOptimizer instance.
    """
    return AsyncGradientsOptimizer(workers, **config["optimizer"])
@override(Trainer)
def _train(self):
prev_steps = self.optimizer.num_steps_sampled
start = time.time()
while time.time() - start < self.config["min_iter_time_s"]:
self.optimizer.step()
result = self.collect_metrics()
result.update(timesteps_this_iter=self.optimizer.num_steps_sampled -
prev_steps)
return result
def _make_optimizer(self):
return AsyncGradientsOptimizer(self.local_evaluator,
self.remote_evaluators,
**self.config["optimizer"])
# A3C trainer class, generated by build_trainer(): uses DEFAULT_CONFIG,
# A3CTFPolicy as the default policy, get_policy_class to pick the policy
# class from the config, validate_config to check the config, and
# make_async_optimizer to build its policy optimizer.
A3CTrainer = build_trainer(
name="A3C",
default_config=DEFAULT_CONFIG,
default_policy=A3CTFPolicy,
get_policy_class=get_policy_class,
validate_config=validate_config,
make_policy_optimizer=make_async_optimizer)
+1 -1
View File
@@ -48,7 +48,7 @@ class ApexDDPGTrainer(DDPGTrainer):
# Ape-X updates based on num steps trained, not sampled
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
self.config["target_network_update_freq"]:
self.local_evaluator.foreach_trainable_policy(
self.workers.local_worker().foreach_trainable_policy(
lambda p, _: p.update_target())
self.last_target_update_ts = self.optimizer.num_steps_trained
self.num_target_updates += 1
+2 -2
View File
@@ -171,9 +171,9 @@ class DDPGTrainer(DQNTrainer):
if pure_expl_steps:
# tell workers whether they should do pure exploration
only_explore = self.global_timestep < pure_expl_steps
self.local_evaluator.foreach_trainable_policy(
self.workers.local_worker().foreach_trainable_policy(
lambda p, _: p.set_pure_exploration_phase(only_explore))
for e in self.remote_evaluators:
for e in self.workers.remote_workers():
e.foreach_trainable_policy.remote(
lambda p, _: p.set_pure_exploration_phase(only_explore))
return super(DDPGTrainer, self)._train()
+1 -1
View File
@@ -515,7 +515,7 @@ class DDPGTFPolicy(DDPGPostprocessing, TFPolicy):
stochastic_actions = tf.cond(
# need to condition on noise_scale > 0 because zeroing
# noise_scale is how evaluator signals no noise should be used
# noise_scale is how a worker signals no noise should be used
# (this is ugly and should be fixed by adding an "eval_mode"
# config flag or something)
tf.logical_and(enable_pure_exploration, noise_scale > 0),
+1 -1
View File
@@ -51,7 +51,7 @@ class ApexTrainer(DQNTrainer):
# Ape-X updates based on num steps trained, not sampled
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
self.config["target_network_update_freq"]:
self.local_evaluator.foreach_trainable_policy(
self.workers.local_worker().foreach_trainable_policy(
lambda p, _: p.update_target())
self.last_target_update_ts = self.optimizer.num_steps_trained
self.num_target_updates += 1
+26 -24
View File
@@ -196,26 +196,26 @@ class DQNTrainer(Trainer):
config["callbacks"]["on_episode_end"] = tune.function(
on_episode_end)
self.local_evaluator = self.make_local_evaluator(
env_creator, self._policy)
def create_remote_evaluators():
return self.make_remote_evaluators(env_creator, self._policy,
config["num_workers"])
if config["optimizer_class"] != "AsyncReplayOptimizer":
self.remote_evaluators = create_remote_evaluators()
self.workers = self._make_workers(
env_creator,
self._policy,
config,
num_workers=self.config["num_workers"])
workers_needed = 0
else:
# Hack to workaround https://github.com/ray-project/ray/issues/2541
self.remote_evaluators = None
self.workers = self._make_workers(
env_creator, self._policy, config, num_workers=0)
workers_needed = self.config["num_workers"]
self.optimizer = getattr(optimizers, config["optimizer_class"])(
self.local_evaluator, self.remote_evaluators,
**config["optimizer"])
# Create the remote evaluators *after* the replay actors
if self.remote_evaluators is None:
self.remote_evaluators = create_remote_evaluators()
self.optimizer._set_evaluators(self.remote_evaluators)
self.workers, **config["optimizer"])
# Create the remote workers *after* the replay actors
if workers_needed > 0:
self.workers.add_workers(workers_needed)
self.optimizer._set_workers(self.workers.remote_workers())
self.last_target_update_ts = 0
self.num_target_updates = 0
@@ -226,9 +226,9 @@ class DQNTrainer(Trainer):
# Update worker explorations
exp_vals = [self.exploration0.value(self.global_timestep)]
self.local_evaluator.foreach_trainable_policy(
self.workers.local_worker().foreach_trainable_policy(
lambda p, _: p.set_epsilon(exp_vals[0]))
for i, e in enumerate(self.remote_evaluators):
for i, e in enumerate(self.workers.remote_workers()):
exp_val = self.explorations[i].value(self.global_timestep)
e.foreach_trainable_policy.remote(
lambda p, _: p.set_epsilon(exp_val))
@@ -245,8 +245,8 @@ class DQNTrainer(Trainer):
if self.config["per_worker_exploration"]:
# Only collect metrics from the third of workers with lowest eps
result = self.collect_metrics(
selected_evaluators=self.remote_evaluators[
-len(self.remote_evaluators) // 3:])
selected_workers=self.workers.remote_workers()[
-len(self.workers.remote_workers()) // 3:])
else:
result = self.collect_metrics()
@@ -263,7 +263,7 @@ class DQNTrainer(Trainer):
def update_target_if_needed(self):
if self.global_timestep - self.last_target_update_ts > \
self.config["target_network_update_freq"]:
self.local_evaluator.foreach_trainable_policy(
self.workers.local_worker().foreach_trainable_policy(
lambda p, _: p.update_target())
self.last_target_update_ts = self.global_timestep
self.num_target_updates += 1
@@ -275,11 +275,13 @@ class DQNTrainer(Trainer):
def _evaluate(self):
logger.info("Evaluating current policy for {} episodes".format(
self.config["evaluation_num_episodes"]))
self.evaluation_ev.restore(self.local_evaluator.save())
self.evaluation_ev.foreach_policy(lambda p, _: p.set_epsilon(0))
self.evaluation_workers.local_worker().restore(
self.workers.local_worker().save())
self.evaluation_workers.local_worker().foreach_policy(
lambda p, _: p.set_epsilon(0))
for _ in range(self.config["evaluation_num_episodes"]):
self.evaluation_ev.sample()
metrics = collect_metrics(self.evaluation_ev)
self.evaluation_workers.local_worker().sample()
metrics = collect_metrics(self.evaluation_workers.local_worker())
return {"evaluation": metrics}
def _make_exploration_schedule(self, worker_index):
+5 -5
View File
@@ -192,7 +192,7 @@ class ESTrainer(Trainer):
# Create the actors.
logger.info("Creating actors.")
self.workers = [
self._workers = [
Worker.remote(config, policy_params, env_creator, noise_id)
for _ in range(config["num_workers"])
]
@@ -270,7 +270,7 @@ class ESTrainer(Trainer):
# Now sync the filters
FilterManager.synchronize({
DEFAULT_POLICY_ID: self.policy.get_filter()
}, self.workers)
}, self._workers)
info = {
"weights_norm": np.square(theta).sum(),
@@ -296,7 +296,7 @@ class ESTrainer(Trainer):
@override(Trainer)
def _stop(self):
# workaround for https://github.com/ray-project/ray/issues/1516
for w in self.workers:
for w in self._workers:
w.__ray_terminate__.remote()
def _collect_results(self, theta_id, min_episodes, min_timesteps):
@@ -307,7 +307,7 @@ class ESTrainer(Trainer):
"Collected {} episodes {} timesteps so far this iter".format(
num_episodes, num_timesteps))
rollout_ids = [
worker.do_rollouts.remote(theta_id) for worker in self.workers
worker.do_rollouts.remote(theta_id) for worker in self._workers
]
# Get the results of the rollouts.
for result in ray_get_and_free(rollout_ids):
@@ -334,4 +334,4 @@ class ESTrainer(Trainer):
self.policy.set_filter(state["filter"])
FilterManager.synchronize({
DEFAULT_POLICY_ID: self.policy.get_filter()
}, self.workers)
}, self._workers)
+4 -6
View File
@@ -113,18 +113,16 @@ class ImpalaTrainer(Trainer):
if k not in config["optimizer"]:
config["optimizer"][k] = config[k]
policy_cls = self._get_policy()
self.local_evaluator = self.make_local_evaluator(
self.env_creator, policy_cls)
self.workers = self._make_workers(
self.env_creator, policy_cls, self.config, num_workers=0)
if self.config["num_aggregation_workers"] > 0:
# Create co-located aggregator actors first for placement pref
aggregators = TreeAggregator.precreate_aggregators(
self.config["num_aggregation_workers"])
self.remote_evaluators = self.make_remote_evaluators(
env_creator, policy_cls, config["num_workers"])
self.optimizer = AsyncSamplesOptimizer(self.local_evaluator,
self.remote_evaluators,
self.workers.add_workers(config["num_workers"])
self.optimizer = AsyncSamplesOptimizer(self.workers,
**config["optimizer"])
if config["entropy_coeff"] < 0:
raise DeprecationWarning("entropy_coeff must be >= 0")
+3 -6
View File
@@ -48,13 +48,10 @@ class MARWILTrainer(Trainer):
@override(Trainer)
def _init(self, config, env_creator):
self.local_evaluator = self.make_local_evaluator(
env_creator, self._policy)
self.remote_evaluators = self.make_remote_evaluators(
env_creator, self._policy, config["num_workers"])
self.workers = self._make_workers(env_creator, self._policy, config,
config["num_workers"])
self.optimizer = SyncBatchReplayOptimizer(
self.local_evaluator,
self.remote_evaluators,
self.workers,
learning_starts=config["learning_starts"],
buffer_size=config["replay_buffer_size"],
train_batch_size=config["train_batch_size"],
+1 -1
View File
@@ -29,7 +29,7 @@ def get_policy_class(config):
PGTrainer = build_trainer(
name="PGTrainer",
name="PG",
default_config=DEFAULT_CONFIG,
default_policy=PGTFPolicy,
get_policy_class=get_policy_class)
+6 -8
View File
@@ -63,17 +63,15 @@ DEFAULT_CONFIG = with_common_config({
# yapf: enable
def choose_policy_optimizer(local_evaluator, remote_evaluators, config):
def choose_policy_optimizer(workers, config):
if config["simple_optimizer"]:
return SyncSamplesOptimizer(
local_evaluator,
remote_evaluators,
workers,
num_sgd_iter=config["num_sgd_iter"],
train_batch_size=config["train_batch_size"])
return LocalMultiGPUOptimizer(
local_evaluator,
remote_evaluators,
workers,
sgd_batch_size=config["sgd_minibatch_size"],
num_sgd_iter=config["num_sgd_iter"],
num_gpus=config["num_gpus"],
@@ -87,7 +85,7 @@ def choose_policy_optimizer(local_evaluator, remote_evaluators, config):
def update_kl(trainer, fetches):
if "kl" in fetches:
# single-agent
trainer.local_evaluator.for_policy(
trainer.workers.local_worker().for_policy(
lambda pi: pi.update_kl(fetches["kl"]))
else:
@@ -98,7 +96,7 @@ def update_kl(trainer, fetches):
logger.debug("No data for {}, not updating kl".format(pi_id))
# multi-agent
trainer.local_evaluator.foreach_trainable_policy(update)
trainer.workers.local_worker().foreach_trainable_policy(update)
def warn_about_obs_filter(trainer):
@@ -155,7 +153,7 @@ def validate_config(config):
PPOTrainer = build_trainer(
name="PPOTrainer",
name="PPO",
default_config=DEFAULT_CONFIG,
default_policy=PPOTFPolicy,
make_policy_optimizer=choose_policy_optimizer,
+1 -1
View File
@@ -50,7 +50,7 @@ class ApexQMixTrainer(QMixTrainer):
# Ape-X updates based on num steps trained, not sampled
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
self.config["target_network_update_freq"]:
self.local_evaluator.foreach_trainable_policy(
self.workers.local_worker().foreach_trainable_policy(
lambda p, _: p.update_target())
self.last_target_update_ts = self.optimizer.num_steps_trained
self.num_target_updates += 1
+64 -181
View File
@@ -10,18 +10,14 @@ import pickle
import six
import time
import tempfile
from types import FunctionType
import ray
from ray.exceptions import RayError
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
ShuffledInput
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator, \
_validate_multiagent_config
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
from ray.rllib.utils.memory import ray_get_and_free
@@ -46,7 +42,7 @@ COMMON_CONFIG = {
# === Debugging ===
# Whether to write episode stats and videos to the agent log dir
"monitor": False,
# Set the ray.rllib.* log level for the agent process and its evaluators.
# Set the ray.rllib.* log level for the agent process and its workers.
# Should be one of DEBUG, INFO, WARN, or ERROR. The DEBUG level will also
# periodically print out summaries of relevant internal dataflow (this is
# also printed out once at startup at the INFO level).
@@ -60,7 +56,7 @@ COMMON_CONFIG = {
"on_episode_start": None, # arg: {"env": .., "episode": ...}
"on_episode_step": None, # arg: {"env": .., "episode": ...}
"on_episode_end": None, # arg: {"env": .., "episode": ...}
"on_sample_end": None, # arg: {"samples": .., "evaluator": ...}
"on_sample_end": None, # arg: {"samples": .., "worker": ...}
"on_train_result": None, # arg: {"trainer": ..., "result": ...}
"on_postprocess_traj": None, # arg: {
# "agent_id": ..., "episode": ...,
@@ -153,7 +149,7 @@ COMMON_CONFIG = {
"synchronize_filters": True,
# Configure TF for single-process operation by default
"tf_session_args": {
# note: overriden by `local_evaluator_tf_session_args`
# note: overridden by `local_tf_session_args`
"intra_op_parallelism_threads": 2,
"inter_op_parallelism_threads": 2,
"gpu_options": {
@@ -165,8 +161,8 @@ COMMON_CONFIG = {
},
"allow_soft_placement": True, # required by PPO multi-gpu
},
# Override the following tf session args on the local evaluator
"local_evaluator_tf_session_args": {
# Override the following tf session args on the local worker
"local_tf_session_args": {
# Allow a higher level of parallelism by default, but not unlimited
# since that can cause crashes with many concurrent drivers.
"intra_op_parallelism_threads": 8,
@@ -188,6 +184,8 @@ COMMON_CONFIG = {
# but optimal value could be obtained by measuring your environment
# step / reset and model inference perf.
"remote_env_batch_wait_ms": 0,
# Minimum time per iteration
"min_iter_time_s": 0,
# === Offline Datasets ===
# Specify how to generate experiences:
@@ -229,7 +227,7 @@ COMMON_CONFIG = {
# === Multiagent ===
"multiagent": {
# Map from policy ids to tuples of (policy_cls, obs_space,
# act_space, config). See policy_evaluator.py for more info.
# act_space, config). See rollout_worker.py for more info.
"policies": {},
# Function mapping agent ids to policy ids.
"policy_mapping_fn": None,
@@ -292,7 +290,7 @@ class Trainer(Trainable):
config = config or {}
# Vars to synchronize to evaluators on each train call
# Vars to synchronize to workers on each train call
self.global_vars = {"timestep": 0}
# Trainers allow env ids to be passed directly to the constructor.
@@ -337,9 +335,10 @@ class Trainer(Trainable):
if self._has_policy_optimizer():
self.global_vars["timestep"] = self.optimizer.num_steps_sampled
self.optimizer.local_evaluator.set_global_vars(self.global_vars)
for ev in self.optimizer.remote_evaluators:
ev.set_global_vars.remote(self.global_vars)
self.optimizer.workers.local_worker().set_global_vars(
self.global_vars)
for w in self.optimizer.workers.remote_workers():
w.set_global_vars.remote(self.global_vars)
logger.debug("updated global vars: {}".format(self.global_vars))
result = None
@@ -366,17 +365,18 @@ class Trainer(Trainable):
raise RuntimeError("Failed to recover from worker crash")
if (self.config.get("observation_filter", "NoFilter") != "NoFilter"
and hasattr(self, "local_evaluator")):
and hasattr(self, "workers")
and isinstance(self.workers, WorkerSet)):
FilterManager.synchronize(
self.local_evaluator.filters,
self.remote_evaluators,
self.workers.local_worker().filters,
self.workers.remote_workers(),
update_remote=self.config["synchronize_filters"])
logger.debug("synchronized filters: {}".format(
self.local_evaluator.filters))
self.workers.local_worker().filters))
if self._has_policy_optimizer():
result["num_healthy_workers"] = len(
self.optimizer.remote_evaluators)
self.optimizer.workers.remote_workers())
if self.config["evaluation_interval"]:
if self._iteration % self.config["evaluation_interval"] == 0:
@@ -441,25 +441,17 @@ class Trainer(Trainable):
})
logger.debug(
"using evaluation_config: {}".format(extra_config))
# Make local evaluation evaluators
self.evaluation_ev = self.make_local_evaluator(
self.env_creator, self._policy, extra_config=extra_config)
self.evaluation_workers = self._make_workers(
self.env_creator,
self._policy,
merge_dicts(self.config, extra_config),
num_workers=0)
self.evaluation_metrics = self._evaluate()
@override(Trainable)
def _stop(self):
# Call stop on all evaluators to release resources
if hasattr(self, "local_evaluator"):
self.local_evaluator.stop()
if hasattr(self, "remote_evaluators"):
for ev in self.remote_evaluators:
ev.stop.remote()
# workaround for https://github.com/ray-project/ray/issues/1516
if hasattr(self, "remote_evaluators"):
for ev in self.remote_evaluators:
ev.__ray_terminate__.remote()
if hasattr(self, "workers"):
self.workers.stop()
if hasattr(self, "optimizer"):
self.optimizer.stop()
@@ -475,6 +467,15 @@ class Trainer(Trainable):
extra_data = pickle.load(open(checkpoint_path, "rb"))
self.__setstate__(extra_data)
# Build and return a WorkerSet for this trainer.
# Arguments are forwarded as-is: env_creator, policy and config positionally,
# num_workers by keyword; the trainer's own logdir is passed as `logdir`.
@DeveloperAPI
def _make_workers(self, env_creator, policy, config, num_workers):
return WorkerSet(
env_creator,
policy,
config,
num_workers=num_workers,
logdir=self.logdir)
@DeveloperAPI
def _init(self, config, env_creator):
"""Subclasses should override this for custom initialization."""
@@ -498,11 +499,12 @@ class Trainer(Trainable):
logger.info("Evaluating current policy for {} episodes".format(
self.config["evaluation_num_episodes"]))
self.evaluation_ev.restore(self.local_evaluator.save())
self.evaluation_workers.local_worker().restore(
self.workers.local_worker().save())
for _ in range(self.config["evaluation_num_episodes"]):
self.evaluation_ev.sample()
self.evaluation_workers.local_worker().sample()
metrics = collect_metrics(self.evaluation_ev)
metrics = collect_metrics(self.evaluation_workers.local_worker())
return {"evaluation": metrics}
@PublicAPI
@@ -540,9 +542,9 @@ class Trainer(Trainable):
if state is None:
state = []
preprocessed = self.local_evaluator.preprocessors[policy_id].transform(
observation)
filtered_obs = self.local_evaluator.filters[policy_id](
preprocessed = self.workers.local_worker().preprocessors[
policy_id].transform(observation)
filtered_obs = self.workers.local_worker().filters[policy_id](
preprocessed, update=False)
if state:
return self.get_policy(policy_id).compute_single_action(
@@ -590,7 +592,7 @@ class Trainer(Trainable):
policy_id (str): id of policy to return.
"""
return self.local_evaluator.get_policy(policy_id)
return self.workers.local_worker().get_policy(policy_id)
@PublicAPI
def get_weights(self, policies=None):
@@ -600,7 +602,7 @@ class Trainer(Trainable):
policies (list): Optional list of policies to return weights for,
or None for all policies.
"""
return self.local_evaluator.get_weights(policies)
return self.workers.local_worker().get_weights(policies)
@PublicAPI
def set_weights(self, weights):
@@ -609,42 +611,7 @@ class Trainer(Trainable):
Arguments:
weights (dict): Map of policy ids to weights to set.
"""
self.local_evaluator.set_weights(weights)
@DeveloperAPI
def make_local_evaluator(self, env_creator, policy, extra_config=None):
"""Convenience method to return configured local evaluator."""
return self._make_evaluator(
PolicyEvaluator,
env_creator,
policy,
0,
merge_dicts(
# important: allow local tf to use more CPUs for optimization
merge_dicts(
self.config, {
"tf_session_args": self.
config["local_evaluator_tf_session_args"]
}),
extra_config or {}))
@DeveloperAPI
def make_remote_evaluators(self, env_creator, policy, count):
"""Convenience method to return a number of remote evaluators."""
remote_args = {
"num_cpus": self.config["num_cpus_per_worker"],
"num_gpus": self.config["num_gpus_per_worker"],
"resources": self.config["custom_resources_per_worker"],
}
cls = PolicyEvaluator.as_remote(**remote_args).remote
return [
self._make_evaluator(cls, env_creator, policy, i + 1, self.config)
for i in range(count)
]
self.workers.local_worker().set_weights(weights)
@DeveloperAPI
def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
@@ -660,7 +627,7 @@ class Trainer(Trainable):
>>> trainer.train()
>>> trainer.export_policy_model("/tmp/export_dir")
"""
self.local_evaluator.export_policy_model(export_dir, policy_id)
self.workers.local_worker().export_policy_model(export_dir, policy_id)
@DeveloperAPI
def export_policy_checkpoint(self,
@@ -680,19 +647,19 @@ class Trainer(Trainable):
>>> trainer.train()
>>> trainer.export_policy_checkpoint("/tmp/export_dir")
"""
self.local_evaluator.export_policy_checkpoint(
self.workers.local_worker().export_policy_checkpoint(
export_dir, filename_prefix, policy_id)
@DeveloperAPI
def collect_metrics(self, selected_evaluators=None):
"""Collects metrics from the remote evaluators of this agent.
def collect_metrics(self, selected_workers=None):
"""Collects metrics from the remote workers of this agent.
This is the same data as returned by a call to train().
"""
return self.optimizer.collect_metrics(
self.config["collect_metrics_timeout"],
min_history=self.config["metrics_smoothing_episodes"],
selected_evaluators=selected_evaluators)
selected_workers=selected_workers)
@classmethod
def resource_help(cls, config):
@@ -742,118 +709,34 @@ class Trainer(Trainable):
logger.info("Health checking all workers...")
checks = []
for ev in self.optimizer.remote_evaluators:
for ev in self.optimizer.workers.remote_workers():
_, obj_id = ev.sample_with_count.remote()
checks.append(obj_id)
healthy_evaluators = []
healthy_workers = []
for i, obj_id in enumerate(checks):
ev = self.optimizer.remote_evaluators[i]
w = self.optimizer.workers.remote_workers()[i]
try:
ray_get_and_free(obj_id)
healthy_evaluators.append(ev)
healthy_workers.append(w)
logger.info("Worker {} looks healthy".format(i + 1))
except RayError:
logger.exception("Blacklisting worker {}".format(i + 1))
try:
ev.__ray_terminate__.remote()
w.__ray_terminate__.remote()
except Exception:
logger.exception("Error terminating unhealthy worker")
if len(healthy_evaluators) < 1:
if len(healthy_workers) < 1:
raise RuntimeError(
"Not enough healthy workers remain to continue.")
self.optimizer.reset(healthy_evaluators)
self.optimizer.reset(healthy_workers)
def _has_policy_optimizer(self):
return hasattr(self, "optimizer") and isinstance(
self.optimizer, PolicyOptimizer)
def _make_evaluator(self, cls, env_creator, policy, worker_index, config):
def session_creator():
logger.debug("Creating TF session {}".format(
config["tf_session_args"]))
return tf.Session(
config=tf.ConfigProto(**config["tf_session_args"]))
if isinstance(config["input"], FunctionType):
input_creator = config["input"]
elif config["input"] == "sampler":
input_creator = (lambda ioctx: ioctx.default_sampler_input())
elif isinstance(config["input"], dict):
input_creator = (lambda ioctx: ShuffledInput(
MixedInput(config["input"], ioctx), config[
"shuffle_buffer_size"]))
else:
input_creator = (lambda ioctx: ShuffledInput(
JsonReader(config["input"], ioctx), config[
"shuffle_buffer_size"]))
if isinstance(config["output"], FunctionType):
output_creator = config["output"]
elif config["output"] is None:
output_creator = (lambda ioctx: NoopOutput())
elif config["output"] == "logdir":
output_creator = (lambda ioctx: JsonWriter(
ioctx.log_dir,
ioctx,
max_file_size=config["output_max_file_size"],
compress_columns=config["output_compress_columns"]))
else:
output_creator = (lambda ioctx: JsonWriter(
config["output"],
ioctx,
max_file_size=config["output_max_file_size"],
compress_columns=config["output_compress_columns"]))
if config["input"] == "sampler":
input_evaluation = []
else:
input_evaluation = config["input_evaluation"]
# Fill in the default policy if 'None' is specified in multiagent
if self.config["multiagent"]["policies"]:
tmp = self.config["multiagent"]["policies"]
_validate_multiagent_config(tmp, allow_none_graph=True)
for k, v in tmp.items():
if v[0] is None:
tmp[k] = (policy, v[1], v[2], v[3])
policy = tmp
return cls(
env_creator,
policy,
policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"],
policies_to_train=self.config["multiagent"]["policies_to_train"],
tf_session_creator=(session_creator
if config["tf_session_args"] else None),
batch_steps=config["sample_batch_size"],
batch_mode=config["batch_mode"],
episode_horizon=config["horizon"],
preprocessor_pref=config["preprocessor_pref"],
sample_async=config["sample_async"],
compress_observations=config["compress_observations"],
num_envs=config["num_envs_per_worker"],
observation_filter=config["observation_filter"],
clip_rewards=config["clip_rewards"],
clip_actions=config["clip_actions"],
env_config=config["env_config"],
model_config=config["model"],
policy_config=config,
worker_index=worker_index,
monitor_path=self.logdir if config["monitor"] else None,
log_dir=self.logdir,
log_level=config["log_level"],
callbacks=config["callbacks"],
input_creator=input_creator,
input_evaluation=input_evaluation,
output_creator=output_creator,
remote_worker_envs=config["remote_worker_envs"],
remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
soft_horizon=config["soft_horizon"],
_fake_sampler=config.get("_fake_sampler", False))
@override(Trainable)
def _export_model(self, export_formats, export_dir):
ExportFormat.validate(export_formats)
@@ -870,17 +753,17 @@ class Trainer(Trainable):
def __getstate__(self):
state = {}
if hasattr(self, "local_evaluator"):
state["evaluator"] = self.local_evaluator.save()
if hasattr(self, "workers"):
state["worker"] = self.workers.local_worker().save()
if hasattr(self, "optimizer") and hasattr(self.optimizer, "save"):
state["optimizer"] = self.optimizer.save()
return state
def __setstate__(self, state):
if "evaluator" in state:
self.local_evaluator.restore(state["evaluator"])
remote_state = ray.put(state["evaluator"])
for r in self.remote_evaluators:
if "worker" in state:
self.workers.local_worker().restore(state["worker"])
remote_state = ray.put(state["worker"])
for r in self.workers.remote_workers():
r.restore.remote(remote_state)
if "optimizer" in state:
self.optimizer.restore(state["optimizer"])
+22 -15
View File
@@ -2,6 +2,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG
from ray.rllib.optimizers import SyncSamplesOptimizer
from ray.rllib.utils.annotations import override, DeveloperAPI
@@ -25,8 +27,7 @@ def build_trainer(name,
default_config (dict): the default config dict of the algorithm,
otherwise uses the Trainer default config
make_policy_optimizer (func): optional function that returns a
PolicyOptimizer instance given
(local_evaluator, remote_evaluators, config)
PolicyOptimizer instance given (WorkerSet, config)
validate_config (func): optional callback that checks a given config
for correctness. It may mutate the config as needed.
get_policy_class (func): optional callback that takes a config and
@@ -44,8 +45,7 @@ def build_trainer(name,
a Trainer instance that uses the specified args.
"""
if not name.endswith("Trainer"):
raise ValueError("Algorithm name should have *Trainer suffix", name)
original_kwargs = locals().copy()
class trainer_cls(Trainer):
_name = name
@@ -59,19 +59,15 @@ def build_trainer(name,
policy = default_policy
else:
policy = get_policy_class(config)
self.local_evaluator = self.make_local_evaluator(
env_creator, policy)
self.remote_evaluators = self.make_remote_evaluators(
env_creator, policy, config["num_workers"])
self.workers = self._make_workers(env_creator, policy, config,
self.config["num_workers"])
if make_policy_optimizer:
self.optimizer = make_policy_optimizer(
self.local_evaluator, self.remote_evaluators, config)
self.optimizer = make_policy_optimizer(self.workers, config)
else:
optimizer_config = dict(
config["optimizer"],
**{"train_batch_size": config["train_batch_size"]})
self.optimizer = SyncSamplesOptimizer(self.local_evaluator,
self.remote_evaluators,
self.optimizer = SyncSamplesOptimizer(self.workers,
**optimizer_config)
@override(Trainer)
@@ -79,9 +75,15 @@ def build_trainer(name,
if before_train_step:
before_train_step(self)
prev_steps = self.optimizer.num_steps_sampled
fetches = self.optimizer.step()
if after_optimizer_step:
after_optimizer_step(self, fetches)
start = time.time()
while True:
fetches = self.optimizer.step()
if after_optimizer_step:
after_optimizer_step(self, fetches)
if time.time() - start > self.config["min_iter_time_s"]:
break
res = self.collect_metrics()
res.update(
timesteps_this_iter=self.optimizer.num_steps_sampled -
@@ -91,6 +93,11 @@ def build_trainer(name,
after_train_result(self, res)
return res
@staticmethod
def with_updates(**overrides):
return build_trainer(**dict(original_kwargs, **overrides))
trainer_cls.with_updates = with_updates
trainer_cls.__name__ = name
trainer_cls.__qualname__ = name
return trainer_cls
+1 -1
View File
@@ -21,7 +21,7 @@ class BaseEnv(object):
can be sent back via send_actions().
All other env types can be adapted to BaseEnv. RLlib handles these
conversions internally in PolicyEvaluator, for example:
conversions internally in RolloutWorker, for example:
gym.Env => rllib.VectorEnv => rllib.BaseEnv
rllib.MultiAgentEnv => rllib.BaseEnv
+16 -4
View File
@@ -1,4 +1,5 @@
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.interface import EvaluatorInterface
from ray.rllib.evaluation.policy_graph import PolicyGraph
@@ -12,8 +13,19 @@ from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.evaluation.metrics import collect_metrics
__all__ = [
"EvaluatorInterface", "PolicyEvaluator", "PolicyGraph", "TFPolicyGraph",
"TorchPolicyGraph", "SampleBatch", "MultiAgentBatch", "SampleBatchBuilder",
"MultiAgentSampleBatchBuilder", "SyncSampler", "AsyncSampler",
"compute_advantages", "collect_metrics", "MultiAgentEpisode"
"EvaluatorInterface",
"RolloutWorker",
"PolicyGraph",
"TFPolicyGraph",
"TorchPolicyGraph",
"SampleBatch",
"MultiAgentBatch",
"SampleBatchBuilder",
"MultiAgentSampleBatchBuilder",
"SyncSampler",
"AsyncSampler",
"compute_advantages",
"collect_metrics",
"MultiAgentEpisode",
"PolicyEvaluator",
]
+1 -1
View File
@@ -11,7 +11,7 @@ from ray.rllib.utils.annotations import DeveloperAPI
class EvaluatorInterface(object):
"""This is the interface between policy optimizers and policy evaluation.
See also: PolicyEvaluator
See also: RolloutWorker
"""
@DeveloperAPI
+8 -12
View File
@@ -39,27 +39,23 @@ def get_learner_stats(grad_info):
@DeveloperAPI
def collect_metrics(local_evaluator=None,
remote_evaluators=[],
timeout_seconds=180):
"""Gathers episode metrics from PolicyEvaluator instances."""
def collect_metrics(local_worker=None, remote_workers=[], timeout_seconds=180):
"""Gathers episode metrics from RolloutWorker instances."""
episodes, num_dropped = collect_episodes(
local_evaluator, remote_evaluators, timeout_seconds=timeout_seconds)
local_worker, remote_workers, timeout_seconds=timeout_seconds)
metrics = summarize_episodes(episodes, episodes, num_dropped)
return metrics
@DeveloperAPI
def collect_episodes(local_evaluator=None,
remote_evaluators=[],
def collect_episodes(local_worker=None, remote_workers=[],
timeout_seconds=180):
"""Gathers new episodes metrics tuples from the given evaluators."""
if remote_evaluators:
if remote_workers:
pending = [
a.apply.remote(lambda ev: ev.get_metrics())
for a in remote_evaluators
a.apply.remote(lambda ev: ev.get_metrics()) for a in remote_workers
]
collected, _ = ray.wait(
pending, num_returns=len(pending), timeout=timeout_seconds * 1.0)
@@ -73,8 +69,8 @@ def collect_episodes(local_evaluator=None,
metric_lists = []
num_metric_batches_dropped = 0
if local_evaluator:
metric_lists.append(local_evaluator.get_metrics())
if local_worker:
metric_lists.append(local_worker.get_metrics())
episodes = []
for metrics in metric_lists:
episodes.extend(metrics)
+4 -801
View File
@@ -2,805 +2,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gym
import logging
import pickle
from ray.rllib.utils import renamed_class
from ray.rllib.evaluation import RolloutWorker
import ray
from ray.rllib.env.atari_wrappers import wrap_deepmind, is_atari
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.evaluation.interface import EvaluatorInterface
from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader
from ray.rllib.offline.is_estimator import ImportanceSamplingEstimator
from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator
from ray.rllib.models import ModelCatalog
from ray.rllib.models.preprocessors import NoPreprocessor
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.debug import disable_log_once_globally, log_once, \
summarize, enable_periodic_logging
from ray.rllib.utils.filter import get_filter
from ray.rllib.utils.tf_run_builder import TFRunBuilder
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
logger = logging.getLogger(__name__)
# Handle to the current evaluator, which will be set to the most recently
# created PolicyEvaluator in this process. This can be helpful to access in
# custom env or policy classes for debugging or advanced use cases.
_global_evaluator = None
@DeveloperAPI
def get_global_evaluator():
    """Return the process-wide policy evaluator handle.

    The handle is the module-level ``_global_evaluator``; it is ``None``
    until a PolicyEvaluator has been constructed in this process.
    """
    # Reading a module global needs no ``global`` declaration.
    return _global_evaluator
@DeveloperAPI
class PolicyEvaluator(EvaluatorInterface):
"""Common ``PolicyEvaluator`` implementation that wraps a ``Policy``.
This class wraps a policy instance and an environment class to
collect experiences from the environment. You can create many replicas of
this class as Ray actors to scale RL training.
This class supports vectorized and multi-agent policy evaluation (e.g.,
VectorEnv, MultiAgentEnv, etc.)
Examples:
>>> # Create a policy evaluator and using it to collect experiences.
>>> evaluator = PolicyEvaluator(
... env_creator=lambda _: gym.make("CartPole-v0"),
... policy=PGTFPolicy)
>>> print(evaluator.sample())
SampleBatch({
"obs": [[...]], "actions": [[...]], "rewards": [[...]],
"dones": [[...]], "new_obs": [[...]]})
>>> # Creating policy evaluators using optimizer_cls.make().
>>> optimizer = SyncSamplesOptimizer.make(
... evaluator_cls=PolicyEvaluator,
... evaluator_args={
... "env_creator": lambda _: gym.make("CartPole-v0"),
... "policy": PGTFPolicy,
... },
... num_workers=10)
>>> for _ in range(10): optimizer.step()
>>> # Creating a multi-agent policy evaluator
>>> evaluator = PolicyEvaluator(
... env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25),
... policies={
... # Use an ensemble of two policies for car agents
... "car_policy1":
... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.99}),
... "car_policy2":
... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.95}),
... # Use a single shared policy for all traffic lights
... "traffic_light_policy":
... (PGTFPolicy, Box(...), Discrete(...), {}),
... },
... policy_mapping_fn=lambda agent_id:
... random.choice(["car_policy1", "car_policy2"])
... if agent_id.startswith("car_") else "traffic_light_policy")
>>> print(evaluator.sample())
MultiAgentBatch({
"car_policy1": SampleBatch(...),
"car_policy2": SampleBatch(...),
"traffic_light_policy": SampleBatch(...)})
"""
@DeveloperAPI
@classmethod
def as_remote(cls, num_cpus=None, num_gpus=None, resources=None):
    """Return this class wrapped by ``ray.remote`` with the given resources.

    Arguments:
        num_cpus: forwarded to ``ray.remote`` as ``num_cpus``.
        num_gpus: forwarded to ``ray.remote`` as ``num_gpus``.
        resources: forwarded to ``ray.remote`` as ``resources``.
    """
    make_remote = ray.remote(
        num_cpus=num_cpus, num_gpus=num_gpus, resources=resources)
    return make_remote(cls)
@DeveloperAPI
def __init__(self,
             env_creator,
             policy,
             policy_mapping_fn=None,
             policies_to_train=None,
             tf_session_creator=None,
             batch_steps=100,
             batch_mode="truncate_episodes",
             episode_horizon=None,
             preprocessor_pref="deepmind",
             sample_async=False,
             compress_observations=False,
             num_envs=1,
             observation_filter="NoFilter",
             clip_rewards=None,
             clip_actions=True,
             env_config=None,
             model_config=None,
             policy_config=None,
             worker_index=0,
             monitor_path=None,
             log_dir=None,
             log_level=None,
             callbacks=None,
             input_creator=lambda ioctx: ioctx.default_sampler_input(),
             input_evaluation=frozenset([]),
             output_creator=lambda ioctx: NoopOutput(),
             remote_worker_envs=False,
             remote_env_batch_wait_ms=0,
             soft_horizon=False,
             _fake_sampler=False):
    """Initialize a policy evaluator.

    Construction registers this instance as the process-global evaluator,
    creates and wraps the env, builds the policy and filter maps, and
    sets up the sampler plus the input reader / output writer.

    Arguments:
        env_creator (func): Function that returns a gym.Env given an
            EnvContext wrapped configuration.
        policy (class|dict): Either a class implementing
            Policy, or a dictionary of policy id strings to
            (Policy, obs_space, action_space, config) tuples. If a
            dict is specified, then we are in multi-agent mode and a
            policy_mapping_fn should also be set.
        policy_mapping_fn (func): A function that maps agent ids to
            policy ids in multi-agent mode. This function will be called
            each time a new agent appears in an episode, to bind that agent
            to a policy for the duration of the episode. Defaults to
            mapping every agent to DEFAULT_POLICY_ID.
        policies_to_train (list): Optional whitelist of policies to train,
            or None for all policies.
        tf_session_creator (func): A function that returns a TF session.
            This is optional and only useful with TFPolicy.
        batch_steps (int): The target number of env transitions to include
            in each sample batch returned from this evaluator.
        batch_mode (str): One of the following batch modes:
            "truncate_episodes": Each call to sample() will return a batch
                of at most `batch_steps * num_envs` in size. The batch will
                be exactly `batch_steps * num_envs` in size if
                postprocessing does not change batch sizes. Episodes may be
                truncated in order to meet this size requirement.
            "complete_episodes": Each call to sample() will return a batch
                of at least `batch_steps * num_envs` in size. Episodes will
                not be truncated, but multiple episodes may be packed
                within one batch to meet the batch size. Note that when
                `num_envs > 1`, episode steps will be buffered until the
                episode completes, and hence batches may contain
                significant amounts of off-policy data.
        episode_horizon (int): Whether to stop episodes at this horizon.
        preprocessor_pref (str): Whether to prefer RLlib preprocessors
            ("rllib") or deepmind ("deepmind") when applicable.
        sample_async (bool): Whether to compute samples asynchronously in
            the background, which improves throughput but can cause samples
            to be slightly off-policy.
        compress_observations (bool): If true, compress the observations.
            They can be decompressed with rllib/utils/compression.
        num_envs (int): If more than one, will create multiple envs
            and vectorize the computation of actions. This has no effect
            if the env already implements VectorEnv.
        observation_filter (str): Name of observation filter to use.
        clip_rewards (bool): Whether to clip rewards to [-1, 1] prior to
            experience postprocessing. Setting to None means clip for Atari
            only.
        clip_actions (bool): Whether to clip action values to the range
            specified by the policy action space.
        env_config (dict): Config to pass to the env creator.
        model_config (dict): Config to use when creating the policy model.
        policy_config (dict): Config to pass to the policy. In the
            multi-agent case, this config will be merged with the
            per-policy configs specified by `policy`.
        worker_index (int): For remote evaluators, this should be set to a
            non-zero and unique value. This index is passed to created envs
            through EnvContext so that envs can be configured per worker.
        monitor_path (str): Write out episode stats and videos to this
            directory if specified.
        log_dir (str): Directory where logs can be placed.
        log_level (str): Set the root log level on creation.
        callbacks (dict): Dict of custom debug callbacks.
        input_creator (func): Function that returns an InputReader object
            for loading previous generated experiences.
        input_evaluation (list): How to evaluate the policy performance.
            This only makes sense to set when the input is reading offline
            data. The possible values include:
              - "is": the step-wise importance sampling estimator.
              - "wis": the weighted step-wise is estimator.
              - "simulation": run the environment in the background, but
                use this data for evaluation only and never for learning.
        output_creator (func): Function that returns an OutputWriter object
            for saving generated experiences.
        remote_worker_envs (bool): If using num_envs > 1, whether to create
            those new envs in remote processes instead of in the current
            process. This adds overheads.
            NOTE(review): the sentence describing when this trade-off is
            worthwhile was truncated in the original docstring; presumably
            it targets envs that are slow to step / reset -- confirm.
        remote_env_batch_wait_ms (float): Timeout that remote workers
            are waiting when polling environments. 0 (continue when at
            least one env is ready) is a reasonable default, but optimal
            value could be obtained by measuring your environment
            step / reset and model inference perf.
        soft_horizon (bool): Calculate rewards but don't reset the
            environment when the horizon is hit.
        _fake_sampler (bool): Use a fake (inf speed) sampler for testing.

    Raises:
        ValueError: if policy_mapping_fn is not callable, if multiple
            policies are given for an env that is not a BaseEnv,
            MultiAgentEnv or ExternalMultiAgentEnv, if batch_mode is
            unsupported, or if input_evaluation has an unknown method.
    """

    # Publish this instance so get_global_evaluator() can return it.
    global _global_evaluator
    _global_evaluator = self

    if log_level:
        logging.getLogger("ray.rllib").setLevel(log_level)

    if worker_index > 1:
        disable_log_once_globally()  # only need 1 evaluator to log
    elif log_level == "DEBUG":
        enable_periodic_logging()

    env_context = EnvContext(env_config or {}, worker_index)
    policy_config = policy_config or {}
    self.policy_config = policy_config
    self.callbacks = callbacks or {}
    self.worker_index = worker_index
    model_config = model_config or {}
    # Without an explicit mapping, every agent uses the default policy.
    policy_mapping_fn = (policy_mapping_fn
                         or (lambda agent_id: DEFAULT_POLICY_ID))
    if not callable(policy_mapping_fn):
        raise ValueError(
            "Policy mapping function not callable. If you're using Tune, "
            "make sure to escape the function with tune.function() "
            "to prevent it from being evaluated as an expression.")
    self.env_creator = env_creator
    self.sample_batch_size = batch_steps * num_envs
    self.batch_mode = batch_mode
    self.compress_observations = compress_observations
    self.preprocessing_enabled = True
    self.last_batch = None
    self._fake_sampler = _fake_sampler

    # Choose the ``wrap`` closure applied to this env and to every env
    # later created by make_env().
    self.env = _validate_env(env_creator(env_context))
    if isinstance(self.env, MultiAgentEnv) or \
            isinstance(self.env, BaseEnv):

        def wrap(env):
            return env  # we can't auto-wrap these env types

    elif is_atari(self.env) and \
            not model_config.get("custom_preprocessor") and \
            preprocessor_pref == "deepmind":

        # Deepmind wrappers already handle all preprocessing
        self.preprocessing_enabled = False

        # clip_rewards=None means "clip for Atari only" (see docstring).
        if clip_rewards is None:
            clip_rewards = True

        def wrap(env):
            env = wrap_deepmind(
                env,
                dim=model_config.get("dim"),
                framestack=model_config.get("framestack"))
            if monitor_path:
                env = _monitor(env, monitor_path)
            return env
    else:

        def wrap(env):
            if monitor_path:
                env = _monitor(env, monitor_path)
            return env

    self.env = wrap(self.env)

    # Factory handed to BaseEnv.to_base_env() for creating extra envs.
    def make_env(vector_index):
        return wrap(
            env_creator(
                env_context.copy_with_overrides(
                    vector_index=vector_index, remote=remote_worker_envs)))

    self.tf_sess = None
    policy_dict = _validate_and_canonicalize(policy, self.env)
    self.policies_to_train = policies_to_train or list(policy_dict.keys())
    if _has_tensorflow_graph(policy_dict):
        if (ray.is_initialized()
                and ray.worker._mode() != ray.worker.LOCAL_MODE
                and not ray.get_gpu_ids()):
            logger.info("Creating policy evaluation worker {}".format(
                worker_index) +
                        " on CPU (please ignore any CUDA init errors)")
        # TF policies are built inside a fresh graph and under the
        # session created here (or by tf_session_creator).
        with tf.Graph().as_default():
            if tf_session_creator:
                self.tf_sess = tf_session_creator()
            else:
                self.tf_sess = tf.Session(
                    config=tf.ConfigProto(
                        gpu_options=tf.GPUOptions(allow_growth=True)))
            with self.tf_sess.as_default():
                self.policy_map, self.preprocessors = \
                    self._build_policy_map(policy_dict, policy_config)
    else:
        self.policy_map, self.preprocessors = self._build_policy_map(
            policy_dict, policy_config)

    # Anything other than exactly {DEFAULT_POLICY_ID} counts as multi-agent.
    self.multiagent = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID}
    if self.multiagent:
        if not ((isinstance(self.env, MultiAgentEnv)
                 or isinstance(self.env, ExternalMultiAgentEnv))
                or isinstance(self.env, BaseEnv)):
            raise ValueError(
                "Have multiple policies {}, but the env ".format(
                    self.policy_map) +
                "{} is not a subclass of BaseEnv, MultiAgentEnv or "
                "ExternalMultiAgentEnv?".format(self.env))

    # One observation filter per policy, sized from its observation space.
    self.filters = {
        policy_id: get_filter(observation_filter,
                              policy.observation_space.shape)
        for (policy_id, policy) in self.policy_map.items()
    }
    if self.worker_index == 0:
        logger.info("Built filter map: {}".format(self.filters))

    # Always use vector env for consistency even if num_envs = 1
    self.async_env = BaseEnv.to_base_env(
        self.env,
        make_env=make_env,
        num_envs=num_envs,
        remote_envs=remote_worker_envs,
        remote_env_batch_wait_ms=remote_env_batch_wait_ms)
    self.num_envs = num_envs

    if self.batch_mode == "truncate_episodes":
        unroll_length = batch_steps
        pack_episodes = True
    elif self.batch_mode == "complete_episodes":
        unroll_length = float("inf")  # never cut episodes
        pack_episodes = False  # sampler will return 1 episode per poll
    else:
        raise ValueError("Unsupported batch mode: {}".format(
            self.batch_mode))

    self.io_context = IOContext(log_dir, policy_config, worker_index, self)
    self.reward_estimators = []
    for method in input_evaluation:
        if method == "simulation":
            logger.warning(
                "Requested 'simulation' input evaluation method: "
                "will discard all sampler outputs and keep only metrics.")
            # Forces the AsyncSampler branch below, which receives
            # blackhole_outputs=True for this method.
            sample_async = True
        elif method == "is":
            ise = ImportanceSamplingEstimator.create(self.io_context)
            self.reward_estimators.append(ise)
        elif method == "wis":
            wise = WeightedImportanceSamplingEstimator.create(
                self.io_context)
            self.reward_estimators.append(wise)
        else:
            raise ValueError(
                "Unknown evaluation method: {}".format(method))

    if sample_async:
        self.sampler = AsyncSampler(
            self.async_env,
            self.policy_map,
            policy_mapping_fn,
            self.preprocessors,
            self.filters,
            clip_rewards,
            unroll_length,
            self.callbacks,
            horizon=episode_horizon,
            pack=pack_episodes,
            tf_sess=self.tf_sess,
            clip_actions=clip_actions,
            blackhole_outputs="simulation" in input_evaluation,
            soft_horizon=soft_horizon)
        self.sampler.start()
    else:
        self.sampler = SyncSampler(
            self.async_env,
            self.policy_map,
            policy_mapping_fn,
            self.preprocessors,
            self.filters,
            clip_rewards,
            unroll_length,
            self.callbacks,
            horizon=episode_horizon,
            pack=pack_episodes,
            tf_sess=self.tf_sess,
            clip_actions=clip_actions,
            soft_horizon=soft_horizon)

    # The reader/writer are created last because they receive the
    # io_context, which holds a reference to this evaluator.
    self.input_reader = input_creator(self.io_context)
    assert isinstance(self.input_reader, InputReader), self.input_reader
    self.output_writer = output_creator(self.io_context)
    assert isinstance(self.output_writer, OutputWriter), self.output_writer

    logger.debug("Created evaluator with env {} ({}), policies {}".format(
        self.async_env, self.env, self.policy_map))
@override(EvaluatorInterface)
def sample(self):
    """Evaluate the current policies and return a batch of experiences.

    Batches are pulled from the input reader until the target size
    ``self.sample_batch_size`` is reached (or, in "truncate_episodes"
    mode, until one batch per env has been pulled), then concatenated.

    Return:
        SampleBatch|MultiAgentBatch from evaluating the current policies.
    """

    # The fake sampler replays the first batch it ever produced.
    if self._fake_sampler and self.last_batch is not None:
        return self.last_batch

    if log_once("sample_start"):
        logger.info("Generating sample batch of size {}".format(
            self.sample_batch_size))

    first = self.input_reader.next()
    collected = [first]
    num_steps = first.count

    # In truncate_episodes mode, never pull more than 1 batch per env.
    # This avoids over-running the target batch size.
    truncating = self.batch_mode == "truncate_episodes"
    batch_limit = self.num_envs if truncating else float("inf")

    while (num_steps < self.sample_batch_size
           and len(collected) < batch_limit):
        extra = self.input_reader.next()
        num_steps += extra.count
        collected.append(extra)
    batch = collected[0].concat_samples(collected)

    on_sample_end = self.callbacks.get("on_sample_end")
    if on_sample_end:
        on_sample_end({"evaluator": self, "samples": batch})

    # Always do writes prior to compression for consistency and to allow
    # for better compression inside the writer.
    self.output_writer.write(batch)

    # Do off-policy estimation if needed
    if self.reward_estimators:
        for episode_batch in batch.split_by_episode():
            for estimator in self.reward_estimators:
                estimator.process(episode_batch)

    if log_once("sample_end"):
        logger.info("Completed sample batch:\n\n{}\n".format(
            summarize(batch)))

    compress_mode = self.compress_observations
    if compress_mode == "bulk":
        batch.compress(bulk=True)
    elif compress_mode:
        batch.compress()

    if self._fake_sampler:
        self.last_batch = batch
    return batch
@DeveloperAPI
@ray.method(num_return_vals=2)
def sample_with_count(self):
    """Like sample(), but also return the batch's count as a second value.

    Declared with ``num_return_vals=2`` so the batch and its count are
    two separate return values of the Ray method.
    """
    sampled = self.sample()
    return sampled, sampled.count
@override(EvaluatorInterface)
def get_weights(self, policies=None):
    """Return a dict mapping policy id to that policy's weights.

    Arguments:
        policies: optional container of policy ids to include; when None,
            every policy in ``self.policy_map`` is included.
    """
    wanted = self.policy_map.keys() if policies is None else policies
    weights = {}
    for pid, pol in self.policy_map.items():
        if pid in wanted:
            weights[pid] = pol.get_weights()
    return weights
@override(EvaluatorInterface)
def set_weights(self, weights):
    """Set the weights of each policy named in ``weights``.

    Arguments:
        weights (dict): maps policy id to the weights passed to that
            policy's ``set_weights``.
    """
    for policy_id, policy_weights in weights.items():
        target = self.policy_map[policy_id]
        target.set_weights(policy_weights)
@override(EvaluatorInterface)
def compute_gradients(self, samples):
    """Compute gradients for the given samples.

    For a MultiAgentBatch, gradients are computed per policy, restricted
    to ``self.policies_to_train``; with a TF session the per-policy ops
    are collected in one TFRunBuilder. Otherwise the default policy is
    used on the whole batch.

    Returns:
        (grad_out, info_out); ``info_out["batch_count"]`` is set to
        ``samples.count``.
    """
    if log_once("compute_gradients"):
        logger.info("Compute gradients on:\n\n{}\n".format(
            summarize(samples)))

    if not isinstance(samples, MultiAgentBatch):
        default_policy = self.policy_map[DEFAULT_POLICY_ID]
        grad_out, info_out = default_policy.compute_gradients(samples)
    else:
        grad_out, info_out = {}, {}
        trainable = [(pid, batch)
                     for pid, batch in samples.policy_batches.items()
                     if pid in self.policies_to_train]
        if self.tf_sess is None:
            for pid, batch in trainable:
                grad_out[pid], info_out[pid] = (
                    self.policy_map[pid].compute_gradients(batch))
        else:
            builder = TFRunBuilder(self.tf_sess, "compute_gradients")
            for pid, batch in trainable:
                grad_out[pid], info_out[pid] = (
                    self.policy_map[pid]._build_compute_gradients(
                        builder, batch))
            # Resolve the queued fetches into concrete values.
            grad_out = {k: builder.get(v) for k, v in grad_out.items()}
            info_out = {k: builder.get(v) for k, v in info_out.items()}

    info_out["batch_count"] = samples.count
    if log_once("grad_out"):
        logger.info("Compute grad info:\n\n{}\n".format(
            summarize(info_out)))
    return grad_out, info_out
@override(EvaluatorInterface)
def apply_gradients(self, grads):
    """Apply the given gradients to the policies.

    Arguments:
        grads: either a dict of policy id -> gradients (applied per
            policy), or gradients for the default policy.

    Returns:
        A dict of per-policy results when ``grads`` is a dict, otherwise
        the default policy's ``apply_gradients`` result.
    """
    if log_once("apply_gradients"):
        logger.info("Apply gradients:\n\n{}\n".format(summarize(grads)))

    if not isinstance(grads, dict):
        return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)

    if self.tf_sess is None:
        return {
            pid: self.policy_map[pid].apply_gradients(g)
            for pid, g in grads.items()
        }

    # TF path: queue every policy's apply op on one builder, then fetch.
    builder = TFRunBuilder(self.tf_sess, "apply_gradients")
    pending = {}
    for pid, grad in grads.items():
        pending[pid] = self.policy_map[pid]._build_apply_gradients(
            builder, grad)
    return {pid: builder.get(op) for pid, op in pending.items()}
@override(EvaluatorInterface)
def learn_on_batch(self, samples):
    """Train on the given samples and return the resulting info.

    For a MultiAgentBatch only policies in ``self.policies_to_train`` are
    updated; policies exposing ``_build_learn_on_batch`` are batched
    through a TFRunBuilder when a TF session exists. Otherwise the
    default policy learns on the whole batch.
    """
    if log_once("learn_on_batch"):
        logger.info(
            "Training on concatenated sample batches:\n\n{}\n".format(
                summarize(samples)))

    if not isinstance(samples, MultiAgentBatch):
        info_out = self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(
            samples)
    else:
        info_out = {}
        to_fetch = {}
        builder = (TFRunBuilder(self.tf_sess, "learn_on_batch")
                   if self.tf_sess is not None else None)
        for pid, batch in samples.policy_batches.items():
            if pid not in self.policies_to_train:
                continue
            policy = self.policy_map[pid]
            if builder and hasattr(policy, "_build_learn_on_batch"):
                to_fetch[pid] = policy._build_learn_on_batch(
                    builder, batch)
            else:
                info_out[pid] = policy.learn_on_batch(batch)
        # to_fetch is only populated when a builder exists.
        fetched = {k: builder.get(v) for k, v in to_fetch.items()}
        info_out.update(fetched)

    if log_once("learn_out"):
        logger.info("Training output:\n\n{}\n".format(summarize(info_out)))
    return info_out
@DeveloperAPI
def get_metrics(self):
    """Return the sampler's new metrics plus those of each reward estimator.

    The list returned by ``self.sampler.get_metrics()`` is extended in
    place with every estimator's ``get_metrics()`` output.
    """
    metrics = self.sampler.get_metrics()
    for estimator in self.reward_estimators:
        estimator_metrics = estimator.get_metrics()
        metrics.extend(estimator_metrics)
    return metrics
@DeveloperAPI
def foreach_env(self, func):
    """Apply ``func`` to each underlying env and return the results as a list.

    If ``self.async_env.get_unwrapped()`` yields nothing, ``func`` is
    applied to ``self.async_env`` itself instead.
    """
    unwrapped = self.async_env.get_unwrapped()
    targets = unwrapped if unwrapped else [self.async_env]
    return [func(env) for env in targets]
@DeveloperAPI
def get_policy(self, policy_id=DEFAULT_POLICY_ID):
    """Look up a policy by id.

    Arguments:
        policy_id (str): id of policy to return.

    Returns:
        The policy from ``self.policy_map``, or None if the id is unknown.
    """
    policies = self.policy_map
    return policies.get(policy_id)
@DeveloperAPI
def for_policy(self, func, policy_id=DEFAULT_POLICY_ID):
    """Call ``func`` on the policy with the given id and return its result.

    Raises KeyError (from the dict lookup) if the id is not in
    ``self.policy_map``.
    """
    selected = self.policy_map[policy_id]
    return func(selected)
@DeveloperAPI
def foreach_policy(self, func):
    """Call ``func(policy, policy_id)`` for every policy; return the results."""
    results = []
    for policy_id, policy in self.policy_map.items():
        results.append(func(policy, policy_id))
    return results
@DeveloperAPI
def foreach_trainable_policy(self, func):
    """Call ``func(policy, policy_id)`` for every trainable policy.

    Only policies whose id is in ``self.policies_to_train`` are visited.
    Returns the list of results.
    """
    results = []
    for policy_id, policy in self.policy_map.items():
        if policy_id not in self.policies_to_train:
            continue
        results.append(func(policy, policy_id))
    return results
@DeveloperAPI
def sync_filters(self, new_filters):
    """Sync each local filter against the matching entry of ``new_filters``.

    Args:
        new_filters (dict): Filters with new state to update local copy.
            Must contain a key for every local filter (asserted).
    """
    has_all_keys = all(name in new_filters for name in self.filters)
    assert has_all_keys
    for name, local_filter in self.filters.items():
        local_filter.sync(new_filters[name])
@DeveloperAPI
def get_filters(self, flush_after=False):
    """Return a serializable snapshot of the filters.

    Args:
        flush_after (bool): If true, clear each filter's buffer after its
            snapshot has been taken.

    Returns:
        return_filters (dict): maps filter key to ``as_serializable()``.
    """
    snapshot = {}
    for name, local_filter in self.filters.items():
        snapshot[name] = local_filter.as_serializable()
        # Snapshot first, then flush, so the snapshot keeps the buffer.
        if flush_after:
            local_filter.clear_buffer()
    return snapshot
@DeveloperAPI
def save(self):
    """Pickle the filters and every policy's state.

    The filter buffers are flushed as a side effect
    (``get_filters(flush_after=True)``).

    Returns:
        The pickled dict ``{"filters": ..., "state": ...}``.
    """
    filters = self.get_filters(flush_after=True)
    state = {}
    for policy_id, policy in self.policy_map.items():
        state[policy_id] = policy.get_state()
    return pickle.dumps({"filters": filters, "state": state})
@DeveloperAPI
def restore(self, objs):
    """Restore filters and policy states from a pickled payload.

    Arguments:
        objs: pickled dict with "filters" and "state" keys, in the
            layout that save() writes.
    """
    # NOTE(review): pickle.loads executes arbitrary code on malicious
    # input; only pass payloads from a trusted source.
    payload = pickle.loads(objs)
    self.sync_filters(payload["filters"])
    for policy_id, policy_state in payload["state"].items():
        self.policy_map[policy_id].set_state(policy_state)
@DeveloperAPI
def set_global_vars(self, global_vars):
    """Notify every policy of updated global vars.

    Calls ``on_global_var_update(global_vars)`` on each policy via
    foreach_policy(); nothing is returned.
    """

    def notify(policy, _policy_id):
        return policy.on_global_var_update(global_vars)

    self.foreach_policy(notify)
@DeveloperAPI
def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
    """Export the given policy's model by calling its ``export_model``.

    Arguments:
        export_dir: passed through to the policy's ``export_model``.
        policy_id: id of the policy to export (default policy if omitted).
    """
    policy = self.policy_map[policy_id]
    policy.export_model(export_dir)
@DeveloperAPI
def export_policy_checkpoint(self,
                             export_dir,
                             filename_prefix="model",
                             policy_id=DEFAULT_POLICY_ID):
    """Export a checkpoint of the given policy via ``export_checkpoint``.

    Arguments:
        export_dir: passed through to the policy's ``export_checkpoint``.
        filename_prefix: passed through as the second positional argument.
        policy_id: id of the policy to export (default policy if omitted).
    """
    policy = self.policy_map[policy_id]
    policy.export_checkpoint(export_dir, filename_prefix)
@DeveloperAPI
def stop(self):
    """Stop the underlying env by calling ``self.async_env.stop()``."""
    env = self.async_env
    env.stop()
def _build_policy_map(self, policy_dict, policy_config):
    """Instantiate every policy in ``policy_dict`` with its preprocessor.

    Arguments:
        policy_dict (dict): maps policy name to a
            (cls, obs_space, act_space, conf) tuple.
        policy_config (dict): base config; merged with each policy's
            ``conf`` via merge_dicts.

    Returns:
        (policy_map, preprocessors), both keyed by policy name.

    Raises:
        ValueError: if a policy would receive a raw Tuple or Dict space.
    """
    policy_map, preprocessors = {}, {}
    # Policies are created in sorted name order.
    for name in sorted(policy_dict):
        cls, obs_space, act_space, conf = policy_dict[name]
        logger.debug("Creating policy for {}".format(name))
        merged_conf = merge_dicts(policy_config, conf)
        if not self.preprocessing_enabled:
            preprocessors[name] = NoPreprocessor(obs_space)
        else:
            preprocessor = ModelCatalog.get_preprocessor_for_space(
                obs_space, merged_conf.get("model"))
            preprocessors[name] = preprocessor
            # The policy sees the preprocessed observation space.
            obs_space = preprocessor.observation_space

        if isinstance(obs_space, (gym.spaces.Dict, gym.spaces.Tuple)):
            raise ValueError(
                "Found raw Tuple|Dict space as input to policy. "
                "Please preprocess these observations with a "
                "Tuple|DictFlatteningPreprocessor.")
        if not tf:
            policy_map[name] = cls(obs_space, act_space, merged_conf)
        else:
            with tf.variable_scope(name):
                policy_map[name] = cls(obs_space, act_space, merged_conf)

    if self.worker_index == 0:
        logger.info("Built policy map: {}".format(policy_map))
        logger.info("Built preprocessor map: {}".format(preprocessors))
    return policy_map, preprocessors
def __del__(self):
    """Flag an AsyncSampler for shutdown when this object is collected."""
    # __init__ may have failed before self.sampler was assigned.
    if not hasattr(self, "sampler"):
        return
    if isinstance(self.sampler, AsyncSampler):
        self.sampler.shutdown = True
def _validate_and_canonicalize(policy, env):
    """Validate ``policy`` and return it in dict form.

    Arguments:
        policy (class|dict): either a Policy class, or a dict of policy id
            to (cls, obs_space, action_space, config) tuples.
        env: the env whose observation and action spaces are used for the
            single-policy case.

    Returns:
        ``policy`` itself if it is a (validated) dict; otherwise a dict
        mapping DEFAULT_POLICY_ID to
        (policy, env.observation_space, env.action_space, {}).

    Raises:
        ValueError: if ``policy`` is neither a dict nor a Policy class, or
            if a MultiAgentEnv without ``observation_space`` is used with
            a single policy.
    """
    if isinstance(policy, dict):
        _validate_multiagent_config(policy)
        return policy

    # issubclass() raises TypeError when its first argument is not a
    # class (e.g. a policy instance or a string), so check that first and
    # report the intended ValueError instead.
    if not (isinstance(policy, type) and issubclass(policy, Policy)):
        raise ValueError("policy must be a rllib.Policy class")

    if (isinstance(env, MultiAgentEnv)
            and not hasattr(env, "observation_space")):
        raise ValueError(
            "MultiAgentEnv must have observation_space defined if run "
            "in a single-agent configuration.")
    return {
        DEFAULT_POLICY_ID: (policy, env.observation_space,
                            env.action_space, {})
    }
def _validate_multiagent_config(policy, allow_none_graph=False):
    """Check that a multi-agent policy dict is well formed.

    Arguments:
        policy (dict): maps policy id (str) to a 4-tuple of
            (cls, obs_space, action_space, config).
        allow_none_graph (bool): if true, ``cls`` may be None.

    Raises:
        ValueError: on the first key or value that does not match the
            expected layout.
    """
    for k, v in policy.items():
        if not isinstance(k, str):
            raise ValueError("policy keys must be strs, got {}".format(
                type(k)))
        if not isinstance(v, tuple) or len(v) != 4:
            raise ValueError(
                "policy values must be tuples of "
                "(cls, obs_space, action_space, config), got {}".format(v))
        policy_cls, obs_space, action_space, config = v
        if allow_none_graph and policy_cls is None:
            pass
        # issubclass() raises TypeError on non-class arguments (including
        # None), so check for a class first to always raise ValueError.
        elif not (isinstance(policy_cls, type)
                  and issubclass(policy_cls, Policy)):
            raise ValueError("policy tuple value 0 must be a rllib.Policy "
                             "class or None, got {}".format(policy_cls))
        if not isinstance(obs_space, gym.Space):
            raise ValueError(
                "policy tuple value 1 (observation_space) must be a "
                "gym.Space, got {}".format(type(obs_space)))
        if not isinstance(action_space, gym.Space):
            raise ValueError("policy tuple value 2 (action_space) must be a "
                             "gym.Space, got {}".format(type(action_space)))
        if not isinstance(config, dict):
            raise ValueError("policy tuple value 3 (config) must be a dict, "
                             "got {}".format(type(config)))
def _validate_env(env):
    """Return ``env`` unchanged if it is an acceptable environment.

    An object exposing both ``observation_space`` and ``action_space`` is
    accepted as-is; otherwise it must be an instance of gym.Env,
    MultiAgentEnv, ExternalEnv, VectorEnv or BaseEnv.

    Raises:
        ValueError: if neither condition holds.
    """
    # allow this as a special case (assumed gym.Env)
    has_spaces = (hasattr(env, "observation_space")
                  and hasattr(env, "action_space"))
    if has_spaces:
        return env

    allowed_types = (gym.Env, MultiAgentEnv, ExternalEnv, VectorEnv, BaseEnv)
    if isinstance(env, allowed_types):
        return env
    raise ValueError(
        "Returned env should be an instance of gym.Env, MultiAgentEnv, "
        "ExternalEnv, VectorEnv, or BaseEnv. The provided env creator "
        "function returned {} ({}).".format(env, type(env)))
def _monitor(env, path):
    """Wrap ``env`` in ``gym.wrappers.Monitor`` writing to ``path``.

    The monitor is created with ``resume=True``.
    """
    monitored = gym.wrappers.Monitor(env, path, resume=True)
    return monitored
def _has_tensorflow_graph(policy_dict):
    """Return True if any policy class in ``policy_dict`` is a TFPolicy.

    ``policy_dict`` values are 4-tuples whose first item is the policy
    class; the remaining items are ignored.
    """
    return any(
        issubclass(policy_cls, TFPolicy)
        for policy_cls, _, _, _ in policy_dict.values())
PolicyEvaluator = renamed_class(
RolloutWorker, old_name="rllib.evaluation.PolicyEvaluator")
@@ -0,0 +1,794 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gym
import logging
import pickle
import ray
from ray.rllib.env.atari_wrappers import wrap_deepmind, is_atari
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.evaluation.interface import EvaluatorInterface
from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader
from ray.rllib.offline.is_estimator import ImportanceSamplingEstimator
from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator
from ray.rllib.models import ModelCatalog
from ray.rllib.models.preprocessors import NoPreprocessor
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.debug import disable_log_once_globally, log_once, \
summarize, enable_periodic_logging
from ray.rllib.utils.filter import get_filter
from ray.rllib.utils.tf_run_builder import TFRunBuilder
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
logger = logging.getLogger(__name__)
# Handle to the current rollout worker, which will be set to the most recently
# created RolloutWorker in this process. This can be helpful to access in
# custom env or policy classes for debugging or advanced use cases.
_global_worker = None
@DeveloperAPI
def get_global_worker():
    """Return the process-wide rollout worker handle.

    The handle is the module-level ``_global_worker``, which starts out
    as ``None``.
    """
    # Reading a module global needs no ``global`` declaration.
    return _global_worker
@DeveloperAPI
class RolloutWorker(EvaluatorInterface):
"""Common experience collection class.
This class wraps a policy instance and an environment class to
collect experiences from the environment. You can create many replicas of
this class as Ray actors to scale RL training.
This class supports vectorized and multi-agent policy evaluation (e.g.,
VectorEnv, MultiAgentEnv, etc.)
Examples:
>>> # Create a rollout worker and using it to collect experiences.
>>> worker = RolloutWorker(
... env_creator=lambda _: gym.make("CartPole-v0"),
... policy=PGTFPolicy)
>>> print(worker.sample())
SampleBatch({
"obs": [[...]], "actions": [[...]], "rewards": [[...]],
"dones": [[...]], "new_obs": [[...]]})
>>> # Creating a multi-agent rollout worker
>>> worker = RolloutWorker(
... env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25),
... policies={
... # Use an ensemble of two policies for car agents
... "car_policy1":
... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.99}),
... "car_policy2":
... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.95}),
... # Use a single shared policy for all traffic lights
... "traffic_light_policy":
... (PGTFPolicy, Box(...), Discrete(...), {}),
... },
... policy_mapping_fn=lambda agent_id:
... random.choice(["car_policy1", "car_policy2"])
... if agent_id.startswith("car_") else "traffic_light_policy")
>>> print(worker.sample())
MultiAgentBatch({
"car_policy1": SampleBatch(...),
"car_policy2": SampleBatch(...),
"traffic_light_policy": SampleBatch(...)})
"""
@DeveloperAPI
@classmethod
def as_remote(cls, num_cpus=None, num_gpus=None, resources=None):
    """Wrap this class with ``ray.remote`` using the given resource options.

    Arguments:
        num_cpus: Forwarded as ``num_cpus`` to ``ray.remote``.
        num_gpus: Forwarded as ``num_gpus`` to ``ray.remote``.
        resources: Forwarded as ``resources`` to ``ray.remote``.

    Returns:
        The result of applying the configured ``ray.remote`` decorator to
        this class.
    """
    remote_decorator = ray.remote(
        num_cpus=num_cpus, num_gpus=num_gpus, resources=resources)
    return remote_decorator(cls)
@DeveloperAPI
def __init__(self,
             env_creator,
             policy,
             policy_mapping_fn=None,
             policies_to_train=None,
             tf_session_creator=None,
             batch_steps=100,
             batch_mode="truncate_episodes",
             episode_horizon=None,
             preprocessor_pref="deepmind",
             sample_async=False,
             compress_observations=False,
             num_envs=1,
             observation_filter="NoFilter",
             clip_rewards=None,
             clip_actions=True,
             env_config=None,
             model_config=None,
             policy_config=None,
             worker_index=0,
             monitor_path=None,
             log_dir=None,
             log_level=None,
             callbacks=None,
             input_creator=lambda ioctx: ioctx.default_sampler_input(),
             input_evaluation=frozenset([]),
             output_creator=lambda ioctx: NoopOutput(),
             remote_worker_envs=False,
             remote_env_batch_wait_ms=0,
             soft_horizon=False,
             _fake_sampler=False):
    """Initialize a rollout worker.

    Construction registers this instance as the process-wide global worker,
    creates and wraps the env, builds the policy map (inside a dedicated
    TF graph/session when any policy is a TFPolicy), then creates the
    filters, the sampler, and the input reader / output writer.

    Arguments:
        env_creator (func): Function that returns a gym.Env given an
            EnvContext wrapped configuration.
        policy (class|dict): Either a class implementing
            Policy, or a dictionary of policy id strings to
            (Policy, obs_space, action_space, config) tuples. If a
            dict is specified, then we are in multi-agent mode and a
            policy_mapping_fn should also be set.
        policy_mapping_fn (func): A function that maps agent ids to
            policy ids in multi-agent mode. This function will be called
            each time a new agent appears in an episode, to bind that agent
            to a policy for the duration of the episode. If not given,
            every agent id is mapped to DEFAULT_POLICY_ID.
        policies_to_train (list): Optional whitelist of policies to train,
            or None for all policies.
        tf_session_creator (func): A function that returns a TF session.
            This is optional and only useful with TFPolicy.
        batch_steps (int): The target number of env transitions to include
            in each sample batch returned from this worker.
        batch_mode (str): One of the following batch modes:
            "truncate_episodes": Each call to sample() will return a batch
                of at most `batch_steps * num_envs` in size. The batch will
                be exactly `batch_steps * num_envs` in size if
                postprocessing does not change batch sizes. Episodes may be
                truncated in order to meet this size requirement.
            "complete_episodes": Each call to sample() will return a batch
                of at least `batch_steps * num_envs` in size. Episodes will
                not be truncated, but multiple episodes may be packed
                within one batch to meet the batch size. Note that when
                `num_envs > 1`, episode steps will be buffered until the
                episode completes, and hence batches may contain
                significant amounts of off-policy data.
            Any other value raises ValueError.
        episode_horizon (int): Whether to stop episodes at this horizon.
        preprocessor_pref (str): Whether to prefer RLlib preprocessors
            ("rllib") or deepmind ("deepmind") when applicable.
        sample_async (bool): Whether to compute samples asynchronously in
            the background, which improves throughput but can cause samples
            to be slightly off-policy.
        compress_observations (bool): If true, compress the observations.
            They can be decompressed with rllib/utils/compression. The
            string "bulk" makes sample() call `batch.compress(bulk=True)`.
        num_envs (int): If more than one, will create multiple envs
            and vectorize the computation of actions. This has no effect
            if the env already implements VectorEnv.
        observation_filter (str): Name of observation filter to use.
        clip_rewards (bool): Whether to clip rewards to [-1, 1] prior to
            experience postprocessing. Setting to None means clip for Atari
            only.
        clip_actions (bool): Whether to clip action values to the range
            specified by the policy action space.
        env_config (dict): Config to pass to the env creator.
        model_config (dict): Config to use when creating the policy model.
        policy_config (dict): Config to pass to the policy. In the
            multi-agent case, this config will be merged with the
            per-policy configs specified by `policy`.
        worker_index (int): For remote workers, this should be set to a
            non-zero and unique value. This index is passed to created envs
            through EnvContext so that envs can be configured per worker.
        monitor_path (str): Write out episode stats and videos to this
            directory if specified.
        log_dir (str): Directory where logs can be placed.
        log_level (str): Set the root log level on creation.
        callbacks (dict): Dict of custom debug callbacks.
        input_creator (func): Function that returns an InputReader object
            for loading previous generated experiences.
        input_evaluation (list): How to evaluate the policy performance.
            This only makes sense to set when the input is reading offline
            data. The possible values include:
              - "is": the step-wise importance sampling estimator.
              - "wis": the weighted step-wise is estimator.
              - "simulation": run the environment in the background, but
                use this data for evaluation only and never for learning.
            Any other value raises ValueError.
        output_creator (func): Function that returns an OutputWriter object
            for saving generated experiences.
        remote_worker_envs (bool): If using num_envs > 1, whether to create
            those new envs in remote processes instead of in the current
            process. This adds overheads, but can make sense if your envs
            take a long time to step / reset.
        remote_env_batch_wait_ms (float): Timeout that remote workers
            are waiting when polling environments. 0 (continue when at
            least one env is ready) is a reasonable default, but optimal
            value could be obtained by measuring your environment
            step / reset and model inference perf.
        soft_horizon (bool): Calculate rewards but don't reset the
            environment when the horizon is hit.
        _fake_sampler (bool): Use a fake (inf speed) sampler for testing.

    Raises:
        ValueError: If policy_mapping_fn is not callable, if multiple
            policies are configured for an env that is not a BaseEnv,
            MultiAgentEnv or ExternalMultiAgentEnv, or if batch_mode or an
            input_evaluation entry is not recognized.
    """

    # Expose this instance through get_global_worker().
    global _global_worker
    _global_worker = self

    if log_level:
        logging.getLogger("ray.rllib").setLevel(log_level)

    if worker_index > 1:
        disable_log_once_globally()  # only need 1 worker to log
    elif log_level == "DEBUG":
        enable_periodic_logging()

    env_context = EnvContext(env_config or {}, worker_index)
    policy_config = policy_config or {}
    self.policy_config = policy_config
    self.callbacks = callbacks or {}
    self.worker_index = worker_index
    model_config = model_config or {}
    # Default mapping: every agent uses the single default policy.
    policy_mapping_fn = (policy_mapping_fn
                         or (lambda agent_id: DEFAULT_POLICY_ID))
    if not callable(policy_mapping_fn):
        raise ValueError(
            "Policy mapping function not callable. If you're using Tune, "
            "make sure to escape the function with tune.function() "
            "to prevent it from being evaluated as an expression.")
    self.env_creator = env_creator
    self.sample_batch_size = batch_steps * num_envs
    self.batch_mode = batch_mode
    self.compress_observations = compress_observations
    self.preprocessing_enabled = True
    self.last_batch = None
    self._fake_sampler = _fake_sampler

    # Create the first env and choose the wrapper applied to it and to every
    # additional env created later through make_env().
    self.env = _validate_env(env_creator(env_context))
    if isinstance(self.env, MultiAgentEnv) or \
            isinstance(self.env, BaseEnv):

        def wrap(env):
            return env  # we can't auto-wrap these env types
    elif is_atari(self.env) and \
            not model_config.get("custom_preprocessor") and \
            preprocessor_pref == "deepmind":

        # Deepmind wrappers already handle all preprocessing
        self.preprocessing_enabled = False

        if clip_rewards is None:
            clip_rewards = True

        def wrap(env):
            env = wrap_deepmind(
                env,
                dim=model_config.get("dim"),
                framestack=model_config.get("framestack"))
            if monitor_path:
                env = _monitor(env, monitor_path)
            return env
    else:

        def wrap(env):
            if monitor_path:
                env = _monitor(env, monitor_path)
            return env

    self.env = wrap(self.env)

    # Factory handed to BaseEnv.to_base_env() for creating further envs.
    def make_env(vector_index):
        return wrap(
            env_creator(
                env_context.copy_with_overrides(
                    vector_index=vector_index, remote=remote_worker_envs)))

    self.tf_sess = None
    policy_dict = _validate_and_canonicalize(policy, self.env)
    self.policies_to_train = policies_to_train or list(policy_dict.keys())
    if _has_tensorflow_graph(policy_dict):
        if (ray.is_initialized()
                and ray.worker._mode() != ray.worker.LOCAL_MODE
                and not ray.get_gpu_ids()):
            logger.info("Creating policy evaluation worker {}".format(
                worker_index) +
                        " on CPU (please ignore any CUDA init errors)")
        # TF policies are built inside a fresh graph and a session owned by
        # this worker.
        with tf.Graph().as_default():
            if tf_session_creator:
                self.tf_sess = tf_session_creator()
            else:
                self.tf_sess = tf.Session(
                    config=tf.ConfigProto(
                        gpu_options=tf.GPUOptions(allow_growth=True)))
            with self.tf_sess.as_default():
                self.policy_map, self.preprocessors = \
                    self._build_policy_map(policy_dict, policy_config)
    else:
        self.policy_map, self.preprocessors = self._build_policy_map(
            policy_dict, policy_config)

    # Anything other than exactly {DEFAULT_POLICY_ID} counts as multi-agent.
    self.multiagent = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID}
    if self.multiagent:
        if not ((isinstance(self.env, MultiAgentEnv)
                 or isinstance(self.env, ExternalMultiAgentEnv))
                or isinstance(self.env, BaseEnv)):
            raise ValueError(
                "Have multiple policies {}, but the env ".format(
                    self.policy_map) +
                "{} is not a subclass of BaseEnv, MultiAgentEnv or "
                "ExternalMultiAgentEnv?".format(self.env))

    # One observation filter per policy, sized from its observation space.
    self.filters = {
        policy_id: get_filter(observation_filter,
                              policy.observation_space.shape)
        for (policy_id, policy) in self.policy_map.items()
    }
    if self.worker_index == 0:
        logger.info("Built filter map: {}".format(self.filters))

    # Always use vector env for consistency even if num_envs = 1
    self.async_env = BaseEnv.to_base_env(
        self.env,
        make_env=make_env,
        num_envs=num_envs,
        remote_envs=remote_worker_envs,
        remote_env_batch_wait_ms=remote_env_batch_wait_ms)
    self.num_envs = num_envs

    if self.batch_mode == "truncate_episodes":
        unroll_length = batch_steps
        pack_episodes = True
    elif self.batch_mode == "complete_episodes":
        unroll_length = float("inf")  # never cut episodes
        pack_episodes = False  # sampler will return 1 episode per poll
    else:
        raise ValueError("Unsupported batch mode: {}".format(
            self.batch_mode))

    self.io_context = IOContext(log_dir, policy_config, worker_index, self)
    self.reward_estimators = []
    for method in input_evaluation:
        if method == "simulation":
            logger.warning(
                "Requested 'simulation' input evaluation method: "
                "will discard all sampler outputs and keep only metrics.")
            # "simulation" forces the asynchronous sampler below.
            sample_async = True
        elif method == "is":
            ise = ImportanceSamplingEstimator.create(self.io_context)
            self.reward_estimators.append(ise)
        elif method == "wis":
            wise = WeightedImportanceSamplingEstimator.create(
                self.io_context)
            self.reward_estimators.append(wise)
        else:
            raise ValueError(
                "Unknown evaluation method: {}".format(method))

    if sample_async:
        self.sampler = AsyncSampler(
            self.async_env,
            self.policy_map,
            policy_mapping_fn,
            self.preprocessors,
            self.filters,
            clip_rewards,
            unroll_length,
            self.callbacks,
            horizon=episode_horizon,
            pack=pack_episodes,
            tf_sess=self.tf_sess,
            clip_actions=clip_actions,
            blackhole_outputs="simulation" in input_evaluation,
            soft_horizon=soft_horizon)
        self.sampler.start()
    else:
        self.sampler = SyncSampler(
            self.async_env,
            self.policy_map,
            policy_mapping_fn,
            self.preprocessors,
            self.filters,
            clip_rewards,
            unroll_length,
            self.callbacks,
            horizon=episode_horizon,
            pack=pack_episodes,
            tf_sess=self.tf_sess,
            clip_actions=clip_actions,
            soft_horizon=soft_horizon)

    # The reader / writer are created last; the io context they receive
    # holds a reference to this worker.
    self.input_reader = input_creator(self.io_context)
    assert isinstance(self.input_reader, InputReader), self.input_reader
    self.output_writer = output_creator(self.io_context)
    assert isinstance(self.output_writer, OutputWriter), self.output_writer

    logger.debug(
        "Created rollout worker with env {} ({}), policies {}".format(
            self.async_env, self.env, self.policy_map))
@override(EvaluatorInterface)
def sample(self):
    """Evaluate the current policies and return a batch of experiences.

    Batches are pulled from ``self.input_reader`` until at least
    ``self.sample_batch_size`` steps have been gathered, then concatenated.
    The combined batch is passed to the "on_sample_end" callback (if set),
    written to ``self.output_writer``, fed to any off-policy reward
    estimators, and finally compressed if compression is enabled.

    Return:
        SampleBatch|MultiAgentBatch from evaluating the current policies.
    """
    # In fake-sampler mode, replay the cached batch once one exists.
    if self._fake_sampler and self.last_batch is not None:
        return self.last_batch

    if log_once("sample_start"):
        logger.info("Generating sample batch of size {}".format(
            self.sample_batch_size))

    first = self.input_reader.next()
    collected = [first]
    num_steps = first.count

    # In truncate_episodes mode, never pull more than 1 batch per env.
    # This avoids over-running the target batch size.
    truncating = self.batch_mode == "truncate_episodes"
    batch_limit = self.num_envs if truncating else float("inf")

    while (num_steps < self.sample_batch_size
           and len(collected) < batch_limit):
        extra = self.input_reader.next()
        num_steps += extra.count
        collected.append(extra)
    batch = first.concat_samples(collected)

    on_sample_end = self.callbacks.get("on_sample_end")
    if on_sample_end:
        on_sample_end({"worker": self, "samples": batch})

    # Always do writes prior to compression for consistency and to allow
    # for better compression inside the writer.
    self.output_writer.write(batch)

    # Do off-policy estimation if needed
    if self.reward_estimators:
        for episode_batch in batch.split_by_episode():
            for estimator in self.reward_estimators:
                estimator.process(episode_batch)

    if log_once("sample_end"):
        logger.info("Completed sample batch:\n\n{}\n".format(
            summarize(batch)))

    compression = self.compress_observations
    if compression == "bulk":
        batch.compress(bulk=True)
    elif compression:
        batch.compress()

    if self._fake_sampler:
        self.last_batch = batch
    return batch
@DeveloperAPI
@ray.method(num_return_vals=2)
def sample_with_count(self):
    """Same as sample() but returns the count as a separate future.

    Returns:
        A 2-tuple of (batch, batch.count).
    """
    samples = self.sample()
    num_steps = samples.count
    return samples, num_steps
@override(EvaluatorInterface)
def get_weights(self, policies=None):
    """Return the weights of this worker's policies keyed by policy id.

    Arguments:
        policies: Optional container of policy ids to include. If None,
            weights for every policy in ``self.policy_map`` are returned.

    Returns:
        dict mapping each selected policy id to ``policy.get_weights()``.
    """
    selected = self.policy_map.keys() if policies is None else policies
    weights = {}
    for pid, policy in self.policy_map.items():
        if pid in selected:
            weights[pid] = policy.get_weights()
    return weights
@override(EvaluatorInterface)
def set_weights(self, weights):
    """Load the given weights into the matching policies.

    Arguments:
        weights (dict): Maps policy id to the weights passed to that
            policy's ``set_weights()``. Every id must be a key of
            ``self.policy_map``.
    """
    for policy_id, policy_weights in weights.items():
        target = self.policy_map[policy_id]
        target.set_weights(policy_weights)
@override(EvaluatorInterface)
def compute_gradients(self, samples):
    """Compute gradients for the given samples without applying them.

    For a MultiAgentBatch, gradients are computed per policy and only for
    policies listed in ``self.policies_to_train``; when a TF session is
    present, the per-policy computations are queued on one TFRunBuilder and
    fetched together. Any other batch is handled by the default policy.

    Arguments:
        samples: SampleBatch or MultiAgentBatch to compute gradients on.

    Returns:
        A 2-tuple (grad_out, info_out). ``info_out`` additionally carries a
        "batch_count" entry set to ``samples.count``.
    """
    if log_once("compute_gradients"):
        logger.info("Compute gradients on:\n\n{}\n".format(
            summarize(samples)))

    if not isinstance(samples, MultiAgentBatch):
        default_policy = self.policy_map[DEFAULT_POLICY_ID]
        grad_out, info_out = default_policy.compute_gradients(samples)
    else:
        grad_out, info_out = {}, {}
        if self.tf_sess is None:
            for pid, batch in samples.policy_batches.items():
                if pid in self.policies_to_train:
                    grad_out[pid], info_out[pid] = (
                        self.policy_map[pid].compute_gradients(batch))
        else:
            builder = TFRunBuilder(self.tf_sess, "compute_gradients")
            for pid, batch in samples.policy_batches.items():
                if pid in self.policies_to_train:
                    grad_out[pid], info_out[pid] = (
                        self.policy_map[pid]._build_compute_gradients(
                            builder, batch))
            # Resolve the queued fetches into concrete values.
            grad_out = {k: builder.get(v) for k, v in grad_out.items()}
            info_out = {k: builder.get(v) for k, v in info_out.items()}

    info_out["batch_count"] = samples.count
    if log_once("grad_out"):
        logger.info("Compute grad info:\n\n{}\n".format(
            summarize(info_out)))
    return grad_out, info_out
@override(EvaluatorInterface)
def apply_gradients(self, grads):
    """Apply previously computed gradients to the policies.

    Arguments:
        grads: Either a dict mapping policy id to that policy's gradients,
            or (any non-dict value) gradients for the default policy.

    Returns:
        For dict input, a dict mapping policy id to the per-policy result;
        otherwise the default policy's ``apply_gradients()`` result.
    """
    if log_once("apply_gradients"):
        logger.info("Apply gradients:\n\n{}\n".format(summarize(grads)))

    if not isinstance(grads, dict):
        return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)

    if self.tf_sess is None:
        return {
            pid: self.policy_map[pid].apply_gradients(g)
            for pid, g in grads.items()
        }

    # With a TF session, queue all per-policy updates on one builder and
    # resolve them afterwards.
    builder = TFRunBuilder(self.tf_sess, "apply_gradients")
    pending = {}
    for pid, grad in grads.items():
        pending[pid] = self.policy_map[pid]._build_apply_gradients(
            builder, grad)
    return {k: builder.get(v) for k, v in pending.items()}
@override(EvaluatorInterface)
def learn_on_batch(self, samples):
    """Update the policies from the given batch of samples.

    For a MultiAgentBatch, only policies in ``self.policies_to_train`` are
    updated. When a TF session exists and a policy provides
    ``_build_learn_on_batch``, its update is queued on a shared
    TFRunBuilder and fetched after the loop; otherwise the policy's
    ``learn_on_batch()`` is called directly.

    Arguments:
        samples: SampleBatch or MultiAgentBatch to learn on.

    Returns:
        The default policy's learn output, or for a MultiAgentBatch a dict
        mapping policy id to that policy's learn output.
    """
    if log_once("learn_on_batch"):
        logger.info(
            "Training on concatenated sample batches:\n\n{}\n".format(
                summarize(samples)))

    if not isinstance(samples, MultiAgentBatch):
        info_out = self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(
            samples)
    else:
        info_out = {}
        to_fetch = {}
        builder = (TFRunBuilder(self.tf_sess, "learn_on_batch")
                   if self.tf_sess is not None else None)
        for pid, batch in samples.policy_batches.items():
            if pid in self.policies_to_train:
                policy = self.policy_map[pid]
                if builder and hasattr(policy, "_build_learn_on_batch"):
                    to_fetch[pid] = policy._build_learn_on_batch(
                        builder, batch)
                else:
                    info_out[pid] = policy.learn_on_batch(batch)
        # to_fetch is only populated when a builder exists.
        for pid, fetch in to_fetch.items():
            info_out[pid] = builder.get(fetch)

    if log_once("learn_out"):
        logger.info("Training output:\n\n{}\n".format(summarize(info_out)))
    return info_out
@DeveloperAPI
def get_metrics(self):
    """Returns a list of new RolloutMetric objects from evaluation.

    The list returned by the sampler is extended in place with the metrics
    reported by each off-policy reward estimator.
    """
    metrics = self.sampler.get_metrics()
    for estimator in self.reward_estimators:
        estimator_metrics = estimator.get_metrics()
        metrics.extend(estimator_metrics)
    return metrics
@DeveloperAPI
def foreach_env(self, func):
    """Apply the given function to each underlying env instance.

    If ``self.async_env.get_unwrapped()`` returns nothing, the function is
    applied to ``self.async_env`` itself instead.

    Returns:
        list of the function's return values, one per env.
    """
    unwrapped = self.async_env.get_unwrapped()
    targets = unwrapped if unwrapped else [self.async_env]
    return [func(env) for env in targets]
@DeveloperAPI
def get_policy(self, policy_id=DEFAULT_POLICY_ID):
    """Return policy for the specified id, or None.

    Arguments:
        policy_id (str): id of policy to return. Defaults to
            DEFAULT_POLICY_ID.

    Returns:
        The entry of ``self.policy_map`` for policy_id, or None if there is
        no such entry.
    """

    return self.policy_map.get(policy_id)
@DeveloperAPI
def for_policy(self, func, policy_id=DEFAULT_POLICY_ID):
    """Apply the given function to the specified policy.

    Arguments:
        func: Callable taking the policy object as its only argument.
        policy_id: Key into ``self.policy_map``; a missing key raises
            KeyError.

    Returns:
        Whatever ``func`` returns.
    """
    target = self.policy_map[policy_id]
    return func(target)
@DeveloperAPI
def foreach_policy(self, func):
    """Apply the given function to each (policy, policy_id) tuple.

    Returns:
        list of ``func(policy, policy_id)`` results in ``self.policy_map``
        iteration order.
    """
    results = []
    for pid, policy in self.policy_map.items():
        results.append(func(policy, pid))
    return results
@DeveloperAPI
def foreach_trainable_policy(self, func):
    """Apply the given function to each (policy, policy_id) tuple.

    This only applies func to policies in `self.policies_to_train`.

    Returns:
        list of ``func(policy, policy_id)`` results for the trainable
        policies.
    """
    results = []
    for pid, policy in self.policy_map.items():
        if pid in self.policies_to_train:
            results.append(func(policy, pid))
    return results
@DeveloperAPI
def sync_filters(self, new_filters):
    """Changes self's filter to given and rebases any accumulated delta.

    Args:
        new_filters (dict): Filters with new state to update local copy.
            Must contain an entry for every key of ``self.filters``.
    """
    assert all(k in new_filters for k in self.filters)
    for key, local_filter in self.filters.items():
        local_filter.sync(new_filters[key])
@DeveloperAPI
def get_filters(self, flush_after=False):
    """Returns a snapshot of filters.

    Args:
        flush_after (bool): Clears the filter buffer state.

    Returns:
        return_filters (dict): Dict for serializable filters
    """
    snapshot = {}
    for key, local_filter in self.filters.items():
        snapshot[key] = local_filter.as_serializable()
        # Clear each buffer only after its snapshot has been taken.
        if flush_after:
            local_filter.clear_buffer()
    return snapshot
@DeveloperAPI
def save(self):
    """Serialize the filters and per-policy state of this worker.

    The filter buffers are flushed as part of taking the snapshot.

    Returns:
        A pickled dict with keys "filters" (serializable filters) and
        "state" (policy id -> ``policy.get_state()``).
    """
    filters = self.get_filters(flush_after=True)
    state = {}
    for pid, policy in self.policy_map.items():
        state[pid] = policy.get_state()
    payload = {"filters": filters, "state": state}
    return pickle.dumps(payload)
@DeveloperAPI
def restore(self, objs):
    """Restore filters and policy state from the output of save().

    Arguments:
        objs: Pickled dict with "filters" and "state" entries.
    """
    # NOTE(review): pickle.loads executes arbitrary code for crafted input;
    # only pass data produced by a trusted save() call.
    restored = pickle.loads(objs)
    self.sync_filters(restored["filters"])
    for pid, policy_state in restored["state"].items():
        self.policy_map[pid].set_state(policy_state)
@DeveloperAPI
def set_global_vars(self, global_vars):
    """Forward updated global vars to every policy of this worker.

    Arguments:
        global_vars: Value passed to each policy's
            ``on_global_var_update()``.
    """

    def _notify(policy, _policy_id):
        return policy.on_global_var_update(global_vars)

    self.foreach_policy(_notify)
@DeveloperAPI
def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
    """Export the model of the given policy.

    Arguments:
        export_dir: Passed to the policy's ``export_model()``.
        policy_id: Key into ``self.policy_map`` selecting the policy.
    """
    target = self.policy_map[policy_id]
    target.export_model(export_dir)
@DeveloperAPI
def export_policy_checkpoint(self,
                             export_dir,
                             filename_prefix="model",
                             policy_id=DEFAULT_POLICY_ID):
    """Export a checkpoint of the given policy.

    Arguments:
        export_dir: Passed to the policy's ``export_checkpoint()``.
        filename_prefix: Passed as the second positional argument to
            ``export_checkpoint()``.
        policy_id: Key into ``self.policy_map`` selecting the policy.
    """
    target = self.policy_map[policy_id]
    target.export_checkpoint(export_dir, filename_prefix)
@DeveloperAPI
def stop(self):
    """Stop this worker by stopping its underlying ``async_env``."""
    env = self.async_env
    env.stop()
def _build_policy_map(self, policy_dict, policy_config):
    """Instantiate every policy and its observation preprocessor.

    Arguments:
        policy_dict (dict): Maps policy id to a
            (cls, obs_space, action_space, config) tuple.
        policy_config (dict): Base config; each per-policy config is merged
            on top of it.

    Returns:
        A 2-tuple (policy_map, preprocessors), both keyed by policy id.

    Raises:
        ValueError: If a policy would receive a raw gym Dict or Tuple
            observation space.
    """
    policy_map, preprocessors = {}, {}
    for name, spec in sorted(policy_dict.items()):
        cls, obs_space, act_space, conf = spec
        logger.debug("Creating policy for {}".format(name))
        merged_conf = merge_dicts(policy_config, conf)

        if not self.preprocessing_enabled:
            preprocessors[name] = NoPreprocessor(obs_space)
        else:
            preprocessor = ModelCatalog.get_preprocessor_for_space(
                obs_space, merged_conf.get("model"))
            preprocessors[name] = preprocessor
            # The policy sees the preprocessed observation space.
            obs_space = preprocessor.observation_space

        if isinstance(obs_space, (gym.spaces.Dict, gym.spaces.Tuple)):
            raise ValueError(
                "Found raw Tuple|Dict space as input to policy. "
                "Please preprocess these observations with a "
                "Tuple|DictFlatteningPreprocessor.")

        if tf:
            # Each policy is created under a variable scope named after it.
            with tf.variable_scope(name):
                policy_map[name] = cls(obs_space, act_space, merged_conf)
        else:
            policy_map[name] = cls(obs_space, act_space, merged_conf)

    if self.worker_index == 0:
        logger.info("Built policy map: {}".format(policy_map))
        logger.info("Built preprocessor map: {}".format(preprocessors))
    return policy_map, preprocessors
def __del__(self):
    """Flag an asynchronous sampler for shutdown when the worker is deleted."""
    # ``sampler`` may be missing if __init__ failed before creating it.
    sampler = getattr(self, "sampler", None)
    if isinstance(sampler, AsyncSampler):
        sampler.shutdown = True
def _validate_and_canonicalize(policy, env):
if isinstance(policy, dict):
_validate_multiagent_config(policy)
return policy
elif not issubclass(policy, Policy):
raise ValueError("policy must be a rllib.Policy class")
else:
if (isinstance(env, MultiAgentEnv)
and not hasattr(env, "observation_space")):
raise ValueError(
"MultiAgentEnv must have observation_space defined if run "
"in a single-agent configuration.")
return {
DEFAULT_POLICY_ID: (policy, env.observation_space,
env.action_space, {})
}
def _validate_multiagent_config(policy, allow_none_graph=False):
for k, v in policy.items():
if not isinstance(k, str):
raise ValueError("policy keys must be strs, got {}".format(
type(k)))
if not isinstance(v, tuple) or len(v) != 4:
raise ValueError(
"policy values must be tuples of "
"(cls, obs_space, action_space, config), got {}".format(v))
if allow_none_graph and v[0] is None:
pass
elif not issubclass(v[0], Policy):
raise ValueError("policy tuple value 0 must be a rllib.Policy "
"class or None, got {}".format(v[0]))
if not isinstance(v[1], gym.Space):
raise ValueError(
"policy tuple value 1 (observation_space) must be a "
"gym.Space, got {}".format(type(v[1])))
if not isinstance(v[2], gym.Space):
raise ValueError("policy tuple value 2 (action_space) must be a "
"gym.Space, got {}".format(type(v[2])))
if not isinstance(v[3], dict):
raise ValueError("policy tuple value 3 (config) must be a dict, "
"got {}".format(type(v[3])))
def _validate_env(env):
# allow this as a special case (assumed gym.Env)
if hasattr(env, "observation_space") and hasattr(env, "action_space"):
return env
allowed_types = [gym.Env, MultiAgentEnv, ExternalEnv, VectorEnv, BaseEnv]
if not any(isinstance(env, tpe) for tpe in allowed_types):
raise ValueError(
"Returned env should be an instance of gym.Env, MultiAgentEnv, "
"ExternalEnv, VectorEnv, or BaseEnv. The provided env creator "
"function returned {} ({}).".format(env, type(env)))
return env
def _monitor(env, path):
    """Wrap ``env`` in ``gym.wrappers.Monitor`` writing to ``path``.

    The monitor is created with ``resume=True``.
    """
    monitor_cls = gym.wrappers.Monitor
    return monitor_cls(env, path, resume=True)
def _has_tensorflow_graph(policy_dict):
for policy, _, _, _ in policy_dict.values():
if issubclass(policy, TFPolicy):
return True
return False
+214
View File
@@ -0,0 +1,214 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
from types import FunctionType
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.evaluation.rollout_worker import RolloutWorker, \
_validate_multiagent_config
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
ShuffledInput
from ray.rllib.utils import merge_dicts, try_import_tf
from ray.rllib.utils.memory import ray_get_and_free
tf = try_import_tf()
logger = logging.getLogger(__name__)
@DeveloperAPI
class WorkerSet(object):
"""Represents a set of RolloutWorkers.
There must be one local worker copy, and zero or more remote workers.
"""
def __init__(self,
             env_creator,
             policy,
             trainer_config=None,
             num_workers=0,
             logdir=None,
             _setup=True):
    """Create a new WorkerSet and initialize its workers.

    Arguments:
        env_creator (func): Function that returns env given env config.
        policy (cls): rllib.policy.Policy class.
        trainer_config (dict): Optional dict that extends the common
            config of the Trainer class. If empty or None, the Trainer's
            COMMON_CONFIG is used.
        num_workers (int): Number of remote rollout workers to create.
        logdir (str): Optional logging directory for workers.
        _setup (bool): Whether to setup workers. This is only for testing.
    """
    if not trainer_config:
        # Imported lazily inside the constructor.
        from ray.rllib.agents.trainer import COMMON_CONFIG
        trainer_config = COMMON_CONFIG

    self._env_creator = env_creator
    self._policy = policy
    self._remote_config = trainer_config
    self._num_workers = num_workers
    self._logdir = logdir

    if not _setup:
        return

    # The local worker uses the "local_tf_session_args" session settings.
    local_overrides = {
        "tf_session_args": trainer_config["local_tf_session_args"]
    }
    self._local_config = merge_dicts(trainer_config, local_overrides)

    # Always create a local worker
    self._local_worker = self._make_worker(
        RolloutWorker, env_creator, policy, 0, self._local_config)

    # Create a number of remote workers
    self._remote_workers = []
    self.add_workers(num_workers)
def local_worker(self):
    """Return the local rollout worker.

    This is the worker stored in ``self._local_worker``.
    """
    return self._local_worker
def remote_workers(self):
    """Return a list of remote rollout workers.

    The internal list object is returned directly, not a copy.
    """
    return self._remote_workers
def add_workers(self, num_workers):
    """Create and add a number of remote workers to this worker set.

    New workers are created as Ray actors using the per-worker resource
    settings of the remote config, and are numbered after the remote
    workers already in the set so that repeated calls do not reuse a
    worker index (index 0 is the local worker).

    Arguments:
        num_workers (int): Number of remote workers to create.
    """
    remote_args = {
        "num_cpus": self._remote_config["num_cpus_per_worker"],
        "num_gpus": self._remote_config["num_gpus_per_worker"],
        "resources": self._remote_config["custom_resources_per_worker"],
    }
    cls = RolloutWorker.as_remote(**remote_args).remote
    # Continue numbering after the existing workers. On the first call
    # this yields 1..num_workers, exactly as before.
    first_index = len(self._remote_workers) + 1
    self._remote_workers.extend([
        self._make_worker(cls, self._env_creator, self._policy,
                          first_index + i, self._remote_config)
        for i in range(num_workers)
    ])
def reset(self, new_remote_workers):
    """Called to change the set of remote workers.

    Arguments:
        new_remote_workers: Replacement for the current remote workers;
            stored as-is (the previous workers are not stopped here).
    """
    self._remote_workers = new_remote_workers
def stop(self):
    """Stop all rollout workers.

    The local worker is stopped synchronously. Each remote worker is sent
    a ``stop`` call followed by ``__ray_terminate__``; the returned futures
    are not waited on.
    """
    self.local_worker().stop()
    for remote_worker in self.remote_workers():
        remote_worker.stop.remote()
        remote_worker.__ray_terminate__.remote()
@DeveloperAPI
def foreach_worker(self, func):
    """Apply the given function to each worker instance.

    Returns:
        list with the local worker's result first, followed by the results
        fetched from the remote workers.
    """
    local_results = [func(self.local_worker())]
    pending = [
        remote_worker.apply.remote(func)
        for remote_worker in self.remote_workers()
    ]
    remote_results = ray_get_and_free(pending)
    return local_results + remote_results
@DeveloperAPI
def foreach_worker_with_index(self, func):
    """Apply the given function to each worker instance.

    The index will be passed as the second arg to the given function. The
    local worker gets index 0; remote workers get 1, 2, ... in list order.

    Returns:
        list with the local worker's result first, followed by the results
        fetched from the remote workers.
    """
    local_results = [func(self.local_worker(), 0)]
    pending = [
        remote_worker.apply.remote(func, index)
        for index, remote_worker in enumerate(self.remote_workers(), 1)
    ]
    remote_results = ray_get_and_free(pending)
    return local_results + remote_results
@staticmethod
def _from_existing(local_worker, remote_workers=None):
    """Build a WorkerSet around already-created workers.

    Arguments:
        local_worker: Worker to use as the local worker.
        remote_workers: Optional list of remote workers; a falsy value
            results in an empty list.

    Returns:
        A WorkerSet constructed with ``_setup=False`` that wraps the given
        workers.
    """
    worker_set = WorkerSet(None, None, {}, _setup=False)
    worker_set._local_worker = local_worker
    worker_set._remote_workers = remote_workers if remote_workers else []
    return worker_set
def _make_worker(self, cls, env_creator, policy, worker_index, config):
    """Construct one worker from a trainer config.

    Arguments:
        cls: Callable used to construct the worker; called with the
            RolloutWorker constructor arguments (either RolloutWorker
            itself or the ``.remote`` constructor of its actor class).
        env_creator (func): Passed through as the worker's env creator.
        policy: Default policy class. If the config defines multiagent
            policies, the policy dict is passed to the worker instead.
        worker_index (int): Index passed to the worker.
        config (dict): Trainer config the worker settings are read from.

    Returns:
        Whatever ``cls(...)`` returns.
    """

    def session_creator():
        logger.debug("Creating TF session {}".format(
            config["tf_session_args"]))
        return tf.Session(
            config=tf.ConfigProto(**config["tf_session_args"]))

    # Resolve config["input"]: a function is used directly, "sampler" uses
    # the worker's own sampler, a dict builds a shuffled MixedInput, and
    # anything else is handed to a shuffled JsonReader.
    if isinstance(config["input"], FunctionType):
        input_creator = config["input"]
    elif config["input"] == "sampler":
        input_creator = (lambda ioctx: ioctx.default_sampler_input())
    elif isinstance(config["input"], dict):
        input_creator = (lambda ioctx: ShuffledInput(
            MixedInput(config["input"], ioctx), config[
                "shuffle_buffer_size"]))
    else:
        input_creator = (lambda ioctx: ShuffledInput(
            JsonReader(config["input"], ioctx), config[
                "shuffle_buffer_size"]))

    # Resolve config["output"]: a function is used directly, None disables
    # output, "logdir" writes JSON to the io context's log dir, and anything
    # else is passed to JsonWriter as the output location.
    if isinstance(config["output"], FunctionType):
        output_creator = config["output"]
    elif config["output"] is None:
        output_creator = (lambda ioctx: NoopOutput())
    elif config["output"] == "logdir":
        output_creator = (lambda ioctx: JsonWriter(
            ioctx.log_dir,
            ioctx,
            max_file_size=config["output_max_file_size"],
            compress_columns=config["output_compress_columns"]))
    else:
        output_creator = (lambda ioctx: JsonWriter(
            config["output"],
            ioctx,
            max_file_size=config["output_max_file_size"],
            compress_columns=config["output_compress_columns"]))

    # Input evaluation is disabled when sampling from the env directly.
    if config["input"] == "sampler":
        input_evaluation = []
    else:
        input_evaluation = config["input_evaluation"]

    # Fill in the default policy if 'None' is specified in multiagent
    if config["multiagent"]["policies"]:
        # NOTE(review): ``tmp`` is the dict stored in the config, so the
        # assignments below update config["multiagent"]["policies"] in place.
        tmp = config["multiagent"]["policies"]
        _validate_multiagent_config(tmp, allow_none_graph=True)
        for k, v in tmp.items():
            if v[0] is None:
                tmp[k] = (policy, v[1], v[2], v[3])
        policy = tmp

    return cls(
        env_creator,
        policy,
        policy_mapping_fn=config["multiagent"]["policy_mapping_fn"],
        policies_to_train=config["multiagent"]["policies_to_train"],
        tf_session_creator=(session_creator
                            if config["tf_session_args"] else None),
        batch_steps=config["sample_batch_size"],
        batch_mode=config["batch_mode"],
        episode_horizon=config["horizon"],
        preprocessor_pref=config["preprocessor_pref"],
        sample_async=config["sample_async"],
        compress_observations=config["compress_observations"],
        num_envs=config["num_envs_per_worker"],
        observation_filter=config["observation_filter"],
        clip_rewards=config["clip_rewards"],
        clip_actions=config["clip_actions"],
        env_config=config["env_config"],
        model_config=config["model"],
        policy_config=config,
        worker_index=worker_index,
        monitor_path=self._logdir if config["monitor"] else None,
        log_dir=self._logdir,
        log_level=config["log_level"],
        callbacks=config["callbacks"],
        input_creator=input_creator,
        input_evaluation=input_evaluation,
        output_creator=output_creator,
        remote_worker_envs=config["remote_worker_envs"],
        remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
        soft_horizon=config["soft_horizon"],
        _fake_sampler=config.get("_fake_sampler", False))
@@ -75,7 +75,7 @@ if __name__ == "__main__":
})
# disable DQN exploration when used by the PPO trainer
ppo_trainer.optimizer.foreach_evaluator(
ppo_trainer.workers.foreach_worker(
lambda ev: ev.for_policy(
lambda pi: pi.set_epsilon(0.0), policy_id="dqn_policy"))
@@ -1,4 +1,4 @@
"""Example of using policy evaluator classes directly to implement training.
"""Example of using rollout worker classes directly to implement training.
Instead of using the built-in Trainer classes provided by RLlib, here we define
a custom Policy class and manually coordinate distributed sample
@@ -15,7 +15,7 @@ import gym
import ray
from ray import tune
from ray.rllib.policy import Policy
from ray.rllib.evaluation import PolicyEvaluator, SampleBatch
from ray.rllib.evaluation import RolloutWorker, SampleBatch
from ray.rllib.evaluation.metrics import collect_metrics
parser = argparse.ArgumentParser()
@@ -67,8 +67,8 @@ def training_workflow(config, reporter):
env = gym.make("CartPole-v0")
policy = CustomPolicy(env.observation_space, env.action_space, {})
workers = [
PolicyEvaluator.as_remote().remote(lambda c: gym.make("CartPole-v0"),
CustomPolicy)
RolloutWorker.as_remote().remote(lambda c: gym.make("CartPole-v0"),
CustomPolicy)
for _ in range(config["num_workers"])
]
@@ -97,7 +97,7 @@ def training_workflow(config, reporter):
# Do some arbitrary updates based on the T2 batch
policy.update_some_value(sum(T2["rewards"]))
reporter(**collect_metrics(remote_evaluators=workers))
reporter(**collect_metrics(remote_workers=workers))
if __name__ == "__main__":
+4 -8
View File
@@ -18,20 +18,16 @@ class IOContext(object):
config (dict): Configuration of the agent.
worker_index (int): When there are multiple workers created, this
uniquely identifies the current worker.
evaluator (PolicyEvaluator): policy evaluator object reference.
worker (RolloutWorker): rollout worker object reference.
"""
@PublicAPI
def __init__(self,
log_dir=None,
config=None,
worker_index=0,
evaluator=None):
def __init__(self, log_dir=None, config=None, worker_index=0, worker=None):
self.log_dir = log_dir or os.getcwd()
self.config = config or {}
self.worker_index = worker_index
self.evaluator = evaluator
self.worker = worker
@PublicAPI
def default_sampler_input(self):
return self.evaluator.sampler
return self.worker.sampler
+1 -1
View File
@@ -88,7 +88,7 @@ class JsonReader(InputReader):
if isinstance(batch, SampleBatch):
out = []
for sub_batch in batch.split_by_episode():
out.append(self.ioctx.evaluator.policy_map[DEFAULT_POLICY_ID]
out.append(self.ioctx.worker.policy_map[DEFAULT_POLICY_ID]
.postprocess_trajectory(sub_batch))
return SampleBatch.concat_samples(out)
else:
@@ -33,14 +33,14 @@ class OffPolicyEstimator(object):
@classmethod
def create(cls, ioctx):
"""Create an off-policy estimator from an IOContext."""
gamma = ioctx.evaluator.policy_config["gamma"]
gamma = ioctx.worker.policy_config["gamma"]
# Grab a reference to the current model
keys = list(ioctx.evaluator.policy_map.keys())
keys = list(ioctx.worker.policy_map.keys())
if len(keys) > 1:
raise NotImplementedError(
"Off-policy estimation is not implemented for multi-agent. "
"You can set `input_evaluation: []` to resolve this.")
policy = ioctx.evaluator.get_policy(keys[0])
policy = ioctx.worker.get_policy(keys[0])
return cls(policy, gamma)
@DeveloperAPI
+16 -16
View File
@@ -14,7 +14,7 @@ from ray.rllib.utils.memory import ray_get_and_free
class Aggregator(object):
"""An aggregator collects and processes samples from evaluators.
"""An aggregator collects and processes samples from workers.
This class is used to abstract away the strategy for sample collection.
For example, you may want to use a tree of actors to collect samples. The
@@ -22,21 +22,21 @@ class Aggregator(object):
as concatenating and decompressing sample batches.
Attributes:
local_evaluator: local PolicyEvaluator copy
local_worker: local RolloutWorker copy
"""
def iter_train_batches(self):
"""Returns a generator over batches ready to learn on.
Iterating through this generator will also send out weight updates to
remote evaluators as needed.
remote workers as needed.
This call may block until results are available.
"""
raise NotImplementedError
def broadcast_new_weights(self):
"""Broadcast a new set of weights from the local evaluator."""
"""Broadcast a new set of weights from the local worker."""
raise NotImplementedError
def should_broadcast(self):
@@ -47,19 +47,19 @@ class Aggregator(object):
"""Returns runtime statistics for debugging."""
raise NotImplementedError
def reset(self, remote_evaluators):
"""Called to change the set of remote evaluators being used."""
def reset(self, remote_workers):
"""Called to change the set of remote workers being used."""
raise NotImplementedError
class AggregationWorkerBase(object):
"""Aggregators should extend from this class."""
def __init__(self, initial_weights_obj_id, remote_evaluators,
def __init__(self, initial_weights_obj_id, remote_workers,
max_sample_requests_in_flight_per_worker, replay_proportion,
replay_buffer_num_slots, train_batch_size, sample_batch_size):
self.broadcasted_weights = initial_weights_obj_id
self.remote_evaluators = remote_evaluators
self.remote_workers = remote_workers
self.sample_batch_size = sample_batch_size
self.train_batch_size = train_batch_size
@@ -73,7 +73,7 @@ class AggregationWorkerBase(object):
# Kick off async background sampling
self.sample_tasks = TaskPool()
for ev in self.remote_evaluators:
for ev in self.remote_workers:
ev.set_weights.remote(self.broadcasted_weights)
for _ in range(max_sample_requests_in_flight_per_worker):
self.sample_tasks.add(ev, ev.sample.remote())
@@ -138,8 +138,8 @@ class AggregationWorkerBase(object):
}
@override(Aggregator)
def reset(self, remote_evaluators):
self.sample_tasks.reset_evaluators(remote_evaluators)
def reset(self, remote_workers):
self.sample_tasks.reset_workers(remote_workers)
def _augment_with_replay(self, sample_futures):
def can_replay():
@@ -164,25 +164,25 @@ class SimpleAggregator(AggregationWorkerBase, Aggregator):
"""Simple single-threaded implementation of an Aggregator."""
def __init__(self,
local_evaluator,
remote_evaluators,
workers,
max_sample_requests_in_flight_per_worker=2,
replay_proportion=0.0,
replay_buffer_num_slots=0,
train_batch_size=500,
sample_batch_size=50,
broadcast_interval=5):
self.local_evaluator = local_evaluator
self.workers = workers
self.local_worker = workers.local_worker()
self.broadcast_interval = broadcast_interval
self.broadcast_new_weights()
AggregationWorkerBase.__init__(
self, self.broadcasted_weights, remote_evaluators,
self, self.broadcasted_weights, self.workers.remote_workers(),
max_sample_requests_in_flight_per_worker, replay_proportion,
replay_buffer_num_slots, train_batch_size, sample_batch_size)
@override(Aggregator)
def broadcast_new_weights(self):
self.broadcasted_weights = ray.put(self.local_evaluator.get_weights())
self.broadcasted_weights = ray.put(self.local_worker.get_weights())
self.num_sent_since_broadcast = 0
@override(Aggregator)
+3 -3
View File
@@ -25,11 +25,11 @@ class LearnerThread(threading.Thread):
improves overall throughput.
"""
def __init__(self, local_evaluator, minibatch_buffer_size, num_sgd_iter,
def __init__(self, local_worker, minibatch_buffer_size, num_sgd_iter,
learner_queue_size):
threading.Thread.__init__(self)
self.learner_queue_size = WindowStat("size", 50)
self.local_evaluator = local_evaluator
self.local_worker = local_worker
self.inqueue = queue.Queue(maxsize=learner_queue_size)
self.outqueue = queue.Queue()
self.minibatch_buffer = MinibatchBuffer(
@@ -52,7 +52,7 @@ class LearnerThread(threading.Thread):
batch, _ = self.minibatch_buffer.get()
with self.grad_timer:
fetches = self.local_evaluator.learn_on_batch(batch)
fetches = self.local_worker.learn_on_batch(batch)
self.weights_updated = True
self.stats = get_learner_stats(fetches)
@@ -31,7 +31,7 @@ class TFMultiGPULearner(LearnerThread):
"""
def __init__(self,
local_evaluator,
local_worker,
num_gpus=1,
lr=0.0005,
train_batch_size=500,
@@ -41,7 +41,7 @@ class TFMultiGPULearner(LearnerThread):
learner_queue_size=16,
num_data_load_threads=16,
_fake_gpus=False):
LearnerThread.__init__(self, local_evaluator, minibatch_buffer_size,
LearnerThread.__init__(self, local_worker, minibatch_buffer_size,
num_sgd_iter, learner_queue_size)
self.lr = lr
self.train_batch_size = train_batch_size
@@ -59,16 +59,16 @@ class TFMultiGPULearner(LearnerThread):
assert self.train_batch_size % len(self.devices) == 0
assert self.train_batch_size >= len(self.devices), "batch too small"
if set(self.local_evaluator.policy_map.keys()) != {DEFAULT_POLICY_ID}:
if set(self.local_worker.policy_map.keys()) != {DEFAULT_POLICY_ID}:
raise NotImplementedError("Multi-gpu mode for multi-agent")
self.policy = self.local_evaluator.policy_map[DEFAULT_POLICY_ID]
self.policy = self.local_worker.policy_map[DEFAULT_POLICY_ID]
# per-GPU graph copies created below must share vars with the policy
# reuse is set to AUTO_REUSE because Adam nodes are created after
# all of the device copies are created.
self.par_opt = []
with self.local_evaluator.tf_sess.graph.as_default():
with self.local_evaluator.tf_sess.as_default():
with self.local_worker.tf_sess.graph.as_default():
with self.local_worker.tf_sess.as_default():
with tf.variable_scope(DEFAULT_POLICY_ID, reuse=tf.AUTO_REUSE):
if self.policy._state_inputs:
rnn_inputs = self.policy._state_inputs + [
@@ -87,7 +87,7 @@ class TFMultiGPULearner(LearnerThread):
999999, # it will get rounded down
self.policy.copy))
self.sess = self.local_evaluator.tf_sess
self.sess = self.local_worker.tf_sess
self.sess.run(tf.global_variables_initializer())
self.idle_optimizers = queue.Queue()
@@ -22,15 +22,14 @@ logger = logging.getLogger(__name__)
class TreeAggregator(Aggregator):
"""A hierarchical experiences aggregator.
The given set of remote evaluators is divided into subsets and assigned to
The given set of remote workers is divided into subsets and assigned to
one of several aggregation workers. These aggregation workers collate
experiences into batches of size `train_batch_size` and we collect them
in this class when `iter_train_batches` is called.
"""
def __init__(self,
local_evaluator,
remote_evaluators,
workers,
num_aggregation_workers,
max_sample_requests_in_flight_per_worker=2,
replay_proportion=0.0,
@@ -38,8 +37,7 @@ class TreeAggregator(Aggregator):
train_batch_size=500,
sample_batch_size=50,
broadcast_interval=5):
self.local_evaluator = local_evaluator
self.remote_evaluators = remote_evaluators
self.workers = workers
self.num_aggregation_workers = num_aggregation_workers
self.max_sample_requests_in_flight_per_worker = \
max_sample_requests_in_flight_per_worker
@@ -48,7 +46,8 @@ class TreeAggregator(Aggregator):
self.sample_batch_size = sample_batch_size
self.train_batch_size = train_batch_size
self.broadcast_interval = broadcast_interval
self.broadcasted_weights = ray.put(local_evaluator.get_weights())
self.broadcasted_weights = ray.put(
workers.local_worker().get_weights())
self.num_batches_processed = 0
self.num_broadcasts = 0
self.num_sent_since_broadcast = 0
@@ -58,26 +57,27 @@ class TreeAggregator(Aggregator):
"""Deferred init so that we can pass in previously created workers."""
assert len(aggregators) == self.num_aggregation_workers, aggregators
if len(self.remote_evaluators) < self.num_aggregation_workers:
if len(self.workers.remote_workers()) < self.num_aggregation_workers:
raise ValueError(
"The number of aggregation workers should not exceed the "
"number of total evaluation workers ({} vs {})".format(
self.num_aggregation_workers, len(self.remote_evaluators)))
self.num_aggregation_workers,
len(self.workers.remote_workers())))
assigned_evaluators = collections.defaultdict(list)
for i, ev in enumerate(self.remote_evaluators):
assigned_evaluators[i % self.num_aggregation_workers].append(ev)
assigned_workers = collections.defaultdict(list)
for i, ev in enumerate(self.workers.remote_workers()):
assigned_workers[i % self.num_aggregation_workers].append(ev)
self.workers = aggregators
for i, worker in enumerate(self.workers):
worker.init.remote(
self.broadcasted_weights, assigned_evaluators[i],
self.max_sample_requests_in_flight_per_worker,
self.replay_proportion, self.replay_buffer_num_slots,
self.train_batch_size, self.sample_batch_size)
self.aggregators = aggregators
for i, agg in enumerate(self.aggregators):
agg.init.remote(self.broadcasted_weights, assigned_workers[i],
self.max_sample_requests_in_flight_per_worker,
self.replay_proportion,
self.replay_buffer_num_slots,
self.train_batch_size, self.sample_batch_size)
self.agg_tasks = TaskPool()
for agg in self.workers:
for agg in self.aggregators:
agg.set_weights.remote(self.broadcasted_weights)
self.agg_tasks.add(agg, agg.get_train_batches.remote())
@@ -96,7 +96,8 @@ class TreeAggregator(Aggregator):
@override(Aggregator)
def broadcast_new_weights(self):
self.broadcasted_weights = ray.put(self.local_evaluator.get_weights())
self.broadcasted_weights = ray.put(
self.workers.local_worker().get_weights())
self.num_sent_since_broadcast = 0
self.num_broadcasts += 1
@@ -112,8 +113,8 @@ class TreeAggregator(Aggregator):
}
@override(Aggregator)
def reset(self, remote_evaluators):
raise NotImplementedError("changing number of remote evaluators")
def reset(self, remote_workers):
raise NotImplementedError("changing number of remote workers")
@staticmethod
def precreate_aggregators(n):
@@ -125,16 +126,16 @@ class AggregationWorker(AggregationWorkerBase):
def __init__(self):
self.initialized = False
def init(self, initial_weights_obj_id, remote_evaluators,
def init(self, initial_weights_obj_id, remote_workers,
max_sample_requests_in_flight_per_worker, replay_proportion,
replay_buffer_num_slots, train_batch_size, sample_batch_size):
"""Deferred init that assigns sub-workers to this aggregator."""
logger.info("Assigned evaluators {} to aggregation worker {}".format(
remote_evaluators, self))
assert remote_evaluators
logger.info("Assigned workers {} to aggregation worker {}".format(
remote_workers, self))
assert remote_workers
AggregationWorkerBase.__init__(
self, initial_weights_obj_id, remote_evaluators,
self, initial_weights_obj_id, remote_workers,
max_sample_requests_in_flight_per_worker, replay_proportion,
replay_buffer_num_slots, train_batch_size, sample_batch_size)
self.initialized = True
@@ -14,30 +14,30 @@ class AsyncGradientsOptimizer(PolicyOptimizer):
"""An asynchronous RL optimizer, e.g. for implementing A3C.
This optimizer asynchronously pulls and applies gradients from remote
evaluators, sending updated weights back as needed. This pipelines the
workers, sending updated weights back as needed. This pipelines the
gradient computations on the remote workers.
"""
def __init__(self, local_evaluator, remote_evaluators, grads_per_step=100):
PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)
def __init__(self, workers, grads_per_step=100):
PolicyOptimizer.__init__(self, workers)
self.apply_timer = TimerStat()
self.wait_timer = TimerStat()
self.dispatch_timer = TimerStat()
self.grads_per_step = grads_per_step
self.learner_stats = {}
if not self.remote_evaluators:
if not self.workers.remote_workers():
raise ValueError(
"Async optimizer requires at least 1 remote evaluator")
"Async optimizer requires at least 1 remote workers")
@override(PolicyOptimizer)
def step(self):
weights = ray.put(self.local_evaluator.get_weights())
weights = ray.put(self.workers.local_worker().get_weights())
pending_gradients = {}
num_gradients = 0
# Kick off the first wave of async tasks
for e in self.remote_evaluators:
for e in self.workers.remote_workers():
e.set_weights.remote(weights)
future = e.compute_gradients.remote(e.sample.remote())
pending_gradients[future] = e
@@ -56,13 +56,14 @@ class AsyncGradientsOptimizer(PolicyOptimizer):
if gradient is not None:
with self.apply_timer:
self.local_evaluator.apply_gradients(gradient)
self.workers.local_worker().apply_gradients(gradient)
self.num_steps_sampled += info["batch_count"]
self.num_steps_trained += info["batch_count"]
if num_gradients < self.grads_per_step:
with self.dispatch_timer:
e.set_weights.remote(self.local_evaluator.get_weights())
e.set_weights.remote(
self.workers.local_worker().get_weights())
future = e.compute_gradients.remote(e.sample.remote())
pending_gradients[future] = e
@@ -36,20 +36,19 @@ class AsyncReplayOptimizer(PolicyOptimizer):
"""Main event loop of the Ape-X optimizer (async sampling with replay).
This class coordinates the data transfers between the learner thread,
remote evaluators (Ape-X actors), and replay buffer actors.
remote workers (Ape-X actors), and replay buffer actors.
This has two modes of operation:
- normal replay: replays independent samples.
- batch replay: simplified mode where entire sample batches are
replayed. This supports RNNs, but not prioritization.
This optimizer requires that policy evaluators return an additional
This optimizer requires that rollout workers return an additional
"td_error" array in the info return of compute_gradients(). This error
term will be used for sample prioritization."""
def __init__(self,
local_evaluator,
remote_evaluators,
workers,
learning_starts=1000,
buffer_size=10000,
prioritized_replay=True,
@@ -62,7 +61,7 @@ class AsyncReplayOptimizer(PolicyOptimizer):
max_weight_sync_delay=400,
debug=False,
batch_replay=False):
PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)
PolicyOptimizer.__init__(self, workers)
self.debug = debug
self.batch_replay = batch_replay
@@ -71,7 +70,7 @@ class AsyncReplayOptimizer(PolicyOptimizer):
self.prioritized_replay_eps = prioritized_replay_eps
self.max_weight_sync_delay = max_weight_sync_delay
self.learner = LearnerThread(self.local_evaluator)
self.learner = LearnerThread(self.workers.local_worker())
self.learner.start()
if self.batch_replay:
@@ -111,13 +110,13 @@ class AsyncReplayOptimizer(PolicyOptimizer):
# Kick off async background sampling
self.sample_tasks = TaskPool()
if self.remote_evaluators:
self._set_evaluators(self.remote_evaluators)
if self.workers.remote_workers():
self._set_workers(self.workers.remote_workers())
@override(PolicyOptimizer)
def step(self):
assert self.learner.is_alive()
assert len(self.remote_evaluators) > 0
assert len(self.workers.remote_workers()) > 0
start = time.time()
sample_timesteps, train_timesteps = self._step()
time_delta = time.time() - start
@@ -138,9 +137,9 @@ class AsyncReplayOptimizer(PolicyOptimizer):
self.learner.stopped = True
@override(PolicyOptimizer)
def reset(self, remote_evaluators):
self.remote_evaluators = remote_evaluators
self.sample_tasks.reset_evaluators(remote_evaluators)
def reset(self, remote_workers):
self.workers.reset(remote_workers)
self.sample_tasks.reset_workers(remote_workers)
@override(PolicyOptimizer)
def stats(self):
@@ -175,10 +174,10 @@ class AsyncReplayOptimizer(PolicyOptimizer):
return dict(PolicyOptimizer.stats(self), **stats)
# For https://github.com/ray-project/ray/issues/2541 only
def _set_evaluators(self, remote_evaluators):
self.remote_evaluators = remote_evaluators
weights = self.local_evaluator.get_weights()
for ev in self.remote_evaluators:
def _set_workers(self, remote_workers):
self.workers.reset(remote_workers)
weights = self.workers.local_worker().get_weights()
for ev in self.workers.remote_workers():
ev.set_weights.remote(weights)
self.steps_since_update[ev] = 0
for _ in range(SAMPLE_QUEUE_DEPTH):
@@ -207,7 +206,7 @@ class AsyncReplayOptimizer(PolicyOptimizer):
self.learner.weights_updated = False
with self.timers["put_weights"]:
weights = ray.put(
self.local_evaluator.get_weights())
self.workers.local_worker().get_weights())
ev.set_weights.remote(weights)
self.num_weight_syncs += 1
self.steps_since_update[ev] = 0
@@ -380,10 +379,10 @@ class LearnerThread(threading.Thread):
improves overall throughput.
"""
def __init__(self, local_evaluator):
def __init__(self, local_worker):
threading.Thread.__init__(self)
self.learner_queue_size = WindowStat("size", 50)
self.local_evaluator = local_evaluator
self.local_worker = local_worker
self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE)
self.outqueue = queue.Queue()
self.queue_timer = TimerStat()
@@ -403,7 +402,7 @@ class LearnerThread(threading.Thread):
if replay is not None:
prio_dict = {}
with self.grad_timer:
grad_out = self.local_evaluator.learn_on_batch(replay)
grad_out = self.local_worker.learn_on_batch(replay)
for pid, info in grad_out.items():
prio_dict[pid] = (
replay.policy_batches[pid].data.get("batch_indexes"),
@@ -24,12 +24,11 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
"""Main event loop of the IMPALA architecture.
This class coordinates the data transfers between the learner thread
and remote evaluators (IMPALA actors).
and remote workers (IMPALA actors).
"""
def __init__(self,
local_evaluator,
remote_evaluators,
workers,
train_batch_size=500,
sample_batch_size=50,
num_envs_per_worker=1,
@@ -45,7 +44,7 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
learner_queue_size=16,
num_aggregation_workers=0,
_fake_gpus=False):
PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)
PolicyOptimizer.__init__(self, workers)
self._stats_start_time = time.time()
self._last_stats_time = {}
@@ -62,7 +61,7 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
"{} vs {}".format(num_data_loader_buffers,
minibatch_buffer_size))
self.learner = TFMultiGPULearner(
self.local_evaluator,
self.workers.local_worker(),
lr=lr,
num_gpus=num_gpus,
train_batch_size=train_batch_size,
@@ -72,7 +71,7 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
learner_queue_size=learner_queue_size,
_fake_gpus=_fake_gpus)
else:
self.learner = LearnerThread(self.local_evaluator,
self.learner = LearnerThread(self.workers.local_worker(),
minibatch_buffer_size, num_sgd_iter,
learner_queue_size)
self.learner.start()
@@ -84,8 +83,7 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
if num_aggregation_workers > 0:
self.aggregator = TreeAggregator(
self.local_evaluator,
self.remote_evaluators,
workers,
num_aggregation_workers,
replay_proportion=replay_proportion,
max_sample_requests_in_flight_per_worker=(
@@ -96,8 +94,7 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
broadcast_interval=broadcast_interval)
else:
self.aggregator = SimpleAggregator(
self.local_evaluator,
self.remote_evaluators,
workers,
replay_proportion=replay_proportion,
max_sample_requests_in_flight_per_worker=(
max_sample_requests_in_flight_per_worker),
@@ -127,7 +124,7 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
@override(PolicyOptimizer)
def step(self):
if len(self.remote_evaluators) == 0:
if len(self.workers.remote_workers()) == 0:
raise ValueError("Config num_workers=0 means training will hang!")
assert self.learner.is_alive()
with self._optimizer_step_timer:
@@ -146,9 +143,9 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
self.learner.stopped = True
@override(PolicyOptimizer)
def reset(self, remote_evaluators):
self.remote_evaluators = remote_evaluators
self.aggregator.reset(remote_evaluators)
def reset(self, remote_workers):
self.workers.reset(remote_workers)
self.aggregator.reset(remote_workers)
@override(PolicyOptimizer)
def stats(self):
@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
class LocalMultiGPUOptimizer(PolicyOptimizer):
"""A synchronous optimizer that uses multiple local GPUs.
Samples are pulled synchronously from multiple remote evaluators,
Samples are pulled synchronously from multiple remote workers,
concatenated, and then split across the memory of multiple local GPUs.
A number of SGD passes are then taken over the in-memory data. For more
details, see `multi_gpu_impl.LocalSyncParallelOptimizer`.
@@ -42,8 +42,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
"""
def __init__(self,
local_evaluator,
remote_evaluators,
workers,
sgd_batch_size=128,
num_sgd_iter=10,
sample_batch_size=200,
@@ -52,7 +51,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
num_gpus=0,
standardize_fields=[],
straggler_mitigation=False):
PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)
PolicyOptimizer.__init__(self, workers)
self.batch_size = sgd_batch_size
self.num_sgd_iter = num_sgd_iter
@@ -79,8 +78,8 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
logger.info("LocalMultiGPUOptimizer devices {}".format(self.devices))
self.policies = dict(
self.local_evaluator.foreach_trainable_policy(lambda p, i: (i, p)))
self.policies = dict(self.workers.local_worker()
.foreach_trainable_policy(lambda p, i: (i, p)))
logger.debug("Policies to train: {}".format(self.policies))
for policy_id, policy in self.policies.items():
if not isinstance(policy, TFPolicy):
@@ -92,8 +91,8 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
# reuse is set to AUTO_REUSE because Adam nodes are created after
# all of the device copies are created.
self.optimizers = {}
with self.local_evaluator.tf_sess.graph.as_default():
with self.local_evaluator.tf_sess.as_default():
with self.workers.local_worker().tf_sess.graph.as_default():
with self.workers.local_worker().tf_sess.as_default():
for policy_id, policy in self.policies.items():
with tf.variable_scope(policy_id, reuse=tf.AUTO_REUSE):
if policy._state_inputs:
@@ -109,25 +108,25 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
for _, v in policy._loss_inputs], rnn_inputs,
self.per_device_batch_size, policy.copy))
self.sess = self.local_evaluator.tf_sess
self.sess = self.workers.local_worker().tf_sess
self.sess.run(tf.global_variables_initializer())
@override(PolicyOptimizer)
def step(self):
with self.update_weights_timer:
if self.remote_evaluators:
weights = ray.put(self.local_evaluator.get_weights())
for e in self.remote_evaluators:
if self.workers.remote_workers():
weights = ray.put(self.workers.local_worker().get_weights())
for e in self.workers.remote_workers():
e.set_weights.remote(weights)
with self.sample_timer:
if self.remote_evaluators:
if self.workers.remote_workers():
if self.straggler_mitigation:
samples = collect_samples_straggler_mitigation(
self.remote_evaluators, self.train_batch_size)
self.workers.remote_workers(), self.train_batch_size)
else:
samples = collect_samples(
self.remote_evaluators, self.sample_batch_size,
self.workers.remote_workers(), self.sample_batch_size,
self.num_envs_per_worker, self.train_batch_size)
if samples.count > self.train_batch_size * 2:
logger.info(
@@ -139,7 +138,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
else:
samples = []
while sum(s.count for s in samples) < self.train_batch_size:
samples.append(self.local_evaluator.sample())
samples.append(self.workers.local_worker().sample())
samples = SampleBatch.concat_samples(samples)
# Handle everything as if multiagent
+29 -45
View File
@@ -6,7 +6,6 @@ import logging
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes
from ray.rllib.utils.memory import ray_get_and_free
logger = logging.getLogger(__name__)
@@ -21,34 +20,21 @@ class PolicyOptimizer(object):
used for PPO. These optimizers are all pluggable, and it is possible
to mix and match as needed.
In order for an algorithm to use an RLlib optimizer, it must implement
the PolicyEvaluator interface and pass a PolicyEvaluator class or set of
PolicyEvaluators to its PolicyOptimizer of choice. The PolicyOptimizer
uses these Evaluators to sample from the environment and compute model
gradient updates.
Attributes:
config (dict): The JSON configuration passed to this optimizer.
local_evaluator (PolicyEvaluator): The embedded evaluator instance.
remote_evaluators (list): List of remote evaluator replicas, or [].
workers (WorkerSet): The set of rollout workers to use.
num_steps_trained (int): Number of timesteps trained on so far.
num_steps_sampled (int): Number of timesteps sampled so far.
evaluator_resources (dict): Optional resource requests to set for
evaluators created by this optimizer.
"""
@DeveloperAPI
def __init__(self, local_evaluator, remote_evaluators=None):
def __init__(self, workers):
"""Create an optimizer instance.
Args:
local_evaluator (Evaluator): Local evaluator instance, required.
remote_evaluators (list): A list of Ray actor handles to remote
evaluators instances. If empty, the optimizer should fall back
to using only the local evaluator.
workers (WorkerSet): The set of rollout workers to use.
"""
self.local_evaluator = local_evaluator
self.remote_evaluators = remote_evaluators or []
self.workers = workers
self.episode_history = []
# Counters that should be updated by sub-classes
@@ -100,23 +86,23 @@ class PolicyOptimizer(object):
def collect_metrics(self,
timeout_seconds,
min_history=100,
selected_evaluators=None):
"""Returns evaluator and optimizer stats.
selected_workers=None):
"""Returns worker and optimizer stats.
Arguments:
timeout_seconds (int): Max wait time for a evaluator before
dropping its results. This usually indicates a hung evaluator.
timeout_seconds (int): Max wait time for a worker before
dropping its results. This usually indicates a hung worker.
min_history (int): Min history length to smooth results over.
selected_evaluators (list): Override the list of remote evaluators
selected_workers (list): Override the list of remote workers
to collect metrics from.
Returns:
res (dict): A training result dict from evaluator metrics with
res (dict): A training result dict from worker metrics with
`info` replaced with stats from self.
"""
episodes, num_dropped = collect_episodes(
self.local_evaluator,
selected_evaluators or self.remote_evaluators,
self.workers.local_worker(),
selected_workers or self.workers.remote_workers(),
timeout_seconds=timeout_seconds)
orig_episodes = list(episodes)
missing = min_history - len(episodes)
@@ -130,30 +116,28 @@ class PolicyOptimizer(object):
return res
@DeveloperAPI
def reset(self, remote_evaluators):
"""Called to change the set of remote evaluators being used."""
self.remote_evaluators = remote_evaluators
# Swap the remote workers used by this optimizer.
# The call is forwarded unchanged to `self.workers.reset(remote_workers)`;
# nothing is returned. Only the worker set held in `self.workers` is touched
# here -- other optimizer attributes are left as they are.
def reset(self, remote_workers):
"""Called to change the set of remote workers being used."""
self.workers.reset(remote_workers)
@DeveloperAPI
def foreach_evaluator(self, func):
"""Apply the given function to each evaluator instance."""
local_result = [func(self.local_evaluator)]
remote_results = ray_get_and_free(
[ev.apply.remote(func) for ev in self.remote_evaluators])
return local_result + remote_results
# Thin pass-through: hands `func` to `self.workers.foreach_worker` and returns
# whatever that call returns. The optimizer itself does no iteration here.
# NOTE(review): presumably the result is a list with one entry per worker,
# as in the replaced foreach_evaluator -- confirm against WorkerSet.
def foreach_worker(self, func):
"""Apply the given function to each worker instance."""
return self.workers.foreach_worker(func)
@DeveloperAPI
def foreach_evaluator_with_index(self, func):
"""Apply the given function to each evaluator instance.
# Thin pass-through: hands `func` to `self.workers.foreach_worker_with_index`
# and returns whatever that call returns.
# NOTE(review): the index convention (which worker gets which index) is
# decided by the worker set, not here -- confirm against WorkerSet.
def foreach_worker_with_index(self, func):
"""Apply the given function to each worker instance.
The index will be passed as the second arg to the given function.
"""
return self.workers.foreach_worker_with_index(func)
local_result = [func(self.local_evaluator, 0)]
remote_results = ray_get_and_free([
ev.apply.remote(func, i + 1)
for i, ev in enumerate(self.remote_evaluators)
])
return local_result + remote_results
# Deprecated name kept only so old call sites fail with a clear message.
# Always raises DeprecationWarning (as an exception, not via warnings.warn);
# `func` is never called. Callers should switch to foreach_worker.
def foreach_evaluator(self, func):
raise DeprecationWarning(
"foreach_evaluator has been renamed to foreach_worker")
# Deprecated name kept only so old call sites fail with a clear message.
# Always raises DeprecationWarning (as an exception, not via warnings.warn);
# `func` is never called. Callers should switch to
# foreach_worker_with_index.
def foreach_evaluator_with_index(self, func):
raise DeprecationWarning(
"foreach_evaluator_with_index has been renamed to "
"foreach_worker_with_index")
@@ -20,12 +20,11 @@ class SyncBatchReplayOptimizer(PolicyOptimizer):
This enables RNN support. Does not currently support prioritization."""
def __init__(self,
local_evaluator,
remote_evaluators,
workers,
learning_starts=1000,
buffer_size=10000,
train_batch_size=32):
PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)
PolicyOptimizer.__init__(self, workers)
self.replay_starts = learning_starts
self.max_buffer_size = buffer_size
@@ -45,17 +44,17 @@ class SyncBatchReplayOptimizer(PolicyOptimizer):
@override(PolicyOptimizer)
def step(self):
with self.update_weights_timer:
if self.remote_evaluators:
weights = ray.put(self.local_evaluator.get_weights())
for e in self.remote_evaluators:
if self.workers.remote_workers():
weights = ray.put(self.workers.local_worker().get_weights())
for e in self.workers.remote_workers():
e.set_weights.remote(weights)
with self.sample_timer:
if self.remote_evaluators:
if self.workers.remote_workers():
batches = ray_get_and_free(
[e.sample.remote() for e in self.remote_evaluators])
[e.sample.remote() for e in self.workers.remote_workers()])
else:
batches = [self.local_evaluator.sample()]
batches = [self.workers.local_worker().sample()]
# Handle everything as if multiagent
tmp = []
@@ -105,7 +104,7 @@ class SyncBatchReplayOptimizer(PolicyOptimizer):
samples.append(random.choice(self.replay_buffer))
samples = SampleBatch.concat_samples(samples)
with self.grad_timer:
info_dict = self.local_evaluator.learn_on_batch(samples)
info_dict = self.workers.local_worker().learn_on_batch(samples)
for policy_id, info in info_dict.items():
self.learner_stats[policy_id] = get_learner_stats(info)
self.grad_timer.push_units_processed(samples.count)
@@ -25,13 +25,12 @@ logger = logging.getLogger(__name__)
class SyncReplayOptimizer(PolicyOptimizer):
"""Variant of the local sync optimizer that supports replay (for DQN).
This optimizer requires that policy evaluators return an additional
This optimizer requires that rollout workers return an additional
"td_error" array in the info return of compute_gradients(). This error
term will be used for sample prioritization."""
def __init__(self,
local_evaluator,
remote_evaluators,
workers,
learning_starts=1000,
buffer_size=10000,
prioritized_replay=True,
@@ -43,7 +42,7 @@ class SyncReplayOptimizer(PolicyOptimizer):
prioritized_replay_eps=1e-6,
train_batch_size=32,
sample_batch_size=4):
PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)
PolicyOptimizer.__init__(self, workers)
self.replay_starts = learning_starts
# linearly annealing beta used in Rainbow paper
@@ -82,18 +81,20 @@ class SyncReplayOptimizer(PolicyOptimizer):
@override(PolicyOptimizer)
def step(self):
with self.update_weights_timer:
if self.remote_evaluators:
weights = ray.put(self.local_evaluator.get_weights())
for e in self.remote_evaluators:
if self.workers.remote_workers():
weights = ray.put(self.workers.local_worker().get_weights())
for e in self.workers.remote_workers():
e.set_weights.remote(weights)
with self.sample_timer:
if self.remote_evaluators:
if self.workers.remote_workers():
batch = SampleBatch.concat_samples(
ray_get_and_free(
[e.sample.remote() for e in self.remote_evaluators]))
ray_get_and_free([
e.sample.remote()
for e in self.workers.remote_workers()
]))
else:
batch = self.local_evaluator.sample()
batch = self.workers.local_worker().sample()
# Handle everything as if multiagent
if isinstance(batch, SampleBatch):
@@ -135,7 +136,7 @@ class SyncReplayOptimizer(PolicyOptimizer):
samples = self._replay()
with self.grad_timer:
info_dict = self.local_evaluator.learn_on_batch(samples)
info_dict = self.workers.local_worker().learn_on_batch(samples)
for policy_id, info in info_dict.items():
self.learner_stats[policy_id] = get_learner_stats(info)
replay_buffer = self.replay_buffers[policy_id]
@@ -19,16 +19,12 @@ class SyncSamplesOptimizer(PolicyOptimizer):
"""A simple synchronous RL optimizer.
In each step, this optimizer pulls samples from a number of remote
evaluators, concatenates them, and then updates a local model. The updated
model weights are then broadcast to all remote evaluators.
workers, concatenates them, and then updates a local model. The updated
model weights are then broadcast to all remote workers.
"""
def __init__(self,
local_evaluator,
remote_evaluators,
num_sgd_iter=1,
train_batch_size=1):
PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)
def __init__(self, workers, num_sgd_iter=1, train_batch_size=1):
PolicyOptimizer.__init__(self, workers)
self.update_weights_timer = TimerStat()
self.sample_timer = TimerStat()
@@ -41,27 +37,28 @@ class SyncSamplesOptimizer(PolicyOptimizer):
@override(PolicyOptimizer)
def step(self):
with self.update_weights_timer:
if self.remote_evaluators:
weights = ray.put(self.local_evaluator.get_weights())
for e in self.remote_evaluators:
if self.workers.remote_workers():
weights = ray.put(self.workers.local_worker().get_weights())
for e in self.workers.remote_workers():
e.set_weights.remote(weights)
with self.sample_timer:
samples = []
while sum(s.count for s in samples) < self.train_batch_size:
if self.remote_evaluators:
if self.workers.remote_workers():
samples.extend(
ray_get_and_free([
e.sample.remote() for e in self.remote_evaluators
e.sample.remote()
for e in self.workers.remote_workers()
]))
else:
samples.append(self.local_evaluator.sample())
samples.append(self.workers.local_worker().sample())
samples = SampleBatch.concat_samples(samples)
self.sample_timer.push_units_processed(samples.count)
with self.grad_timer:
for i in range(self.num_sgd_iter):
fetches = self.local_evaluator.learn_on_batch(samples)
fetches = self.workers.local_worker().learn_on_batch(samples)
self.learner_stats = get_learner_stats(fetches)
if self.num_sgd_iter > 1:
logger.debug("{} {}".format(i, fetches))
+1 -1
View File
@@ -142,7 +142,7 @@ class DynamicTFPolicy(TFPolicy):
action_prob = self.action_dist.sampled_action_prob()
# Phase 1 init
sess = tf.get_default_session()
sess = tf.get_default_session() or tf.Session()
if get_batch_divisibility_req:
batch_divisibility_req = get_batch_divisibility_req(self)
else:
+1 -1
View File
@@ -36,7 +36,7 @@ class Policy(object):
"""Initialize the graph.
This is the standard constructor for policies. The policy
class you pass into PolicyEvaluator will be constructed with
class you pass into RolloutWorker will be constructed with
these arguments.
Args:
@@ -88,9 +88,7 @@ def build_tf_policy(name,
a DynamicTFPolicy instance that uses the specified args
"""
if not name.endswith("TFPolicy"):
raise ValueError("Name should match *TFPolicy", name)
original_kwargs = locals().copy()
base = DynamicTFPolicy
while mixins:
@@ -191,6 +189,11 @@ def build_tf_policy(name,
else:
return TFPolicy.extra_compute_grad_feed_dict(self)
@staticmethod
def with_updates(**overrides):
    """Build a new policy class from the captured builder arguments.

    Starts from the keyword arguments stored in ``original_kwargs`` and
    lets any entry in ``overrides`` take precedence before calling
    ``build_tf_policy`` again.
    """
    merged_kwargs = dict(original_kwargs)
    merged_kwargs.update(overrides)
    return build_tf_policy(**merged_kwargs)
policy_cls.with_updates = with_updates
policy_cls.__name__ = name
policy_cls.__qualname__ = name
return policy_cls
@@ -24,7 +24,7 @@ def build_torch_policy(name,
"""Helper function for creating a torch policy at runtime.
Arguments:
name (str): name of the policy (e.g., "PPOTFPolicy")
name (str): name of the policy (e.g., "PPOTorchPolicy")
loss_fn (func): function that returns a loss tensor given the policy
and a dict of experience tensor placeholders
get_default_config (func): optional function that returns the default
@@ -55,9 +55,7 @@ def build_torch_policy(name,
a TorchPolicy instance that uses the specified args
"""
if not name.endswith("TorchPolicy"):
raise ValueError("Name should match *TorchPolicy", name)
original_kwargs = locals().copy()
base = TorchPolicy
while mixins:
@@ -66,7 +64,7 @@ def build_torch_policy(name,
base = new_base
class graph_cls(base):
class policy_cls(base):
def __init__(self, obs_space, action_space, config):
if get_default_config:
config = dict(get_default_config(), **config)
@@ -130,6 +128,11 @@ def build_torch_policy(name,
else:
return TorchPolicy.extra_grad_info(self, batch_tensors)
graph_cls.__name__ = name
graph_cls.__qualname__ = name
return graph_cls
@staticmethod
def with_updates(**overrides):
    """Build a new policy class from the captured builder arguments.

    Starts from the keyword arguments stored in ``original_kwargs`` and
    lets any entry in ``overrides`` take precedence before calling
    ``build_torch_policy`` again.
    """
    merged_kwargs = dict(original_kwargs)
    merged_kwargs.update(overrides)
    return build_torch_policy(**merged_kwargs)
policy_cls.with_updates = with_updates
policy_cls.__name__ = name
policy_cls.__qualname__ = name
return policy_cls
+4 -4
View File
@@ -120,14 +120,14 @@ def default_policy_agent_mapping(unused_agent_id):
def rollout(agent, env_name, num_steps, out=None, no_render=True):
policy_agent_mapping = default_policy_agent_mapping
if hasattr(agent, "local_evaluator"):
env = agent.local_evaluator.env
if hasattr(agent, "workers"):
env = agent.workers.local_worker().env
multiagent = isinstance(env, MultiAgentEnv)
if agent.local_evaluator.multiagent:
if agent.workers.local_worker().multiagent:
policy_agent_mapping = agent.config["multiagent"][
"policy_mapping_fn"]
policy_map = agent.local_evaluator.policy_map
policy_map = agent.workers.local_worker().policy_map
state_init = {p: m.get_initial_state() for p, m in policy_map.items()}
use_lstm = {p: len(s) > 0 for p, s in state_init.items()}
action_init = {
@@ -8,7 +8,7 @@ from ray.rllib.evaluation import SampleBatch
from ray.rllib.utils.filter import MeanStdFilter
class _MockEvaluator(object):
class _MockWorker(object):
def __init__(self, sample_count=10):
self._weights = np.array([-10, -10, -10, -10])
self._grad = np.array([1, 1, 1, 1])
+8 -8
View File
@@ -11,10 +11,10 @@ import uuid
import ray
from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.tests.test_policy_evaluator import (BadPolicy, MockPolicy,
MockEnv)
from ray.rllib.tests.test_rollout_worker import (BadPolicy, MockPolicy,
MockEnv)
from ray.tune.registry import register_env
@@ -119,7 +119,7 @@ class MultiServing(ExternalEnv):
class TestExternalEnv(unittest.TestCase):
def testExternalEnvCompleteEpisodes(self):
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: SimpleServing(MockEnv(25)),
policy=MockPolicy,
batch_steps=40,
@@ -129,7 +129,7 @@ class TestExternalEnv(unittest.TestCase):
self.assertEqual(batch.count, 50)
def testExternalEnvTruncateEpisodes(self):
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: SimpleServing(MockEnv(25)),
policy=MockPolicy,
batch_steps=40,
@@ -139,7 +139,7 @@ class TestExternalEnv(unittest.TestCase):
self.assertEqual(batch.count, 40)
def testExternalEnvOffPolicy(self):
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: SimpleOffPolicyServing(MockEnv(25), 42),
policy=MockPolicy,
batch_steps=40,
@@ -151,7 +151,7 @@ class TestExternalEnv(unittest.TestCase):
self.assertEqual(batch["actions"][-1], 42)
def testExternalEnvBadActions(self):
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: SimpleServing(MockEnv(25)),
policy=BadPolicy,
sample_async=True,
@@ -196,7 +196,7 @@ class TestExternalEnv(unittest.TestCase):
raise Exception("failed to improve reward")
def testExternalEnvHorizonNotSupported(self):
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: SimpleServing(MockEnv(25)),
policy=MockPolicy,
episode_horizon=20,
@@ -10,9 +10,10 @@ import unittest
import ray
from ray.rllib.agents.pg.pg_policy import PGTFPolicy
from ray.rllib.optimizers import SyncSamplesOptimizer
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
from ray.rllib.tests.test_policy_evaluator import MockPolicy
from ray.rllib.tests.test_rollout_worker import MockPolicy
from ray.rllib.tests.test_external_env import make_simple_serving
from ray.rllib.tests.test_multi_agent_env import BasicMultiAgent, MultiCartpole
from ray.rllib.evaluation.metrics import collect_metrics
@@ -23,7 +24,7 @@ SimpleMultiServing = make_simple_serving(True, ExternalMultiAgentEnv)
class TestExternalMultiAgentEnv(unittest.TestCase):
def testExternalMultiAgentEnvCompleteEpisodes(self):
agents = 4
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)),
policy=MockPolicy,
batch_steps=40,
@@ -35,7 +36,7 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
def testExternalMultiAgentEnvTruncateEpisodes(self):
agents = 4
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)),
policy=MockPolicy,
batch_steps=40,
@@ -49,7 +50,7 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
agents = 2
act_space = gym.spaces.Discrete(2)
obs_space = gym.spaces.Discrete(2)
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)),
policy={
"p0": (MockPolicy, obs_space, act_space, {}),
@@ -70,12 +71,12 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
policies["pg_{}".format(i)] = (PGTFPolicy, obs_space, act_space,
{})
policy_ids = list(policies.keys())
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: MultiCartpole(n),
policy=policies,
policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
batch_steps=100)
optimizer = SyncSamplesOptimizer(ev, [])
optimizer = SyncSamplesOptimizer(WorkerSet._from_existing(ev))
for i in range(100):
optimizer.step()
result = collect_metrics(ev)
+3 -3
View File
@@ -8,7 +8,7 @@ import numpy as np
import ray
from ray.rllib.utils.filter import RunningStat, MeanStdFilter
from ray.rllib.utils import FilterManager
from ray.rllib.tests.mock_evaluator import _MockEvaluator
from ray.rllib.tests.mock_worker import _MockWorker
class RunningStatTest(unittest.TestCase):
@@ -89,8 +89,8 @@ class FilterManagerTest(unittest.TestCase):
filt1.clear_buffer()
self.assertEqual(filt1.buffer.n, 0)
RemoteEvaluator = ray.remote(_MockEvaluator)
remote_e = RemoteEvaluator.remote(sample_count=10)
RemoteWorker = ray.remote(_MockWorker)
remote_e = RemoteWorker.remote(sample_count=10)
remote_e.sample.remote()
FilterManager.synchronize({
+24 -22
View File
@@ -12,11 +12,11 @@ from ray.rllib.agents.pg.pg_policy import PGTFPolicy
from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy
from ray.rllib.optimizers import (SyncSamplesOptimizer, SyncReplayOptimizer,
AsyncGradientsOptimizer)
from ray.rllib.tests.test_policy_evaluator import (MockEnv, MockEnv2,
MockPolicy)
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.tests.test_rollout_worker import (MockEnv, MockEnv2, MockPolicy)
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.policy.policy import Policy
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.env.base_env import _MultiAgentEnvToBaseEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.tune.registry import register_env
@@ -327,7 +327,7 @@ class TestMultiAgentEnv(unittest.TestCase):
def testMultiAgentSample(self):
act_space = gym.spaces.Discrete(2)
obs_space = gym.spaces.Discrete(2)
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: BasicMultiAgent(5),
policy={
"p0": (MockPolicy, obs_space, act_space, {}),
@@ -345,7 +345,7 @@ class TestMultiAgentEnv(unittest.TestCase):
def testMultiAgentSampleSyncRemote(self):
act_space = gym.spaces.Discrete(2)
obs_space = gym.spaces.Discrete(2)
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: BasicMultiAgent(5),
policy={
"p0": (MockPolicy, obs_space, act_space, {}),
@@ -362,7 +362,7 @@ class TestMultiAgentEnv(unittest.TestCase):
def testMultiAgentSampleAsyncRemote(self):
act_space = gym.spaces.Discrete(2)
obs_space = gym.spaces.Discrete(2)
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: BasicMultiAgent(5),
policy={
"p0": (MockPolicy, obs_space, act_space, {}),
@@ -378,7 +378,7 @@ class TestMultiAgentEnv(unittest.TestCase):
def testMultiAgentSampleWithHorizon(self):
act_space = gym.spaces.Discrete(2)
obs_space = gym.spaces.Discrete(2)
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: BasicMultiAgent(5),
policy={
"p0": (MockPolicy, obs_space, act_space, {}),
@@ -393,7 +393,7 @@ class TestMultiAgentEnv(unittest.TestCase):
def testSampleFromEarlyDoneEnv(self):
act_space = gym.spaces.Discrete(2)
obs_space = gym.spaces.Discrete(2)
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: EarlyDoneMultiAgent(),
policy={
"p0": (MockPolicy, obs_space, act_space, {}),
@@ -409,7 +409,7 @@ class TestMultiAgentEnv(unittest.TestCase):
def testMultiAgentSampleRoundRobin(self):
act_space = gym.spaces.Discrete(2)
obs_space = gym.spaces.Discrete(10)
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: RoundRobinMultiAgent(5, increment_obs=True),
policy={
"p0": (MockPolicy, obs_space, act_space, {}),
@@ -458,7 +458,7 @@ class TestMultiAgentEnv(unittest.TestCase):
def get_initial_state(self):
return [{}] # empty dict
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
policy=StatefulPolicy,
batch_steps=5)
@@ -503,7 +503,7 @@ class TestMultiAgentEnv(unittest.TestCase):
single_env = gym.make("CartPole-v0")
obs_space = single_env.observation_space
act_space = single_env.action_space
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: MultiCartpole(2),
policy={
"p0": (ModelBasedPolicy, obs_space, act_space, {}),
@@ -587,7 +587,7 @@ class TestMultiAgentEnv(unittest.TestCase):
"p1": (PGTFPolicy, obs_space, act_space, {}),
"p2": (DQNTFPolicy, obs_space, act_space, dqn_config),
}
ev = PolicyEvaluator(
worker = RolloutWorker(
env_creator=lambda _: MultiCartpole(n),
policy=policies,
policy_mapping_fn=lambda agent_id: ["p1", "p2"][agent_id % 2],
@@ -597,29 +597,30 @@ class TestMultiAgentEnv(unittest.TestCase):
def policy_mapper(agent_id):
return ["p1", "p2"][agent_id % 2]
remote_evs = [
PolicyEvaluator.as_remote().remote(
remote_workers = [
RolloutWorker.as_remote().remote(
env_creator=lambda _: MultiCartpole(n),
policy=policies,
policy_mapping_fn=policy_mapper,
batch_steps=50)
]
else:
remote_evs = []
optimizer = optimizer_cls(ev, remote_evs)
remote_workers = []
workers = WorkerSet._from_existing(worker, remote_workers)
optimizer = optimizer_cls(workers)
for i in range(200):
ev.foreach_policy(lambda p, _: p.set_epsilon(
worker.foreach_policy(lambda p, _: p.set_epsilon(
max(0.02, 1 - i * .02))
if isinstance(p, DQNTFPolicy) else None)
optimizer.step()
result = collect_metrics(ev, remote_evs)
result = collect_metrics(worker, remote_workers)
if i % 20 == 0:
def do_update(p):
if isinstance(p, DQNTFPolicy):
p.update_target()
ev.foreach_policy(lambda p, _: do_update(p))
worker.foreach_policy(lambda p, _: do_update(p))
print("Iter {}, rew {}".format(i,
result["policy_reward_mean"]))
print("Total reward", result["episode_reward_mean"])
@@ -647,15 +648,16 @@ class TestMultiAgentEnv(unittest.TestCase):
policies["pg_{}".format(i)] = (PGTFPolicy, obs_space, act_space,
{})
policy_ids = list(policies.keys())
ev = PolicyEvaluator(
worker = RolloutWorker(
env_creator=lambda _: MultiCartpole(n),
policy=policies,
policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
batch_steps=100)
optimizer = SyncSamplesOptimizer(ev, [])
workers = WorkerSet._from_existing(worker, [])
optimizer = SyncSamplesOptimizer(workers)
for i in range(100):
optimizer.step()
result = collect_metrics(ev)
result = collect_metrics(worker)
print("Iteration {}, rew {}".format(i,
result["policy_reward_mean"]))
print("Total reward", result["episode_reward_mean"])
+30 -33
View File
@@ -11,10 +11,11 @@ import ray
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.agents.ppo.ppo_policy import PPOTFPolicy
from ray.rllib.evaluation import SampleBatch
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.optimizers import AsyncGradientsOptimizer, AsyncSamplesOptimizer
from ray.rllib.optimizers.aso_tree_aggregator import TreeAggregator
from ray.rllib.tests.mock_evaluator import _MockEvaluator
from ray.rllib.tests.mock_worker import _MockWorker
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
@@ -26,11 +27,11 @@ class AsyncOptimizerTest(unittest.TestCase):
def testBasic(self):
ray.init(num_cpus=4)
local = _MockEvaluator()
remotes = ray.remote(_MockEvaluator)
remote_evaluators = [remotes.remote() for i in range(5)]
test_optimizer = AsyncGradientsOptimizer(
local, remote_evaluators, grads_per_step=10)
local = _MockWorker()
remotes = ray.remote(_MockWorker)
remote_workers = [remotes.remote() for i in range(5)]
workers = WorkerSet._from_existing(local, remote_workers)
test_optimizer = AsyncGradientsOptimizer(workers, grads_per_step=10)
test_optimizer.step()
self.assertTrue(all(local.get_weights() == 0))
@@ -117,30 +118,28 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
def testSimple(self):
local, remotes = self._make_evs()
optimizer = AsyncSamplesOptimizer(local, remotes)
workers = WorkerSet._from_existing(local, remotes)
optimizer = AsyncSamplesOptimizer(workers)
self._wait_for(optimizer, 1000, 1000)
def testMultiGPU(self):
local, remotes = self._make_evs()
optimizer = AsyncSamplesOptimizer(
local, remotes, num_gpus=2, _fake_gpus=True)
workers = WorkerSet._from_existing(local, remotes)
optimizer = AsyncSamplesOptimizer(workers, num_gpus=2, _fake_gpus=True)
self._wait_for(optimizer, 1000, 1000)
def testMultiGPUParallelLoad(self):
local, remotes = self._make_evs()
workers = WorkerSet._from_existing(local, remotes)
optimizer = AsyncSamplesOptimizer(
local,
remotes,
num_gpus=2,
num_data_loader_buffers=2,
_fake_gpus=True)
workers, num_gpus=2, num_data_loader_buffers=2, _fake_gpus=True)
self._wait_for(optimizer, 1000, 1000)
def testMultiplePasses(self):
local, remotes = self._make_evs()
workers = WorkerSet._from_existing(local, remotes)
optimizer = AsyncSamplesOptimizer(
local,
remotes,
workers,
minibatch_buffer_size=10,
num_sgd_iter=10,
sample_batch_size=10,
@@ -151,9 +150,9 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
def testReplay(self):
local, remotes = self._make_evs()
workers = WorkerSet._from_existing(local, remotes)
optimizer = AsyncSamplesOptimizer(
local,
remotes,
workers,
replay_buffer_num_slots=100,
replay_proportion=10,
sample_batch_size=10,
@@ -168,9 +167,9 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
def testReplayAndMultiplePasses(self):
local, remotes = self._make_evs()
workers = WorkerSet._from_existing(local, remotes)
optimizer = AsyncSamplesOptimizer(
local,
remotes,
workers,
minibatch_buffer_size=10,
num_sgd_iter=10,
replay_buffer_num_slots=100,
@@ -189,45 +188,43 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
def testMultiTierAggregationBadConf(self):
local, remotes = self._make_evs()
workers = WorkerSet._from_existing(local, remotes)
aggregators = TreeAggregator.precreate_aggregators(4)
optimizer = AsyncSamplesOptimizer(
local, remotes, num_aggregation_workers=4)
optimizer = AsyncSamplesOptimizer(workers, num_aggregation_workers=4)
self.assertRaises(ValueError,
lambda: optimizer.aggregator.init(aggregators))
def testMultiTierAggregation(self):
local, remotes = self._make_evs()
workers = WorkerSet._from_existing(local, remotes)
aggregators = TreeAggregator.precreate_aggregators(1)
optimizer = AsyncSamplesOptimizer(
local, remotes, num_aggregation_workers=1)
optimizer = AsyncSamplesOptimizer(workers, num_aggregation_workers=1)
optimizer.aggregator.init(aggregators)
self._wait_for(optimizer, 1000, 1000)
def testRejectBadConfigs(self):
local, remotes = self._make_evs()
workers = WorkerSet._from_existing(local, remotes)
# AsyncSamplesOptimizer is constructed from a single WorkerSet everywhere
# else in this test; pass `workers` here too so the assertion exercises the
# config validation rather than a mis-bound positional argument.
self.assertRaises(
    ValueError, lambda: AsyncSamplesOptimizer(
        workers,
        num_data_loader_buffers=2, minibatch_buffer_size=4))
optimizer = AsyncSamplesOptimizer(
local,
remotes,
workers,
num_gpus=2,
train_batch_size=100,
sample_batch_size=50,
_fake_gpus=True)
self._wait_for(optimizer, 1000, 1000)
optimizer = AsyncSamplesOptimizer(
local,
remotes,
workers,
num_gpus=2,
train_batch_size=100,
sample_batch_size=25,
_fake_gpus=True)
self._wait_for(optimizer, 1000, 1000)
optimizer = AsyncSamplesOptimizer(
local,
remotes,
workers,
num_gpus=2,
train_batch_size=100,
sample_batch_size=74,
@@ -238,12 +235,12 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
def make_sess():
return tf.Session(config=tf.ConfigProto(device_count={"CPU": 2}))
local = PolicyEvaluator(
local = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
policy=PPOTFPolicy,
tf_session_creator=make_sess)
remotes = [
PolicyEvaluator.as_remote().remote(
RolloutWorker.as_remote().remote(
env_creator=lambda _: gym.make("CartPole-v0"),
policy=PPOTFPolicy,
tf_session_creator=make_sess)
+3 -3
View File
@@ -7,8 +7,8 @@ import time
import unittest
import ray
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.tests.test_policy_evaluator import MockPolicy
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.tests.test_rollout_worker import MockPolicy
class TestPerf(unittest.TestCase):
@@ -17,7 +17,7 @@ class TestPerf(unittest.TestCase):
# 03/01/19: Samples per second 8610.164353268685
def testBaselinePerformance(self):
for _ in range(20):
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
policy=MockPolicy,
batch_steps=100)
@@ -12,7 +12,7 @@ from collections import Counter
import ray
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.agents.a3c import A2CTrainer
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.policy.policy import Policy
from ray.rllib.evaluation.postprocessing import compute_advantages
@@ -129,9 +129,9 @@ class MockVectorEnv(VectorEnv):
return self.envs
class TestPolicyEvaluator(unittest.TestCase):
class TestRolloutWorker(unittest.TestCase):
def testBasic(self):
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"), policy=MockPolicy)
batch = ev.sample()
for key in [
@@ -155,7 +155,7 @@ class TestPolicyEvaluator(unittest.TestCase):
self.assertGreater(batch["advantages"][0], 1)
def testBatchIds(self):
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"), policy=MockPolicy)
batch1 = ev.sample()
batch2 = ev.sample()
@@ -213,11 +213,10 @@ class TestPolicyEvaluator(unittest.TestCase):
"sample_batch_size": 5,
"num_envs_per_worker": 2,
})
results = pg.optimizer.foreach_evaluator(
lambda ev: ev.sample_batch_size)
results2 = pg.optimizer.foreach_evaluator_with_index(
results = pg.workers.foreach_worker(lambda ev: ev.sample_batch_size)
results2 = pg.workers.foreach_worker_with_index(
lambda ev, i: (i, ev.sample_batch_size))
results3 = pg.optimizer.foreach_evaluator(
results3 = pg.workers.foreach_worker(
lambda ev: ev.foreach_env(lambda env: 1))
self.assertEqual(results, [10, 10, 10])
self.assertEqual(results2, [(0, 10), (1, 10), (2, 10)])
@@ -225,7 +224,7 @@ class TestPolicyEvaluator(unittest.TestCase):
def testRewardClipping(self):
# clipping on
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: MockEnv2(episode_length=10),
policy=MockPolicy,
clip_rewards=True,
@@ -235,7 +234,7 @@ class TestPolicyEvaluator(unittest.TestCase):
self.assertEqual(result["episode_reward_mean"], 1000)
# clipping off
ev2 = PolicyEvaluator(
ev2 = RolloutWorker(
env_creator=lambda _: MockEnv2(episode_length=10),
policy=MockPolicy,
clip_rewards=False,
@@ -245,7 +244,7 @@ class TestPolicyEvaluator(unittest.TestCase):
self.assertEqual(result2["episode_reward_mean"], 1000)
def testHardHorizon(self):
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: MockEnv(episode_length=10),
policy=MockPolicy,
batch_mode="complete_episodes",
@@ -259,7 +258,7 @@ class TestPolicyEvaluator(unittest.TestCase):
self.assertEqual(sum(samples["dones"]), 3)
def testSoftHorizon(self):
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: MockEnv(episode_length=10),
policy=MockPolicy,
batch_mode="complete_episodes",
@@ -273,11 +272,11 @@ class TestPolicyEvaluator(unittest.TestCase):
self.assertEqual(sum(samples["dones"]), 1)
def testMetrics(self):
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: MockEnv(episode_length=10),
policy=MockPolicy,
batch_mode="complete_episodes")
remote_ev = PolicyEvaluator.as_remote().remote(
remote_ev = RolloutWorker.as_remote().remote(
env_creator=lambda _: MockEnv(episode_length=10),
policy=MockPolicy,
batch_mode="complete_episodes")
@@ -288,7 +287,7 @@ class TestPolicyEvaluator(unittest.TestCase):
self.assertEqual(result["episode_reward_mean"], 10)
def testAsync(self):
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
sample_async=True,
policy=MockPolicy)
@@ -298,7 +297,7 @@ class TestPolicyEvaluator(unittest.TestCase):
self.assertGreater(batch["advantages"][0], 1)
def testAutoVectorization(self):
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda cfg: MockEnv(episode_length=20, config=cfg),
policy=MockPolicy,
batch_mode="truncate_episodes",
@@ -321,7 +320,7 @@ class TestPolicyEvaluator(unittest.TestCase):
self.assertEqual(indices, [0, 1, 2, 3, 4, 5, 6, 7])
def testBatchesLargerWhenVectorized(self):
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: MockEnv(episode_length=8),
policy=MockPolicy,
batch_mode="truncate_episodes",
@@ -336,7 +335,7 @@ class TestPolicyEvaluator(unittest.TestCase):
self.assertEqual(result["episodes_this_iter"], 4)
def testVectorEnvSupport(self):
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: MockVectorEnv(episode_length=20, num_envs=8),
policy=MockPolicy,
batch_mode="truncate_episodes",
@@ -353,7 +352,7 @@ class TestPolicyEvaluator(unittest.TestCase):
self.assertEqual(result["episodes_this_iter"], 8)
def testTruncateEpisodes(self):
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: MockEnv(10),
policy=MockPolicy,
batch_steps=15,
@@ -362,7 +361,7 @@ class TestPolicyEvaluator(unittest.TestCase):
self.assertEqual(batch.count, 15)
def testCompleteEpisodes(self):
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: MockEnv(10),
policy=MockPolicy,
batch_steps=5,
@@ -371,7 +370,7 @@ class TestPolicyEvaluator(unittest.TestCase):
self.assertEqual(batch.count, 10)
def testCompleteEpisodesPacking(self):
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: MockEnv(10),
policy=MockPolicy,
batch_steps=15,
@@ -383,7 +382,7 @@ class TestPolicyEvaluator(unittest.TestCase):
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
def testFilterSync(self):
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
policy=MockPolicy,
sample_async=True,
@@ -396,7 +395,7 @@ class TestPolicyEvaluator(unittest.TestCase):
self.assertNotEqual(obs_f.buffer.n, 0)
def testGetFilters(self):
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
policy=MockPolicy,
sample_async=True,
@@ -411,7 +410,7 @@ class TestPolicyEvaluator(unittest.TestCase):
self.assertGreaterEqual(obs_f2.buffer.n, obs_f.buffer.n)
def testSyncFilter(self):
ev = PolicyEvaluator(
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
policy=MockPolicy,
sample_async=True,
+4 -4
View File
@@ -58,15 +58,15 @@ class TaskPool(object):
remaining.append((worker, obj_id))
self._fetching = remaining
def reset_evaluators(self, evaluators):
"""Notify that some evaluators may be removed."""
def reset_workers(self, workers):
"""Notify that some workers may be removed."""
for obj_id, ev in self._tasks.copy().items():
if ev not in evaluators:
if ev not in workers:
del self._tasks[obj_id]
del self._objects[obj_id]
ok = []
for ev, obj_id in self._fetching:
if ev in evaluators:
if ev in workers:
ok.append((ev, obj_id))
self._fetching = ok