mirror of
https://github.com/wassname/ray.git
synced 2026-07-04 15:05:49 +08:00
[rllib] Rename PolicyEvaluator => RolloutWorker (#4820)
This commit is contained in:
@@ -11,7 +11,7 @@ from ray.tune.registry import register_trainable
|
||||
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
||||
from ray.rllib.env.base_env import BaseEnv
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from ray.rllib.env.vector_env import VectorEnv
|
||||
@@ -55,6 +55,7 @@ __all__ = [
|
||||
"PolicyGraph",
|
||||
"TFPolicy",
|
||||
"TFPolicyGraph",
|
||||
"RolloutWorker",
|
||||
"PolicyEvaluator",
|
||||
"SampleBatch",
|
||||
"BaseEnv",
|
||||
|
||||
@@ -2,9 +2,10 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.a3c.a3c import A3CTrainer, DEFAULT_CONFIG as A3C_CONFIG
|
||||
from ray.rllib.optimizers import SyncSamplesOptimizer
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.agents.a3c.a3c import DEFAULT_CONFIG as A3C_CONFIG, \
|
||||
validate_config, get_policy_class
|
||||
from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
|
||||
from ray.rllib.agents.trainer_template import build_trainer
|
||||
from ray.rllib.utils import merge_dicts
|
||||
|
||||
A2C_DEFAULT_CONFIG = merge_dicts(
|
||||
@@ -16,16 +17,9 @@ A2C_DEFAULT_CONFIG = merge_dicts(
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class A2CTrainer(A3CTrainer):
|
||||
"""Synchronous variant of the A3CTrainer."""
|
||||
|
||||
_name = "A2C"
|
||||
_default_config = A2C_DEFAULT_CONFIG
|
||||
|
||||
@override(A3CTrainer)
|
||||
def _make_optimizer(self):
|
||||
return SyncSamplesOptimizer(
|
||||
self.local_evaluator,
|
||||
self.remote_evaluators,
|
||||
train_batch_size=self.config["train_batch_size"])
|
||||
A2CTrainer = build_trainer(
|
||||
name="A2C",
|
||||
default_config=A2C_DEFAULT_CONFIG,
|
||||
default_policy=A3CTFPolicy,
|
||||
get_policy_class=get_policy_class,
|
||||
validate_config=validate_config)
|
||||
|
||||
@@ -2,12 +2,10 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import time
|
||||
|
||||
from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
|
||||
from ray.rllib.agents.trainer import Trainer, with_common_config
|
||||
from ray.rllib.agents.trainer import with_common_config
|
||||
from ray.rllib.agents.trainer_template import build_trainer
|
||||
from ray.rllib.optimizers import AsyncGradientsOptimizer
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
@@ -38,43 +36,28 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# yapf: enable
|
||||
|
||||
|
||||
class A3CTrainer(Trainer):
|
||||
"""A3C implementations in TensorFlow and PyTorch."""
|
||||
def get_policy_class(config):
|
||||
if config["use_pytorch"]:
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy import \
|
||||
A3CTorchPolicy
|
||||
return A3CTorchPolicy
|
||||
else:
|
||||
return A3CTFPolicy
|
||||
|
||||
_name = "A3C"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy = A3CTFPolicy
|
||||
|
||||
@override(Trainer)
|
||||
def _init(self, config, env_creator):
|
||||
if config["use_pytorch"]:
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy import \
|
||||
A3CTorchPolicy
|
||||
policy_cls = A3CTorchPolicy
|
||||
else:
|
||||
policy_cls = self._policy
|
||||
def validate_config(config):
|
||||
if config["entropy_coeff"] < 0:
|
||||
raise DeprecationWarning("entropy_coeff must be >= 0")
|
||||
|
||||
if config["entropy_coeff"] < 0:
|
||||
raise DeprecationWarning("entropy_coeff must be >= 0")
|
||||
|
||||
self.local_evaluator = self.make_local_evaluator(
|
||||
env_creator, policy_cls)
|
||||
self.remote_evaluators = self.make_remote_evaluators(
|
||||
env_creator, policy_cls, config["num_workers"])
|
||||
self.optimizer = self._make_optimizer()
|
||||
def make_async_optimizer(workers, config):
|
||||
return AsyncGradientsOptimizer(workers, **config["optimizer"])
|
||||
|
||||
@override(Trainer)
|
||||
def _train(self):
|
||||
prev_steps = self.optimizer.num_steps_sampled
|
||||
start = time.time()
|
||||
while time.time() - start < self.config["min_iter_time_s"]:
|
||||
self.optimizer.step()
|
||||
result = self.collect_metrics()
|
||||
result.update(timesteps_this_iter=self.optimizer.num_steps_sampled -
|
||||
prev_steps)
|
||||
return result
|
||||
|
||||
def _make_optimizer(self):
|
||||
return AsyncGradientsOptimizer(self.local_evaluator,
|
||||
self.remote_evaluators,
|
||||
**self.config["optimizer"])
|
||||
A3CTrainer = build_trainer(
|
||||
name="A3C",
|
||||
default_config=DEFAULT_CONFIG,
|
||||
default_policy=A3CTFPolicy,
|
||||
get_policy_class=get_policy_class,
|
||||
validate_config=validate_config,
|
||||
make_policy_optimizer=make_async_optimizer)
|
||||
|
||||
@@ -48,7 +48,7 @@ class ApexDDPGTrainer(DDPGTrainer):
|
||||
# Ape-X updates based on num steps trained, not sampled
|
||||
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
|
||||
self.config["target_network_update_freq"]:
|
||||
self.local_evaluator.foreach_trainable_policy(
|
||||
self.workers.local_worker().foreach_trainable_policy(
|
||||
lambda p, _: p.update_target())
|
||||
self.last_target_update_ts = self.optimizer.num_steps_trained
|
||||
self.num_target_updates += 1
|
||||
|
||||
@@ -171,9 +171,9 @@ class DDPGTrainer(DQNTrainer):
|
||||
if pure_expl_steps:
|
||||
# tell workers whether they should do pure exploration
|
||||
only_explore = self.global_timestep < pure_expl_steps
|
||||
self.local_evaluator.foreach_trainable_policy(
|
||||
self.workers.local_worker().foreach_trainable_policy(
|
||||
lambda p, _: p.set_pure_exploration_phase(only_explore))
|
||||
for e in self.remote_evaluators:
|
||||
for e in self.workers.remote_workers():
|
||||
e.foreach_trainable_policy.remote(
|
||||
lambda p, _: p.set_pure_exploration_phase(only_explore))
|
||||
return super(DDPGTrainer, self)._train()
|
||||
|
||||
@@ -515,7 +515,7 @@ class DDPGTFPolicy(DDPGPostprocessing, TFPolicy):
|
||||
|
||||
stochastic_actions = tf.cond(
|
||||
# need to condition on noise_scale > 0 because zeroing
|
||||
# noise_scale is how evaluator signals no noise should be used
|
||||
# noise_scale is how a worker signals no noise should be used
|
||||
# (this is ugly and should be fixed by adding an "eval_mode"
|
||||
# config flag or something)
|
||||
tf.logical_and(enable_pure_exploration, noise_scale > 0),
|
||||
|
||||
@@ -51,7 +51,7 @@ class ApexTrainer(DQNTrainer):
|
||||
# Ape-X updates based on num steps trained, not sampled
|
||||
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
|
||||
self.config["target_network_update_freq"]:
|
||||
self.local_evaluator.foreach_trainable_policy(
|
||||
self.workers.local_worker().foreach_trainable_policy(
|
||||
lambda p, _: p.update_target())
|
||||
self.last_target_update_ts = self.optimizer.num_steps_trained
|
||||
self.num_target_updates += 1
|
||||
|
||||
@@ -196,26 +196,26 @@ class DQNTrainer(Trainer):
|
||||
config["callbacks"]["on_episode_end"] = tune.function(
|
||||
on_episode_end)
|
||||
|
||||
self.local_evaluator = self.make_local_evaluator(
|
||||
env_creator, self._policy)
|
||||
|
||||
def create_remote_evaluators():
|
||||
return self.make_remote_evaluators(env_creator, self._policy,
|
||||
config["num_workers"])
|
||||
|
||||
if config["optimizer_class"] != "AsyncReplayOptimizer":
|
||||
self.remote_evaluators = create_remote_evaluators()
|
||||
self.workers = self._make_workers(
|
||||
env_creator,
|
||||
self._policy,
|
||||
config,
|
||||
num_workers=self.config["num_workers"])
|
||||
workers_needed = 0
|
||||
else:
|
||||
# Hack to workaround https://github.com/ray-project/ray/issues/2541
|
||||
self.remote_evaluators = None
|
||||
self.workers = self._make_workers(
|
||||
env_creator, self._policy, config, num_workers=0)
|
||||
workers_needed = self.config["num_workers"]
|
||||
|
||||
self.optimizer = getattr(optimizers, config["optimizer_class"])(
|
||||
self.local_evaluator, self.remote_evaluators,
|
||||
**config["optimizer"])
|
||||
# Create the remote evaluators *after* the replay actors
|
||||
if self.remote_evaluators is None:
|
||||
self.remote_evaluators = create_remote_evaluators()
|
||||
self.optimizer._set_evaluators(self.remote_evaluators)
|
||||
self.workers, **config["optimizer"])
|
||||
|
||||
# Create the remote workers *after* the replay actors
|
||||
if workers_needed > 0:
|
||||
self.workers.add_workers(workers_needed)
|
||||
self.optimizer._set_workers(self.workers.remote_workers())
|
||||
|
||||
self.last_target_update_ts = 0
|
||||
self.num_target_updates = 0
|
||||
@@ -226,9 +226,9 @@ class DQNTrainer(Trainer):
|
||||
|
||||
# Update worker explorations
|
||||
exp_vals = [self.exploration0.value(self.global_timestep)]
|
||||
self.local_evaluator.foreach_trainable_policy(
|
||||
self.workers.local_worker().foreach_trainable_policy(
|
||||
lambda p, _: p.set_epsilon(exp_vals[0]))
|
||||
for i, e in enumerate(self.remote_evaluators):
|
||||
for i, e in enumerate(self.workers.remote_workers()):
|
||||
exp_val = self.explorations[i].value(self.global_timestep)
|
||||
e.foreach_trainable_policy.remote(
|
||||
lambda p, _: p.set_epsilon(exp_val))
|
||||
@@ -245,8 +245,8 @@ class DQNTrainer(Trainer):
|
||||
if self.config["per_worker_exploration"]:
|
||||
# Only collect metrics from the third of workers with lowest eps
|
||||
result = self.collect_metrics(
|
||||
selected_evaluators=self.remote_evaluators[
|
||||
-len(self.remote_evaluators) // 3:])
|
||||
selected_workers=self.workers.remote_workers()[
|
||||
-len(self.workers.remote_workers()) // 3:])
|
||||
else:
|
||||
result = self.collect_metrics()
|
||||
|
||||
@@ -263,7 +263,7 @@ class DQNTrainer(Trainer):
|
||||
def update_target_if_needed(self):
|
||||
if self.global_timestep - self.last_target_update_ts > \
|
||||
self.config["target_network_update_freq"]:
|
||||
self.local_evaluator.foreach_trainable_policy(
|
||||
self.workers.local_worker().foreach_trainable_policy(
|
||||
lambda p, _: p.update_target())
|
||||
self.last_target_update_ts = self.global_timestep
|
||||
self.num_target_updates += 1
|
||||
@@ -275,11 +275,13 @@ class DQNTrainer(Trainer):
|
||||
def _evaluate(self):
|
||||
logger.info("Evaluating current policy for {} episodes".format(
|
||||
self.config["evaluation_num_episodes"]))
|
||||
self.evaluation_ev.restore(self.local_evaluator.save())
|
||||
self.evaluation_ev.foreach_policy(lambda p, _: p.set_epsilon(0))
|
||||
self.evaluation_workers.local_worker().restore(
|
||||
self.workers.local_worker().save())
|
||||
self.evaluation_workers.local_worker().foreach_policy(
|
||||
lambda p, _: p.set_epsilon(0))
|
||||
for _ in range(self.config["evaluation_num_episodes"]):
|
||||
self.evaluation_ev.sample()
|
||||
metrics = collect_metrics(self.evaluation_ev)
|
||||
self.evaluation_workers.local_worker().sample()
|
||||
metrics = collect_metrics(self.evaluation_workers.local_worker())
|
||||
return {"evaluation": metrics}
|
||||
|
||||
def _make_exploration_schedule(self, worker_index):
|
||||
|
||||
@@ -192,7 +192,7 @@ class ESTrainer(Trainer):
|
||||
|
||||
# Create the actors.
|
||||
logger.info("Creating actors.")
|
||||
self.workers = [
|
||||
self._workers = [
|
||||
Worker.remote(config, policy_params, env_creator, noise_id)
|
||||
for _ in range(config["num_workers"])
|
||||
]
|
||||
@@ -270,7 +270,7 @@ class ESTrainer(Trainer):
|
||||
# Now sync the filters
|
||||
FilterManager.synchronize({
|
||||
DEFAULT_POLICY_ID: self.policy.get_filter()
|
||||
}, self.workers)
|
||||
}, self._workers)
|
||||
|
||||
info = {
|
||||
"weights_norm": np.square(theta).sum(),
|
||||
@@ -296,7 +296,7 @@ class ESTrainer(Trainer):
|
||||
@override(Trainer)
|
||||
def _stop(self):
|
||||
# workaround for https://github.com/ray-project/ray/issues/1516
|
||||
for w in self.workers:
|
||||
for w in self._workers:
|
||||
w.__ray_terminate__.remote()
|
||||
|
||||
def _collect_results(self, theta_id, min_episodes, min_timesteps):
|
||||
@@ -307,7 +307,7 @@ class ESTrainer(Trainer):
|
||||
"Collected {} episodes {} timesteps so far this iter".format(
|
||||
num_episodes, num_timesteps))
|
||||
rollout_ids = [
|
||||
worker.do_rollouts.remote(theta_id) for worker in self.workers
|
||||
worker.do_rollouts.remote(theta_id) for worker in self._workers
|
||||
]
|
||||
# Get the results of the rollouts.
|
||||
for result in ray_get_and_free(rollout_ids):
|
||||
@@ -334,4 +334,4 @@ class ESTrainer(Trainer):
|
||||
self.policy.set_filter(state["filter"])
|
||||
FilterManager.synchronize({
|
||||
DEFAULT_POLICY_ID: self.policy.get_filter()
|
||||
}, self.workers)
|
||||
}, self._workers)
|
||||
|
||||
@@ -113,18 +113,16 @@ class ImpalaTrainer(Trainer):
|
||||
if k not in config["optimizer"]:
|
||||
config["optimizer"][k] = config[k]
|
||||
policy_cls = self._get_policy()
|
||||
self.local_evaluator = self.make_local_evaluator(
|
||||
self.env_creator, policy_cls)
|
||||
self.workers = self._make_workers(
|
||||
self.env_creator, policy_cls, self.config, num_workers=0)
|
||||
|
||||
if self.config["num_aggregation_workers"] > 0:
|
||||
# Create co-located aggregator actors first for placement pref
|
||||
aggregators = TreeAggregator.precreate_aggregators(
|
||||
self.config["num_aggregation_workers"])
|
||||
|
||||
self.remote_evaluators = self.make_remote_evaluators(
|
||||
env_creator, policy_cls, config["num_workers"])
|
||||
self.optimizer = AsyncSamplesOptimizer(self.local_evaluator,
|
||||
self.remote_evaluators,
|
||||
self.workers.add_workers(config["num_workers"])
|
||||
self.optimizer = AsyncSamplesOptimizer(self.workers,
|
||||
**config["optimizer"])
|
||||
if config["entropy_coeff"] < 0:
|
||||
raise DeprecationWarning("entropy_coeff must be >= 0")
|
||||
|
||||
@@ -48,13 +48,10 @@ class MARWILTrainer(Trainer):
|
||||
|
||||
@override(Trainer)
|
||||
def _init(self, config, env_creator):
|
||||
self.local_evaluator = self.make_local_evaluator(
|
||||
env_creator, self._policy)
|
||||
self.remote_evaluators = self.make_remote_evaluators(
|
||||
env_creator, self._policy, config["num_workers"])
|
||||
self.workers = self._make_workers(env_creator, self._policy, config,
|
||||
config["num_workers"])
|
||||
self.optimizer = SyncBatchReplayOptimizer(
|
||||
self.local_evaluator,
|
||||
self.remote_evaluators,
|
||||
self.workers,
|
||||
learning_starts=config["learning_starts"],
|
||||
buffer_size=config["replay_buffer_size"],
|
||||
train_batch_size=config["train_batch_size"],
|
||||
|
||||
@@ -29,7 +29,7 @@ def get_policy_class(config):
|
||||
|
||||
|
||||
PGTrainer = build_trainer(
|
||||
name="PGTrainer",
|
||||
name="PG",
|
||||
default_config=DEFAULT_CONFIG,
|
||||
default_policy=PGTFPolicy,
|
||||
get_policy_class=get_policy_class)
|
||||
|
||||
@@ -63,17 +63,15 @@ DEFAULT_CONFIG = with_common_config({
|
||||
# yapf: enable
|
||||
|
||||
|
||||
def choose_policy_optimizer(local_evaluator, remote_evaluators, config):
|
||||
def choose_policy_optimizer(workers, config):
|
||||
if config["simple_optimizer"]:
|
||||
return SyncSamplesOptimizer(
|
||||
local_evaluator,
|
||||
remote_evaluators,
|
||||
workers,
|
||||
num_sgd_iter=config["num_sgd_iter"],
|
||||
train_batch_size=config["train_batch_size"])
|
||||
|
||||
return LocalMultiGPUOptimizer(
|
||||
local_evaluator,
|
||||
remote_evaluators,
|
||||
workers,
|
||||
sgd_batch_size=config["sgd_minibatch_size"],
|
||||
num_sgd_iter=config["num_sgd_iter"],
|
||||
num_gpus=config["num_gpus"],
|
||||
@@ -87,7 +85,7 @@ def choose_policy_optimizer(local_evaluator, remote_evaluators, config):
|
||||
def update_kl(trainer, fetches):
|
||||
if "kl" in fetches:
|
||||
# single-agent
|
||||
trainer.local_evaluator.for_policy(
|
||||
trainer.workers.local_worker().for_policy(
|
||||
lambda pi: pi.update_kl(fetches["kl"]))
|
||||
else:
|
||||
|
||||
@@ -98,7 +96,7 @@ def update_kl(trainer, fetches):
|
||||
logger.debug("No data for {}, not updating kl".format(pi_id))
|
||||
|
||||
# multi-agent
|
||||
trainer.local_evaluator.foreach_trainable_policy(update)
|
||||
trainer.workers.local_worker().foreach_trainable_policy(update)
|
||||
|
||||
|
||||
def warn_about_obs_filter(trainer):
|
||||
@@ -155,7 +153,7 @@ def validate_config(config):
|
||||
|
||||
|
||||
PPOTrainer = build_trainer(
|
||||
name="PPOTrainer",
|
||||
name="PPO",
|
||||
default_config=DEFAULT_CONFIG,
|
||||
default_policy=PPOTFPolicy,
|
||||
make_policy_optimizer=choose_policy_optimizer,
|
||||
|
||||
@@ -50,7 +50,7 @@ class ApexQMixTrainer(QMixTrainer):
|
||||
# Ape-X updates based on num steps trained, not sampled
|
||||
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
|
||||
self.config["target_network_update_freq"]:
|
||||
self.local_evaluator.foreach_trainable_policy(
|
||||
self.workers.local_worker().foreach_trainable_policy(
|
||||
lambda p, _: p.update_target())
|
||||
self.last_target_update_ts = self.optimizer.num_steps_trained
|
||||
self.num_target_updates += 1
|
||||
|
||||
@@ -10,18 +10,14 @@ import pickle
|
||||
import six
|
||||
import time
|
||||
import tempfile
|
||||
from types import FunctionType
|
||||
|
||||
import ray
|
||||
from ray.exceptions import RayError
|
||||
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
|
||||
ShuffledInput
|
||||
from ray.rllib.models import MODEL_DEFAULTS
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator, \
|
||||
_validate_multiagent_config
|
||||
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.evaluation.worker_set import WorkerSet
|
||||
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
|
||||
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
@@ -46,7 +42,7 @@ COMMON_CONFIG = {
|
||||
# === Debugging ===
|
||||
# Whether to write episode stats and videos to the agent log dir
|
||||
"monitor": False,
|
||||
# Set the ray.rllib.* log level for the agent process and its evaluators.
|
||||
# Set the ray.rllib.* log level for the agent process and its workers.
|
||||
# Should be one of DEBUG, INFO, WARN, or ERROR. The DEBUG level will also
|
||||
# periodically print out summaries of relevant internal dataflow (this is
|
||||
# also printed out once at startup at the INFO level).
|
||||
@@ -60,7 +56,7 @@ COMMON_CONFIG = {
|
||||
"on_episode_start": None, # arg: {"env": .., "episode": ...}
|
||||
"on_episode_step": None, # arg: {"env": .., "episode": ...}
|
||||
"on_episode_end": None, # arg: {"env": .., "episode": ...}
|
||||
"on_sample_end": None, # arg: {"samples": .., "evaluator": ...}
|
||||
"on_sample_end": None, # arg: {"samples": .., "worker": ...}
|
||||
"on_train_result": None, # arg: {"trainer": ..., "result": ...}
|
||||
"on_postprocess_traj": None, # arg: {
|
||||
# "agent_id": ..., "episode": ...,
|
||||
@@ -153,7 +149,7 @@ COMMON_CONFIG = {
|
||||
"synchronize_filters": True,
|
||||
# Configure TF for single-process operation by default
|
||||
"tf_session_args": {
|
||||
# note: overriden by `local_evaluator_tf_session_args`
|
||||
# note: overriden by `local_tf_session_args`
|
||||
"intra_op_parallelism_threads": 2,
|
||||
"inter_op_parallelism_threads": 2,
|
||||
"gpu_options": {
|
||||
@@ -165,8 +161,8 @@ COMMON_CONFIG = {
|
||||
},
|
||||
"allow_soft_placement": True, # required by PPO multi-gpu
|
||||
},
|
||||
# Override the following tf session args on the local evaluator
|
||||
"local_evaluator_tf_session_args": {
|
||||
# Override the following tf session args on the local worker
|
||||
"local_tf_session_args": {
|
||||
# Allow a higher level of parallelism by default, but not unlimited
|
||||
# since that can cause crashes with many concurrent drivers.
|
||||
"intra_op_parallelism_threads": 8,
|
||||
@@ -188,6 +184,8 @@ COMMON_CONFIG = {
|
||||
# but optimal value could be obtained by measuring your environment
|
||||
# step / reset and model inference perf.
|
||||
"remote_env_batch_wait_ms": 0,
|
||||
# Minimum time per iteration
|
||||
"min_iter_time_s": 0,
|
||||
|
||||
# === Offline Datasets ===
|
||||
# Specify how to generate experiences:
|
||||
@@ -229,7 +227,7 @@ COMMON_CONFIG = {
|
||||
# === Multiagent ===
|
||||
"multiagent": {
|
||||
# Map from policy ids to tuples of (policy_cls, obs_space,
|
||||
# act_space, config). See policy_evaluator.py for more info.
|
||||
# act_space, config). See rollout_worker.py for more info.
|
||||
"policies": {},
|
||||
# Function mapping agent ids to policy ids.
|
||||
"policy_mapping_fn": None,
|
||||
@@ -292,7 +290,7 @@ class Trainer(Trainable):
|
||||
|
||||
config = config or {}
|
||||
|
||||
# Vars to synchronize to evaluators on each train call
|
||||
# Vars to synchronize to workers on each train call
|
||||
self.global_vars = {"timestep": 0}
|
||||
|
||||
# Trainers allow env ids to be passed directly to the constructor.
|
||||
@@ -337,9 +335,10 @@ class Trainer(Trainable):
|
||||
|
||||
if self._has_policy_optimizer():
|
||||
self.global_vars["timestep"] = self.optimizer.num_steps_sampled
|
||||
self.optimizer.local_evaluator.set_global_vars(self.global_vars)
|
||||
for ev in self.optimizer.remote_evaluators:
|
||||
ev.set_global_vars.remote(self.global_vars)
|
||||
self.optimizer.workers.local_worker().set_global_vars(
|
||||
self.global_vars)
|
||||
for w in self.optimizer.workers.remote_workers():
|
||||
w.set_global_vars.remote(self.global_vars)
|
||||
logger.debug("updated global vars: {}".format(self.global_vars))
|
||||
|
||||
result = None
|
||||
@@ -366,17 +365,18 @@ class Trainer(Trainable):
|
||||
raise RuntimeError("Failed to recover from worker crash")
|
||||
|
||||
if (self.config.get("observation_filter", "NoFilter") != "NoFilter"
|
||||
and hasattr(self, "local_evaluator")):
|
||||
and hasattr(self, "workers")
|
||||
and isinstance(self.workers, WorkerSet)):
|
||||
FilterManager.synchronize(
|
||||
self.local_evaluator.filters,
|
||||
self.remote_evaluators,
|
||||
self.workers.local_worker().filters,
|
||||
self.workers.remote_workers(),
|
||||
update_remote=self.config["synchronize_filters"])
|
||||
logger.debug("synchronized filters: {}".format(
|
||||
self.local_evaluator.filters))
|
||||
self.workers.local_worker().filters))
|
||||
|
||||
if self._has_policy_optimizer():
|
||||
result["num_healthy_workers"] = len(
|
||||
self.optimizer.remote_evaluators)
|
||||
self.optimizer.workers.remote_workers())
|
||||
|
||||
if self.config["evaluation_interval"]:
|
||||
if self._iteration % self.config["evaluation_interval"] == 0:
|
||||
@@ -441,25 +441,17 @@ class Trainer(Trainable):
|
||||
})
|
||||
logger.debug(
|
||||
"using evaluation_config: {}".format(extra_config))
|
||||
# Make local evaluation evaluators
|
||||
self.evaluation_ev = self.make_local_evaluator(
|
||||
self.env_creator, self._policy, extra_config=extra_config)
|
||||
self.evaluation_workers = self._make_workers(
|
||||
self.env_creator,
|
||||
self._policy,
|
||||
merge_dicts(self.config, extra_config),
|
||||
num_workers=0)
|
||||
self.evaluation_metrics = self._evaluate()
|
||||
|
||||
@override(Trainable)
|
||||
def _stop(self):
|
||||
# Call stop on all evaluators to release resources
|
||||
if hasattr(self, "local_evaluator"):
|
||||
self.local_evaluator.stop()
|
||||
if hasattr(self, "remote_evaluators"):
|
||||
for ev in self.remote_evaluators:
|
||||
ev.stop.remote()
|
||||
|
||||
# workaround for https://github.com/ray-project/ray/issues/1516
|
||||
if hasattr(self, "remote_evaluators"):
|
||||
for ev in self.remote_evaluators:
|
||||
ev.__ray_terminate__.remote()
|
||||
|
||||
if hasattr(self, "workers"):
|
||||
self.workers.stop()
|
||||
if hasattr(self, "optimizer"):
|
||||
self.optimizer.stop()
|
||||
|
||||
@@ -475,6 +467,15 @@ class Trainer(Trainable):
|
||||
extra_data = pickle.load(open(checkpoint_path, "rb"))
|
||||
self.__setstate__(extra_data)
|
||||
|
||||
@DeveloperAPI
|
||||
def _make_workers(self, env_creator, policy, config, num_workers):
|
||||
return WorkerSet(
|
||||
env_creator,
|
||||
policy,
|
||||
config,
|
||||
num_workers=num_workers,
|
||||
logdir=self.logdir)
|
||||
|
||||
@DeveloperAPI
|
||||
def _init(self, config, env_creator):
|
||||
"""Subclasses should override this for custom initialization."""
|
||||
@@ -498,11 +499,12 @@ class Trainer(Trainable):
|
||||
|
||||
logger.info("Evaluating current policy for {} episodes".format(
|
||||
self.config["evaluation_num_episodes"]))
|
||||
self.evaluation_ev.restore(self.local_evaluator.save())
|
||||
self.evaluation_workers.local_worker().restore(
|
||||
self.workers.local_worker().save())
|
||||
for _ in range(self.config["evaluation_num_episodes"]):
|
||||
self.evaluation_ev.sample()
|
||||
self.evaluation_workers.local_worker().sample()
|
||||
|
||||
metrics = collect_metrics(self.evaluation_ev)
|
||||
metrics = collect_metrics(self.evaluation_workers.local_worker())
|
||||
return {"evaluation": metrics}
|
||||
|
||||
@PublicAPI
|
||||
@@ -540,9 +542,9 @@ class Trainer(Trainable):
|
||||
|
||||
if state is None:
|
||||
state = []
|
||||
preprocessed = self.local_evaluator.preprocessors[policy_id].transform(
|
||||
observation)
|
||||
filtered_obs = self.local_evaluator.filters[policy_id](
|
||||
preprocessed = self.workers.local_worker().preprocessors[
|
||||
policy_id].transform(observation)
|
||||
filtered_obs = self.workers.local_worker().filters[policy_id](
|
||||
preprocessed, update=False)
|
||||
if state:
|
||||
return self.get_policy(policy_id).compute_single_action(
|
||||
@@ -590,7 +592,7 @@ class Trainer(Trainable):
|
||||
policy_id (str): id of policy to return.
|
||||
"""
|
||||
|
||||
return self.local_evaluator.get_policy(policy_id)
|
||||
return self.workers.local_worker().get_policy(policy_id)
|
||||
|
||||
@PublicAPI
|
||||
def get_weights(self, policies=None):
|
||||
@@ -600,7 +602,7 @@ class Trainer(Trainable):
|
||||
policies (list): Optional list of policies to return weights for,
|
||||
or None for all policies.
|
||||
"""
|
||||
return self.local_evaluator.get_weights(policies)
|
||||
return self.workers.local_worker().get_weights(policies)
|
||||
|
||||
@PublicAPI
|
||||
def set_weights(self, weights):
|
||||
@@ -609,42 +611,7 @@ class Trainer(Trainable):
|
||||
Arguments:
|
||||
weights (dict): Map of policy ids to weights to set.
|
||||
"""
|
||||
self.local_evaluator.set_weights(weights)
|
||||
|
||||
@DeveloperAPI
|
||||
def make_local_evaluator(self, env_creator, policy, extra_config=None):
|
||||
"""Convenience method to return configured local evaluator."""
|
||||
|
||||
return self._make_evaluator(
|
||||
PolicyEvaluator,
|
||||
env_creator,
|
||||
policy,
|
||||
0,
|
||||
merge_dicts(
|
||||
# important: allow local tf to use more CPUs for optimization
|
||||
merge_dicts(
|
||||
self.config, {
|
||||
"tf_session_args": self.
|
||||
config["local_evaluator_tf_session_args"]
|
||||
}),
|
||||
extra_config or {}))
|
||||
|
||||
@DeveloperAPI
|
||||
def make_remote_evaluators(self, env_creator, policy, count):
|
||||
"""Convenience method to return a number of remote evaluators."""
|
||||
|
||||
remote_args = {
|
||||
"num_cpus": self.config["num_cpus_per_worker"],
|
||||
"num_gpus": self.config["num_gpus_per_worker"],
|
||||
"resources": self.config["custom_resources_per_worker"],
|
||||
}
|
||||
|
||||
cls = PolicyEvaluator.as_remote(**remote_args).remote
|
||||
|
||||
return [
|
||||
self._make_evaluator(cls, env_creator, policy, i + 1, self.config)
|
||||
for i in range(count)
|
||||
]
|
||||
self.workers.local_worker().set_weights(weights)
|
||||
|
||||
@DeveloperAPI
|
||||
def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
|
||||
@@ -660,7 +627,7 @@ class Trainer(Trainable):
|
||||
>>> trainer.train()
|
||||
>>> trainer.export_policy_model("/tmp/export_dir")
|
||||
"""
|
||||
self.local_evaluator.export_policy_model(export_dir, policy_id)
|
||||
self.workers.local_worker().export_policy_model(export_dir, policy_id)
|
||||
|
||||
@DeveloperAPI
|
||||
def export_policy_checkpoint(self,
|
||||
@@ -680,19 +647,19 @@ class Trainer(Trainable):
|
||||
>>> trainer.train()
|
||||
>>> trainer.export_policy_checkpoint("/tmp/export_dir")
|
||||
"""
|
||||
self.local_evaluator.export_policy_checkpoint(
|
||||
self.workers.local_worker().export_policy_checkpoint(
|
||||
export_dir, filename_prefix, policy_id)
|
||||
|
||||
@DeveloperAPI
|
||||
def collect_metrics(self, selected_evaluators=None):
|
||||
"""Collects metrics from the remote evaluators of this agent.
|
||||
def collect_metrics(self, selected_workers=None):
|
||||
"""Collects metrics from the remote workers of this agent.
|
||||
|
||||
This is the same data as returned by a call to train().
|
||||
"""
|
||||
return self.optimizer.collect_metrics(
|
||||
self.config["collect_metrics_timeout"],
|
||||
min_history=self.config["metrics_smoothing_episodes"],
|
||||
selected_evaluators=selected_evaluators)
|
||||
selected_workers=selected_workers)
|
||||
|
||||
@classmethod
|
||||
def resource_help(cls, config):
|
||||
@@ -742,118 +709,34 @@ class Trainer(Trainable):
|
||||
|
||||
logger.info("Health checking all workers...")
|
||||
checks = []
|
||||
for ev in self.optimizer.remote_evaluators:
|
||||
for ev in self.optimizer.workers.remote_workers():
|
||||
_, obj_id = ev.sample_with_count.remote()
|
||||
checks.append(obj_id)
|
||||
|
||||
healthy_evaluators = []
|
||||
healthy_workers = []
|
||||
for i, obj_id in enumerate(checks):
|
||||
ev = self.optimizer.remote_evaluators[i]
|
||||
w = self.optimizer.workers.remote_workers()[i]
|
||||
try:
|
||||
ray_get_and_free(obj_id)
|
||||
healthy_evaluators.append(ev)
|
||||
healthy_workers.append(w)
|
||||
logger.info("Worker {} looks healthy".format(i + 1))
|
||||
except RayError:
|
||||
logger.exception("Blacklisting worker {}".format(i + 1))
|
||||
try:
|
||||
ev.__ray_terminate__.remote()
|
||||
w.__ray_terminate__.remote()
|
||||
except Exception:
|
||||
logger.exception("Error terminating unhealthy worker")
|
||||
|
||||
if len(healthy_evaluators) < 1:
|
||||
if len(healthy_workers) < 1:
|
||||
raise RuntimeError(
|
||||
"Not enough healthy workers remain to continue.")
|
||||
|
||||
self.optimizer.reset(healthy_evaluators)
|
||||
self.optimizer.reset(healthy_workers)
|
||||
|
||||
def _has_policy_optimizer(self):
|
||||
return hasattr(self, "optimizer") and isinstance(
|
||||
self.optimizer, PolicyOptimizer)
|
||||
|
||||
def _make_evaluator(self, cls, env_creator, policy, worker_index, config):
|
||||
def session_creator():
|
||||
logger.debug("Creating TF session {}".format(
|
||||
config["tf_session_args"]))
|
||||
return tf.Session(
|
||||
config=tf.ConfigProto(**config["tf_session_args"]))
|
||||
|
||||
if isinstance(config["input"], FunctionType):
|
||||
input_creator = config["input"]
|
||||
elif config["input"] == "sampler":
|
||||
input_creator = (lambda ioctx: ioctx.default_sampler_input())
|
||||
elif isinstance(config["input"], dict):
|
||||
input_creator = (lambda ioctx: ShuffledInput(
|
||||
MixedInput(config["input"], ioctx), config[
|
||||
"shuffle_buffer_size"]))
|
||||
else:
|
||||
input_creator = (lambda ioctx: ShuffledInput(
|
||||
JsonReader(config["input"], ioctx), config[
|
||||
"shuffle_buffer_size"]))
|
||||
|
||||
if isinstance(config["output"], FunctionType):
|
||||
output_creator = config["output"]
|
||||
elif config["output"] is None:
|
||||
output_creator = (lambda ioctx: NoopOutput())
|
||||
elif config["output"] == "logdir":
|
||||
output_creator = (lambda ioctx: JsonWriter(
|
||||
ioctx.log_dir,
|
||||
ioctx,
|
||||
max_file_size=config["output_max_file_size"],
|
||||
compress_columns=config["output_compress_columns"]))
|
||||
else:
|
||||
output_creator = (lambda ioctx: JsonWriter(
|
||||
config["output"],
|
||||
ioctx,
|
||||
max_file_size=config["output_max_file_size"],
|
||||
compress_columns=config["output_compress_columns"]))
|
||||
|
||||
if config["input"] == "sampler":
|
||||
input_evaluation = []
|
||||
else:
|
||||
input_evaluation = config["input_evaluation"]
|
||||
|
||||
# Fill in the default policy if 'None' is specified in multiagent
|
||||
if self.config["multiagent"]["policies"]:
|
||||
tmp = self.config["multiagent"]["policies"]
|
||||
_validate_multiagent_config(tmp, allow_none_graph=True)
|
||||
for k, v in tmp.items():
|
||||
if v[0] is None:
|
||||
tmp[k] = (policy, v[1], v[2], v[3])
|
||||
policy = tmp
|
||||
|
||||
return cls(
|
||||
env_creator,
|
||||
policy,
|
||||
policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"],
|
||||
policies_to_train=self.config["multiagent"]["policies_to_train"],
|
||||
tf_session_creator=(session_creator
|
||||
if config["tf_session_args"] else None),
|
||||
batch_steps=config["sample_batch_size"],
|
||||
batch_mode=config["batch_mode"],
|
||||
episode_horizon=config["horizon"],
|
||||
preprocessor_pref=config["preprocessor_pref"],
|
||||
sample_async=config["sample_async"],
|
||||
compress_observations=config["compress_observations"],
|
||||
num_envs=config["num_envs_per_worker"],
|
||||
observation_filter=config["observation_filter"],
|
||||
clip_rewards=config["clip_rewards"],
|
||||
clip_actions=config["clip_actions"],
|
||||
env_config=config["env_config"],
|
||||
model_config=config["model"],
|
||||
policy_config=config,
|
||||
worker_index=worker_index,
|
||||
monitor_path=self.logdir if config["monitor"] else None,
|
||||
log_dir=self.logdir,
|
||||
log_level=config["log_level"],
|
||||
callbacks=config["callbacks"],
|
||||
input_creator=input_creator,
|
||||
input_evaluation=input_evaluation,
|
||||
output_creator=output_creator,
|
||||
remote_worker_envs=config["remote_worker_envs"],
|
||||
remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
|
||||
soft_horizon=config["soft_horizon"],
|
||||
_fake_sampler=config.get("_fake_sampler", False))
|
||||
|
||||
@override(Trainable)
|
||||
def _export_model(self, export_formats, export_dir):
|
||||
ExportFormat.validate(export_formats)
|
||||
@@ -870,17 +753,17 @@ class Trainer(Trainable):
|
||||
|
||||
def __getstate__(self):
|
||||
state = {}
|
||||
if hasattr(self, "local_evaluator"):
|
||||
state["evaluator"] = self.local_evaluator.save()
|
||||
if hasattr(self, "workers"):
|
||||
state["worker"] = self.workers.local_worker().save()
|
||||
if hasattr(self, "optimizer") and hasattr(self.optimizer, "save"):
|
||||
state["optimizer"] = self.optimizer.save()
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
if "evaluator" in state:
|
||||
self.local_evaluator.restore(state["evaluator"])
|
||||
remote_state = ray.put(state["evaluator"])
|
||||
for r in self.remote_evaluators:
|
||||
if "worker" in state:
|
||||
self.workers.local_worker().restore(state["worker"])
|
||||
remote_state = ray.put(state["worker"])
|
||||
for r in self.workers.remote_workers():
|
||||
r.restore.remote(remote_state)
|
||||
if "optimizer" in state:
|
||||
self.optimizer.restore(state["optimizer"])
|
||||
|
||||
@@ -2,6 +2,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import time
|
||||
|
||||
from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG
|
||||
from ray.rllib.optimizers import SyncSamplesOptimizer
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
@@ -25,8 +27,7 @@ def build_trainer(name,
|
||||
default_config (dict): the default config dict of the algorithm,
|
||||
otherwise uses the Trainer default config
|
||||
make_policy_optimizer (func): optional function that returns a
|
||||
PolicyOptimizer instance given
|
||||
(local_evaluator, remote_evaluators, config)
|
||||
PolicyOptimizer instance given (WorkerSet, config)
|
||||
validate_config (func): optional callback that checks a given config
|
||||
for correctness. It may mutate the config as needed.
|
||||
get_policy_class (func): optional callback that takes a config and
|
||||
@@ -44,8 +45,7 @@ def build_trainer(name,
|
||||
a Trainer instance that uses the specified args.
|
||||
"""
|
||||
|
||||
if not name.endswith("Trainer"):
|
||||
raise ValueError("Algorithm name should have *Trainer suffix", name)
|
||||
original_kwargs = locals().copy()
|
||||
|
||||
class trainer_cls(Trainer):
|
||||
_name = name
|
||||
@@ -59,19 +59,15 @@ def build_trainer(name,
|
||||
policy = default_policy
|
||||
else:
|
||||
policy = get_policy_class(config)
|
||||
self.local_evaluator = self.make_local_evaluator(
|
||||
env_creator, policy)
|
||||
self.remote_evaluators = self.make_remote_evaluators(
|
||||
env_creator, policy, config["num_workers"])
|
||||
self.workers = self._make_workers(env_creator, policy, config,
|
||||
self.config["num_workers"])
|
||||
if make_policy_optimizer:
|
||||
self.optimizer = make_policy_optimizer(
|
||||
self.local_evaluator, self.remote_evaluators, config)
|
||||
self.optimizer = make_policy_optimizer(self.workers, config)
|
||||
else:
|
||||
optimizer_config = dict(
|
||||
config["optimizer"],
|
||||
**{"train_batch_size": config["train_batch_size"]})
|
||||
self.optimizer = SyncSamplesOptimizer(self.local_evaluator,
|
||||
self.remote_evaluators,
|
||||
self.optimizer = SyncSamplesOptimizer(self.workers,
|
||||
**optimizer_config)
|
||||
|
||||
@override(Trainer)
|
||||
@@ -79,9 +75,15 @@ def build_trainer(name,
|
||||
if before_train_step:
|
||||
before_train_step(self)
|
||||
prev_steps = self.optimizer.num_steps_sampled
|
||||
fetches = self.optimizer.step()
|
||||
if after_optimizer_step:
|
||||
after_optimizer_step(self, fetches)
|
||||
|
||||
start = time.time()
|
||||
while True:
|
||||
fetches = self.optimizer.step()
|
||||
if after_optimizer_step:
|
||||
after_optimizer_step(self, fetches)
|
||||
if time.time() - start > self.config["min_iter_time_s"]:
|
||||
break
|
||||
|
||||
res = self.collect_metrics()
|
||||
res.update(
|
||||
timesteps_this_iter=self.optimizer.num_steps_sampled -
|
||||
@@ -91,6 +93,11 @@ def build_trainer(name,
|
||||
after_train_result(self, res)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def with_updates(**overrides):
|
||||
return build_trainer(**dict(original_kwargs, **overrides))
|
||||
|
||||
trainer_cls.with_updates = with_updates
|
||||
trainer_cls.__name__ = name
|
||||
trainer_cls.__qualname__ = name
|
||||
return trainer_cls
|
||||
|
||||
Vendored
+1
-1
@@ -21,7 +21,7 @@ class BaseEnv(object):
|
||||
can be sent back via send_actions().
|
||||
|
||||
All other env types can be adapted to BaseEnv. RLlib handles these
|
||||
conversions internally in PolicyEvaluator, for example:
|
||||
conversions internally in RolloutWorker, for example:
|
||||
|
||||
gym.Env => rllib.VectorEnv => rllib.BaseEnv
|
||||
rllib.MultiAgentEnv => rllib.BaseEnv
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode
|
||||
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.evaluation.interface import EvaluatorInterface
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
@@ -12,8 +13,19 @@ from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
|
||||
__all__ = [
|
||||
"EvaluatorInterface", "PolicyEvaluator", "PolicyGraph", "TFPolicyGraph",
|
||||
"TorchPolicyGraph", "SampleBatch", "MultiAgentBatch", "SampleBatchBuilder",
|
||||
"MultiAgentSampleBatchBuilder", "SyncSampler", "AsyncSampler",
|
||||
"compute_advantages", "collect_metrics", "MultiAgentEpisode"
|
||||
"EvaluatorInterface",
|
||||
"RolloutWorker",
|
||||
"PolicyGraph",
|
||||
"TFPolicyGraph",
|
||||
"TorchPolicyGraph",
|
||||
"SampleBatch",
|
||||
"MultiAgentBatch",
|
||||
"SampleBatchBuilder",
|
||||
"MultiAgentSampleBatchBuilder",
|
||||
"SyncSampler",
|
||||
"AsyncSampler",
|
||||
"compute_advantages",
|
||||
"collect_metrics",
|
||||
"MultiAgentEpisode",
|
||||
"PolicyEvaluator",
|
||||
]
|
||||
|
||||
@@ -11,7 +11,7 @@ from ray.rllib.utils.annotations import DeveloperAPI
|
||||
class EvaluatorInterface(object):
|
||||
"""This is the interface between policy optimizers and policy evaluation.
|
||||
|
||||
See also: PolicyEvaluator
|
||||
See also: RolloutWorker
|
||||
"""
|
||||
|
||||
@DeveloperAPI
|
||||
|
||||
@@ -39,27 +39,23 @@ def get_learner_stats(grad_info):
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def collect_metrics(local_evaluator=None,
|
||||
remote_evaluators=[],
|
||||
timeout_seconds=180):
|
||||
"""Gathers episode metrics from PolicyEvaluator instances."""
|
||||
def collect_metrics(local_worker=None, remote_workers=[], timeout_seconds=180):
|
||||
"""Gathers episode metrics from RolloutWorker instances."""
|
||||
|
||||
episodes, num_dropped = collect_episodes(
|
||||
local_evaluator, remote_evaluators, timeout_seconds=timeout_seconds)
|
||||
local_worker, remote_workers, timeout_seconds=timeout_seconds)
|
||||
metrics = summarize_episodes(episodes, episodes, num_dropped)
|
||||
return metrics
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def collect_episodes(local_evaluator=None,
|
||||
remote_evaluators=[],
|
||||
def collect_episodes(local_worker=None, remote_workers=[],
|
||||
timeout_seconds=180):
|
||||
"""Gathers new episodes metrics tuples from the given evaluators."""
|
||||
|
||||
if remote_evaluators:
|
||||
if remote_workers:
|
||||
pending = [
|
||||
a.apply.remote(lambda ev: ev.get_metrics())
|
||||
for a in remote_evaluators
|
||||
a.apply.remote(lambda ev: ev.get_metrics()) for a in remote_workers
|
||||
]
|
||||
collected, _ = ray.wait(
|
||||
pending, num_returns=len(pending), timeout=timeout_seconds * 1.0)
|
||||
@@ -73,8 +69,8 @@ def collect_episodes(local_evaluator=None,
|
||||
metric_lists = []
|
||||
num_metric_batches_dropped = 0
|
||||
|
||||
if local_evaluator:
|
||||
metric_lists.append(local_evaluator.get_metrics())
|
||||
if local_worker:
|
||||
metric_lists.append(local_worker.get_metrics())
|
||||
episodes = []
|
||||
for metrics in metric_lists:
|
||||
episodes.extend(metrics)
|
||||
|
||||
@@ -2,805 +2,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gym
|
||||
import logging
|
||||
import pickle
|
||||
from ray.rllib.utils import renamed_class
|
||||
from ray.rllib.evaluation import RolloutWorker
|
||||
|
||||
import ray
|
||||
from ray.rllib.env.atari_wrappers import wrap_deepmind, is_atari
|
||||
from ray.rllib.env.base_env import BaseEnv
|
||||
from ray.rllib.env.env_context import EnvContext
|
||||
from ray.rllib.env.external_env import ExternalEnv
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
|
||||
from ray.rllib.env.vector_env import VectorEnv
|
||||
from ray.rllib.evaluation.interface import EvaluatorInterface
|
||||
from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
|
||||
from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader
|
||||
from ray.rllib.offline.is_estimator import ImportanceSamplingEstimator
|
||||
from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models.preprocessors import NoPreprocessor
|
||||
from ray.rllib.utils import merge_dicts
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.debug import disable_log_once_globally, log_once, \
|
||||
summarize, enable_periodic_logging
|
||||
from ray.rllib.utils.filter import get_filter
|
||||
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Handle to the current evaluator, which will be set to the most recently
|
||||
# created PolicyEvaluator in this process. This can be helpful to access in
|
||||
# custom env or policy classes for debugging or advanced use cases.
|
||||
_global_evaluator = None
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def get_global_evaluator():
|
||||
"""Returns a handle to the active policy evaluator in this process."""
|
||||
|
||||
global _global_evaluator
|
||||
return _global_evaluator
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class PolicyEvaluator(EvaluatorInterface):
|
||||
"""Common ``PolicyEvaluator`` implementation that wraps a ``Policy``.
|
||||
|
||||
This class wraps a policy instance and an environment class to
|
||||
collect experiences from the environment. You can create many replicas of
|
||||
this class as Ray actors to scale RL training.
|
||||
|
||||
This class supports vectorized and multi-agent policy evaluation (e.g.,
|
||||
VectorEnv, MultiAgentEnv, etc.)
|
||||
|
||||
Examples:
|
||||
>>> # Create a policy evaluator and using it to collect experiences.
|
||||
>>> evaluator = PolicyEvaluator(
|
||||
... env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
... policy=PGTFPolicy)
|
||||
>>> print(evaluator.sample())
|
||||
SampleBatch({
|
||||
"obs": [[...]], "actions": [[...]], "rewards": [[...]],
|
||||
"dones": [[...]], "new_obs": [[...]]})
|
||||
|
||||
>>> # Creating policy evaluators using optimizer_cls.make().
|
||||
>>> optimizer = SyncSamplesOptimizer.make(
|
||||
... evaluator_cls=PolicyEvaluator,
|
||||
... evaluator_args={
|
||||
... "env_creator": lambda _: gym.make("CartPole-v0"),
|
||||
... "policy": PGTFPolicy,
|
||||
... },
|
||||
... num_workers=10)
|
||||
>>> for _ in range(10): optimizer.step()
|
||||
|
||||
>>> # Creating a multi-agent policy evaluator
|
||||
>>> evaluator = PolicyEvaluator(
|
||||
... env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25),
|
||||
... policies={
|
||||
... # Use an ensemble of two policies for car agents
|
||||
... "car_policy1":
|
||||
... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.99}),
|
||||
... "car_policy2":
|
||||
... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.95}),
|
||||
... # Use a single shared policy for all traffic lights
|
||||
... "traffic_light_policy":
|
||||
... (PGTFPolicy, Box(...), Discrete(...), {}),
|
||||
... },
|
||||
... policy_mapping_fn=lambda agent_id:
|
||||
... random.choice(["car_policy1", "car_policy2"])
|
||||
... if agent_id.startswith("car_") else "traffic_light_policy")
|
||||
>>> print(evaluator.sample())
|
||||
MultiAgentBatch({
|
||||
"car_policy1": SampleBatch(...),
|
||||
"car_policy2": SampleBatch(...),
|
||||
"traffic_light_policy": SampleBatch(...)})
|
||||
"""
|
||||
|
||||
@DeveloperAPI
|
||||
@classmethod
|
||||
def as_remote(cls, num_cpus=None, num_gpus=None, resources=None):
|
||||
return ray.remote(
|
||||
num_cpus=num_cpus, num_gpus=num_gpus, resources=resources)(cls)
|
||||
|
||||
@DeveloperAPI
|
||||
def __init__(self,
|
||||
env_creator,
|
||||
policy,
|
||||
policy_mapping_fn=None,
|
||||
policies_to_train=None,
|
||||
tf_session_creator=None,
|
||||
batch_steps=100,
|
||||
batch_mode="truncate_episodes",
|
||||
episode_horizon=None,
|
||||
preprocessor_pref="deepmind",
|
||||
sample_async=False,
|
||||
compress_observations=False,
|
||||
num_envs=1,
|
||||
observation_filter="NoFilter",
|
||||
clip_rewards=None,
|
||||
clip_actions=True,
|
||||
env_config=None,
|
||||
model_config=None,
|
||||
policy_config=None,
|
||||
worker_index=0,
|
||||
monitor_path=None,
|
||||
log_dir=None,
|
||||
log_level=None,
|
||||
callbacks=None,
|
||||
input_creator=lambda ioctx: ioctx.default_sampler_input(),
|
||||
input_evaluation=frozenset([]),
|
||||
output_creator=lambda ioctx: NoopOutput(),
|
||||
remote_worker_envs=False,
|
||||
remote_env_batch_wait_ms=0,
|
||||
soft_horizon=False,
|
||||
_fake_sampler=False):
|
||||
"""Initialize a policy evaluator.
|
||||
|
||||
Arguments:
|
||||
env_creator (func): Function that returns a gym.Env given an
|
||||
EnvContext wrapped configuration.
|
||||
policy (class|dict): Either a class implementing
|
||||
Policy, or a dictionary of policy id strings to
|
||||
(Policy, obs_space, action_space, config) tuples. If a
|
||||
dict is specified, then we are in multi-agent mode and a
|
||||
policy_mapping_fn should also be set.
|
||||
policy_mapping_fn (func): A function that maps agent ids to
|
||||
policy ids in multi-agent mode. This function will be called
|
||||
each time a new agent appears in an episode, to bind that agent
|
||||
to a policy for the duration of the episode.
|
||||
policies_to_train (list): Optional whitelist of policies to train,
|
||||
or None for all policies.
|
||||
tf_session_creator (func): A function that returns a TF session.
|
||||
This is optional and only useful with TFPolicy.
|
||||
batch_steps (int): The target number of env transitions to include
|
||||
in each sample batch returned from this evaluator.
|
||||
batch_mode (str): One of the following batch modes:
|
||||
"truncate_episodes": Each call to sample() will return a batch
|
||||
of at most `batch_steps * num_envs` in size. The batch will
|
||||
be exactly `batch_steps * num_envs` in size if
|
||||
postprocessing does not change batch sizes. Episodes may be
|
||||
truncated in order to meet this size requirement.
|
||||
"complete_episodes": Each call to sample() will return a batch
|
||||
of at least `batch_steps * num_envs` in size. Episodes will
|
||||
not be truncated, but multiple episodes may be packed
|
||||
within one batch to meet the batch size. Note that when
|
||||
`num_envs > 1`, episode steps will be buffered until the
|
||||
episode completes, and hence batches may contain
|
||||
significant amounts of off-policy data.
|
||||
episode_horizon (int): If set, stop episodes at this horizon.
|
||||
preprocessor_pref (str): Whether to prefer RLlib preprocessors
|
||||
("rllib") or deepmind ("deepmind") when applicable.
|
||||
sample_async (bool): Whether to compute samples asynchronously in
|
||||
the background, which improves throughput but can cause samples
|
||||
to be slightly off-policy.
|
||||
compress_observations (bool): If true, compress the observations.
|
||||
They can be decompressed with rllib/utils/compression.
|
||||
num_envs (int): If more than one, will create multiple envs
|
||||
and vectorize the computation of actions. This has no effect if
|
||||
the env already implements VectorEnv.
|
||||
observation_filter (str): Name of observation filter to use.
|
||||
clip_rewards (bool): Whether to clip rewards to [-1, 1] prior to
|
||||
experience postprocessing. Setting to None means clip for Atari
|
||||
only.
|
||||
clip_actions (bool): Whether to clip action values to the range
|
||||
specified by the policy action space.
|
||||
env_config (dict): Config to pass to the env creator.
|
||||
model_config (dict): Config to use when creating the policy model.
|
||||
policy_config (dict): Config to pass to the policy. In the
|
||||
multi-agent case, this config will be merged with the
|
||||
per-policy configs specified by `policy`.
|
||||
worker_index (int): For remote evaluators, this should be set to a
|
||||
non-zero and unique value. This index is passed to created envs
|
||||
through EnvContext so that envs can be configured per worker.
|
||||
monitor_path (str): Write out episode stats and videos to this
|
||||
directory if specified.
|
||||
log_dir (str): Directory where logs can be placed.
|
||||
log_level (str): Set the root log level on creation.
|
||||
callbacks (dict): Dict of custom debug callbacks.
|
||||
input_creator (func): Function that returns an InputReader object
|
||||
for loading previously generated experiences.
|
||||
input_evaluation (list): How to evaluate the policy performance.
|
||||
This only makes sense to set when the input is reading offline
|
||||
data. The possible values include:
|
||||
- "is": the step-wise importance sampling estimator.
|
||||
- "wis": the weighted step-wise importance sampling estimator.
|
||||
- "simulation": run the environment in the background, but
|
||||
use this data for evaluation only and never for learning.
|
||||
output_creator (func): Function that returns an OutputWriter object
|
||||
for saving generated experiences.
|
||||
remote_worker_envs (bool): If using num_envs > 1, whether to create
|
||||
those new envs in remote processes instead of in the current
|
||||
process. This adds overheads, but can make sense if your envs take a long time to step / reset.
|
||||
remote_env_batch_wait_ms (float): Timeout that remote workers
|
||||
are waiting when polling environments. 0 (continue when at
|
||||
least one env is ready) is a reasonable default, but optimal
|
||||
value could be obtained by measuring your environment
|
||||
step / reset and model inference perf.
|
||||
soft_horizon (bool): Calculate rewards but don't reset the
|
||||
environment when the horizon is hit.
|
||||
_fake_sampler (bool): Use a fake (inf speed) sampler for testing.
|
||||
"""
|
||||
|
||||
global _global_evaluator
|
||||
_global_evaluator = self
|
||||
|
||||
if log_level:
|
||||
logging.getLogger("ray.rllib").setLevel(log_level)
|
||||
|
||||
if worker_index > 1:
|
||||
disable_log_once_globally() # only need 1 evaluator to log
|
||||
elif log_level == "DEBUG":
|
||||
enable_periodic_logging()
|
||||
|
||||
env_context = EnvContext(env_config or {}, worker_index)
|
||||
policy_config = policy_config or {}
|
||||
self.policy_config = policy_config
|
||||
self.callbacks = callbacks or {}
|
||||
self.worker_index = worker_index
|
||||
model_config = model_config or {}
|
||||
policy_mapping_fn = (policy_mapping_fn
|
||||
or (lambda agent_id: DEFAULT_POLICY_ID))
|
||||
if not callable(policy_mapping_fn):
|
||||
raise ValueError(
|
||||
"Policy mapping function not callable. If you're using Tune, "
|
||||
"make sure to escape the function with tune.function() "
|
||||
"to prevent it from being evaluated as an expression.")
|
||||
self.env_creator = env_creator
|
||||
self.sample_batch_size = batch_steps * num_envs
|
||||
self.batch_mode = batch_mode
|
||||
self.compress_observations = compress_observations
|
||||
self.preprocessing_enabled = True
|
||||
self.last_batch = None
|
||||
self._fake_sampler = _fake_sampler
|
||||
|
||||
self.env = _validate_env(env_creator(env_context))
|
||||
if isinstance(self.env, MultiAgentEnv) or \
|
||||
isinstance(self.env, BaseEnv):
|
||||
|
||||
def wrap(env):
|
||||
return env # we can't auto-wrap these env types
|
||||
elif is_atari(self.env) and \
|
||||
not model_config.get("custom_preprocessor") and \
|
||||
preprocessor_pref == "deepmind":
|
||||
|
||||
# Deepmind wrappers already handle all preprocessing
|
||||
self.preprocessing_enabled = False
|
||||
|
||||
if clip_rewards is None:
|
||||
clip_rewards = True
|
||||
|
||||
def wrap(env):
|
||||
env = wrap_deepmind(
|
||||
env,
|
||||
dim=model_config.get("dim"),
|
||||
framestack=model_config.get("framestack"))
|
||||
if monitor_path:
|
||||
env = _monitor(env, monitor_path)
|
||||
return env
|
||||
else:
|
||||
|
||||
def wrap(env):
|
||||
if monitor_path:
|
||||
env = _monitor(env, monitor_path)
|
||||
return env
|
||||
|
||||
self.env = wrap(self.env)
|
||||
|
||||
def make_env(vector_index):
|
||||
return wrap(
|
||||
env_creator(
|
||||
env_context.copy_with_overrides(
|
||||
vector_index=vector_index, remote=remote_worker_envs)))
|
||||
|
||||
self.tf_sess = None
|
||||
policy_dict = _validate_and_canonicalize(policy, self.env)
|
||||
self.policies_to_train = policies_to_train or list(policy_dict.keys())
|
||||
if _has_tensorflow_graph(policy_dict):
|
||||
if (ray.is_initialized()
|
||||
and ray.worker._mode() != ray.worker.LOCAL_MODE
|
||||
and not ray.get_gpu_ids()):
|
||||
logger.info("Creating policy evaluation worker {}".format(
|
||||
worker_index) +
|
||||
" on CPU (please ignore any CUDA init errors)")
|
||||
with tf.Graph().as_default():
|
||||
if tf_session_creator:
|
||||
self.tf_sess = tf_session_creator()
|
||||
else:
|
||||
self.tf_sess = tf.Session(
|
||||
config=tf.ConfigProto(
|
||||
gpu_options=tf.GPUOptions(allow_growth=True)))
|
||||
with self.tf_sess.as_default():
|
||||
self.policy_map, self.preprocessors = \
|
||||
self._build_policy_map(policy_dict, policy_config)
|
||||
else:
|
||||
self.policy_map, self.preprocessors = self._build_policy_map(
|
||||
policy_dict, policy_config)
|
||||
|
||||
self.multiagent = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID}
|
||||
if self.multiagent:
|
||||
if not ((isinstance(self.env, MultiAgentEnv)
|
||||
or isinstance(self.env, ExternalMultiAgentEnv))
|
||||
or isinstance(self.env, BaseEnv)):
|
||||
raise ValueError(
|
||||
"Have multiple policies {}, but the env ".format(
|
||||
self.policy_map) +
|
||||
"{} is not a subclass of BaseEnv, MultiAgentEnv or "
|
||||
"ExternalMultiAgentEnv?".format(self.env))
|
||||
|
||||
self.filters = {
|
||||
policy_id: get_filter(observation_filter,
|
||||
policy.observation_space.shape)
|
||||
for (policy_id, policy) in self.policy_map.items()
|
||||
}
|
||||
if self.worker_index == 0:
|
||||
logger.info("Built filter map: {}".format(self.filters))
|
||||
|
||||
# Always use vector env for consistency even if num_envs = 1
|
||||
self.async_env = BaseEnv.to_base_env(
|
||||
self.env,
|
||||
make_env=make_env,
|
||||
num_envs=num_envs,
|
||||
remote_envs=remote_worker_envs,
|
||||
remote_env_batch_wait_ms=remote_env_batch_wait_ms)
|
||||
self.num_envs = num_envs
|
||||
|
||||
if self.batch_mode == "truncate_episodes":
|
||||
unroll_length = batch_steps
|
||||
pack_episodes = True
|
||||
elif self.batch_mode == "complete_episodes":
|
||||
unroll_length = float("inf") # never cut episodes
|
||||
pack_episodes = False # sampler will return 1 episode per poll
|
||||
else:
|
||||
raise ValueError("Unsupported batch mode: {}".format(
|
||||
self.batch_mode))
|
||||
|
||||
self.io_context = IOContext(log_dir, policy_config, worker_index, self)
|
||||
self.reward_estimators = []
|
||||
for method in input_evaluation:
|
||||
if method == "simulation":
|
||||
logger.warning(
|
||||
"Requested 'simulation' input evaluation method: "
|
||||
"will discard all sampler outputs and keep only metrics.")
|
||||
sample_async = True
|
||||
elif method == "is":
|
||||
ise = ImportanceSamplingEstimator.create(self.io_context)
|
||||
self.reward_estimators.append(ise)
|
||||
elif method == "wis":
|
||||
wise = WeightedImportanceSamplingEstimator.create(
|
||||
self.io_context)
|
||||
self.reward_estimators.append(wise)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unknown evaluation method: {}".format(method))
|
||||
|
||||
if sample_async:
|
||||
self.sampler = AsyncSampler(
|
||||
self.async_env,
|
||||
self.policy_map,
|
||||
policy_mapping_fn,
|
||||
self.preprocessors,
|
||||
self.filters,
|
||||
clip_rewards,
|
||||
unroll_length,
|
||||
self.callbacks,
|
||||
horizon=episode_horizon,
|
||||
pack=pack_episodes,
|
||||
tf_sess=self.tf_sess,
|
||||
clip_actions=clip_actions,
|
||||
blackhole_outputs="simulation" in input_evaluation,
|
||||
soft_horizon=soft_horizon)
|
||||
self.sampler.start()
|
||||
else:
|
||||
self.sampler = SyncSampler(
|
||||
self.async_env,
|
||||
self.policy_map,
|
||||
policy_mapping_fn,
|
||||
self.preprocessors,
|
||||
self.filters,
|
||||
clip_rewards,
|
||||
unroll_length,
|
||||
self.callbacks,
|
||||
horizon=episode_horizon,
|
||||
pack=pack_episodes,
|
||||
tf_sess=self.tf_sess,
|
||||
clip_actions=clip_actions,
|
||||
soft_horizon=soft_horizon)
|
||||
|
||||
self.input_reader = input_creator(self.io_context)
|
||||
assert isinstance(self.input_reader, InputReader), self.input_reader
|
||||
self.output_writer = output_creator(self.io_context)
|
||||
assert isinstance(self.output_writer, OutputWriter), self.output_writer
|
||||
|
||||
logger.debug("Created evaluator with env {} ({}), policies {}".format(
|
||||
self.async_env, self.env, self.policy_map))
|
||||
|
||||
@override(EvaluatorInterface)
|
||||
def sample(self):
|
||||
"""Evaluate the current policies and return a batch of experiences.
|
||||
|
||||
Return:
|
||||
SampleBatch|MultiAgentBatch from evaluating the current policies.
|
||||
"""
|
||||
|
||||
if self._fake_sampler and self.last_batch is not None:
|
||||
return self.last_batch
|
||||
|
||||
if log_once("sample_start"):
|
||||
logger.info("Generating sample batch of size {}".format(
|
||||
self.sample_batch_size))
|
||||
|
||||
batches = [self.input_reader.next()]
|
||||
steps_so_far = batches[0].count
|
||||
|
||||
# In truncate_episodes mode, never pull more than 1 batch per env.
|
||||
# This avoids over-running the target batch size.
|
||||
if self.batch_mode == "truncate_episodes":
|
||||
max_batches = self.num_envs
|
||||
else:
|
||||
max_batches = float("inf")
|
||||
|
||||
while steps_so_far < self.sample_batch_size and len(
|
||||
batches) < max_batches:
|
||||
batch = self.input_reader.next()
|
||||
steps_so_far += batch.count
|
||||
batches.append(batch)
|
||||
batch = batches[0].concat_samples(batches)
|
||||
|
||||
if self.callbacks.get("on_sample_end"):
|
||||
self.callbacks["on_sample_end"]({
|
||||
"evaluator": self,
|
||||
"samples": batch
|
||||
})
|
||||
|
||||
# Always do writes prior to compression for consistency and to allow
|
||||
# for better compression inside the writer.
|
||||
self.output_writer.write(batch)
|
||||
|
||||
# Do off-policy estimation if needed
|
||||
if self.reward_estimators:
|
||||
for sub_batch in batch.split_by_episode():
|
||||
for estimator in self.reward_estimators:
|
||||
estimator.process(sub_batch)
|
||||
|
||||
if log_once("sample_end"):
|
||||
logger.info("Completed sample batch:\n\n{}\n".format(
|
||||
summarize(batch)))
|
||||
|
||||
if self.compress_observations == "bulk":
|
||||
batch.compress(bulk=True)
|
||||
elif self.compress_observations:
|
||||
batch.compress()
|
||||
|
||||
if self._fake_sampler:
|
||||
self.last_batch = batch
|
||||
return batch
|
||||
|
||||
@DeveloperAPI
|
||||
@ray.method(num_return_vals=2)
|
||||
def sample_with_count(self):
|
||||
"""Same as sample() but returns the count as a separate future."""
|
||||
batch = self.sample()
|
||||
return batch, batch.count
|
||||
|
||||
@override(EvaluatorInterface)
|
||||
def get_weights(self, policies=None):
|
||||
if policies is None:
|
||||
policies = self.policy_map.keys()
|
||||
return {
|
||||
pid: policy.get_weights()
|
||||
for pid, policy in self.policy_map.items() if pid in policies
|
||||
}
|
||||
|
||||
@override(EvaluatorInterface)
|
||||
def set_weights(self, weights):
    """Push the given weights into the matching policies.

    Arguments:
        weights (dict): Maps policy id to the weights to install; every
            key must exist in the policy map (a missing id raises
            KeyError).
    """
    for policy_id, policy_weights in weights.items():
        target = self.policy_map[policy_id]
        target.set_weights(policy_weights)
|
||||
|
||||
@override(EvaluatorInterface)
|
||||
def compute_gradients(self, samples):
    """Compute gradients for the given batch without applying them.

    Arguments:
        samples: A MultiAgentBatch, or a single-policy batch that is
            passed to the default policy.

    Returns:
        tuple: ``(grad_out, info_out)``. For a MultiAgentBatch both are
            dicts keyed by policy id and only cover policies listed in
            ``self.policies_to_train``. In every case
            ``info_out["batch_count"]`` is set to ``samples.count``.
    """
    if log_once("compute_gradients"):
        logger.info("Compute gradients on:\n\n{}\n".format(
            summarize(samples)))

    if not isinstance(samples, MultiAgentBatch):
        grad_out, info_out = (
            self.policy_map[DEFAULT_POLICY_ID].compute_gradients(samples))
    else:
        grad_out, info_out = {}, {}
        if self.tf_sess is None:
            for pid, batch in samples.policy_batches.items():
                if pid in self.policies_to_train:
                    grad_out[pid], info_out[pid] = (
                        self.policy_map[pid].compute_gradients(batch))
        else:
            # Register every policy's fetches on one builder first, then
            # resolve them through builder.get().
            builder = TFRunBuilder(self.tf_sess, "compute_gradients")
            for pid, batch in samples.policy_batches.items():
                if pid in self.policies_to_train:
                    grad_out[pid], info_out[pid] = (
                        self.policy_map[pid]._build_compute_gradients(
                            builder, batch))
            grad_out = {k: builder.get(v) for k, v in grad_out.items()}
            info_out = {k: builder.get(v) for k, v in info_out.items()}

    info_out["batch_count"] = samples.count
    if log_once("grad_out"):
        logger.info("Compute grad info:\n\n{}\n".format(
            summarize(info_out)))
    return grad_out, info_out
|
||||
|
||||
@override(EvaluatorInterface)
|
||||
def apply_gradients(self, grads):
    """Apply previously computed gradients to the policies.

    Arguments:
        grads: Either a dict mapping policy id to that policy's
            gradients, or gradients for the default policy.

    Returns:
        For a dict input, a dict mapping policy id to the per-policy
        result; otherwise the default policy's apply_gradients() result.
    """
    if log_once("apply_gradients"):
        logger.info("Apply gradients:\n\n{}\n".format(summarize(grads)))

    if not isinstance(grads, dict):
        return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)

    if self.tf_sess is None:
        applied = {}
        for pid, g in grads.items():
            applied[pid] = self.policy_map[pid].apply_gradients(g)
        return applied

    # With a TF session: register all updates on one builder, then
    # resolve each policy's output through builder.get().
    builder = TFRunBuilder(self.tf_sess, "apply_gradients")
    pending = {}
    for pid, grad in grads.items():
        pending[pid] = self.policy_map[pid]._build_apply_gradients(
            builder, grad)
    return {k: builder.get(v) for k, v in pending.items()}
|
||||
|
||||
@override(EvaluatorInterface)
|
||||
def learn_on_batch(self, samples):
    """Update the policies from the given batch.

    Arguments:
        samples: A MultiAgentBatch, or a single-policy batch that is
            passed to the default policy.

    Returns:
        For a MultiAgentBatch, a dict mapping each trained policy id to
        its learner output (only policies in ``self.policies_to_train``
        are updated); otherwise the default policy's learn_on_batch()
        result.
    """
    if log_once("learn_on_batch"):
        logger.info(
            "Training on concatenated sample batches:\n\n{}\n".format(
                summarize(samples)))

    if isinstance(samples, MultiAgentBatch):
        builder = None
        if self.tf_sess is not None:
            builder = TFRunBuilder(self.tf_sess, "learn_on_batch")

        info_out, to_fetch = {}, {}
        for pid, batch in samples.policy_batches.items():
            if pid in self.policies_to_train:
                policy = self.policy_map[pid]
                # Policies exposing _build_learn_on_batch are deferred to
                # the shared builder; the rest are trained immediately.
                if builder and hasattr(policy, "_build_learn_on_batch"):
                    to_fetch[pid] = policy._build_learn_on_batch(
                        builder, batch)
                else:
                    info_out[pid] = policy.learn_on_batch(batch)
        for pid, fetch in to_fetch.items():
            info_out[pid] = builder.get(fetch)
    else:
        info_out = self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(
            samples)

    if log_once("learn_out"):
        logger.info("Training output:\n\n{}\n".format(summarize(info_out)))
    return info_out
|
||||
|
||||
@DeveloperAPI
|
||||
def get_metrics(self):
    """Gather new metrics from the sampler and the reward estimators.

    Returns:
        list: The list returned by ``self.sampler.get_metrics()``,
            extended in place with each reward estimator's metrics.
    """
    metrics = self.sampler.get_metrics()
    for estimator in self.reward_estimators:
        metrics.extend(estimator.get_metrics())
    return metrics
|
||||
|
||||
@DeveloperAPI
|
||||
def foreach_env(self, func):
    """Call func on every underlying env instance.

    When ``self.async_env.get_unwrapped()`` yields nothing, func is
    applied to ``self.async_env`` itself instead.

    Returns:
        list: One func result per env visited.
    """
    targets = self.async_env.get_unwrapped() or [self.async_env]
    return [func(env) for env in targets]
|
||||
|
||||
@DeveloperAPI
|
||||
def get_policy(self, policy_id=DEFAULT_POLICY_ID):
    """Look up a policy by id.

    Arguments:
        policy_id (str): id of policy to return.

    Returns:
        The policy registered under policy_id, or None if there is none.
    """
    policies = self.policy_map
    return policies.get(policy_id)
|
||||
|
||||
@DeveloperAPI
|
||||
def for_policy(self, func, policy_id=DEFAULT_POLICY_ID):
    """Call func on the policy registered under policy_id.

    Returns:
        Whatever func returns. An unknown policy_id raises KeyError.
    """
    policy = self.policy_map[policy_id]
    return func(policy)
|
||||
|
||||
@DeveloperAPI
|
||||
def foreach_policy(self, func):
    """Call ``func(policy, policy_id)`` for every policy in the map.

    Returns:
        list: The func results, in policy-map iteration order.
    """
    results = []
    for pid, policy in self.policy_map.items():
        results.append(func(policy, pid))
    return results
|
||||
|
||||
@DeveloperAPI
|
||||
def foreach_trainable_policy(self, func):
    """Call ``func(policy, policy_id)`` for each trainable policy.

    Only policies whose id is in `self.policies_to_train` are visited.

    Returns:
        list: The func results, in policy-map iteration order.
    """
    results = []
    for pid, policy in self.policy_map.items():
        if pid in self.policies_to_train:
            results.append(func(policy, pid))
    return results
|
||||
|
||||
@DeveloperAPI
|
||||
def sync_filters(self, new_filters):
    """Sync each local filter with the matching entry of new_filters.

    Args:
        new_filters (dict): Filters with new state to update local copy.
            Must contain an entry for every key in ``self.filters``
            (checked with an assert).
    """
    assert all(key in new_filters for key in self.filters)
    for key, local_filter in self.filters.items():
        local_filter.sync(new_filters[key])
|
||||
|
||||
@DeveloperAPI
|
||||
def get_filters(self, flush_after=False):
    """Take a serializable snapshot of every filter.

    Args:
        flush_after (bool): If true, call clear_buffer() on each filter
            right after it has been snapshotted.

    Returns:
        dict: Maps each filter key to that filter's as_serializable()
            result.
    """
    snapshot = {}
    for key, flt in self.filters.items():
        snapshot[key] = flt.as_serializable()
        if flush_after:
            flt.clear_buffer()
    return snapshot
|
||||
|
||||
@DeveloperAPI
|
||||
def save(self):
    """Serialize the filters and all policy states into a pickle blob.

    The filters are fetched with ``flush_after=True``, so saving also
    clears their buffers.

    Returns:
        bytes: Pickled dict with the keys "filters" and "state".
    """
    filters = self.get_filters(flush_after=True)
    state = {}
    for pid in self.policy_map:
        state[pid] = self.policy_map[pid].get_state()
    return pickle.dumps({"filters": filters, "state": state})
|
||||
|
||||
@DeveloperAPI
|
||||
def restore(self, objs):
    """Load filters and policy states from a blob produced by save().

    Arguments:
        objs (bytes): Pickled dict with the keys "filters" and "state".
    """
    # NOTE(review): pickle.loads executes arbitrary code embedded in the
    # payload; only feed this checkpoints from a trusted source.
    restored = pickle.loads(objs)
    self.sync_filters(restored["filters"])
    for pid, policy_state in restored["state"].items():
        self.policy_map[pid].set_state(policy_state)
|
||||
|
||||
@DeveloperAPI
|
||||
def set_global_vars(self, global_vars):
    """Forward global_vars to on_global_var_update() of every policy."""

    def _notify(policy, _policy_id):
        return policy.on_global_var_update(global_vars)

    self.foreach_policy(_notify)
|
||||
|
||||
@DeveloperAPI
|
||||
def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
    """Ask the policy registered under policy_id to export its model.

    Arguments:
        export_dir: Passed through to the policy's export_model().
        policy_id (str): Id of the policy to export.
    """
    policy = self.policy_map[policy_id]
    policy.export_model(export_dir)
|
||||
|
||||
@DeveloperAPI
|
||||
def export_policy_checkpoint(self,
                             export_dir,
                             filename_prefix="model",
                             policy_id=DEFAULT_POLICY_ID):
    """Ask the policy registered under policy_id to export a checkpoint.

    Arguments:
        export_dir: Passed through to the policy's export_checkpoint().
        filename_prefix (str): Passed through as the second argument.
        policy_id (str): Id of the policy to export.
    """
    policy = self.policy_map[policy_id]
    policy.export_checkpoint(export_dir, filename_prefix)
|
||||
|
||||
@DeveloperAPI
|
||||
def stop(self):
    """Stop this worker by calling stop() on its async env."""
    env = self.async_env
    env.stop()
|
||||
|
||||
def _build_policy_map(self, policy_dict, policy_config):
    """Instantiate every policy in policy_dict with its preprocessor.

    Arguments:
        policy_dict (dict): Maps policy id to a
            (cls, obs_space, act_space, conf) tuple.
        policy_config (dict): Base config; each per-policy conf is
            merged on top of it via merge_dicts.

    Returns:
        tuple: ``(policy_map, preprocessors)``, both keyed by policy id.

    Raises:
        ValueError: If a policy would receive a raw Dict or Tuple
            observation space.
    """
    policy_map, preprocessors = {}, {}
    # Policy ids are unique dict keys, so walking the sorted keys visits
    # the entries in the same order as sorting the items.
    for name in sorted(policy_dict):
        cls, obs_space, act_space, conf = policy_dict[name]
        logger.debug("Creating policy for {}".format(name))
        merged_conf = merge_dicts(policy_config, conf)

        if self.preprocessing_enabled:
            preprocessor = ModelCatalog.get_preprocessor_for_space(
                obs_space, merged_conf.get("model"))
            # The policy sees the preprocessed observation space.
            obs_space = preprocessor.observation_space
        else:
            preprocessor = NoPreprocessor(obs_space)
        preprocessors[name] = preprocessor

        if isinstance(obs_space, (gym.spaces.Dict, gym.spaces.Tuple)):
            raise ValueError(
                "Found raw Tuple|Dict space as input to policy. "
                "Please preprocess these observations with a "
                "Tuple|DictFlatteningPreprocessor.")

        if tf:
            # Build each policy inside a variable scope named after it.
            with tf.variable_scope(name):
                policy_map[name] = cls(obs_space, act_space, merged_conf)
        else:
            policy_map[name] = cls(obs_space, act_space, merged_conf)

    if self.worker_index == 0:
        logger.info("Built policy map: {}".format(policy_map))
        logger.info("Built preprocessor map: {}".format(preprocessors))
    return policy_map, preprocessors
|
||||
|
||||
def __del__(self):
    """Flag an AsyncSampler for shutdown when the worker is collected."""
    # __init__ may have failed before the sampler was assigned.
    if not hasattr(self, "sampler"):
        return
    if isinstance(self.sampler, AsyncSampler):
        self.sampler.shutdown = True
|
||||
|
||||
|
||||
def _validate_and_canonicalize(policy, env):
    """Normalize the `policy` argument into a policy-spec dict.

    Arguments:
        policy (class|dict): Either a Policy subclass, or a dict mapping
            policy ids to (cls, obs_space, action_space, config) tuples.
        env: The env; its observation_space and action_space are used to
            build the spec when a single policy class is given.

    Returns:
        dict: The validated input dict unchanged, or a one-entry dict
            mapping DEFAULT_POLICY_ID to
            ``(policy, env.observation_space, env.action_space, {})``.

    Raises:
        ValueError: If policy is neither a dict nor a Policy subclass, if
            the dict fails validation, or if a MultiAgentEnv without an
            observation_space is used with a single policy class.
    """
    import inspect

    if isinstance(policy, dict):
        _validate_multiagent_config(policy)
        return policy

    # issubclass() raises TypeError when its first argument is not a
    # class, so check that first: any invalid value must surface as the
    # ValueError below instead of an unrelated TypeError.
    if not (inspect.isclass(policy) and issubclass(policy, Policy)):
        raise ValueError("policy must be a rllib.Policy class")

    if (isinstance(env, MultiAgentEnv)
            and not hasattr(env, "observation_space")):
        raise ValueError(
            "MultiAgentEnv must have observation_space defined if run "
            "in a single-agent configuration.")
    return {
        DEFAULT_POLICY_ID: (policy, env.observation_space,
                            env.action_space, {})
    }
|
||||
|
||||
|
||||
def _validate_multiagent_config(policy, allow_none_graph=False):
    """Check that a multi-agent policy dict is well formed.

    Arguments:
        policy (dict): Maps policy id strings to
            (cls, obs_space, action_space, config) tuples.
        allow_none_graph (bool): If true, the cls slot of a tuple may be
            None.

    Raises:
        ValueError: If a key is not a str, a value is not a 4-tuple, the
            cls slot is not a Policy subclass (or None when allowed),
            either space is not a gym.Space, or the config is not a dict.
    """
    import inspect

    for k, v in policy.items():
        if not isinstance(k, str):
            raise ValueError("policy keys must be strs, got {}".format(
                type(k)))
        if not isinstance(v, tuple) or len(v) != 4:
            raise ValueError(
                "policy values must be tuples of "
                "(cls, obs_space, action_space, config), got {}".format(v))
        if allow_none_graph and v[0] is None:
            pass
        # issubclass() raises TypeError for a non-class first argument;
        # test for a class first so bad entries yield the ValueError.
        elif not (inspect.isclass(v[0]) and issubclass(v[0], Policy)):
            raise ValueError("policy tuple value 0 must be a rllib.Policy "
                             "class or None, got {}".format(v[0]))
        if not isinstance(v[1], gym.Space):
            raise ValueError(
                "policy tuple value 1 (observation_space) must be a "
                "gym.Space, got {}".format(type(v[1])))
        if not isinstance(v[2], gym.Space):
            raise ValueError("policy tuple value 2 (action_space) must be a "
                             "gym.Space, got {}".format(type(v[2])))
        if not isinstance(v[3], dict):
            raise ValueError("policy tuple value 3 (config) must be a dict, "
                             "got {}".format(type(v[3])))
|
||||
|
||||
|
||||
def _validate_env(env):
    """Return env unchanged if it looks like a supported environment.

    Raises:
        ValueError: If env neither exposes both observation_space and
            action_space nor is an instance of a supported env type.
    """
    # allow this as a special case (assumed gym.Env)
    duck_typed = all(
        hasattr(env, attr) for attr in ("observation_space", "action_space"))
    if duck_typed:
        return env

    allowed_types = (gym.Env, MultiAgentEnv, ExternalEnv, VectorEnv, BaseEnv)
    if isinstance(env, allowed_types):
        return env
    raise ValueError(
        "Returned env should be an instance of gym.Env, MultiAgentEnv, "
        "ExternalEnv, VectorEnv, or BaseEnv. The provided env creator "
        "function returned {} ({}).".format(env, type(env)))
|
||||
|
||||
|
||||
def _monitor(env, path):
    """Wrap env in a gym Monitor that writes to path with resume=True."""
    monitored = gym.wrappers.Monitor(env, path, resume=True)
    return monitored
|
||||
|
||||
|
||||
def _has_tensorflow_graph(policy_dict):
    """Report whether any policy class in policy_dict is a TFPolicy.

    Arguments:
        policy_dict (dict): Maps policy id to a 4-tuple whose first
            element is the policy class.
    """
    return any(
        issubclass(policy_cls, TFPolicy)
        for policy_cls, _, _, _ in policy_dict.values())
|
||||
# Keep the pre-rename class name importable.
# NOTE(review): renamed_class presumably wraps RolloutWorker so that use of
# the old name is flagged as deprecated — confirm against its definition.
PolicyEvaluator = renamed_class(
    RolloutWorker,
    old_name="rllib.evaluation.PolicyEvaluator",
)
|
||||
|
||||
@@ -0,0 +1,794 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gym
|
||||
import logging
|
||||
import pickle
|
||||
|
||||
import ray
|
||||
from ray.rllib.env.atari_wrappers import wrap_deepmind, is_atari
|
||||
from ray.rllib.env.base_env import BaseEnv
|
||||
from ray.rllib.env.env_context import EnvContext
|
||||
from ray.rllib.env.external_env import ExternalEnv
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
|
||||
from ray.rllib.env.vector_env import VectorEnv
|
||||
from ray.rllib.evaluation.interface import EvaluatorInterface
|
||||
from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
|
||||
from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader
|
||||
from ray.rllib.offline.is_estimator import ImportanceSamplingEstimator
|
||||
from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models.preprocessors import NoPreprocessor
|
||||
from ray.rllib.utils import merge_dicts
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.debug import disable_log_once_globally, log_once, \
|
||||
summarize, enable_periodic_logging
|
||||
from ray.rllib.utils.filter import get_filter
|
||||
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
# NOTE(review): try_import_tf() presumably returns a falsy value when
# TensorFlow is unavailable — confirm against its definition.
tf = try_import_tf()

logger = logging.getLogger(__name__)

# Process-wide handle to the most recently created RolloutWorker. Custom
# env or policy classes can read it (via get_global_worker) for debugging
# or other advanced use cases.
_global_worker = None
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def get_global_worker():
    """Return the rollout worker most recently created in this process.

    Returns:
        The current value of the module-level `_global_worker`, which is
        None until a RolloutWorker has been constructed.
    """
    # Reading a module global needs no `global` declaration.
    return _global_worker
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class RolloutWorker(EvaluatorInterface):
|
||||
"""Common experience collection class.
|
||||
|
||||
This class wraps a policy instance and an environment class to
|
||||
collect experiences from the environment. You can create many replicas of
|
||||
this class as Ray actors to scale RL training.
|
||||
|
||||
This class supports vectorized and multi-agent policy evaluation (e.g.,
|
||||
VectorEnv, MultiAgentEnv, etc.)
|
||||
|
||||
Examples:
|
||||
>>> # Create a rollout worker and use it to collect experiences.
|
||||
>>> worker = RolloutWorker(
|
||||
... env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
... policy=PGTFPolicy)
|
||||
>>> print(worker.sample())
|
||||
SampleBatch({
|
||||
"obs": [[...]], "actions": [[...]], "rewards": [[...]],
|
||||
"dones": [[...]], "new_obs": [[...]]})
|
||||
|
||||
>>> # Creating a multi-agent rollout worker
|
||||
>>> worker = RolloutWorker(
|
||||
... env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25),
|
||||
... policies={
|
||||
... # Use an ensemble of two policies for car agents
|
||||
... "car_policy1":
|
||||
... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.99}),
|
||||
... "car_policy2":
|
||||
... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.95}),
|
||||
... # Use a single shared policy for all traffic lights
|
||||
... "traffic_light_policy":
|
||||
... (PGTFPolicy, Box(...), Discrete(...), {}),
|
||||
... },
|
||||
... policy_mapping_fn=lambda agent_id:
|
||||
... random.choice(["car_policy1", "car_policy2"])
|
||||
... if agent_id.startswith("car_") else "traffic_light_policy")
|
||||
>>> print(worker.sample())
|
||||
MultiAgentBatch({
|
||||
"car_policy1": SampleBatch(...),
|
||||
"car_policy2": SampleBatch(...),
|
||||
"traffic_light_policy": SampleBatch(...)})
|
||||
"""
|
||||
|
||||
@DeveloperAPI
|
||||
@classmethod
|
||||
def as_remote(cls, num_cpus=None, num_gpus=None, resources=None):
    """Return this class wrapped by ray.remote with the given resources.

    Arguments:
        num_cpus: Forwarded to ray.remote as num_cpus.
        num_gpus: Forwarded to ray.remote as num_gpus.
        resources: Forwarded to ray.remote as resources.
    """
    make_remote = ray.remote(
        num_cpus=num_cpus, num_gpus=num_gpus, resources=resources)
    return make_remote(cls)
|
||||
|
||||
@DeveloperAPI
|
||||
def __init__(self,
             env_creator,
             policy,
             policy_mapping_fn=None,
             policies_to_train=None,
             tf_session_creator=None,
             batch_steps=100,
             batch_mode="truncate_episodes",
             episode_horizon=None,
             preprocessor_pref="deepmind",
             sample_async=False,
             compress_observations=False,
             num_envs=1,
             observation_filter="NoFilter",
             clip_rewards=None,
             clip_actions=True,
             env_config=None,
             model_config=None,
             policy_config=None,
             worker_index=0,
             monitor_path=None,
             log_dir=None,
             log_level=None,
             callbacks=None,
             input_creator=lambda ioctx: ioctx.default_sampler_input(),
             input_evaluation=frozenset([]),
             output_creator=lambda ioctx: NoopOutput(),
             remote_worker_envs=False,
             remote_env_batch_wait_ms=0,
             soft_horizon=False,
             _fake_sampler=False):
    """Set up the env, policies, filters, sampler and offline I/O.

    Arguments:
        env_creator (func): Function that returns a gym.Env given an
            EnvContext wrapped configuration.
        policy (class|dict): Either a class implementing Policy, or a
            dictionary of policy id strings to
            (Policy, obs_space, action_space, config) tuples. A dict
            selects multi-agent mode, in which case policy_mapping_fn
            should also be set.
        policy_mapping_fn (func): Maps agent ids to policy ids in
            multi-agent mode. It is called each time a new agent appears
            in an episode, binding that agent to a policy for the
            duration of the episode. Defaults to mapping every agent to
            the default policy id.
        policies_to_train (list): Optional whitelist of policies to
            train, or None for all policies.
        tf_session_creator (func): Returns a TF session. Optional and
            only useful with TFPolicy.
        batch_steps (int): Target number of env transitions to include
            in each sample batch returned from this worker.
        batch_mode (str): One of the following batch modes:
            "truncate_episodes": Each call to sample() returns a batch
                of at most `batch_steps * num_envs` in size, and exactly
                that size if postprocessing does not change batch
                sizes. Episodes may be truncated to meet this size.
            "complete_episodes": Each call to sample() returns a batch
                of at least `batch_steps * num_envs` in size. Episodes
                are never truncated, but several episodes may be packed
                into one batch to reach the batch size. When
                `num_envs > 1`, episode steps are buffered until the
                episode completes, so batches may contain significant
                amounts of off-policy data.
        episode_horizon (int): Whether to stop episodes at this horizon.
        preprocessor_pref (str): Whether to prefer RLlib preprocessors
            ("rllib") or deepmind ("deepmind") when applicable.
        sample_async (bool): Whether to compute samples asynchronously
            in the background, which improves throughput but can make
            samples slightly off-policy.
        compress_observations (bool): If true, compress the
            observations. They can be decompressed with
            rllib/utils/compression.
        num_envs (int): If more than one, create multiple envs and
            vectorize the computation of actions. This has no effect if
            the env already implements VectorEnv.
        observation_filter (str): Name of observation filter to use.
        clip_rewards (bool): Whether to clip rewards to [-1, 1] prior to
            experience postprocessing. None means clip for Atari only.
        clip_actions (bool): Whether to clip action values to the range
            specified by the policy action space.
        env_config (dict): Config to pass to the env creator.
        model_config (dict): Config to use when creating the policy
            model.
        policy_config (dict): Config to pass to the policy. In the
            multi-agent case it is merged with the per-policy configs
            specified by `policy`.
        worker_index (int): For remote workers, this should be set to a
            non-zero and unique value. The index is passed to created
            envs through EnvContext so envs can be configured per
            worker.
        monitor_path (str): Write out episode stats and videos to this
            directory if specified.
        log_dir (str): Directory where logs can be placed.
        log_level (str): Set the root log level on creation.
        callbacks (dict): Dict of custom debug callbacks.
        input_creator (func): Returns an InputReader object for loading
            previously generated experiences.
        input_evaluation (list): How to evaluate the policy performance.
            Only meaningful when the input is reading offline data.
            Possible values:
            - "is": the step-wise importance sampling estimator.
            - "wis": the weighted step-wise is estimator.
            - "simulation": run the environment in the background, but
                use this data for evaluation only, never for learning.
        output_creator (func): Returns an OutputWriter object for saving
            generated experiences.
        remote_worker_envs (bool): If using num_envs > 1, whether to
            create those new envs in remote processes instead of in the
            current process. This adds overheads, so it only pays off
            for some envs (the original sentence was cut off here;
            presumably expensive-to-step envs — TODO confirm).
        remote_env_batch_wait_ms (float): Timeout that remote workers
            wait when polling environments. 0 (continue when at least
            one env is ready) is a reasonable default, but the optimal
            value could be obtained by measuring your environment
            step / reset and model inference perf.
        soft_horizon (bool): Calculate rewards but don't reset the
            environment when the horizon is hit.
        _fake_sampler (bool): Use a fake (inf speed) sampler for
            testing.

    Raises:
        ValueError: If policy_mapping_fn is not callable, the policy
            spec is invalid, several policies are used with an env type
            that cannot serve them, or batch_mode / an input_evaluation
            method is unknown.
    """

    # Publish this instance as the process-wide "current" worker.
    global _global_worker
    _global_worker = self

    if log_level:
        logging.getLogger("ray.rllib").setLevel(log_level)

    if worker_index > 1:
        disable_log_once_globally()  # only need 1 worker to log
    elif log_level == "DEBUG":
        enable_periodic_logging()

    env_context = EnvContext(env_config or {}, worker_index)
    policy_config = policy_config or {}
    model_config = model_config or {}
    policy_mapping_fn = (policy_mapping_fn
                         or (lambda agent_id: DEFAULT_POLICY_ID))
    if not callable(policy_mapping_fn):
        raise ValueError(
            "Policy mapping function not callable. If you're using Tune, "
            "make sure to escape the function with tune.function() "
            "to prevent it from being evaluated as an expression.")

    self.policy_config = policy_config
    self.callbacks = callbacks or {}
    self.worker_index = worker_index
    self.env_creator = env_creator
    self.sample_batch_size = batch_steps * num_envs
    self.batch_mode = batch_mode
    self.compress_observations = compress_observations
    self.preprocessing_enabled = True
    self.last_batch = None
    self._fake_sampler = _fake_sampler

    self.env = _validate_env(env_creator(env_context))

    def _maybe_monitor(env):
        # Add the monitor wrapper only when a monitor path was given.
        if monitor_path:
            env = _monitor(env, monitor_path)
        return env

    if isinstance(self.env, (MultiAgentEnv, BaseEnv)):

        def wrap(env):
            return env  # we can't auto-wrap these env types
    elif (is_atari(self.env)
          and not model_config.get("custom_preprocessor")
          and preprocessor_pref == "deepmind"):

        # Deepmind wrappers already handle all preprocessing
        self.preprocessing_enabled = False

        if clip_rewards is None:
            clip_rewards = True

        def wrap(env):
            return _maybe_monitor(
                wrap_deepmind(
                    env,
                    dim=model_config.get("dim"),
                    framestack=model_config.get("framestack")))
    else:
        wrap = _maybe_monitor

    self.env = wrap(self.env)

    def make_env(vector_index):
        sub_context = env_context.copy_with_overrides(
            vector_index=vector_index, remote=remote_worker_envs)
        return wrap(env_creator(sub_context))

    self.tf_sess = None
    policy_dict = _validate_and_canonicalize(policy, self.env)
    self.policies_to_train = policies_to_train or list(policy_dict.keys())
    if not _has_tensorflow_graph(policy_dict):
        self.policy_map, self.preprocessors = self._build_policy_map(
            policy_dict, policy_config)
    else:
        if (ray.is_initialized()
                and ray.worker._mode() != ray.worker.LOCAL_MODE
                and not ray.get_gpu_ids()):
            logger.info(
                "Creating policy evaluation worker {}".format(worker_index)
                + " on CPU (please ignore any CUDA init errors)")
        # TF policies are built in a fresh graph with this worker's
        # session installed as the default session.
        with tf.Graph().as_default():
            if tf_session_creator:
                self.tf_sess = tf_session_creator()
            else:
                self.tf_sess = tf.Session(
                    config=tf.ConfigProto(
                        gpu_options=tf.GPUOptions(allow_growth=True)))
            with self.tf_sess.as_default():
                self.policy_map, self.preprocessors = \
                    self._build_policy_map(policy_dict, policy_config)

    self.multiagent = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID}
    if self.multiagent and not isinstance(
            self.env, (MultiAgentEnv, ExternalMultiAgentEnv, BaseEnv)):
        raise ValueError(
            "Have multiple policies {}, but the env ".format(
                self.policy_map) +
            "{} is not a subclass of BaseEnv, MultiAgentEnv or "
            "ExternalMultiAgentEnv?".format(self.env))

    self.filters = {}
    for policy_id, built_policy in self.policy_map.items():
        self.filters[policy_id] = get_filter(
            observation_filter, built_policy.observation_space.shape)
    if self.worker_index == 0:
        logger.info("Built filter map: {}".format(self.filters))

    # Always use vector env for consistency even if num_envs = 1
    self.async_env = BaseEnv.to_base_env(
        self.env,
        make_env=make_env,
        num_envs=num_envs,
        remote_envs=remote_worker_envs,
        remote_env_batch_wait_ms=remote_env_batch_wait_ms)
    self.num_envs = num_envs

    if self.batch_mode == "truncate_episodes":
        unroll_length = batch_steps
        pack_episodes = True
    elif self.batch_mode == "complete_episodes":
        unroll_length = float("inf")  # never cut episodes
        pack_episodes = False  # sampler will return 1 episode per poll
    else:
        raise ValueError("Unsupported batch mode: {}".format(
            self.batch_mode))

    self.io_context = IOContext(log_dir, policy_config, worker_index, self)
    self.reward_estimators = []
    for method in input_evaluation:
        if method == "simulation":
            logger.warning(
                "Requested 'simulation' input evaluation method: "
                "will discard all sampler outputs and keep only metrics.")
            sample_async = True
        elif method == "is":
            self.reward_estimators.append(
                ImportanceSamplingEstimator.create(self.io_context))
        elif method == "wis":
            self.reward_estimators.append(
                WeightedImportanceSamplingEstimator.create(
                    self.io_context))
        else:
            raise ValueError(
                "Unknown evaluation method: {}".format(method))

    # Both sampler types share these constructor arguments.
    sampler_args = (self.async_env, self.policy_map, policy_mapping_fn,
                    self.preprocessors, self.filters, clip_rewards,
                    unroll_length, self.callbacks)
    sampler_kwargs = dict(
        horizon=episode_horizon,
        pack=pack_episodes,
        tf_sess=self.tf_sess,
        clip_actions=clip_actions,
        soft_horizon=soft_horizon)
    if sample_async:
        self.sampler = AsyncSampler(
            *sampler_args,
            blackhole_outputs="simulation" in input_evaluation,
            **sampler_kwargs)
        self.sampler.start()
    else:
        self.sampler = SyncSampler(*sampler_args, **sampler_kwargs)

    self.input_reader = input_creator(self.io_context)
    assert isinstance(self.input_reader, InputReader), self.input_reader
    self.output_writer = output_creator(self.io_context)
    assert isinstance(self.output_writer, OutputWriter), self.output_writer

    logger.debug(
        "Created rollout worker with env {} ({}), policies {}".format(
            self.async_env, self.env, self.policy_map))
|
||||
|
||||
@override(EvaluatorInterface)
|
||||
def sample(self):
    """Read batches from the input reader and return them as one batch.

    The combined batch is passed to the "on_sample_end" callback (if
    set), written to the output writer, fed to the reward estimators
    episode by episode, and finally compressed if configured.

    Return:
        SampleBatch|MultiAgentBatch from evaluating the current policies.
    """

    # The fake sampler replays the first batch it ever produced.
    if self._fake_sampler and self.last_batch is not None:
        return self.last_batch

    if log_once("sample_start"):
        logger.info("Generating sample batch of size {}".format(
            self.sample_batch_size))

    collected = [self.input_reader.next()]
    steps_so_far = collected[0].count

    # In truncate_episodes mode, never pull more than 1 batch per env.
    # This avoids over-running the target batch size.
    max_batches = (self.num_envs if self.batch_mode == "truncate_episodes"
                   else float("inf"))

    while (steps_so_far < self.sample_batch_size
           and len(collected) < max_batches):
        extra = self.input_reader.next()
        steps_so_far += extra.count
        collected.append(extra)
    batch = collected[0].concat_samples(collected)

    on_sample_end = self.callbacks.get("on_sample_end")
    if on_sample_end:
        on_sample_end({"worker": self, "samples": batch})

    # Always do writes prior to compression for consistency and to allow
    # for better compression inside the writer.
    self.output_writer.write(batch)

    # Do off-policy estimation if needed
    if self.reward_estimators:
        for episode_batch in batch.split_by_episode():
            for estimator in self.reward_estimators:
                estimator.process(episode_batch)

    if log_once("sample_end"):
        logger.info("Completed sample batch:\n\n{}\n".format(
            summarize(batch)))

    if self.compress_observations == "bulk":
        batch.compress(bulk=True)
    elif self.compress_observations:
        batch.compress()

    if self._fake_sampler:
        self.last_batch = batch
    return batch
|
||||
|
||||
@DeveloperAPI
|
||||
@ray.method(num_return_vals=2)
|
||||
def sample_with_count(self):
    """Call sample() and also hand back the size of the resulting batch.

    Returns:
        tuple: ``(batch, batch.count)``. The method is declared with
            ``num_return_vals=2``, so the two values are returned as
            separate futures when invoked remotely.
    """
    sampled = self.sample()
    return sampled, sampled.count
|
||||
|
||||
@override(EvaluatorInterface)
|
||||
def get_weights(self, policies=None):
    """Collect the weights of the policies held by this worker.

    Arguments:
        policies (list): Ids of the policies to include, or None to
            include every policy in the policy map.

    Returns:
        dict: Maps each selected policy id to the value returned by that
            policy's get_weights().
    """
    wanted = self.policy_map.keys() if policies is None else policies
    weights = {}
    for pid, policy in self.policy_map.items():
        if pid in wanted:
            weights[pid] = policy.get_weights()
    return weights
|
||||
|
||||
@override(EvaluatorInterface)
|
||||
def set_weights(self, weights):
    """Push the given weights into the matching policies.

    Arguments:
        weights (dict): Maps policy id to the weights to install; every
            key must exist in the policy map (a missing id raises
            KeyError).
    """
    for policy_id, policy_weights in weights.items():
        target = self.policy_map[policy_id]
        target.set_weights(policy_weights)
|
||||
|
||||
@override(EvaluatorInterface)
|
||||
def compute_gradients(self, samples):
    """Compute gradients for the given samples without applying them.

    A MultiAgentBatch is processed per policy, skipping policies not in
    self.policies_to_train. When a TF session is present, the per-policy
    fetches are collected on one TFRunBuilder and resolved afterwards.
    Any other batch is handed to the default policy.

    Returns:
        tuple: (grad_out, info_out). info_out additionally carries a
            "batch_count" entry set to samples.count.
    """
    if log_once("compute_gradients"):
        logger.info("Compute gradients on:\n\n{}\n".format(
            summarize(samples)))

    if not isinstance(samples, MultiAgentBatch):
        default_policy = self.policy_map[DEFAULT_POLICY_ID]
        grad_out, info_out = default_policy.compute_gradients(samples)
    elif self.tf_sess is not None:
        builder = TFRunBuilder(self.tf_sess, "compute_gradients")
        pending_grads, pending_infos = {}, {}
        for pid, batch in samples.policy_batches.items():
            if pid in self.policies_to_train:
                pending_grads[pid], pending_infos[pid] = (
                    self.policy_map[pid]._build_compute_gradients(
                        builder, batch))
        # Resolve the deferred fetches once all policies are registered.
        grad_out = {k: builder.get(v) for k, v in pending_grads.items()}
        info_out = {k: builder.get(v) for k, v in pending_infos.items()}
    else:
        grad_out, info_out = {}, {}
        for pid, batch in samples.policy_batches.items():
            if pid in self.policies_to_train:
                grad_out[pid], info_out[pid] = (
                    self.policy_map[pid].compute_gradients(batch))

    info_out["batch_count"] = samples.count
    if log_once("grad_out"):
        logger.info("Compute grad info:\n\n{}\n".format(
            summarize(info_out)))
    return grad_out, info_out
|
||||
|
||||
@override(EvaluatorInterface)
|
||||
def apply_gradients(self, grads):
    """Apply previously computed gradients to the policies.

    Arguments:
        grads: either a dict of policy id -> gradients, or the gradients
            for the default policy.

    Returns:
        For dict input, a dict of policy id -> per-policy result;
        otherwise the default policy's apply_gradients() result.
    """
    if log_once("apply_gradients"):
        logger.info("Apply gradients:\n\n{}\n".format(summarize(grads)))

    if not isinstance(grads, dict):
        return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)

    if self.tf_sess is None:
        return {
            pid: self.policy_map[pid].apply_gradients(g)
            for pid, g in grads.items()
        }

    # With a TF session, register every policy on one builder first and
    # resolve the fetches afterwards.
    builder = TFRunBuilder(self.tf_sess, "apply_gradients")
    pending = {
        pid: self.policy_map[pid]._build_apply_gradients(builder, grad)
        for pid, grad in grads.items()
    }
    return {k: builder.get(v) for k, v in pending.items()}
|
||||
|
||||
@override(EvaluatorInterface)
|
||||
def learn_on_batch(self, samples):
    """Update the policies from the given samples and return the info.

    A MultiAgentBatch is processed per policy, skipping policies not in
    self.policies_to_train. Policies exposing _build_learn_on_batch are
    registered on a shared TFRunBuilder when a TF session is present;
    the rest go through learn_on_batch() directly. Any other batch is
    handed to the default policy.

    Returns:
        The default policy's result, or a dict of policy id -> result.
    """
    if log_once("learn_on_batch"):
        logger.info(
            "Training on concatenated sample batches:\n\n{}\n".format(
                summarize(samples)))

    if not isinstance(samples, MultiAgentBatch):
        info_out = self.policy_map[DEFAULT_POLICY_ID].learn_on_batch(
            samples)
    else:
        info_out = {}
        deferred = {}
        builder = (TFRunBuilder(self.tf_sess, "learn_on_batch")
                   if self.tf_sess is not None else None)
        for pid, batch in samples.policy_batches.items():
            if pid not in self.policies_to_train:
                continue
            policy = self.policy_map[pid]
            if builder and hasattr(policy, "_build_learn_on_batch"):
                deferred[pid] = policy._build_learn_on_batch(builder, batch)
            else:
                info_out[pid] = policy.learn_on_batch(batch)
        # deferred is only populated when a builder exists.
        for pid, fetch in deferred.items():
            info_out[pid] = builder.get(fetch)

    if log_once("learn_out"):
        logger.info("Training output:\n\n{}\n".format(summarize(info_out)))
    return info_out
|
||||
|
||||
@DeveloperAPI
|
||||
def get_metrics(self):
    """Returns the sampler's new metrics plus those of each estimator.

    The list returned by self.sampler.get_metrics() is extended in place
    with the metrics of every entry in self.reward_estimators.
    """
    metrics = self.sampler.get_metrics()
    for estimator in self.reward_estimators:
        estimator_metrics = estimator.get_metrics()
        metrics.extend(estimator_metrics)
    return metrics
|
||||
|
||||
@DeveloperAPI
|
||||
def foreach_env(self, func):
    """Call func on every unwrapped env and return the list of results.

    If self.async_env.get_unwrapped() yields nothing, func is applied to
    self.async_env itself instead.
    """
    unwrapped = self.async_env.get_unwrapped()
    targets = unwrapped if unwrapped else [self.async_env]
    return [func(env) for env in targets]
|
||||
|
||||
@DeveloperAPI
|
||||
def get_policy(self, policy_id=DEFAULT_POLICY_ID):
    """Look up a policy by id.

    Arguments:
        policy_id (str): id of policy to return.

    Returns:
        The matching entry of self.policy_map, or None if the id is
        not present.
    """
    found = self.policy_map.get(policy_id)
    return found
|
||||
|
||||
@DeveloperAPI
|
||||
def for_policy(self, func, policy_id=DEFAULT_POLICY_ID):
    """Call func on the policy with the given id and return its result.

    The id must exist in self.policy_map; the lookup is a plain
    subscript, so a missing id raises KeyError.
    """
    policy = self.policy_map[policy_id]
    return func(policy)
|
||||
|
||||
@DeveloperAPI
|
||||
def foreach_policy(self, func):
    """Call func(policy, policy_id) for every policy; return the results.

    Results are listed in the iteration order of self.policy_map.
    """
    results = []
    for policy_id, policy in self.policy_map.items():
        results.append(func(policy, policy_id))
    return results
|
||||
|
||||
@DeveloperAPI
|
||||
def foreach_trainable_policy(self, func):
    """Call func(policy, policy_id) for each trainable policy.

    Only policies whose id is in `self.policies_to_train` are visited;
    the results are returned as a list.
    """
    results = []
    for policy_id, policy in self.policy_map.items():
        if policy_id not in self.policies_to_train:
            continue
        results.append(func(policy, policy_id))
    return results
|
||||
|
||||
@DeveloperAPI
|
||||
def sync_filters(self, new_filters):
    """Sync every local filter with its counterpart in new_filters.

    Args:
        new_filters (dict): Filters with new state to update local copy.
            Must contain an entry for every key of self.filters
            (checked with an assert).
    """
    assert all(k in new_filters for k in self.filters)
    for name, local_filter in self.filters.items():
        local_filter.sync(new_filters[name])
|
||||
|
||||
@DeveloperAPI
|
||||
def get_filters(self, flush_after=False):
    """Snapshot every filter in its serializable form.

    Args:
        flush_after (bool): When true, clear_buffer() is called on each
            filter right after its snapshot is taken.

    Returns:
        return_filters (dict): filter key -> f.as_serializable()
    """
    snapshot = {}
    for name in self.filters:
        current = self.filters[name]
        snapshot[name] = current.as_serializable()
        if flush_after:
            current.clear_buffer()
    return snapshot
|
||||
|
||||
@DeveloperAPI
|
||||
def save(self):
    """Serialize the filters and the state of every policy.

    The filters are fetched with flush_after=True, so their buffers are
    cleared as part of saving.

    Returns:
        bytes: pickled dict with the keys "filters" and "state".
    """
    filters = self.get_filters(flush_after=True)
    state = {}
    for policy_id in self.policy_map:
        state[policy_id] = self.policy_map[policy_id].get_state()
    payload = {"filters": filters, "state": state}
    return pickle.dumps(payload)
|
||||
|
||||
@DeveloperAPI
|
||||
def restore(self, objs):
    """Load filters and policy state from a pickled payload.

    Arguments:
        objs (bytes): pickled dict with the keys "filters" and "state",
            the same layout that save() writes.
    """
    # NOTE(review): pickle.loads executes arbitrary code on malicious
    # input -- only pass payloads from a trusted source.
    payload = pickle.loads(objs)
    self.sync_filters(payload["filters"])
    policy_states = payload["state"]
    for policy_id, policy_state in policy_states.items():
        self.policy_map[policy_id].set_state(policy_state)
|
||||
|
||||
@DeveloperAPI
|
||||
def set_global_vars(self, global_vars):
    """Forward global_vars to on_global_var_update() of every policy."""

    def notify(policy, _policy_id):
        return policy.on_global_var_update(global_vars)

    self.foreach_policy(notify)
|
||||
|
||||
@DeveloperAPI
|
||||
def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
    """Ask the policy with the given id to export its model.

    Arguments:
        export_dir: passed through to the policy's export_model().
        policy_id: id of the policy in self.policy_map to export.
    """
    policy = self.policy_map[policy_id]
    policy.export_model(export_dir)
|
||||
|
||||
@DeveloperAPI
|
||||
def export_policy_checkpoint(self,
                             export_dir,
                             filename_prefix="model",
                             policy_id=DEFAULT_POLICY_ID):
    """Ask the policy with the given id to export a checkpoint.

    Arguments:
        export_dir: passed through to the policy's export_checkpoint().
        filename_prefix: passed through as the second argument.
        policy_id: id of the policy in self.policy_map to export.
    """
    policy = self.policy_map[policy_id]
    policy.export_checkpoint(export_dir, filename_prefix)
|
||||
|
||||
@DeveloperAPI
|
||||
def stop(self):
    """Stop this worker by stopping its async env."""
    env = self.async_env
    env.stop()
|
||||
|
||||
def _build_policy_map(self, policy_dict, policy_config):
    """Instantiate every policy in policy_dict plus its preprocessor.

    Arguments:
        policy_dict (dict): name -> (cls, obs_space, act_space, conf).
            Entries are processed in sorted order.
        policy_config (dict): base config; each entry's own conf is
            merged on top of it.

    Returns:
        tuple: (policy_map, preprocessors), both keyed by policy name.

    Raises:
        ValueError: if the observation space handed to a policy is
            still a gym Dict or Tuple space.
    """
    built_policies = {}
    built_preprocessors = {}
    for name, spec in sorted(policy_dict.items()):
        cls, obs_space, act_space, conf = spec
        logger.debug("Creating policy for {}".format(name))
        merged_conf = merge_dicts(policy_config, conf)

        if not self.preprocessing_enabled:
            built_preprocessors[name] = NoPreprocessor(obs_space)
        else:
            preprocessor = ModelCatalog.get_preprocessor_for_space(
                obs_space, merged_conf.get("model"))
            built_preprocessors[name] = preprocessor
            # The policy sees the preprocessed space, not the raw one.
            obs_space = preprocessor.observation_space

        if isinstance(obs_space, (gym.spaces.Dict, gym.spaces.Tuple)):
            raise ValueError(
                "Found raw Tuple|Dict space as input to policy. "
                "Please preprocess these observations with a "
                "Tuple|DictFlatteningPreprocessor.")

        if not tf:
            built_policies[name] = cls(obs_space, act_space, merged_conf)
        else:
            # Each policy is built inside a variable scope of its name.
            with tf.variable_scope(name):
                built_policies[name] = cls(obs_space, act_space,
                                           merged_conf)

    if self.worker_index == 0:
        logger.info("Built policy map: {}".format(built_policies))
        logger.info(
            "Built preprocessor map: {}".format(built_preprocessors))
    return built_policies, built_preprocessors
|
||||
|
||||
def __del__(self):
    """Flag an AsyncSampler for shutdown when this object is deleted."""
    # The sampler attribute may not exist, hence the hasattr guard.
    if not hasattr(self, "sampler"):
        return
    if isinstance(self.sampler, AsyncSampler):
        self.sampler.shutdown = True
|
||||
|
||||
|
||||
def _validate_and_canonicalize(policy, env):
|
||||
if isinstance(policy, dict):
|
||||
_validate_multiagent_config(policy)
|
||||
return policy
|
||||
elif not issubclass(policy, Policy):
|
||||
raise ValueError("policy must be a rllib.Policy class")
|
||||
else:
|
||||
if (isinstance(env, MultiAgentEnv)
|
||||
and not hasattr(env, "observation_space")):
|
||||
raise ValueError(
|
||||
"MultiAgentEnv must have observation_space defined if run "
|
||||
"in a single-agent configuration.")
|
||||
return {
|
||||
DEFAULT_POLICY_ID: (policy, env.observation_space,
|
||||
env.action_space, {})
|
||||
}
|
||||
|
||||
|
||||
def _validate_multiagent_config(policy, allow_none_graph=False):
|
||||
for k, v in policy.items():
|
||||
if not isinstance(k, str):
|
||||
raise ValueError("policy keys must be strs, got {}".format(
|
||||
type(k)))
|
||||
if not isinstance(v, tuple) or len(v) != 4:
|
||||
raise ValueError(
|
||||
"policy values must be tuples of "
|
||||
"(cls, obs_space, action_space, config), got {}".format(v))
|
||||
if allow_none_graph and v[0] is None:
|
||||
pass
|
||||
elif not issubclass(v[0], Policy):
|
||||
raise ValueError("policy tuple value 0 must be a rllib.Policy "
|
||||
"class or None, got {}".format(v[0]))
|
||||
if not isinstance(v[1], gym.Space):
|
||||
raise ValueError(
|
||||
"policy tuple value 1 (observation_space) must be a "
|
||||
"gym.Space, got {}".format(type(v[1])))
|
||||
if not isinstance(v[2], gym.Space):
|
||||
raise ValueError("policy tuple value 2 (action_space) must be a "
|
||||
"gym.Space, got {}".format(type(v[2])))
|
||||
if not isinstance(v[3], dict):
|
||||
raise ValueError("policy tuple value 3 (config) must be a dict, "
|
||||
"got {}".format(type(v[3])))
|
||||
|
||||
|
||||
def _validate_env(env):
|
||||
# allow this as a special case (assumed gym.Env)
|
||||
if hasattr(env, "observation_space") and hasattr(env, "action_space"):
|
||||
return env
|
||||
|
||||
allowed_types = [gym.Env, MultiAgentEnv, ExternalEnv, VectorEnv, BaseEnv]
|
||||
if not any(isinstance(env, tpe) for tpe in allowed_types):
|
||||
raise ValueError(
|
||||
"Returned env should be an instance of gym.Env, MultiAgentEnv, "
|
||||
"ExternalEnv, VectorEnv, or BaseEnv. The provided env creator "
|
||||
"function returned {} ({}).".format(env, type(env)))
|
||||
return env
|
||||
|
||||
|
||||
def _monitor(env, path):
    """Wrap env in a gym Monitor writing to path, with resume=True."""
    monitored = gym.wrappers.Monitor(env, path, resume=True)
    return monitored
|
||||
|
||||
|
||||
def _has_tensorflow_graph(policy_dict):
|
||||
for policy, _, _, _ in policy_dict.values():
|
||||
if issubclass(policy, TFPolicy):
|
||||
return True
|
||||
return False
|
||||
@@ -0,0 +1,214 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
from types import FunctionType
|
||||
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.evaluation.rollout_worker import RolloutWorker, \
|
||||
_validate_multiagent_config
|
||||
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
|
||||
ShuffledInput
|
||||
from ray.rllib.utils import merge_dicts, try_import_tf
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class WorkerSet(object):
    """Holds one local RolloutWorker and a list of remote ones.

    The local worker is created during setup. Remote workers are created
    by add_workers() and can be replaced wholesale with reset().
    """

    def __init__(self,
                 env_creator,
                 policy,
                 trainer_config=None,
                 num_workers=0,
                 logdir=None,
                 _setup=True):
        """Create a new WorkerSet and initialize its workers.

        Arguments:
            env_creator (func): Function that returns env given env config.
            policy (cls): rllib.policy.Policy class.
            trainer_config (dict): Optional dict that extends the common
                config of the Trainer class.
            num_workers (int): Number of remote rollout workers to create.
            logdir (str): Optional logging directory for workers.
            _setup (bool): Whether to setup workers. This is only for testing.
        """
        if not trainer_config:
            # Function-scope import, used only when no config was given.
            from ray.rllib.agents.trainer import COMMON_CONFIG
            trainer_config = COMMON_CONFIG

        self._env_creator = env_creator
        self._policy = policy
        self._remote_config = trainer_config
        self._num_workers = num_workers
        self._logdir = logdir

        if not _setup:
            return

        # The local worker uses the local TF session args.
        local_overrides = {
            "tf_session_args": trainer_config["local_tf_session_args"]
        }
        self._local_config = merge_dicts(trainer_config, local_overrides)

        # Always create a local worker
        self._local_worker = self._make_worker(
            RolloutWorker, env_creator, policy, 0, self._local_config)

        # Create a number of remote workers
        self._remote_workers = []
        self.add_workers(num_workers)

    def local_worker(self):
        """Return the local rollout worker."""
        return self._local_worker

    def remote_workers(self):
        """Return a list of remote rollout workers."""
        return self._remote_workers

    def add_workers(self, num_workers):
        """Create num_workers remote workers and append them to the set.

        The new workers are given the worker indices 1..num_workers.
        """
        remote_config = self._remote_config
        remote_cls = RolloutWorker.as_remote(
            num_cpus=remote_config["num_cpus_per_worker"],
            num_gpus=remote_config["num_gpus_per_worker"],
            resources=remote_config["custom_resources_per_worker"],
        ).remote
        created = []
        for i in range(num_workers):
            created.append(
                self._make_worker(remote_cls, self._env_creator,
                                  self._policy, i + 1, self._remote_config))
        self._remote_workers.extend(created)

    def reset(self, new_remote_workers):
        """Replace the current list of remote workers with the given one."""
        self._remote_workers = new_remote_workers

    def stop(self):
        """Stop the local worker, then stop and terminate each remote one."""
        self.local_worker().stop()
        for worker in self.remote_workers():
            worker.stop.remote()
            worker.__ray_terminate__.remote()

    @DeveloperAPI
    def foreach_worker(self, func):
        """Call func on every worker and return the list of results.

        The local worker's result comes first, followed by the results
        of the remote workers.
        """
        results = [func(self.local_worker())]
        pending = [w.apply.remote(func) for w in self.remote_workers()]
        return results + ray_get_and_free(pending)

    @DeveloperAPI
    def foreach_worker_with_index(self, func):
        """Call func(worker, index) on every worker; return the results.

        The local worker gets index 0; remote workers get their position
        in the remote list plus one.
        """
        results = [func(self.local_worker(), 0)]
        pending = []
        for position, worker in enumerate(self.remote_workers()):
            pending.append(worker.apply.remote(func, position + 1))
        return results + ray_get_and_free(pending)

    @staticmethod
    def _from_existing(local_worker, remote_workers=None):
        """Wrap already-created workers in a WorkerSet without setup."""
        worker_set = WorkerSet(None, None, {}, _setup=False)
        worker_set._local_worker = local_worker
        worker_set._remote_workers = remote_workers or []
        return worker_set

    def _make_worker(self, cls, env_creator, policy, worker_index, config):
        """Construct one worker through cls from the trainer config.

        Arguments:
            cls: callable that builds the worker (RolloutWorker or the
                .remote constructor of its remote version).
            env_creator (func): forwarded to the worker.
            policy: policy class; also used to fill in multi-agent
                policy specs whose class is None.
            worker_index (int): index handed to the worker.
            config (dict): trainer config the worker options are read
                from.
        """

        def session_creator():
            logger.debug("Creating TF session {}".format(
                config["tf_session_args"]))
            return tf.Session(
                config=tf.ConfigProto(**config["tf_session_args"]))

        # Input reader factory: a user function, the worker's own
        # sampler, or shuffled JSON-based input.
        if isinstance(config["input"], FunctionType):
            input_creator = config["input"]
        elif config["input"] == "sampler":

            def input_creator(ioctx):
                return ioctx.default_sampler_input()
        elif isinstance(config["input"], dict):

            def input_creator(ioctx):
                return ShuffledInput(
                    MixedInput(config["input"], ioctx),
                    config["shuffle_buffer_size"])
        else:

            def input_creator(ioctx):
                return ShuffledInput(
                    JsonReader(config["input"], ioctx),
                    config["shuffle_buffer_size"])

        # Output writer factory: a user function, no output, or a JSON
        # writer targeting the log dir or the configured path.
        if isinstance(config["output"], FunctionType):
            output_creator = config["output"]
        elif config["output"] is None:

            def output_creator(ioctx):
                return NoopOutput()
        elif config["output"] == "logdir":

            def output_creator(ioctx):
                return JsonWriter(
                    ioctx.log_dir,
                    ioctx,
                    max_file_size=config["output_max_file_size"],
                    compress_columns=config["output_compress_columns"])
        else:

            def output_creator(ioctx):
                return JsonWriter(
                    config["output"],
                    ioctx,
                    max_file_size=config["output_max_file_size"],
                    compress_columns=config["output_compress_columns"])

        # No input evaluation is configured when reading from the sampler.
        input_evaluation = ([] if config["input"] == "sampler" else
                            config["input_evaluation"])

        # Fill in the default policy if 'None' is specified in multiagent
        multiagent_policies = config["multiagent"]["policies"]
        if multiagent_policies:
            _validate_multiagent_config(
                multiagent_policies, allow_none_graph=True)
            # Existing keys are reassigned in place; no keys are added.
            for policy_id, spec in multiagent_policies.items():
                if spec[0] is None:
                    multiagent_policies[policy_id] = (policy, spec[1],
                                                      spec[2], spec[3])
            policy = multiagent_policies

        worker_kwargs = dict(
            policy_mapping_fn=config["multiagent"]["policy_mapping_fn"],
            policies_to_train=config["multiagent"]["policies_to_train"],
            tf_session_creator=(session_creator
                                if config["tf_session_args"] else None),
            batch_steps=config["sample_batch_size"],
            batch_mode=config["batch_mode"],
            episode_horizon=config["horizon"],
            preprocessor_pref=config["preprocessor_pref"],
            sample_async=config["sample_async"],
            compress_observations=config["compress_observations"],
            num_envs=config["num_envs_per_worker"],
            observation_filter=config["observation_filter"],
            clip_rewards=config["clip_rewards"],
            clip_actions=config["clip_actions"],
            env_config=config["env_config"],
            model_config=config["model"],
            policy_config=config,
            worker_index=worker_index,
            monitor_path=self._logdir if config["monitor"] else None,
            log_dir=self._logdir,
            log_level=config["log_level"],
            callbacks=config["callbacks"],
            input_creator=input_creator,
            input_evaluation=input_evaluation,
            output_creator=output_creator,
            remote_worker_envs=config["remote_worker_envs"],
            remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
            soft_horizon=config["soft_horizon"],
            _fake_sampler=config.get("_fake_sampler", False))
        return cls(env_creator, policy, **worker_kwargs)
|
||||
@@ -75,7 +75,7 @@ if __name__ == "__main__":
|
||||
})
|
||||
|
||||
# disable DQN exploration when used by the PPO trainer
|
||||
ppo_trainer.optimizer.foreach_evaluator(
|
||||
ppo_trainer.workers.foreach_worker(
|
||||
lambda ev: ev.for_policy(
|
||||
lambda pi: pi.set_epsilon(0.0), policy_id="dqn_policy"))
|
||||
|
||||
|
||||
+5
-5
@@ -1,4 +1,4 @@
|
||||
"""Example of using policy evaluator classes directly to implement training.
|
||||
"""Example of using rollout worker classes directly to implement training.
|
||||
|
||||
Instead of using the built-in Trainer classes provided by RLlib, here we define
|
||||
a custom Policy class and manually coordinate distributed sample
|
||||
@@ -15,7 +15,7 @@ import gym
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.policy import Policy
|
||||
from ray.rllib.evaluation import PolicyEvaluator, SampleBatch
|
||||
from ray.rllib.evaluation import RolloutWorker, SampleBatch
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -67,8 +67,8 @@ def training_workflow(config, reporter):
|
||||
env = gym.make("CartPole-v0")
|
||||
policy = CustomPolicy(env.observation_space, env.action_space, {})
|
||||
workers = [
|
||||
PolicyEvaluator.as_remote().remote(lambda c: gym.make("CartPole-v0"),
|
||||
CustomPolicy)
|
||||
RolloutWorker.as_remote().remote(lambda c: gym.make("CartPole-v0"),
|
||||
CustomPolicy)
|
||||
for _ in range(config["num_workers"])
|
||||
]
|
||||
|
||||
@@ -97,7 +97,7 @@ def training_workflow(config, reporter):
|
||||
# Do some arbitrary updates based on the T2 batch
|
||||
policy.update_some_value(sum(T2["rewards"]))
|
||||
|
||||
reporter(**collect_metrics(remote_evaluators=workers))
|
||||
reporter(**collect_metrics(remote_workers=workers))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -18,20 +18,16 @@ class IOContext(object):
|
||||
config (dict): Configuration of the agent.
|
||||
worker_index (int): When there are multiple workers created, this
|
||||
uniquely identifies the current worker.
|
||||
evaluator (PolicyEvaluator): policy evaluator object reference.
|
||||
worker (RolloutWorker): rollout worker object reference.
|
||||
"""
|
||||
|
||||
@PublicAPI
|
||||
def __init__(self,
|
||||
log_dir=None,
|
||||
config=None,
|
||||
worker_index=0,
|
||||
evaluator=None):
|
||||
def __init__(self, log_dir=None, config=None, worker_index=0, worker=None):
|
||||
self.log_dir = log_dir or os.getcwd()
|
||||
self.config = config or {}
|
||||
self.worker_index = worker_index
|
||||
self.evaluator = evaluator
|
||||
self.worker = worker
|
||||
|
||||
@PublicAPI
|
||||
def default_sampler_input(self):
|
||||
return self.evaluator.sampler
|
||||
return self.worker.sampler
|
||||
|
||||
@@ -88,7 +88,7 @@ class JsonReader(InputReader):
|
||||
if isinstance(batch, SampleBatch):
|
||||
out = []
|
||||
for sub_batch in batch.split_by_episode():
|
||||
out.append(self.ioctx.evaluator.policy_map[DEFAULT_POLICY_ID]
|
||||
out.append(self.ioctx.worker.policy_map[DEFAULT_POLICY_ID]
|
||||
.postprocess_trajectory(sub_batch))
|
||||
return SampleBatch.concat_samples(out)
|
||||
else:
|
||||
|
||||
@@ -33,14 +33,14 @@ class OffPolicyEstimator(object):
|
||||
@classmethod
|
||||
def create(cls, ioctx):
|
||||
"""Create an off-policy estimator from a IOContext."""
|
||||
gamma = ioctx.evaluator.policy_config["gamma"]
|
||||
gamma = ioctx.worker.policy_config["gamma"]
|
||||
# Grab a reference to the current model
|
||||
keys = list(ioctx.evaluator.policy_map.keys())
|
||||
keys = list(ioctx.worker.policy_map.keys())
|
||||
if len(keys) > 1:
|
||||
raise NotImplementedError(
|
||||
"Off-policy estimation is not implemented for multi-agent. "
|
||||
"You can set `input_evaluation: []` to resolve this.")
|
||||
policy = ioctx.evaluator.get_policy(keys[0])
|
||||
policy = ioctx.worker.get_policy(keys[0])
|
||||
return cls(policy, gamma)
|
||||
|
||||
@DeveloperAPI
|
||||
|
||||
@@ -14,7 +14,7 @@ from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
|
||||
class Aggregator(object):
|
||||
"""An aggregator collects and processes samples from evaluators.
|
||||
"""An aggregator collects and processes samples from workers.
|
||||
|
||||
This class is used to abstract away the strategy for sample collection.
|
||||
For example, you may want to use a tree of actors to collect samples. The
|
||||
@@ -22,21 +22,21 @@ class Aggregator(object):
|
||||
as concatenating and decompressing sample batches.
|
||||
|
||||
Attributes:
|
||||
local_evaluator: local PolicyEvaluator copy
|
||||
local_worker: local RolloutWorker copy
|
||||
"""
|
||||
|
||||
def iter_train_batches(self):
|
||||
"""Returns a generator over batches ready to learn on.
|
||||
|
||||
Iterating through this generator will also send out weight updates to
|
||||
remote evaluators as needed.
|
||||
remote workers as needed.
|
||||
|
||||
This call may block until results are available.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def broadcast_new_weights(self):
|
||||
"""Broadcast a new set of weights from the local evaluator."""
|
||||
"""Broadcast a new set of weights from the local workers."""
|
||||
raise NotImplementedError
|
||||
|
||||
def should_broadcast(self):
|
||||
@@ -47,19 +47,19 @@ class Aggregator(object):
|
||||
"""Returns runtime statistics for debugging."""
|
||||
raise NotImplementedError
|
||||
|
||||
def reset(self, remote_evaluators):
|
||||
"""Called to change the set of remote evaluators being used."""
|
||||
def reset(self, remote_workers):
|
||||
"""Called to change the set of remote workers being used."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AggregationWorkerBase(object):
|
||||
"""Aggregators should extend from this class."""
|
||||
|
||||
def __init__(self, initial_weights_obj_id, remote_evaluators,
|
||||
def __init__(self, initial_weights_obj_id, remote_workers,
|
||||
max_sample_requests_in_flight_per_worker, replay_proportion,
|
||||
replay_buffer_num_slots, train_batch_size, sample_batch_size):
|
||||
self.broadcasted_weights = initial_weights_obj_id
|
||||
self.remote_evaluators = remote_evaluators
|
||||
self.remote_workers = remote_workers
|
||||
self.sample_batch_size = sample_batch_size
|
||||
self.train_batch_size = train_batch_size
|
||||
|
||||
@@ -73,7 +73,7 @@ class AggregationWorkerBase(object):
|
||||
|
||||
# Kick off async background sampling
|
||||
self.sample_tasks = TaskPool()
|
||||
for ev in self.remote_evaluators:
|
||||
for ev in self.remote_workers:
|
||||
ev.set_weights.remote(self.broadcasted_weights)
|
||||
for _ in range(max_sample_requests_in_flight_per_worker):
|
||||
self.sample_tasks.add(ev, ev.sample.remote())
|
||||
@@ -138,8 +138,8 @@ class AggregationWorkerBase(object):
|
||||
}
|
||||
|
||||
@override(Aggregator)
|
||||
def reset(self, remote_evaluators):
|
||||
self.sample_tasks.reset_evaluators(remote_evaluators)
|
||||
def reset(self, remote_workers):
|
||||
self.sample_tasks.reset_workers(remote_workers)
|
||||
|
||||
def _augment_with_replay(self, sample_futures):
|
||||
def can_replay():
|
||||
@@ -164,25 +164,25 @@ class SimpleAggregator(AggregationWorkerBase, Aggregator):
|
||||
"""Simple single-threaded implementation of an Aggregator."""
|
||||
|
||||
def __init__(self,
|
||||
local_evaluator,
|
||||
remote_evaluators,
|
||||
workers,
|
||||
max_sample_requests_in_flight_per_worker=2,
|
||||
replay_proportion=0.0,
|
||||
replay_buffer_num_slots=0,
|
||||
train_batch_size=500,
|
||||
sample_batch_size=50,
|
||||
broadcast_interval=5):
|
||||
self.local_evaluator = local_evaluator
|
||||
self.workers = workers
|
||||
self.local_worker = workers.local_worker()
|
||||
self.broadcast_interval = broadcast_interval
|
||||
self.broadcast_new_weights()
|
||||
AggregationWorkerBase.__init__(
|
||||
self, self.broadcasted_weights, remote_evaluators,
|
||||
self, self.broadcasted_weights, self.workers.remote_workers(),
|
||||
max_sample_requests_in_flight_per_worker, replay_proportion,
|
||||
replay_buffer_num_slots, train_batch_size, sample_batch_size)
|
||||
|
||||
@override(Aggregator)
|
||||
def broadcast_new_weights(self):
|
||||
self.broadcasted_weights = ray.put(self.local_evaluator.get_weights())
|
||||
self.broadcasted_weights = ray.put(self.local_worker.get_weights())
|
||||
self.num_sent_since_broadcast = 0
|
||||
|
||||
@override(Aggregator)
|
||||
|
||||
@@ -25,11 +25,11 @@ class LearnerThread(threading.Thread):
|
||||
improves overall throughput.
|
||||
"""
|
||||
|
||||
def __init__(self, local_evaluator, minibatch_buffer_size, num_sgd_iter,
|
||||
def __init__(self, local_worker, minibatch_buffer_size, num_sgd_iter,
|
||||
learner_queue_size):
|
||||
threading.Thread.__init__(self)
|
||||
self.learner_queue_size = WindowStat("size", 50)
|
||||
self.local_evaluator = local_evaluator
|
||||
self.local_worker = local_worker
|
||||
self.inqueue = queue.Queue(maxsize=learner_queue_size)
|
||||
self.outqueue = queue.Queue()
|
||||
self.minibatch_buffer = MinibatchBuffer(
|
||||
@@ -52,7 +52,7 @@ class LearnerThread(threading.Thread):
|
||||
batch, _ = self.minibatch_buffer.get()
|
||||
|
||||
with self.grad_timer:
|
||||
fetches = self.local_evaluator.learn_on_batch(batch)
|
||||
fetches = self.local_worker.learn_on_batch(batch)
|
||||
self.weights_updated = True
|
||||
self.stats = get_learner_stats(fetches)
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ class TFMultiGPULearner(LearnerThread):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
local_evaluator,
|
||||
local_worker,
|
||||
num_gpus=1,
|
||||
lr=0.0005,
|
||||
train_batch_size=500,
|
||||
@@ -41,7 +41,7 @@ class TFMultiGPULearner(LearnerThread):
|
||||
learner_queue_size=16,
|
||||
num_data_load_threads=16,
|
||||
_fake_gpus=False):
|
||||
LearnerThread.__init__(self, local_evaluator, minibatch_buffer_size,
|
||||
LearnerThread.__init__(self, local_worker, minibatch_buffer_size,
|
||||
num_sgd_iter, learner_queue_size)
|
||||
self.lr = lr
|
||||
self.train_batch_size = train_batch_size
|
||||
@@ -59,16 +59,16 @@ class TFMultiGPULearner(LearnerThread):
|
||||
assert self.train_batch_size % len(self.devices) == 0
|
||||
assert self.train_batch_size >= len(self.devices), "batch too small"
|
||||
|
||||
if set(self.local_evaluator.policy_map.keys()) != {DEFAULT_POLICY_ID}:
|
||||
if set(self.local_worker.policy_map.keys()) != {DEFAULT_POLICY_ID}:
|
||||
raise NotImplementedError("Multi-gpu mode for multi-agent")
|
||||
self.policy = self.local_evaluator.policy_map[DEFAULT_POLICY_ID]
|
||||
self.policy = self.local_worker.policy_map[DEFAULT_POLICY_ID]
|
||||
|
||||
# per-GPU graph copies created below must share vars with the policy
|
||||
# reuse is set to AUTO_REUSE because Adam nodes are created after
|
||||
# all of the device copies are created.
|
||||
self.par_opt = []
|
||||
with self.local_evaluator.tf_sess.graph.as_default():
|
||||
with self.local_evaluator.tf_sess.as_default():
|
||||
with self.local_worker.tf_sess.graph.as_default():
|
||||
with self.local_worker.tf_sess.as_default():
|
||||
with tf.variable_scope(DEFAULT_POLICY_ID, reuse=tf.AUTO_REUSE):
|
||||
if self.policy._state_inputs:
|
||||
rnn_inputs = self.policy._state_inputs + [
|
||||
@@ -87,7 +87,7 @@ class TFMultiGPULearner(LearnerThread):
|
||||
999999, # it will get rounded down
|
||||
self.policy.copy))
|
||||
|
||||
self.sess = self.local_evaluator.tf_sess
|
||||
self.sess = self.local_worker.tf_sess
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
self.idle_optimizers = queue.Queue()
|
||||
|
||||
@@ -22,15 +22,14 @@ logger = logging.getLogger(__name__)
|
||||
class TreeAggregator(Aggregator):
|
||||
"""A hierarchical experiences aggregator.
|
||||
|
||||
The given set of remote evaluators is divided into subsets and assigned to
|
||||
The given set of remote workers is divided into subsets and assigned to
|
||||
one of several aggregation workers. These aggregation workers collate
|
||||
experiences into batches of size `train_batch_size` and we collect them
|
||||
in this class when `iter_train_batches` is called.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
local_evaluator,
|
||||
remote_evaluators,
|
||||
workers,
|
||||
num_aggregation_workers,
|
||||
max_sample_requests_in_flight_per_worker=2,
|
||||
replay_proportion=0.0,
|
||||
@@ -38,8 +37,7 @@ class TreeAggregator(Aggregator):
|
||||
train_batch_size=500,
|
||||
sample_batch_size=50,
|
||||
broadcast_interval=5):
|
||||
self.local_evaluator = local_evaluator
|
||||
self.remote_evaluators = remote_evaluators
|
||||
self.workers = workers
|
||||
self.num_aggregation_workers = num_aggregation_workers
|
||||
self.max_sample_requests_in_flight_per_worker = \
|
||||
max_sample_requests_in_flight_per_worker
|
||||
@@ -48,7 +46,8 @@ class TreeAggregator(Aggregator):
|
||||
self.sample_batch_size = sample_batch_size
|
||||
self.train_batch_size = train_batch_size
|
||||
self.broadcast_interval = broadcast_interval
|
||||
self.broadcasted_weights = ray.put(local_evaluator.get_weights())
|
||||
self.broadcasted_weights = ray.put(
|
||||
workers.local_worker().get_weights())
|
||||
self.num_batches_processed = 0
|
||||
self.num_broadcasts = 0
|
||||
self.num_sent_since_broadcast = 0
|
||||
@@ -58,26 +57,27 @@ class TreeAggregator(Aggregator):
|
||||
"""Deferred init so that we can pass in previously created workers."""
|
||||
|
||||
assert len(aggregators) == self.num_aggregation_workers, aggregators
|
||||
if len(self.remote_evaluators) < self.num_aggregation_workers:
|
||||
if len(self.workers.remote_workers()) < self.num_aggregation_workers:
|
||||
raise ValueError(
|
||||
"The number of aggregation workers should not exceed the "
|
||||
"number of total evaluation workers ({} vs {})".format(
|
||||
self.num_aggregation_workers, len(self.remote_evaluators)))
|
||||
self.num_aggregation_workers,
|
||||
len(self.workers.remote_workers())))
|
||||
|
||||
assigned_evaluators = collections.defaultdict(list)
|
||||
for i, ev in enumerate(self.remote_evaluators):
|
||||
assigned_evaluators[i % self.num_aggregation_workers].append(ev)
|
||||
assigned_workers = collections.defaultdict(list)
|
||||
for i, ev in enumerate(self.workers.remote_workers()):
|
||||
assigned_workers[i % self.num_aggregation_workers].append(ev)
|
||||
|
||||
self.workers = aggregators
|
||||
for i, worker in enumerate(self.workers):
|
||||
worker.init.remote(
|
||||
self.broadcasted_weights, assigned_evaluators[i],
|
||||
self.max_sample_requests_in_flight_per_worker,
|
||||
self.replay_proportion, self.replay_buffer_num_slots,
|
||||
self.train_batch_size, self.sample_batch_size)
|
||||
self.aggregators = aggregators
|
||||
for i, agg in enumerate(self.aggregators):
|
||||
agg.init.remote(self.broadcasted_weights, assigned_workers[i],
|
||||
self.max_sample_requests_in_flight_per_worker,
|
||||
self.replay_proportion,
|
||||
self.replay_buffer_num_slots,
|
||||
self.train_batch_size, self.sample_batch_size)
|
||||
|
||||
self.agg_tasks = TaskPool()
|
||||
for agg in self.workers:
|
||||
for agg in self.aggregators:
|
||||
agg.set_weights.remote(self.broadcasted_weights)
|
||||
self.agg_tasks.add(agg, agg.get_train_batches.remote())
|
||||
|
||||
@@ -96,7 +96,8 @@ class TreeAggregator(Aggregator):
|
||||
|
||||
@override(Aggregator)
|
||||
def broadcast_new_weights(self):
|
||||
self.broadcasted_weights = ray.put(self.local_evaluator.get_weights())
|
||||
self.broadcasted_weights = ray.put(
|
||||
self.workers.local_worker().get_weights())
|
||||
self.num_sent_since_broadcast = 0
|
||||
self.num_broadcasts += 1
|
||||
|
||||
@@ -112,8 +113,8 @@ class TreeAggregator(Aggregator):
|
||||
}
|
||||
|
||||
@override(Aggregator)
|
||||
def reset(self, remote_evaluators):
|
||||
raise NotImplementedError("changing number of remote evaluators")
|
||||
def reset(self, remote_workers):
|
||||
raise NotImplementedError("changing number of remote workers")
|
||||
|
||||
@staticmethod
|
||||
def precreate_aggregators(n):
|
||||
@@ -125,16 +126,16 @@ class AggregationWorker(AggregationWorkerBase):
|
||||
def __init__(self):
|
||||
self.initialized = False
|
||||
|
||||
def init(self, initial_weights_obj_id, remote_evaluators,
|
||||
def init(self, initial_weights_obj_id, remote_workers,
|
||||
max_sample_requests_in_flight_per_worker, replay_proportion,
|
||||
replay_buffer_num_slots, train_batch_size, sample_batch_size):
|
||||
"""Deferred init that assigns sub-workers to this aggregator."""
|
||||
|
||||
logger.info("Assigned evaluators {} to aggregation worker {}".format(
|
||||
remote_evaluators, self))
|
||||
assert remote_evaluators
|
||||
logger.info("Assigned workers {} to aggregation worker {}".format(
|
||||
remote_workers, self))
|
||||
assert remote_workers
|
||||
AggregationWorkerBase.__init__(
|
||||
self, initial_weights_obj_id, remote_evaluators,
|
||||
self, initial_weights_obj_id, remote_workers,
|
||||
max_sample_requests_in_flight_per_worker, replay_proportion,
|
||||
replay_buffer_num_slots, train_batch_size, sample_batch_size)
|
||||
self.initialized = True
|
||||
|
||||
@@ -14,30 +14,30 @@ class AsyncGradientsOptimizer(PolicyOptimizer):
|
||||
"""An asynchronous RL optimizer, e.g. for implementing A3C.
|
||||
|
||||
This optimizer asynchronously pulls and applies gradients from remote
|
||||
evaluators, sending updated weights back as needed. This pipelines the
|
||||
workers, sending updated weights back as needed. This pipelines the
|
||||
gradient computations on the remote workers.
|
||||
"""
|
||||
|
||||
def __init__(self, local_evaluator, remote_evaluators, grads_per_step=100):
|
||||
PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)
|
||||
def __init__(self, workers, grads_per_step=100):
|
||||
PolicyOptimizer.__init__(self, workers)
|
||||
|
||||
self.apply_timer = TimerStat()
|
||||
self.wait_timer = TimerStat()
|
||||
self.dispatch_timer = TimerStat()
|
||||
self.grads_per_step = grads_per_step
|
||||
self.learner_stats = {}
|
||||
if not self.remote_evaluators:
|
||||
if not self.workers.remote_workers():
|
||||
raise ValueError(
|
||||
"Async optimizer requires at least 1 remote evaluator")
|
||||
"Async optimizer requires at least 1 remote workers")
|
||||
|
||||
@override(PolicyOptimizer)
|
||||
def step(self):
|
||||
weights = ray.put(self.local_evaluator.get_weights())
|
||||
weights = ray.put(self.workers.local_worker().get_weights())
|
||||
pending_gradients = {}
|
||||
num_gradients = 0
|
||||
|
||||
# Kick off the first wave of async tasks
|
||||
for e in self.remote_evaluators:
|
||||
for e in self.workers.remote_workers():
|
||||
e.set_weights.remote(weights)
|
||||
future = e.compute_gradients.remote(e.sample.remote())
|
||||
pending_gradients[future] = e
|
||||
@@ -56,13 +56,14 @@ class AsyncGradientsOptimizer(PolicyOptimizer):
|
||||
|
||||
if gradient is not None:
|
||||
with self.apply_timer:
|
||||
self.local_evaluator.apply_gradients(gradient)
|
||||
self.workers.local_worker().apply_gradients(gradient)
|
||||
self.num_steps_sampled += info["batch_count"]
|
||||
self.num_steps_trained += info["batch_count"]
|
||||
|
||||
if num_gradients < self.grads_per_step:
|
||||
with self.dispatch_timer:
|
||||
e.set_weights.remote(self.local_evaluator.get_weights())
|
||||
e.set_weights.remote(
|
||||
self.workers.local_worker().get_weights())
|
||||
future = e.compute_gradients.remote(e.sample.remote())
|
||||
|
||||
pending_gradients[future] = e
|
||||
|
||||
@@ -36,20 +36,19 @@ class AsyncReplayOptimizer(PolicyOptimizer):
|
||||
"""Main event loop of the Ape-X optimizer (async sampling with replay).
|
||||
|
||||
This class coordinates the data transfers between the learner thread,
|
||||
remote evaluators (Ape-X actors), and replay buffer actors.
|
||||
remote workers (Ape-X actors), and replay buffer actors.
|
||||
|
||||
This has two modes of operation:
|
||||
- normal replay: replays independent samples.
|
||||
- batch replay: simplified mode where entire sample batches are
|
||||
replayed. This supports RNNs, but not prioritization.
|
||||
|
||||
This optimizer requires that policy evaluators return an additional
|
||||
This optimizer requires that rollout workers return an additional
|
||||
"td_error" array in the info return of compute_gradients(). This error
|
||||
term will be used for sample prioritization."""
|
||||
|
||||
def __init__(self,
|
||||
local_evaluator,
|
||||
remote_evaluators,
|
||||
workers,
|
||||
learning_starts=1000,
|
||||
buffer_size=10000,
|
||||
prioritized_replay=True,
|
||||
@@ -62,7 +61,7 @@ class AsyncReplayOptimizer(PolicyOptimizer):
|
||||
max_weight_sync_delay=400,
|
||||
debug=False,
|
||||
batch_replay=False):
|
||||
PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)
|
||||
PolicyOptimizer.__init__(self, workers)
|
||||
|
||||
self.debug = debug
|
||||
self.batch_replay = batch_replay
|
||||
@@ -71,7 +70,7 @@ class AsyncReplayOptimizer(PolicyOptimizer):
|
||||
self.prioritized_replay_eps = prioritized_replay_eps
|
||||
self.max_weight_sync_delay = max_weight_sync_delay
|
||||
|
||||
self.learner = LearnerThread(self.local_evaluator)
|
||||
self.learner = LearnerThread(self.workers.local_worker())
|
||||
self.learner.start()
|
||||
|
||||
if self.batch_replay:
|
||||
@@ -111,13 +110,13 @@ class AsyncReplayOptimizer(PolicyOptimizer):
|
||||
|
||||
# Kick off async background sampling
|
||||
self.sample_tasks = TaskPool()
|
||||
if self.remote_evaluators:
|
||||
self._set_evaluators(self.remote_evaluators)
|
||||
if self.workers.remote_workers():
|
||||
self._set_workers(self.workers.remote_workers())
|
||||
|
||||
@override(PolicyOptimizer)
|
||||
def step(self):
|
||||
assert self.learner.is_alive()
|
||||
assert len(self.remote_evaluators) > 0
|
||||
assert len(self.workers.remote_workers()) > 0
|
||||
start = time.time()
|
||||
sample_timesteps, train_timesteps = self._step()
|
||||
time_delta = time.time() - start
|
||||
@@ -138,9 +137,9 @@ class AsyncReplayOptimizer(PolicyOptimizer):
|
||||
self.learner.stopped = True
|
||||
|
||||
@override(PolicyOptimizer)
|
||||
def reset(self, remote_evaluators):
|
||||
self.remote_evaluators = remote_evaluators
|
||||
self.sample_tasks.reset_evaluators(remote_evaluators)
|
||||
def reset(self, remote_workers):
|
||||
self.workers.reset(remote_workers)
|
||||
self.sample_tasks.reset_workers(remote_workers)
|
||||
|
||||
@override(PolicyOptimizer)
|
||||
def stats(self):
|
||||
@@ -175,10 +174,10 @@ class AsyncReplayOptimizer(PolicyOptimizer):
|
||||
return dict(PolicyOptimizer.stats(self), **stats)
|
||||
|
||||
# For https://github.com/ray-project/ray/issues/2541 only
|
||||
def _set_evaluators(self, remote_evaluators):
|
||||
self.remote_evaluators = remote_evaluators
|
||||
weights = self.local_evaluator.get_weights()
|
||||
for ev in self.remote_evaluators:
|
||||
def _set_workers(self, remote_workers):
|
||||
self.workers.reset(remote_workers)
|
||||
weights = self.workers.local_worker().get_weights()
|
||||
for ev in self.workers.remote_workers():
|
||||
ev.set_weights.remote(weights)
|
||||
self.steps_since_update[ev] = 0
|
||||
for _ in range(SAMPLE_QUEUE_DEPTH):
|
||||
@@ -207,7 +206,7 @@ class AsyncReplayOptimizer(PolicyOptimizer):
|
||||
self.learner.weights_updated = False
|
||||
with self.timers["put_weights"]:
|
||||
weights = ray.put(
|
||||
self.local_evaluator.get_weights())
|
||||
self.workers.local_worker().get_weights())
|
||||
ev.set_weights.remote(weights)
|
||||
self.num_weight_syncs += 1
|
||||
self.steps_since_update[ev] = 0
|
||||
@@ -380,10 +379,10 @@ class LearnerThread(threading.Thread):
|
||||
improves overall throughput.
|
||||
"""
|
||||
|
||||
def __init__(self, local_evaluator):
|
||||
def __init__(self, local_worker):
|
||||
threading.Thread.__init__(self)
|
||||
self.learner_queue_size = WindowStat("size", 50)
|
||||
self.local_evaluator = local_evaluator
|
||||
self.local_worker = local_worker
|
||||
self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE)
|
||||
self.outqueue = queue.Queue()
|
||||
self.queue_timer = TimerStat()
|
||||
@@ -403,7 +402,7 @@ class LearnerThread(threading.Thread):
|
||||
if replay is not None:
|
||||
prio_dict = {}
|
||||
with self.grad_timer:
|
||||
grad_out = self.local_evaluator.learn_on_batch(replay)
|
||||
grad_out = self.local_worker.learn_on_batch(replay)
|
||||
for pid, info in grad_out.items():
|
||||
prio_dict[pid] = (
|
||||
replay.policy_batches[pid].data.get("batch_indexes"),
|
||||
|
||||
@@ -24,12 +24,11 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
|
||||
"""Main event loop of the IMPALA architecture.
|
||||
|
||||
This class coordinates the data transfers between the learner thread
|
||||
and remote evaluators (IMPALA actors).
|
||||
and remote workers (IMPALA actors).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
local_evaluator,
|
||||
remote_evaluators,
|
||||
workers,
|
||||
train_batch_size=500,
|
||||
sample_batch_size=50,
|
||||
num_envs_per_worker=1,
|
||||
@@ -45,7 +44,7 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
|
||||
learner_queue_size=16,
|
||||
num_aggregation_workers=0,
|
||||
_fake_gpus=False):
|
||||
PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)
|
||||
PolicyOptimizer.__init__(self, workers)
|
||||
|
||||
self._stats_start_time = time.time()
|
||||
self._last_stats_time = {}
|
||||
@@ -62,7 +61,7 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
|
||||
"{} vs {}".format(num_data_loader_buffers,
|
||||
minibatch_buffer_size))
|
||||
self.learner = TFMultiGPULearner(
|
||||
self.local_evaluator,
|
||||
self.workers.local_worker(),
|
||||
lr=lr,
|
||||
num_gpus=num_gpus,
|
||||
train_batch_size=train_batch_size,
|
||||
@@ -72,7 +71,7 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
|
||||
learner_queue_size=learner_queue_size,
|
||||
_fake_gpus=_fake_gpus)
|
||||
else:
|
||||
self.learner = LearnerThread(self.local_evaluator,
|
||||
self.learner = LearnerThread(self.workers.local_worker(),
|
||||
minibatch_buffer_size, num_sgd_iter,
|
||||
learner_queue_size)
|
||||
self.learner.start()
|
||||
@@ -84,8 +83,7 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
|
||||
|
||||
if num_aggregation_workers > 0:
|
||||
self.aggregator = TreeAggregator(
|
||||
self.local_evaluator,
|
||||
self.remote_evaluators,
|
||||
workers,
|
||||
num_aggregation_workers,
|
||||
replay_proportion=replay_proportion,
|
||||
max_sample_requests_in_flight_per_worker=(
|
||||
@@ -96,8 +94,7 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
|
||||
broadcast_interval=broadcast_interval)
|
||||
else:
|
||||
self.aggregator = SimpleAggregator(
|
||||
self.local_evaluator,
|
||||
self.remote_evaluators,
|
||||
workers,
|
||||
replay_proportion=replay_proportion,
|
||||
max_sample_requests_in_flight_per_worker=(
|
||||
max_sample_requests_in_flight_per_worker),
|
||||
@@ -127,7 +124,7 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
|
||||
|
||||
@override(PolicyOptimizer)
|
||||
def step(self):
|
||||
if len(self.remote_evaluators) == 0:
|
||||
if len(self.workers.remote_workers()) == 0:
|
||||
raise ValueError("Config num_workers=0 means training will hang!")
|
||||
assert self.learner.is_alive()
|
||||
with self._optimizer_step_timer:
|
||||
@@ -146,9 +143,9 @@ class AsyncSamplesOptimizer(PolicyOptimizer):
|
||||
self.learner.stopped = True
|
||||
|
||||
@override(PolicyOptimizer)
|
||||
def reset(self, remote_evaluators):
|
||||
self.remote_evaluators = remote_evaluators
|
||||
self.aggregator.reset(remote_evaluators)
|
||||
def reset(self, remote_workers):
|
||||
self.workers.reset(remote_workers)
|
||||
self.aggregator.reset(remote_workers)
|
||||
|
||||
@override(PolicyOptimizer)
|
||||
def stats(self):
|
||||
|
||||
@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
|
||||
class LocalMultiGPUOptimizer(PolicyOptimizer):
|
||||
"""A synchronous optimizer that uses multiple local GPUs.
|
||||
|
||||
Samples are pulled synchronously from multiple remote evaluators,
|
||||
Samples are pulled synchronously from multiple remote workers,
|
||||
concatenated, and then split across the memory of multiple local GPUs.
|
||||
A number of SGD passes are then taken over the in-memory data. For more
|
||||
details, see `multi_gpu_impl.LocalSyncParallelOptimizer`.
|
||||
@@ -42,8 +42,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
local_evaluator,
|
||||
remote_evaluators,
|
||||
workers,
|
||||
sgd_batch_size=128,
|
||||
num_sgd_iter=10,
|
||||
sample_batch_size=200,
|
||||
@@ -52,7 +51,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
|
||||
num_gpus=0,
|
||||
standardize_fields=[],
|
||||
straggler_mitigation=False):
|
||||
PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)
|
||||
PolicyOptimizer.__init__(self, workers)
|
||||
|
||||
self.batch_size = sgd_batch_size
|
||||
self.num_sgd_iter = num_sgd_iter
|
||||
@@ -79,8 +78,8 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
|
||||
|
||||
logger.info("LocalMultiGPUOptimizer devices {}".format(self.devices))
|
||||
|
||||
self.policies = dict(
|
||||
self.local_evaluator.foreach_trainable_policy(lambda p, i: (i, p)))
|
||||
self.policies = dict(self.workers.local_worker()
|
||||
.foreach_trainable_policy(lambda p, i: (i, p)))
|
||||
logger.debug("Policies to train: {}".format(self.policies))
|
||||
for policy_id, policy in self.policies.items():
|
||||
if not isinstance(policy, TFPolicy):
|
||||
@@ -92,8 +91,8 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
|
||||
# reuse is set to AUTO_REUSE because Adam nodes are created after
|
||||
# all of the device copies are created.
|
||||
self.optimizers = {}
|
||||
with self.local_evaluator.tf_sess.graph.as_default():
|
||||
with self.local_evaluator.tf_sess.as_default():
|
||||
with self.workers.local_worker().tf_sess.graph.as_default():
|
||||
with self.workers.local_worker().tf_sess.as_default():
|
||||
for policy_id, policy in self.policies.items():
|
||||
with tf.variable_scope(policy_id, reuse=tf.AUTO_REUSE):
|
||||
if policy._state_inputs:
|
||||
@@ -109,25 +108,25 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
|
||||
for _, v in policy._loss_inputs], rnn_inputs,
|
||||
self.per_device_batch_size, policy.copy))
|
||||
|
||||
self.sess = self.local_evaluator.tf_sess
|
||||
self.sess = self.workers.local_worker().tf_sess
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
@override(PolicyOptimizer)
|
||||
def step(self):
|
||||
with self.update_weights_timer:
|
||||
if self.remote_evaluators:
|
||||
weights = ray.put(self.local_evaluator.get_weights())
|
||||
for e in self.remote_evaluators:
|
||||
if self.workers.remote_workers():
|
||||
weights = ray.put(self.workers.local_worker().get_weights())
|
||||
for e in self.workers.remote_workers():
|
||||
e.set_weights.remote(weights)
|
||||
|
||||
with self.sample_timer:
|
||||
if self.remote_evaluators:
|
||||
if self.workers.remote_workers():
|
||||
if self.straggler_mitigation:
|
||||
samples = collect_samples_straggler_mitigation(
|
||||
self.remote_evaluators, self.train_batch_size)
|
||||
self.workers.remote_workers(), self.train_batch_size)
|
||||
else:
|
||||
samples = collect_samples(
|
||||
self.remote_evaluators, self.sample_batch_size,
|
||||
self.workers.remote_workers(), self.sample_batch_size,
|
||||
self.num_envs_per_worker, self.train_batch_size)
|
||||
if samples.count > self.train_batch_size * 2:
|
||||
logger.info(
|
||||
@@ -139,7 +138,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
|
||||
else:
|
||||
samples = []
|
||||
while sum(s.count for s in samples) < self.train_batch_size:
|
||||
samples.append(self.local_evaluator.sample())
|
||||
samples.append(self.workers.local_worker().sample())
|
||||
samples = SampleBatch.concat_samples(samples)
|
||||
|
||||
# Handle everything as if multiagent
|
||||
|
||||
@@ -6,7 +6,6 @@ import logging
|
||||
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -21,34 +20,21 @@ class PolicyOptimizer(object):
|
||||
used for PPO. These optimizers are all pluggable, and it is possible
|
||||
to mix and match as needed.
|
||||
|
||||
In order for an algorithm to use an RLlib optimizer, it must implement
|
||||
the PolicyEvaluator interface and pass a PolicyEvaluator class or set of
|
||||
PolicyEvaluators to its PolicyOptimizer of choice. The PolicyOptimizer
|
||||
uses these Evaluators to sample from the environment and compute model
|
||||
gradient updates.
|
||||
|
||||
Attributes:
|
||||
config (dict): The JSON configuration passed to this optimizer.
|
||||
local_evaluator (PolicyEvaluator): The embedded evaluator instance.
|
||||
remote_evaluators (list): List of remote evaluator replicas, or [].
|
||||
workers (WorkerSet): The set of rollout workers to use.
|
||||
num_steps_trained (int): Number of timesteps trained on so far.
|
||||
num_steps_sampled (int): Number of timesteps sampled so far.
|
||||
evaluator_resources (dict): Optional resource requests to set for
|
||||
evaluators created by this optimizer.
|
||||
"""
|
||||
|
||||
@DeveloperAPI
|
||||
def __init__(self, local_evaluator, remote_evaluators=None):
|
||||
def __init__(self, workers):
|
||||
"""Create an optimizer instance.
|
||||
|
||||
Args:
|
||||
local_evaluator (Evaluator): Local evaluator instance, required.
|
||||
remote_evaluators (list): A list of Ray actor handles to remote
|
||||
evaluators instances. If empty, the optimizer should fall back
|
||||
to using only the local evaluator.
|
||||
workers (WorkerSet): The set of rollout workers to use.
|
||||
"""
|
||||
self.local_evaluator = local_evaluator
|
||||
self.remote_evaluators = remote_evaluators or []
|
||||
self.workers = workers
|
||||
self.episode_history = []
|
||||
|
||||
# Counters that should be updated by sub-classes
|
||||
@@ -100,23 +86,23 @@ class PolicyOptimizer(object):
|
||||
def collect_metrics(self,
|
||||
timeout_seconds,
|
||||
min_history=100,
|
||||
selected_evaluators=None):
|
||||
"""Returns evaluator and optimizer stats.
|
||||
selected_workers=None):
|
||||
"""Returns worker and optimizer stats.
|
||||
|
||||
Arguments:
|
||||
timeout_seconds (int): Max wait time for a evaluator before
|
||||
dropping its results. This usually indicates a hung evaluator.
|
||||
timeout_seconds (int): Max wait time for a worker before
|
||||
dropping its results. This usually indicates a hung worker.
|
||||
min_history (int): Min history length to smooth results over.
|
||||
selected_evaluators (list): Override the list of remote evaluators
|
||||
selected_workers (list): Override the list of remote workers
|
||||
to collect metrics from.
|
||||
|
||||
Returns:
|
||||
res (dict): A training result dict from evaluator metrics with
|
||||
res (dict): A training result dict from worker metrics with
|
||||
`info` replaced with stats from self.
|
||||
"""
|
||||
episodes, num_dropped = collect_episodes(
|
||||
self.local_evaluator,
|
||||
selected_evaluators or self.remote_evaluators,
|
||||
self.workers.local_worker(),
|
||||
selected_workers or self.workers.remote_workers(),
|
||||
timeout_seconds=timeout_seconds)
|
||||
orig_episodes = list(episodes)
|
||||
missing = min_history - len(episodes)
|
||||
@@ -130,30 +116,28 @@ class PolicyOptimizer(object):
|
||||
return res
|
||||
|
||||
@DeveloperAPI
|
||||
def reset(self, remote_evaluators):
|
||||
"""Called to change the set of remote evaluators being used."""
|
||||
|
||||
self.remote_evaluators = remote_evaluators
|
||||
def reset(self, remote_workers):
|
||||
"""Called to change the set of remote workers being used."""
|
||||
self.workers.reset(remote_workers)
|
||||
|
||||
@DeveloperAPI
|
||||
def foreach_evaluator(self, func):
|
||||
"""Apply the given function to each evaluator instance."""
|
||||
|
||||
local_result = [func(self.local_evaluator)]
|
||||
remote_results = ray_get_and_free(
|
||||
[ev.apply.remote(func) for ev in self.remote_evaluators])
|
||||
return local_result + remote_results
|
||||
def foreach_worker(self, func):
|
||||
"""Apply the given function to each worker instance."""
|
||||
return self.workers.foreach_worker(func)
|
||||
|
||||
@DeveloperAPI
|
||||
def foreach_evaluator_with_index(self, func):
|
||||
"""Apply the given function to each evaluator instance.
|
||||
def foreach_worker_with_index(self, func):
|
||||
"""Apply the given function to each worker instance.
|
||||
|
||||
The index will be passed as the second arg to the given function.
|
||||
"""
|
||||
return self.workers.foreach_worker_with_index(func)
|
||||
|
||||
local_result = [func(self.local_evaluator, 0)]
|
||||
remote_results = ray_get_and_free([
|
||||
ev.apply.remote(func, i + 1)
|
||||
for i, ev in enumerate(self.remote_evaluators)
|
||||
])
|
||||
return local_result + remote_results
|
||||
def foreach_evaluator(self, func):
|
||||
raise DeprecationWarning(
|
||||
"foreach_evaluator has been renamed to foreach_worker")
|
||||
|
||||
def foreach_evaluator_with_index(self, func):
|
||||
raise DeprecationWarning(
|
||||
"foreach_evaluator_with_index has been renamed to "
|
||||
"foreach_worker_with_index")
|
||||
|
||||
@@ -20,12 +20,11 @@ class SyncBatchReplayOptimizer(PolicyOptimizer):
|
||||
This enables RNN support. Does not currently support prioritization."""
|
||||
|
||||
def __init__(self,
|
||||
local_evaluator,
|
||||
remote_evaluators,
|
||||
workers,
|
||||
learning_starts=1000,
|
||||
buffer_size=10000,
|
||||
train_batch_size=32):
|
||||
PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)
|
||||
PolicyOptimizer.__init__(self, workers)
|
||||
|
||||
self.replay_starts = learning_starts
|
||||
self.max_buffer_size = buffer_size
|
||||
@@ -45,17 +44,17 @@ class SyncBatchReplayOptimizer(PolicyOptimizer):
|
||||
@override(PolicyOptimizer)
|
||||
def step(self):
|
||||
with self.update_weights_timer:
|
||||
if self.remote_evaluators:
|
||||
weights = ray.put(self.local_evaluator.get_weights())
|
||||
for e in self.remote_evaluators:
|
||||
if self.workers.remote_workers():
|
||||
weights = ray.put(self.workers.local_worker().get_weights())
|
||||
for e in self.workers.remote_workers():
|
||||
e.set_weights.remote(weights)
|
||||
|
||||
with self.sample_timer:
|
||||
if self.remote_evaluators:
|
||||
if self.workers.remote_workers():
|
||||
batches = ray_get_and_free(
|
||||
[e.sample.remote() for e in self.remote_evaluators])
|
||||
[e.sample.remote() for e in self.workers.remote_workers()])
|
||||
else:
|
||||
batches = [self.local_evaluator.sample()]
|
||||
batches = [self.workers.local_worker().sample()]
|
||||
|
||||
# Handle everything as if multiagent
|
||||
tmp = []
|
||||
@@ -105,7 +104,7 @@ class SyncBatchReplayOptimizer(PolicyOptimizer):
|
||||
samples.append(random.choice(self.replay_buffer))
|
||||
samples = SampleBatch.concat_samples(samples)
|
||||
with self.grad_timer:
|
||||
info_dict = self.local_evaluator.learn_on_batch(samples)
|
||||
info_dict = self.workers.local_worker().learn_on_batch(samples)
|
||||
for policy_id, info in info_dict.items():
|
||||
self.learner_stats[policy_id] = get_learner_stats(info)
|
||||
self.grad_timer.push_units_processed(samples.count)
|
||||
|
||||
@@ -25,13 +25,12 @@ logger = logging.getLogger(__name__)
|
||||
class SyncReplayOptimizer(PolicyOptimizer):
|
||||
"""Variant of the local sync optimizer that supports replay (for DQN).
|
||||
|
||||
This optimizer requires that policy evaluators return an additional
|
||||
This optimizer requires that rollout workers return an additional
|
||||
"td_error" array in the info return of compute_gradients(). This error
|
||||
term will be used for sample prioritization."""
|
||||
|
||||
def __init__(self,
|
||||
local_evaluator,
|
||||
remote_evaluators,
|
||||
workers,
|
||||
learning_starts=1000,
|
||||
buffer_size=10000,
|
||||
prioritized_replay=True,
|
||||
@@ -43,7 +42,7 @@ class SyncReplayOptimizer(PolicyOptimizer):
|
||||
prioritized_replay_eps=1e-6,
|
||||
train_batch_size=32,
|
||||
sample_batch_size=4):
|
||||
PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)
|
||||
PolicyOptimizer.__init__(self, workers)
|
||||
|
||||
self.replay_starts = learning_starts
|
||||
# linearly annealing beta used in Rainbow paper
|
||||
@@ -82,18 +81,20 @@ class SyncReplayOptimizer(PolicyOptimizer):
|
||||
@override(PolicyOptimizer)
|
||||
def step(self):
|
||||
with self.update_weights_timer:
|
||||
if self.remote_evaluators:
|
||||
weights = ray.put(self.local_evaluator.get_weights())
|
||||
for e in self.remote_evaluators:
|
||||
if self.workers.remote_workers():
|
||||
weights = ray.put(self.workers.local_worker().get_weights())
|
||||
for e in self.workers.remote_workers():
|
||||
e.set_weights.remote(weights)
|
||||
|
||||
with self.sample_timer:
|
||||
if self.remote_evaluators:
|
||||
if self.workers.remote_workers():
|
||||
batch = SampleBatch.concat_samples(
|
||||
ray_get_and_free(
|
||||
[e.sample.remote() for e in self.remote_evaluators]))
|
||||
ray_get_and_free([
|
||||
e.sample.remote()
|
||||
for e in self.workers.remote_workers()
|
||||
]))
|
||||
else:
|
||||
batch = self.local_evaluator.sample()
|
||||
batch = self.workers.local_worker().sample()
|
||||
|
||||
# Handle everything as if multiagent
|
||||
if isinstance(batch, SampleBatch):
|
||||
@@ -135,7 +136,7 @@ class SyncReplayOptimizer(PolicyOptimizer):
|
||||
samples = self._replay()
|
||||
|
||||
with self.grad_timer:
|
||||
info_dict = self.local_evaluator.learn_on_batch(samples)
|
||||
info_dict = self.workers.local_worker().learn_on_batch(samples)
|
||||
for policy_id, info in info_dict.items():
|
||||
self.learner_stats[policy_id] = get_learner_stats(info)
|
||||
replay_buffer = self.replay_buffers[policy_id]
|
||||
|
||||
@@ -19,16 +19,12 @@ class SyncSamplesOptimizer(PolicyOptimizer):
|
||||
"""A simple synchronous RL optimizer.
|
||||
|
||||
In each step, this optimizer pulls samples from a number of remote
|
||||
evaluators, concatenates them, and then updates a local model. The updated
|
||||
model weights are then broadcast to all remote evaluators.
|
||||
workers, concatenates them, and then updates a local model. The updated
|
||||
model weights are then broadcast to all remote workers.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
local_evaluator,
|
||||
remote_evaluators,
|
||||
num_sgd_iter=1,
|
||||
train_batch_size=1):
|
||||
PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)
|
||||
def __init__(self, workers, num_sgd_iter=1, train_batch_size=1):
|
||||
PolicyOptimizer.__init__(self, workers)
|
||||
|
||||
self.update_weights_timer = TimerStat()
|
||||
self.sample_timer = TimerStat()
|
||||
@@ -41,27 +37,28 @@ class SyncSamplesOptimizer(PolicyOptimizer):
|
||||
@override(PolicyOptimizer)
|
||||
def step(self):
|
||||
with self.update_weights_timer:
|
||||
if self.remote_evaluators:
|
||||
weights = ray.put(self.local_evaluator.get_weights())
|
||||
for e in self.remote_evaluators:
|
||||
if self.workers.remote_workers():
|
||||
weights = ray.put(self.workers.local_worker().get_weights())
|
||||
for e in self.workers.remote_workers():
|
||||
e.set_weights.remote(weights)
|
||||
|
||||
with self.sample_timer:
|
||||
samples = []
|
||||
while sum(s.count for s in samples) < self.train_batch_size:
|
||||
if self.remote_evaluators:
|
||||
if self.workers.remote_workers():
|
||||
samples.extend(
|
||||
ray_get_and_free([
|
||||
e.sample.remote() for e in self.remote_evaluators
|
||||
e.sample.remote()
|
||||
for e in self.workers.remote_workers()
|
||||
]))
|
||||
else:
|
||||
samples.append(self.local_evaluator.sample())
|
||||
samples.append(self.workers.local_worker().sample())
|
||||
samples = SampleBatch.concat_samples(samples)
|
||||
self.sample_timer.push_units_processed(samples.count)
|
||||
|
||||
with self.grad_timer:
|
||||
for i in range(self.num_sgd_iter):
|
||||
fetches = self.local_evaluator.learn_on_batch(samples)
|
||||
fetches = self.workers.local_worker().learn_on_batch(samples)
|
||||
self.learner_stats = get_learner_stats(fetches)
|
||||
if self.num_sgd_iter > 1:
|
||||
logger.debug("{} {}".format(i, fetches))
|
||||
|
||||
@@ -142,7 +142,7 @@ class DynamicTFPolicy(TFPolicy):
|
||||
action_prob = self.action_dist.sampled_action_prob()
|
||||
|
||||
# Phase 1 init
|
||||
sess = tf.get_default_session()
|
||||
sess = tf.get_default_session() or tf.Session()
|
||||
if get_batch_divisibility_req:
|
||||
batch_divisibility_req = get_batch_divisibility_req(self)
|
||||
else:
|
||||
|
||||
@@ -36,7 +36,7 @@ class Policy(object):
|
||||
"""Initialize the graph.
|
||||
|
||||
This is the standard constructor for policies. The policy
|
||||
class you pass into PolicyEvaluator will be constructed with
|
||||
class you pass into RolloutWorker will be constructed with
|
||||
these arguments.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -88,9 +88,7 @@ def build_tf_policy(name,
|
||||
a DynamicTFPolicy instance that uses the specified args
|
||||
"""
|
||||
|
||||
if not name.endswith("TFPolicy"):
|
||||
raise ValueError("Name should match *TFPolicy", name)
|
||||
|
||||
original_kwargs = locals().copy()
|
||||
base = DynamicTFPolicy
|
||||
while mixins:
|
||||
|
||||
@@ -191,6 +189,11 @@ def build_tf_policy(name,
|
||||
else:
|
||||
return TFPolicy.extra_compute_grad_feed_dict(self)
|
||||
|
||||
@staticmethod
|
||||
def with_updates(**overrides):
|
||||
return build_tf_policy(**dict(original_kwargs, **overrides))
|
||||
|
||||
policy_cls.with_updates = with_updates
|
||||
policy_cls.__name__ = name
|
||||
policy_cls.__qualname__ = name
|
||||
return policy_cls
|
||||
|
||||
@@ -24,7 +24,7 @@ def build_torch_policy(name,
|
||||
"""Helper function for creating a torch policy at runtime.
|
||||
|
||||
Arguments:
|
||||
name (str): name of the policy (e.g., "PPOTFPolicy")
|
||||
name (str): name of the policy (e.g., "PPOTorchPolicy")
|
||||
loss_fn (func): function that returns a loss tensor the policy,
|
||||
and dict of experience tensor placeholders
|
||||
get_default_config (func): optional function that returns the default
|
||||
@@ -55,9 +55,7 @@ def build_torch_policy(name,
|
||||
a TorchPolicy instance that uses the specified args
|
||||
"""
|
||||
|
||||
if not name.endswith("TorchPolicy"):
|
||||
raise ValueError("Name should match *TorchPolicy", name)
|
||||
|
||||
original_kwargs = locals().copy()
|
||||
base = TorchPolicy
|
||||
while mixins:
|
||||
|
||||
@@ -66,7 +64,7 @@ def build_torch_policy(name,
|
||||
|
||||
base = new_base
|
||||
|
||||
class graph_cls(base):
|
||||
class policy_cls(base):
|
||||
def __init__(self, obs_space, action_space, config):
|
||||
if get_default_config:
|
||||
config = dict(get_default_config(), **config)
|
||||
@@ -130,6 +128,11 @@ def build_torch_policy(name,
|
||||
else:
|
||||
return TorchPolicy.extra_grad_info(self, batch_tensors)
|
||||
|
||||
graph_cls.__name__ = name
|
||||
graph_cls.__qualname__ = name
|
||||
return graph_cls
|
||||
@staticmethod
|
||||
def with_updates(**overrides):
|
||||
return build_torch_policy(**dict(original_kwargs, **overrides))
|
||||
|
||||
policy_cls.with_updates = with_updates
|
||||
policy_cls.__name__ = name
|
||||
policy_cls.__qualname__ = name
|
||||
return policy_cls
|
||||
|
||||
@@ -120,14 +120,14 @@ def default_policy_agent_mapping(unused_agent_id):
|
||||
def rollout(agent, env_name, num_steps, out=None, no_render=True):
|
||||
policy_agent_mapping = default_policy_agent_mapping
|
||||
|
||||
if hasattr(agent, "local_evaluator"):
|
||||
env = agent.local_evaluator.env
|
||||
if hasattr(agent, "workers"):
|
||||
env = agent.workers.local_worker().env
|
||||
multiagent = isinstance(env, MultiAgentEnv)
|
||||
if agent.local_evaluator.multiagent:
|
||||
if agent.workers.local_worker().multiagent:
|
||||
policy_agent_mapping = agent.config["multiagent"][
|
||||
"policy_mapping_fn"]
|
||||
|
||||
policy_map = agent.local_evaluator.policy_map
|
||||
policy_map = agent.workers.local_worker().policy_map
|
||||
state_init = {p: m.get_initial_state() for p, m in policy_map.items()}
|
||||
use_lstm = {p: len(s) > 0 for p, s in state_init.items()}
|
||||
action_init = {
|
||||
|
||||
@@ -8,7 +8,7 @@ from ray.rllib.evaluation import SampleBatch
|
||||
from ray.rllib.utils.filter import MeanStdFilter
|
||||
|
||||
|
||||
class _MockEvaluator(object):
|
||||
class _MockWorker(object):
|
||||
def __init__(self, sample_count=10):
|
||||
self._weights = np.array([-10, -10, -10, -10])
|
||||
self._grad = np.array([1, 1, 1, 1])
|
||||
@@ -11,10 +11,10 @@ import uuid
|
||||
import ray
|
||||
from ray.rllib.agents.dqn import DQNTrainer
|
||||
from ray.rllib.agents.pg import PGTrainer
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
||||
from ray.rllib.env.external_env import ExternalEnv
|
||||
from ray.rllib.tests.test_policy_evaluator import (BadPolicy, MockPolicy,
|
||||
MockEnv)
|
||||
from ray.rllib.tests.test_rollout_worker import (BadPolicy, MockPolicy,
|
||||
MockEnv)
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
|
||||
@@ -119,7 +119,7 @@ class MultiServing(ExternalEnv):
|
||||
|
||||
class TestExternalEnv(unittest.TestCase):
|
||||
def testExternalEnvCompleteEpisodes(self):
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: SimpleServing(MockEnv(25)),
|
||||
policy=MockPolicy,
|
||||
batch_steps=40,
|
||||
@@ -129,7 +129,7 @@ class TestExternalEnv(unittest.TestCase):
|
||||
self.assertEqual(batch.count, 50)
|
||||
|
||||
def testExternalEnvTruncateEpisodes(self):
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: SimpleServing(MockEnv(25)),
|
||||
policy=MockPolicy,
|
||||
batch_steps=40,
|
||||
@@ -139,7 +139,7 @@ class TestExternalEnv(unittest.TestCase):
|
||||
self.assertEqual(batch.count, 40)
|
||||
|
||||
def testExternalEnvOffPolicy(self):
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: SimpleOffPolicyServing(MockEnv(25), 42),
|
||||
policy=MockPolicy,
|
||||
batch_steps=40,
|
||||
@@ -151,7 +151,7 @@ class TestExternalEnv(unittest.TestCase):
|
||||
self.assertEqual(batch["actions"][-1], 42)
|
||||
|
||||
def testExternalEnvBadActions(self):
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: SimpleServing(MockEnv(25)),
|
||||
policy=BadPolicy,
|
||||
sample_async=True,
|
||||
@@ -196,7 +196,7 @@ class TestExternalEnv(unittest.TestCase):
|
||||
raise Exception("failed to improve reward")
|
||||
|
||||
def testExternalEnvHorizonNotSupported(self):
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: SimpleServing(MockEnv(25)),
|
||||
policy=MockPolicy,
|
||||
episode_horizon=20,
|
||||
|
||||
@@ -10,9 +10,10 @@ import unittest
|
||||
import ray
|
||||
from ray.rllib.agents.pg.pg_policy import PGTFPolicy
|
||||
from ray.rllib.optimizers import SyncSamplesOptimizer
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
||||
from ray.rllib.evaluation.worker_set import WorkerSet
|
||||
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
|
||||
from ray.rllib.tests.test_policy_evaluator import MockPolicy
|
||||
from ray.rllib.tests.test_rollout_worker import MockPolicy
|
||||
from ray.rllib.tests.test_external_env import make_simple_serving
|
||||
from ray.rllib.tests.test_multi_agent_env import BasicMultiAgent, MultiCartpole
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
@@ -23,7 +24,7 @@ SimpleMultiServing = make_simple_serving(True, ExternalMultiAgentEnv)
|
||||
class TestExternalMultiAgentEnv(unittest.TestCase):
|
||||
def testExternalMultiAgentEnvCompleteEpisodes(self):
|
||||
agents = 4
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)),
|
||||
policy=MockPolicy,
|
||||
batch_steps=40,
|
||||
@@ -35,7 +36,7 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
|
||||
|
||||
def testExternalMultiAgentEnvTruncateEpisodes(self):
|
||||
agents = 4
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)),
|
||||
policy=MockPolicy,
|
||||
batch_steps=40,
|
||||
@@ -49,7 +50,7 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
|
||||
agents = 2
|
||||
act_space = gym.spaces.Discrete(2)
|
||||
obs_space = gym.spaces.Discrete(2)
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)),
|
||||
policy={
|
||||
"p0": (MockPolicy, obs_space, act_space, {}),
|
||||
@@ -70,12 +71,12 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
|
||||
policies["pg_{}".format(i)] = (PGTFPolicy, obs_space, act_space,
|
||||
{})
|
||||
policy_ids = list(policies.keys())
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: MultiCartpole(n),
|
||||
policy=policies,
|
||||
policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
|
||||
batch_steps=100)
|
||||
optimizer = SyncSamplesOptimizer(ev, [])
|
||||
optimizer = SyncSamplesOptimizer(WorkerSet._from_existing(ev))
|
||||
for i in range(100):
|
||||
optimizer.step()
|
||||
result = collect_metrics(ev)
|
||||
|
||||
@@ -8,7 +8,7 @@ import numpy as np
|
||||
import ray
|
||||
from ray.rllib.utils.filter import RunningStat, MeanStdFilter
|
||||
from ray.rllib.utils import FilterManager
|
||||
from ray.rllib.tests.mock_evaluator import _MockEvaluator
|
||||
from ray.rllib.tests.mock_worker import _MockWorker
|
||||
|
||||
|
||||
class RunningStatTest(unittest.TestCase):
|
||||
@@ -89,8 +89,8 @@ class FilterManagerTest(unittest.TestCase):
|
||||
filt1.clear_buffer()
|
||||
self.assertEqual(filt1.buffer.n, 0)
|
||||
|
||||
RemoteEvaluator = ray.remote(_MockEvaluator)
|
||||
remote_e = RemoteEvaluator.remote(sample_count=10)
|
||||
RemoteWorker = ray.remote(_MockWorker)
|
||||
remote_e = RemoteWorker.remote(sample_count=10)
|
||||
remote_e.sample.remote()
|
||||
|
||||
FilterManager.synchronize({
|
||||
|
||||
@@ -12,11 +12,11 @@ from ray.rllib.agents.pg.pg_policy import PGTFPolicy
|
||||
from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy
|
||||
from ray.rllib.optimizers import (SyncSamplesOptimizer, SyncReplayOptimizer,
|
||||
AsyncGradientsOptimizer)
|
||||
from ray.rllib.tests.test_policy_evaluator import (MockEnv, MockEnv2,
|
||||
MockPolicy)
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.tests.test_rollout_worker import (MockEnv, MockEnv2, MockPolicy)
|
||||
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
from ray.rllib.evaluation.worker_set import WorkerSet
|
||||
from ray.rllib.env.base_env import _MultiAgentEnvToBaseEnv
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from ray.tune.registry import register_env
|
||||
@@ -327,7 +327,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
def testMultiAgentSample(self):
|
||||
act_space = gym.spaces.Discrete(2)
|
||||
obs_space = gym.spaces.Discrete(2)
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: BasicMultiAgent(5),
|
||||
policy={
|
||||
"p0": (MockPolicy, obs_space, act_space, {}),
|
||||
@@ -345,7 +345,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
def testMultiAgentSampleSyncRemote(self):
|
||||
act_space = gym.spaces.Discrete(2)
|
||||
obs_space = gym.spaces.Discrete(2)
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: BasicMultiAgent(5),
|
||||
policy={
|
||||
"p0": (MockPolicy, obs_space, act_space, {}),
|
||||
@@ -362,7 +362,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
def testMultiAgentSampleAsyncRemote(self):
|
||||
act_space = gym.spaces.Discrete(2)
|
||||
obs_space = gym.spaces.Discrete(2)
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: BasicMultiAgent(5),
|
||||
policy={
|
||||
"p0": (MockPolicy, obs_space, act_space, {}),
|
||||
@@ -378,7 +378,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
def testMultiAgentSampleWithHorizon(self):
|
||||
act_space = gym.spaces.Discrete(2)
|
||||
obs_space = gym.spaces.Discrete(2)
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: BasicMultiAgent(5),
|
||||
policy={
|
||||
"p0": (MockPolicy, obs_space, act_space, {}),
|
||||
@@ -393,7 +393,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
def testSampleFromEarlyDoneEnv(self):
|
||||
act_space = gym.spaces.Discrete(2)
|
||||
obs_space = gym.spaces.Discrete(2)
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: EarlyDoneMultiAgent(),
|
||||
policy={
|
||||
"p0": (MockPolicy, obs_space, act_space, {}),
|
||||
@@ -409,7 +409,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
def testMultiAgentSampleRoundRobin(self):
|
||||
act_space = gym.spaces.Discrete(2)
|
||||
obs_space = gym.spaces.Discrete(10)
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: RoundRobinMultiAgent(5, increment_obs=True),
|
||||
policy={
|
||||
"p0": (MockPolicy, obs_space, act_space, {}),
|
||||
@@ -458,7 +458,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
def get_initial_state(self):
|
||||
return [{}] # empty dict
|
||||
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy=StatefulPolicy,
|
||||
batch_steps=5)
|
||||
@@ -503,7 +503,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
single_env = gym.make("CartPole-v0")
|
||||
obs_space = single_env.observation_space
|
||||
act_space = single_env.action_space
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: MultiCartpole(2),
|
||||
policy={
|
||||
"p0": (ModelBasedPolicy, obs_space, act_space, {}),
|
||||
@@ -587,7 +587,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
"p1": (PGTFPolicy, obs_space, act_space, {}),
|
||||
"p2": (DQNTFPolicy, obs_space, act_space, dqn_config),
|
||||
}
|
||||
ev = PolicyEvaluator(
|
||||
worker = RolloutWorker(
|
||||
env_creator=lambda _: MultiCartpole(n),
|
||||
policy=policies,
|
||||
policy_mapping_fn=lambda agent_id: ["p1", "p2"][agent_id % 2],
|
||||
@@ -597,29 +597,30 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
def policy_mapper(agent_id):
|
||||
return ["p1", "p2"][agent_id % 2]
|
||||
|
||||
remote_evs = [
|
||||
PolicyEvaluator.as_remote().remote(
|
||||
remote_workers = [
|
||||
RolloutWorker.as_remote().remote(
|
||||
env_creator=lambda _: MultiCartpole(n),
|
||||
policy=policies,
|
||||
policy_mapping_fn=policy_mapper,
|
||||
batch_steps=50)
|
||||
]
|
||||
else:
|
||||
remote_evs = []
|
||||
optimizer = optimizer_cls(ev, remote_evs)
|
||||
remote_workers = []
|
||||
workers = WorkerSet._from_existing(worker, remote_workers)
|
||||
optimizer = optimizer_cls(workers)
|
||||
for i in range(200):
|
||||
ev.foreach_policy(lambda p, _: p.set_epsilon(
|
||||
worker.foreach_policy(lambda p, _: p.set_epsilon(
|
||||
max(0.02, 1 - i * .02))
|
||||
if isinstance(p, DQNTFPolicy) else None)
|
||||
optimizer.step()
|
||||
result = collect_metrics(ev, remote_evs)
|
||||
result = collect_metrics(worker, remote_workers)
|
||||
if i % 20 == 0:
|
||||
|
||||
def do_update(p):
|
||||
if isinstance(p, DQNTFPolicy):
|
||||
p.update_target()
|
||||
|
||||
ev.foreach_policy(lambda p, _: do_update(p))
|
||||
worker.foreach_policy(lambda p, _: do_update(p))
|
||||
print("Iter {}, rew {}".format(i,
|
||||
result["policy_reward_mean"]))
|
||||
print("Total reward", result["episode_reward_mean"])
|
||||
@@ -647,15 +648,16 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
policies["pg_{}".format(i)] = (PGTFPolicy, obs_space, act_space,
|
||||
{})
|
||||
policy_ids = list(policies.keys())
|
||||
ev = PolicyEvaluator(
|
||||
worker = RolloutWorker(
|
||||
env_creator=lambda _: MultiCartpole(n),
|
||||
policy=policies,
|
||||
policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
|
||||
batch_steps=100)
|
||||
optimizer = SyncSamplesOptimizer(ev, [])
|
||||
workers = WorkerSet._from_existing(worker, [])
|
||||
optimizer = SyncSamplesOptimizer(workers)
|
||||
for i in range(100):
|
||||
optimizer.step()
|
||||
result = collect_metrics(ev)
|
||||
result = collect_metrics(worker)
|
||||
print("Iteration {}, rew {}".format(i,
|
||||
result["policy_reward_mean"]))
|
||||
print("Total reward", result["episode_reward_mean"])
|
||||
|
||||
@@ -11,10 +11,11 @@ import ray
|
||||
from ray.rllib.agents.ppo import PPOTrainer
|
||||
from ray.rllib.agents.ppo.ppo_policy import PPOTFPolicy
|
||||
from ray.rllib.evaluation import SampleBatch
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
||||
from ray.rllib.evaluation.worker_set import WorkerSet
|
||||
from ray.rllib.optimizers import AsyncGradientsOptimizer, AsyncSamplesOptimizer
|
||||
from ray.rllib.optimizers.aso_tree_aggregator import TreeAggregator
|
||||
from ray.rllib.tests.mock_evaluator import _MockEvaluator
|
||||
from ray.rllib.tests.mock_worker import _MockWorker
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
@@ -26,11 +27,11 @@ class AsyncOptimizerTest(unittest.TestCase):
|
||||
|
||||
def testBasic(self):
|
||||
ray.init(num_cpus=4)
|
||||
local = _MockEvaluator()
|
||||
remotes = ray.remote(_MockEvaluator)
|
||||
remote_evaluators = [remotes.remote() for i in range(5)]
|
||||
test_optimizer = AsyncGradientsOptimizer(
|
||||
local, remote_evaluators, grads_per_step=10)
|
||||
local = _MockWorker()
|
||||
remotes = ray.remote(_MockWorker)
|
||||
remote_workers = [remotes.remote() for i in range(5)]
|
||||
workers = WorkerSet._from_existing(local, remote_workers)
|
||||
test_optimizer = AsyncGradientsOptimizer(workers, grads_per_step=10)
|
||||
test_optimizer.step()
|
||||
self.assertTrue(all(local.get_weights() == 0))
|
||||
|
||||
@@ -117,30 +118,28 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
|
||||
|
||||
def testSimple(self):
|
||||
local, remotes = self._make_evs()
|
||||
optimizer = AsyncSamplesOptimizer(local, remotes)
|
||||
workers = WorkerSet._from_existing(local, remotes)
|
||||
optimizer = AsyncSamplesOptimizer(workers)
|
||||
self._wait_for(optimizer, 1000, 1000)
|
||||
|
||||
def testMultiGPU(self):
|
||||
local, remotes = self._make_evs()
|
||||
optimizer = AsyncSamplesOptimizer(
|
||||
local, remotes, num_gpus=2, _fake_gpus=True)
|
||||
workers = WorkerSet._from_existing(local, remotes)
|
||||
optimizer = AsyncSamplesOptimizer(workers, num_gpus=2, _fake_gpus=True)
|
||||
self._wait_for(optimizer, 1000, 1000)
|
||||
|
||||
def testMultiGPUParallelLoad(self):
|
||||
local, remotes = self._make_evs()
|
||||
workers = WorkerSet._from_existing(local, remotes)
|
||||
optimizer = AsyncSamplesOptimizer(
|
||||
local,
|
||||
remotes,
|
||||
num_gpus=2,
|
||||
num_data_loader_buffers=2,
|
||||
_fake_gpus=True)
|
||||
workers, num_gpus=2, num_data_loader_buffers=2, _fake_gpus=True)
|
||||
self._wait_for(optimizer, 1000, 1000)
|
||||
|
||||
def testMultiplePasses(self):
|
||||
local, remotes = self._make_evs()
|
||||
workers = WorkerSet._from_existing(local, remotes)
|
||||
optimizer = AsyncSamplesOptimizer(
|
||||
local,
|
||||
remotes,
|
||||
workers,
|
||||
minibatch_buffer_size=10,
|
||||
num_sgd_iter=10,
|
||||
sample_batch_size=10,
|
||||
@@ -151,9 +150,9 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
|
||||
|
||||
def testReplay(self):
|
||||
local, remotes = self._make_evs()
|
||||
workers = WorkerSet._from_existing(local, remotes)
|
||||
optimizer = AsyncSamplesOptimizer(
|
||||
local,
|
||||
remotes,
|
||||
workers,
|
||||
replay_buffer_num_slots=100,
|
||||
replay_proportion=10,
|
||||
sample_batch_size=10,
|
||||
@@ -168,9 +167,9 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
|
||||
|
||||
def testReplayAndMultiplePasses(self):
|
||||
local, remotes = self._make_evs()
|
||||
workers = WorkerSet._from_existing(local, remotes)
|
||||
optimizer = AsyncSamplesOptimizer(
|
||||
local,
|
||||
remotes,
|
||||
workers,
|
||||
minibatch_buffer_size=10,
|
||||
num_sgd_iter=10,
|
||||
replay_buffer_num_slots=100,
|
||||
@@ -189,45 +188,43 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
|
||||
|
||||
def testMultiTierAggregationBadConf(self):
|
||||
local, remotes = self._make_evs()
|
||||
workers = WorkerSet._from_existing(local, remotes)
|
||||
aggregators = TreeAggregator.precreate_aggregators(4)
|
||||
optimizer = AsyncSamplesOptimizer(
|
||||
local, remotes, num_aggregation_workers=4)
|
||||
optimizer = AsyncSamplesOptimizer(workers, num_aggregation_workers=4)
|
||||
self.assertRaises(ValueError,
|
||||
lambda: optimizer.aggregator.init(aggregators))
|
||||
|
||||
def testMultiTierAggregation(self):
|
||||
local, remotes = self._make_evs()
|
||||
workers = WorkerSet._from_existing(local, remotes)
|
||||
aggregators = TreeAggregator.precreate_aggregators(1)
|
||||
optimizer = AsyncSamplesOptimizer(
|
||||
local, remotes, num_aggregation_workers=1)
|
||||
optimizer = AsyncSamplesOptimizer(workers, num_aggregation_workers=1)
|
||||
optimizer.aggregator.init(aggregators)
|
||||
self._wait_for(optimizer, 1000, 1000)
|
||||
|
||||
def testRejectBadConfigs(self):
|
||||
local, remotes = self._make_evs()
|
||||
workers = WorkerSet._from_existing(local, remotes)
|
||||
self.assertRaises(
|
||||
ValueError, lambda: AsyncSamplesOptimizer(
|
||||
local, remotes,
|
||||
num_data_loader_buffers=2, minibatch_buffer_size=4))
|
||||
optimizer = AsyncSamplesOptimizer(
|
||||
local,
|
||||
remotes,
|
||||
workers,
|
||||
num_gpus=2,
|
||||
train_batch_size=100,
|
||||
sample_batch_size=50,
|
||||
_fake_gpus=True)
|
||||
self._wait_for(optimizer, 1000, 1000)
|
||||
optimizer = AsyncSamplesOptimizer(
|
||||
local,
|
||||
remotes,
|
||||
workers,
|
||||
num_gpus=2,
|
||||
train_batch_size=100,
|
||||
sample_batch_size=25,
|
||||
_fake_gpus=True)
|
||||
self._wait_for(optimizer, 1000, 1000)
|
||||
optimizer = AsyncSamplesOptimizer(
|
||||
local,
|
||||
remotes,
|
||||
workers,
|
||||
num_gpus=2,
|
||||
train_batch_size=100,
|
||||
sample_batch_size=74,
|
||||
@@ -238,12 +235,12 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
|
||||
def make_sess():
|
||||
return tf.Session(config=tf.ConfigProto(device_count={"CPU": 2}))
|
||||
|
||||
local = PolicyEvaluator(
|
||||
local = RolloutWorker(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy=PPOTFPolicy,
|
||||
tf_session_creator=make_sess)
|
||||
remotes = [
|
||||
PolicyEvaluator.as_remote().remote(
|
||||
RolloutWorker.as_remote().remote(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy=PPOTFPolicy,
|
||||
tf_session_creator=make_sess)
|
||||
|
||||
@@ -7,8 +7,8 @@ import time
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.tests.test_policy_evaluator import MockPolicy
|
||||
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
||||
from ray.rllib.tests.test_rollout_worker import MockPolicy
|
||||
|
||||
|
||||
class TestPerf(unittest.TestCase):
|
||||
@@ -17,7 +17,7 @@ class TestPerf(unittest.TestCase):
|
||||
# 03/01/19: Samples per second 8610.164353268685
|
||||
def testBaselinePerformance(self):
|
||||
for _ in range(20):
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy=MockPolicy,
|
||||
batch_steps=100)
|
||||
|
||||
+23
-24
@@ -12,7 +12,7 @@ from collections import Counter
|
||||
import ray
|
||||
from ray.rllib.agents.pg import PGTrainer
|
||||
from ray.rllib.agents.a3c import A2CTrainer
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
@@ -129,9 +129,9 @@ class MockVectorEnv(VectorEnv):
|
||||
return self.envs
|
||||
|
||||
|
||||
class TestPolicyEvaluator(unittest.TestCase):
|
||||
class TestRolloutWorker(unittest.TestCase):
|
||||
def testBasic(self):
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"), policy=MockPolicy)
|
||||
batch = ev.sample()
|
||||
for key in [
|
||||
@@ -155,7 +155,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
self.assertGreater(batch["advantages"][0], 1)
|
||||
|
||||
def testBatchIds(self):
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"), policy=MockPolicy)
|
||||
batch1 = ev.sample()
|
||||
batch2 = ev.sample()
|
||||
@@ -213,11 +213,10 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
"sample_batch_size": 5,
|
||||
"num_envs_per_worker": 2,
|
||||
})
|
||||
results = pg.optimizer.foreach_evaluator(
|
||||
lambda ev: ev.sample_batch_size)
|
||||
results2 = pg.optimizer.foreach_evaluator_with_index(
|
||||
results = pg.workers.foreach_worker(lambda ev: ev.sample_batch_size)
|
||||
results2 = pg.workers.foreach_worker_with_index(
|
||||
lambda ev, i: (i, ev.sample_batch_size))
|
||||
results3 = pg.optimizer.foreach_evaluator(
|
||||
results3 = pg.workers.foreach_worker(
|
||||
lambda ev: ev.foreach_env(lambda env: 1))
|
||||
self.assertEqual(results, [10, 10, 10])
|
||||
self.assertEqual(results2, [(0, 10), (1, 10), (2, 10)])
|
||||
@@ -225,7 +224,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
|
||||
def testRewardClipping(self):
|
||||
# clipping on
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: MockEnv2(episode_length=10),
|
||||
policy=MockPolicy,
|
||||
clip_rewards=True,
|
||||
@@ -235,7 +234,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
self.assertEqual(result["episode_reward_mean"], 1000)
|
||||
|
||||
# clipping off
|
||||
ev2 = PolicyEvaluator(
|
||||
ev2 = RolloutWorker(
|
||||
env_creator=lambda _: MockEnv2(episode_length=10),
|
||||
policy=MockPolicy,
|
||||
clip_rewards=False,
|
||||
@@ -245,7 +244,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
self.assertEqual(result2["episode_reward_mean"], 1000)
|
||||
|
||||
def testHardHorizon(self):
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: MockEnv(episode_length=10),
|
||||
policy=MockPolicy,
|
||||
batch_mode="complete_episodes",
|
||||
@@ -259,7 +258,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
self.assertEqual(sum(samples["dones"]), 3)
|
||||
|
||||
def testSoftHorizon(self):
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: MockEnv(episode_length=10),
|
||||
policy=MockPolicy,
|
||||
batch_mode="complete_episodes",
|
||||
@@ -273,11 +272,11 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
self.assertEqual(sum(samples["dones"]), 1)
|
||||
|
||||
def testMetrics(self):
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: MockEnv(episode_length=10),
|
||||
policy=MockPolicy,
|
||||
batch_mode="complete_episodes")
|
||||
remote_ev = PolicyEvaluator.as_remote().remote(
|
||||
remote_ev = RolloutWorker.as_remote().remote(
|
||||
env_creator=lambda _: MockEnv(episode_length=10),
|
||||
policy=MockPolicy,
|
||||
batch_mode="complete_episodes")
|
||||
@@ -288,7 +287,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
self.assertEqual(result["episode_reward_mean"], 10)
|
||||
|
||||
def testAsync(self):
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
sample_async=True,
|
||||
policy=MockPolicy)
|
||||
@@ -298,7 +297,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
self.assertGreater(batch["advantages"][0], 1)
|
||||
|
||||
def testAutoVectorization(self):
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda cfg: MockEnv(episode_length=20, config=cfg),
|
||||
policy=MockPolicy,
|
||||
batch_mode="truncate_episodes",
|
||||
@@ -321,7 +320,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
self.assertEqual(indices, [0, 1, 2, 3, 4, 5, 6, 7])
|
||||
|
||||
def testBatchesLargerWhenVectorized(self):
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: MockEnv(episode_length=8),
|
||||
policy=MockPolicy,
|
||||
batch_mode="truncate_episodes",
|
||||
@@ -336,7 +335,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
self.assertEqual(result["episodes_this_iter"], 4)
|
||||
|
||||
def testVectorEnvSupport(self):
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: MockVectorEnv(episode_length=20, num_envs=8),
|
||||
policy=MockPolicy,
|
||||
batch_mode="truncate_episodes",
|
||||
@@ -353,7 +352,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
self.assertEqual(result["episodes_this_iter"], 8)
|
||||
|
||||
def testTruncateEpisodes(self):
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: MockEnv(10),
|
||||
policy=MockPolicy,
|
||||
batch_steps=15,
|
||||
@@ -362,7 +361,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
self.assertEqual(batch.count, 15)
|
||||
|
||||
def testCompleteEpisodes(self):
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: MockEnv(10),
|
||||
policy=MockPolicy,
|
||||
batch_steps=5,
|
||||
@@ -371,7 +370,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
self.assertEqual(batch.count, 10)
|
||||
|
||||
def testCompleteEpisodesPacking(self):
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: MockEnv(10),
|
||||
policy=MockPolicy,
|
||||
batch_steps=15,
|
||||
@@ -383,7 +382,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
|
||||
|
||||
def testFilterSync(self):
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy=MockPolicy,
|
||||
sample_async=True,
|
||||
@@ -396,7 +395,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
self.assertNotEqual(obs_f.buffer.n, 0)
|
||||
|
||||
def testGetFilters(self):
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy=MockPolicy,
|
||||
sample_async=True,
|
||||
@@ -411,7 +410,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
self.assertGreaterEqual(obs_f2.buffer.n, obs_f.buffer.n)
|
||||
|
||||
def testSyncFilter(self):
|
||||
ev = PolicyEvaluator(
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy=MockPolicy,
|
||||
sample_async=True,
|
||||
@@ -58,15 +58,15 @@ class TaskPool(object):
|
||||
remaining.append((worker, obj_id))
|
||||
self._fetching = remaining
|
||||
|
||||
def reset_evaluators(self, evaluators):
|
||||
"""Notify that some evaluators may be removed."""
|
||||
def reset_workers(self, workers):
|
||||
"""Notify that some workers may be removed."""
|
||||
for obj_id, ev in self._tasks.copy().items():
|
||||
if ev not in evaluators:
|
||||
if ev not in workers:
|
||||
del self._tasks[obj_id]
|
||||
del self._objects[obj_id]
|
||||
ok = []
|
||||
for ev, obj_id in self._fetching:
|
||||
if ev in evaluators:
|
||||
if ev in workers:
|
||||
ok.append((ev, obj_id))
|
||||
self._fetching = ok
|
||||
|
||||
|
||||
Reference in New Issue
Block a user