# File: ray/python/ray/rllib/agents/agent.py
# (repository listing at capture time: 747 lines, 30 KiB, Python)

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import datetime
import copy
import logging
import os
import pickle
import six
import tempfile
import tensorflow as tf
from types import FunctionType
import ray
from ray.exceptions import RayError
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
ShuffledInput
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
from ray.tune.trainable import Trainable
from ray.tune.trial import Resources, ExportFormat
from ray.tune.logger import UnifiedLogger
from ray.tune.result import DEFAULT_RESULTS_DIR
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Max number of times to retry a worker failure. We shouldn't try too many
# times in a row since that would indicate a persistent cluster issue.
MAX_WORKER_FAILURE_RETRIES = 3
# Default configuration shared by agents. with_common_config() overlays
# agent-specific settings on a deep copy of this dict. The __sphinx_doc_*__
# markers below delimit regions of this literal; keep them in place.
# yapf: disable
# __sphinx_doc_begin__
COMMON_CONFIG = {
    # === Debugging ===
    # Whether to write episode stats and videos to the agent log dir
    "monitor": False,
    # Set the ray.rllib.* log level for the agent process and its evaluators
    "log_level": "INFO",
    # Callbacks that will be run during various phases of training. These all
    # take a single "info" dict as an argument. For episode callbacks, custom
    # metrics can be attached to the episode by updating the episode object's
    # custom metrics dict (see examples/custom_metrics_and_callbacks.py).
    "callbacks": {
        "on_episode_start": None,  # arg: {"env": .., "episode": ...}
        "on_episode_step": None,  # arg: {"env": .., "episode": ...}
        "on_episode_end": None,  # arg: {"env": .., "episode": ...}
        "on_sample_end": None,  # arg: {"samples": .., "evaluator": ...}
        "on_train_result": None,  # arg: {"agent": ..., "result": ...}
    },
    # Whether to attempt to continue training if a worker crashes.
    "ignore_worker_failures": False,
    # === Policy ===
    # Arguments to pass to model. See models/catalog.py for a full list of the
    # available model options.
    "model": MODEL_DEFAULTS,
    # Arguments to pass to the policy optimizer. These vary by optimizer.
    "optimizer": {},
    # === Environment ===
    # Discount factor of the MDP
    "gamma": 0.99,
    # Number of steps after which the episode is forced to terminate
    "horizon": None,
    # Arguments to pass to the env creator
    "env_config": {},
    # Environment name can also be passed via config
    "env": None,
    # Whether to clip rewards prior to experience postprocessing. Setting to
    # None means clip for Atari only.
    "clip_rewards": None,
    # Whether to np.clip() actions to the action space low/high range spec.
    "clip_actions": True,
    # Whether to use rllib or deepmind preprocessors by default
    "preprocessor_pref": "deepmind",
    # === Resources ===
    # Number of actors used for parallelism
    "num_workers": 2,
    # Number of GPUs to allocate to the driver. Note that not all algorithms
    # can take advantage of driver GPUs. This can be fractional (e.g., 0.3
    # GPUs).
    "num_gpus": 0,
    # Number of CPUs to allocate per worker.
    "num_cpus_per_worker": 1,
    # Number of GPUs to allocate per worker. This can be fractional.
    "num_gpus_per_worker": 0,
    # Any custom resources to allocate per worker.
    "custom_resources_per_worker": {},
    # Number of CPUs to allocate for the driver. Note: this only takes effect
    # when running in Tune.
    "num_cpus_for_driver": 1,
    # === Execution ===
    # Number of environments to evaluate vectorwise per worker.
    "num_envs_per_worker": 1,
    # Default sample batch size
    "sample_batch_size": 200,
    # Training batch size, if applicable. Should be >= sample_batch_size.
    # Sample batches will be concatenated together to this size for training.
    "train_batch_size": 200,
    # Whether to rollout "complete_episodes" or "truncate_episodes"
    "batch_mode": "truncate_episodes",
    # (Deprecated) Use a background thread for sampling (slightly off-policy)
    "sample_async": False,
    # Element-wise observation filter, either "NoFilter" or "MeanStdFilter"
    "observation_filter": "NoFilter",
    # Whether to synchronize the statistics of remote filters.
    "synchronize_filters": True,
    # Configure TF for single-process operation by default
    "tf_session_args": {
        # note: overridden by `local_evaluator_tf_session_args`
        "intra_op_parallelism_threads": 2,
        "inter_op_parallelism_threads": 2,
        "gpu_options": {
            "allow_growth": True,
        },
        "log_device_placement": False,
        "device_count": {
            "CPU": 1
        },
        "allow_soft_placement": True,  # required by PPO multi-gpu
    },
    # Override the following tf session args on the local evaluator
    "local_evaluator_tf_session_args": {
        # Allow a higher level of parallelism by default, but not unlimited
        # since that can cause crashes with many concurrent drivers.
        "intra_op_parallelism_threads": 8,
        "inter_op_parallelism_threads": 8,
    },
    # Whether to LZ4 compress individual observations
    "compress_observations": False,
    # Drop metric batches from unresponsive workers after this many seconds
    "collect_metrics_timeout": 180,
    # If using num_envs_per_worker > 1, whether to create those new envs in
    # remote processes instead of in the same worker. This adds overheads, but
    # can make sense if your envs are very CPU intensive (e.g., for StarCraft).
    "remote_worker_envs": False,
    # Similar to remote_worker_envs, but runs the envs asynchronously in the
    # background for greater efficiency. Conflicts with remote_worker_envs.
    "async_remote_worker_envs": False,
    # === Offline Datasets ===
    # __sphinx_doc_input_begin__
    # Specify how to generate experiences:
    #  - "sampler": generate experiences via online simulation (default)
    #  - a local directory or file glob expression (e.g., "/tmp/*.json")
    #  - a list of individual file paths/URIs (e.g., ["/tmp/1.json",
    #    "s3://bucket/2.json"])
    #  - a dict with string keys and sampling probabilities as values (e.g.,
    #    {"sampler": 0.4, "/tmp/*.json": 0.4, "s3://bucket/expert.json": 0.2}).
    #  - a function that returns a rllib.offline.InputReader
    "input": "sampler",
    # Specify how to evaluate the current policy. This only has an effect when
    # reading offline experiences. Available options:
    #  - "wis": the weighted step-wise importance sampling estimator.
    #  - "is": the step-wise importance sampling estimator.
    #  - "simulation": run the environment in the background, but use
    #    this data for evaluation only and not for learning.
    "input_evaluation": ["is", "wis"],
    # Whether to run postprocess_trajectory() on the trajectory fragments from
    # offline inputs. Note that postprocessing will be done using the *current*
    # policy, not the *behaviour* policy, which is typically undesirable for
    # on-policy algorithms.
    "postprocess_inputs": False,
    # If positive, input batches will be shuffled via a sliding window buffer
    # of this number of batches. Use this if the input data is not in random
    # enough order. Input is delayed until the shuffle buffer is filled.
    "shuffle_buffer_size": 0,
    # __sphinx_doc_input_end__
    # __sphinx_doc_output_begin__
    # Specify where experiences should be saved:
    #  - None: don't save any experiences
    #  - "logdir" to save to the agent log dir
    #  - a path/URI to save to a custom output directory (e.g., "s3://bucket/")
    #  - a function that returns a rllib.offline.OutputWriter
    "output": None,
    # What sample batch columns to LZ4 compress in the output data.
    "output_compress_columns": ["obs", "new_obs"],
    # Max output file size before rolling over to a new file.
    "output_max_file_size": 64 * 1024 * 1024,
    # __sphinx_doc_output_end__
    # === Multiagent ===
    "multiagent": {
        # Map from policy ids to tuples of (policy_graph_cls, obs_space,
        # act_space, config). See policy_evaluator.py for more info.
        "policy_graphs": {},
        # Function mapping agent ids to policy ids.
        "policy_mapping_fn": None,
        # Optional whitelist of policies to train, or None for all policies.
        "policies_to_train": None,
    },
}
# __sphinx_doc_end__
# yapf: enable
@DeveloperAPI
def with_common_config(extra_config):
    """Overlay `extra_config` on top of the shared COMMON_CONFIG defaults.

    Args:
        extra_config (dict): Agent-specific settings to apply over the
            common defaults.

    Returns:
        dict: Whatever with_base_config() produces for COMMON_CONFIG and
            `extra_config`.
    """
    merged = with_base_config(COMMON_CONFIG, extra_config)
    return merged
def with_base_config(base_config, extra_config):
    """Returns the given config dict merged with a base agent conf.

    Args:
        base_config (dict): Base configuration. It is deep-copied, so the
            caller's dict is never mutated by this function.
        extra_config (dict): Overrides applied on top of the copy. The merge
            is shallow: a top-level key present here replaces the base value
            wholesale (nested dicts are NOT merged key by key). None is
            treated as "no overrides".

    Returns:
        dict: A new dict holding the copied base entries updated with
            `extra_config`.
    """
    config = copy.deepcopy(base_config)
    # dict.update(None) raises TypeError; accept None the same way
    # Agent.__init__ accepts a missing config.
    if extra_config is not None:
        config.update(extra_config)
    return config
@PublicAPI
class Agent(Trainable):
    """All RLlib agents extend this base class.

    Agent objects retain internal model state between calls to train(), so
    you should create a new agent instance for each training session.

    Attributes:
        env_creator (func): Function that creates a new training env.
        config (obj): Algorithm-specific configuration data.
        logdir (str): Directory in which training outputs should be placed.
    """

    # Both values are forwarded to deep_update() when _setup() merges the
    # user config into the class default config.
    # NOTE(review): presumably they control whether unknown top-level keys /
    # unknown keys under the listed sub-dicts are accepted -- confirm against
    # ray.rllib.utils.deep_update.
    _allow_unknown_configs = False
    _allow_unknown_subkeys = [
        "tf_session_args", "env_config", "model", "optimizer", "multiagent",
        "custom_resources_per_worker"
    ]

    @PublicAPI
    def __init__(self, config=None, env=None, logger_creator=None):
        """Initialize an RLLib agent.

        Args:
            config (dict): Algorithm-specific configuration data.
            env (str): Name of the environment to use. Note that this can also
                be specified as the `env` key in config.
            logger_creator (func): Function that creates a ray.tune.Logger
                object. If unspecified, a default logger is created.

        Raises:
            ValueError: If the env (from `env` or config["env"]) is neither a
                string id nor a class (see _register_if_needed).
        """
        config = config or {}

        # Vars to synchronize to evaluators on each train call
        self.global_vars = {"timestep": 0}

        # Agents allow env ids to be passed directly to the constructor.
        self._env_id = self._register_if_needed(env or config.get("env"))

        # Create a default logger creator if no logger_creator is specified
        if logger_creator is None:
            timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
            logdir_prefix = "{}_{}_{}".format(self._agent_name, self._env_id,
                                              timestr)

            def default_logger_creator(config):
                """Creates a Unified logger with a default logdir prefix
                containing the agent name and the env id
                """
                if not os.path.isdir(DEFAULT_RESULTS_DIR):
                    try:
                        os.makedirs(DEFAULT_RESULTS_DIR)
                    except OSError:
                        # Another process may have created the directory
                        # between the check and makedirs(); only fail if it
                        # is really missing.
                        if not os.path.isdir(DEFAULT_RESULTS_DIR):
                            raise
                logdir = tempfile.mkdtemp(
                    prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR)
                return UnifiedLogger(config, logdir, None)

            logger_creator = default_logger_creator

        Trainable.__init__(self, config, logger_creator)

    @classmethod
    @override(Trainable)
    def default_resource_request(cls, config):
        """Returns the Resources requested for an agent run with `config`.

        The given config is overlaid (top-level keys only) on the class
        default config and validated before the request is computed.

        Args:
            config (dict): User-supplied configuration overrides.

        Returns:
            Resources: Driver cpu/gpu plus per-worker cpu/gpu multiplied by
                `num_workers` as the "extra" resources.
        """
        cf = dict(cls._default_config, **config)
        Agent._validate_config(cf)
        # TODO(ekl): add custom resources here once tune supports them
        return Resources(
            cpu=cf["num_cpus_for_driver"],
            gpu=cf["num_gpus"],
            extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"],
            extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])

    @override(Trainable)
    @PublicAPI
    def train(self):
        """Overrides super.train to synchronize global vars.

        When config["ignore_worker_failures"] is set, a RayError raised by
        the underlying train call triggers _try_recover() and a retry, up to
        MAX_WORKER_FAILURE_RETRIES retries. Otherwise the error is re-raised.

        Returns:
            The result of Trainable.train(), with "num_healthy_workers" added
            when the agent has a policy optimizer.

        Raises:
            RayError: If a worker fails and failures are not ignored.
            RuntimeError: If no attempt produced a result.
        """
        if self._has_policy_optimizer():
            self.global_vars["timestep"] = self.optimizer.num_steps_sampled
            self.optimizer.local_evaluator.set_global_vars(self.global_vars)
            for ev in self.optimizer.remote_evaluators:
                ev.set_global_vars.remote(self.global_vars)
            logger.debug("updated global vars: {}".format(self.global_vars))

        result = None
        for _ in range(1 + MAX_WORKER_FAILURE_RETRIES):
            try:
                result = Trainable.train(self)
            except RayError:
                if self.config["ignore_worker_failures"]:
                    logger.exception(
                        "Error in train call, attempting to recover")
                    self._try_recover()
                else:
                    logger.info(
                        "Worker crashed during call to train(). To attempt to "
                        "continue training without the failed worker, set "
                        "`'ignore_worker_failures': True`.")
                    # Bare raise keeps the original traceback intact.
                    raise
            else:
                break

        if result is None:
            raise RuntimeError("Failed to recover from worker crash")

        if (self.config.get("observation_filter", "NoFilter") != "NoFilter"
                and hasattr(self, "local_evaluator")):
            FilterManager.synchronize(
                self.local_evaluator.filters,
                self.remote_evaluators,
                update_remote=self.config["synchronize_filters"])
            logger.debug("synchronized filters: {}".format(
                self.local_evaluator.filters))

        if self._has_policy_optimizer():
            result["num_healthy_workers"] = len(
                self.optimizer.remote_evaluators)
        return result

    @override(Trainable)
    def _log_result(self, result):
        """Invokes the `on_train_result` callback, then logs the result.

        Args:
            result (dict): Training result; passed to the callback together
                with this agent before being handed to Trainable._log_result.
        """
        if self.config["callbacks"].get("on_train_result"):
            self.config["callbacks"]["on_train_result"]({
                "agent": self,
                "result": result,
            })
        # log after the callback is invoked, so that the user has a chance
        # to mutate the result
        Trainable._log_result(self, result)

    @override(Trainable)
    def _setup(self, config):
        """Resolves the env creator, merges/validates config, calls _init().

        Args:
            config (dict): User-supplied config. Its "env" key is overwritten
                with the env id resolved in __init__ when that id is truthy.
        """
        env = self._env_id
        if env:
            config["env"] = env
            if _global_registry.contains(ENV_CREATOR, env):
                self.env_creator = _global_registry.get(ENV_CREATOR, env)
            else:
                import gym  # soft dependency
                self.env_creator = lambda env_config: gym.make(env)
        else:
            self.env_creator = lambda env_config: None

        # Merge the supplied config with the class default
        merged_config = copy.deepcopy(self._default_config)
        merged_config = deep_update(merged_config, config,
                                    self._allow_unknown_configs,
                                    self._allow_unknown_subkeys)
        self.raw_user_config = config
        self.config = merged_config
        Agent._validate_config(self.config)
        if self.config.get("log_level"):
            logging.getLogger("ray.rllib").setLevel(self.config["log_level"])

        # TODO(ekl) setting the graph is unnecessary for PyTorch agents
        with tf.Graph().as_default():
            self._init()

    @override(Trainable)
    def _stop(self):
        """Terminates remote evaluators and stops the optimizer, if present."""
        # workaround for https://github.com/ray-project/ray/issues/1516
        if hasattr(self, "remote_evaluators"):
            for ev in self.remote_evaluators:
                ev.__ray_terminate__.remote()
        if hasattr(self, "optimizer"):
            self.optimizer.stop()

    @override(Trainable)
    def _save(self, checkpoint_dir):
        """Pickles the agent state into `checkpoint_dir`.

        Args:
            checkpoint_dir (str): Directory to write the checkpoint file to.

        Returns:
            str: Path of the written file, "checkpoint-<iteration>" inside
                `checkpoint_dir`.
        """
        checkpoint_path = os.path.join(checkpoint_dir,
                                       "checkpoint-{}".format(self.iteration))
        # Collect the state before opening the file so that a failure in
        # __getstate__() does not leave an empty checkpoint file behind.
        state = self.__getstate__()
        with open(checkpoint_path, "wb") as f:
            pickle.dump(state, f)
        return checkpoint_path

    @override(Trainable)
    def _restore(self, checkpoint_path):
        """Restores agent state from a file written by _save().

        Args:
            checkpoint_path (str): Path of the pickled checkpoint file.
        """
        # NOTE: pickle.load() can execute arbitrary code; only restore
        # checkpoints that come from a trusted source.
        with open(checkpoint_path, "rb") as f:
            extra_data = pickle.load(f)
        self.__setstate__(extra_data)

    @DeveloperAPI
    def _init(self):
        """Subclasses should override this for custom initialization."""
        raise NotImplementedError

    @PublicAPI
    def compute_action(self,
                       observation,
                       state=None,
                       prev_action=None,
                       prev_reward=None,
                       info=None,
                       policy_id=DEFAULT_POLICY_ID):
        """Computes an action for the specified policy.

        Note that you can also access the policy object through
        self.get_policy(policy_id) and call compute_actions() on it directly.

        Arguments:
            observation (obj): observation from the environment.
            state (list): RNN hidden state, if any. If state is non-empty,
                then all of compute_single_action(...) is returned
                (computed action, rnn state, logits dictionary).
                Otherwise (None or empty) compute_single_action(...)[0] is
                returned (computed action).
            prev_action (obj): previous action value, if any
            prev_reward (int): previous reward, if any
            info (dict): info object, if any
            policy_id (str): policy to query (only applies to multi-agent).
        """
        if state is None:
            state = []
        # Apply the local evaluator's preprocessor and (frozen) filter for
        # this policy before querying it.
        preprocessed = self.local_evaluator.preprocessors[policy_id].transform(
            observation)
        filtered_obs = self.local_evaluator.filters[policy_id](
            preprocessed, update=False)
        result = self.get_policy(policy_id).compute_single_action(
            filtered_obs, state, prev_action, prev_reward, info)
        # Only callers that supplied a non-empty state get the full tuple.
        if state:
            return result
        return result[0]

    @property
    def iteration(self):
        """Current training iter, auto-incremented with each train() call."""
        return self._iteration

    @property
    def _agent_name(self):
        """Subclasses should override this to declare their name."""
        raise NotImplementedError

    @property
    def _default_config(self):
        """Subclasses should override this to declare their default config."""
        raise NotImplementedError

    @PublicAPI
    def get_policy(self, policy_id=DEFAULT_POLICY_ID):
        """Return policy graph for the specified id, or None.

        Arguments:
            policy_id (str): id of policy graph to return.
        """
        return self.local_evaluator.get_policy(policy_id)

    @PublicAPI
    def get_weights(self, policies=None):
        """Return a dictionary of policy ids to weights.

        Arguments:
            policies (list): Optional list of policies to return weights for,
                or None for all policies.
        """
        return self.local_evaluator.get_weights(policies)

    @PublicAPI
    def set_weights(self, weights):
        """Set policy weights by policy id.

        Arguments:
            weights (dict): Map of policy ids to weights to set.
        """
        self.local_evaluator.set_weights(weights)

    @DeveloperAPI
    def make_local_evaluator(self,
                             env_creator,
                             policy_graph,
                             extra_config=None):
        """Convenience method to return configured local evaluator.

        Arguments:
            env_creator (func): Function that creates the environment.
            policy_graph: Policy graph spec forwarded to the evaluator.
            extra_config (dict): Optional overrides merged on top of the
                agent config for this evaluator only.
        """
        return self._make_evaluator(
            PolicyEvaluator,
            env_creator,
            policy_graph,
            0,
            merge_dicts(
                # important: allow local tf to use more CPUs for optimization
                merge_dicts(
                    self.config, {
                        "tf_session_args": self.
                        config["local_evaluator_tf_session_args"]
                    }),
                extra_config or {}))

    @DeveloperAPI
    def make_remote_evaluators(self, env_creator, policy_graph, count):
        """Convenience method to return a number of remote evaluators.

        Arguments:
            env_creator (func): Function that creates the environment.
            policy_graph: Policy graph spec forwarded to each evaluator.
            count (int): Number of remote evaluators to create; they get
                worker indices 1..count.
        """
        remote_args = {
            "num_cpus": self.config["num_cpus_per_worker"],
            "num_gpus": self.config["num_gpus_per_worker"],
            "resources": self.config["custom_resources_per_worker"],
        }

        cls = PolicyEvaluator.as_remote(**remote_args).remote
        return [
            self._make_evaluator(cls, env_creator, policy_graph, i + 1,
                                 self.config) for i in range(count)
        ]

    @DeveloperAPI
    def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
        """Export policy model with given policy_id to local directory.

        Arguments:
            export_dir (string): Writable local directory.
            policy_id (string): Optional policy id to export.

        Example:
            >>> agent = MyAgent()
            >>> for _ in range(10):
            >>>     agent.train()
            >>> agent.export_policy_model("/tmp/export_dir")
        """
        self.local_evaluator.export_policy_model(export_dir, policy_id)

    @DeveloperAPI
    def export_policy_checkpoint(self,
                                 export_dir,
                                 filename_prefix="model",
                                 policy_id=DEFAULT_POLICY_ID):
        """Export tensorflow policy model checkpoint to local directory.

        Arguments:
            export_dir (string): Writable local directory.
            filename_prefix (string): file name prefix of checkpoint files.
            policy_id (string): Optional policy id to export.

        Example:
            >>> agent = MyAgent()
            >>> for _ in range(10):
            >>>     agent.train()
            >>> agent.export_policy_checkpoint("/tmp/export_dir")
        """
        self.local_evaluator.export_policy_checkpoint(
            export_dir, filename_prefix, policy_id)

    @classmethod
    def resource_help(cls, config):
        """Returns a help message on adjusting resource requests.

        Arguments:
            config (dict): Agent config, embedded verbatim in the message.
        """
        return ("\n\nYou can adjust the resource requests of RLlib agents by "
                "setting `num_workers`, `num_gpus`, and other configs. See "
                "the DEFAULT_CONFIG defined by each agent for more info.\n\n"
                "The config of this agent is: {}".format(config))

    @staticmethod
    def _validate_config(config):
        """Rejects deprecated keys and malformed `input_evaluation`.

        Arguments:
            config (dict): Merged agent config; must contain
                "input_evaluation".

        Raises:
            ValueError: If a deprecated key (`gpu`, `gpu_fraction`,
                `use_gpu_for_workers`) is present, or `input_evaluation` is
                not a list.
        """
        if "gpu" in config:
            raise ValueError(
                "The `gpu` config is deprecated, please use `num_gpus=0|1` "
                "instead.")
        if "gpu_fraction" in config:
            raise ValueError(
                "The `gpu_fraction` config is deprecated, please use "
                "`num_gpus=<fraction>` instead.")
        if "use_gpu_for_workers" in config:
            raise ValueError(
                "The `use_gpu_for_workers` config is deprecated, please use "
                "`num_gpus_per_worker=1` instead.")
        if not isinstance(config["input_evaluation"], list):
            raise ValueError(
                "`input_evaluation` must be a list of strings, got {}".format(
                    config["input_evaluation"]))

    def _try_recover(self):
        """Try to identify and blacklist any unhealthy workers.

        This method is called after an unexpected remote error is encountered
        from a worker. It issues check requests to all current workers and
        blacklists any that respond with error. If no healthy workers remain,
        an error is raised.

        Raises:
            NotImplementedError: If the agent has no policy optimizer.
            RuntimeError: If no healthy worker remains.
        """
        if not self._has_policy_optimizer():
            raise NotImplementedError(
                "Recovery is not supported for this algorithm")

        logger.info("Health checking all workers...")
        # Launch all checks first so they run concurrently, then collect.
        checks = []
        for ev in self.optimizer.remote_evaluators:
            _, obj_id = ev.sample_with_count.remote()
            checks.append(obj_id)

        healthy_evaluators = []
        for i, obj_id in enumerate(checks):
            ev = self.optimizer.remote_evaluators[i]
            try:
                ray.get(obj_id)
                healthy_evaluators.append(ev)
                logger.info("Worker {} looks healthy".format(i + 1))
            except RayError:
                logger.exception("Blacklisting worker {}".format(i + 1))
                # Termination is best-effort: the worker is already broken.
                try:
                    ev.__ray_terminate__.remote()
                except Exception:
                    logger.exception("Error terminating unhealthy worker")

        if len(healthy_evaluators) < 1:
            raise RuntimeError(
                "Not enough healthy workers remain to continue.")

        self.optimizer.reset(healthy_evaluators)

    def _has_policy_optimizer(self):
        """Returns whether self.optimizer exists and is a PolicyOptimizer."""
        return hasattr(self, "optimizer") and isinstance(
            self.optimizer, PolicyOptimizer)

    def _make_evaluator(self, cls, env_creator, policy_graph, worker_index,
                        config):
        """Builds one evaluator by calling `cls` with settings from `config`.

        Arguments:
            cls (callable): Evaluator constructor (PolicyEvaluator, or the
                `.remote` factory of its remote version).
            env_creator (func): Function that creates the environment.
            policy_graph: Used unless config["multiagent"]["policy_graphs"]
                of the agent config is non-empty.
            worker_index (int): Index passed to the evaluator (0 is used for
                the local evaluator).
            config (dict): Evaluator config. Note that the multiagent
                settings are read from self.config, not from this argument.
        """

        def session_creator():
            logger.debug("Creating TF session {}".format(
                config["tf_session_args"]))
            return tf.Session(
                config=tf.ConfigProto(**config["tf_session_args"]))

        # Input: user function, online sampler, weighted mix (dict), or
        # anything else is handed to JsonReader.
        if isinstance(config["input"], FunctionType):
            input_creator = config["input"]
        elif config["input"] == "sampler":
            input_creator = (lambda ioctx: ioctx.default_sampler_input())
        elif isinstance(config["input"], dict):
            input_creator = (lambda ioctx: ShuffledInput(
                MixedInput(config["input"], ioctx),
                config["shuffle_buffer_size"]))
        else:
            input_creator = (lambda ioctx: ShuffledInput(
                JsonReader(config["input"], ioctx),
                config["shuffle_buffer_size"]))

        # Output: user function, nothing, the log dir, or a custom path.
        if isinstance(config["output"], FunctionType):
            output_creator = config["output"]
        elif config["output"] is None:
            output_creator = (lambda ioctx: NoopOutput())
        elif config["output"] == "logdir":
            output_creator = (lambda ioctx: JsonWriter(
                ioctx.log_dir,
                ioctx,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"]))
        else:
            output_creator = (lambda ioctx: JsonWriter(
                config["output"],
                ioctx,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"]))

        # Off-policy estimators are only configured for non-sampler input.
        if config["input"] == "sampler":
            input_evaluation = []
        else:
            input_evaluation = config["input_evaluation"]

        return cls(
            env_creator,
            self.config["multiagent"]["policy_graphs"] or policy_graph,
            policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"],
            policies_to_train=self.config["multiagent"]["policies_to_train"],
            tf_session_creator=(session_creator
                                if config["tf_session_args"] else None),
            batch_steps=config["sample_batch_size"],
            batch_mode=config["batch_mode"],
            episode_horizon=config["horizon"],
            preprocessor_pref=config["preprocessor_pref"],
            sample_async=config["sample_async"],
            compress_observations=config["compress_observations"],
            num_envs=config["num_envs_per_worker"],
            observation_filter=config["observation_filter"],
            clip_rewards=config["clip_rewards"],
            clip_actions=config["clip_actions"],
            env_config=config["env_config"],
            model_config=config["model"],
            policy_config=config,
            worker_index=worker_index,
            monitor_path=self.logdir if config["monitor"] else None,
            log_dir=self.logdir,
            log_level=config["log_level"],
            callbacks=config["callbacks"],
            input_creator=input_creator,
            input_evaluation=input_evaluation,
            output_creator=output_creator,
            remote_worker_envs=config["remote_worker_envs"],
            async_remote_worker_envs=config["async_remote_worker_envs"])

    @override(Trainable)
    def _export_model(self, export_formats, export_dir):
        """Exports the default policy in each requested format.

        Arguments:
            export_formats (list): Formats to export; validated with
                ExportFormat.validate().
            export_dir (str): Directory under which one sub-path per format
                is used.

        Returns:
            dict: Map from each exported format to its export path.
        """
        ExportFormat.validate(export_formats)
        exported = {}
        if ExportFormat.CHECKPOINT in export_formats:
            path = os.path.join(export_dir, ExportFormat.CHECKPOINT)
            self.export_policy_checkpoint(path)
            exported[ExportFormat.CHECKPOINT] = path
        if ExportFormat.MODEL in export_formats:
            path = os.path.join(export_dir, ExportFormat.MODEL)
            self.export_policy_model(path)
            exported[ExportFormat.MODEL] = path
        return exported

    def __getstate__(self):
        """Returns a dict with the saved evaluator/optimizer state.

        The "evaluator" key is set when a local evaluator exists; the
        "optimizer" key is set when the optimizer exposes save().
        """
        state = {}
        if hasattr(self, "local_evaluator"):
            state["evaluator"] = self.local_evaluator.save()
        if hasattr(self, "optimizer") and hasattr(self.optimizer, "save"):
            state["optimizer"] = self.optimizer.save()
        return state

    def __setstate__(self, state):
        """Restores evaluator/optimizer state produced by __getstate__().

        Arguments:
            state (dict): May hold "evaluator" and/or "optimizer" entries.
        """
        if "evaluator" in state:
            self.local_evaluator.restore(state["evaluator"])
            # Put the state in the object store once and share the handle
            # with every remote evaluator.
            remote_state = ray.put(state["evaluator"])
            for r in self.remote_evaluators:
                r.restore.remote(remote_state)
        if "optimizer" in state:
            self.optimizer.restore(state["optimizer"])

    def _register_if_needed(self, env_object):
        """Resolves an env spec to an env id, registering classes on the fly.

        Arguments:
            env_object (str|type): A registered env id, or an env class. A
                class is registered under its __name__ with a creator that
                calls the class with the env config.

        Returns:
            str: The env id.

        Raises:
            ValueError: If `env_object` is neither a string nor a class.
        """
        if isinstance(env_object, six.string_types):
            return env_object
        elif isinstance(env_object, type):
            name = env_object.__name__
            register_env(name, lambda config: env_object(config))
            return name
        raise ValueError(
            "{} is an invalid env specification. ".format(env_object) +
            "You can specify a custom env as either a class "
            "(e.g., YourEnvCls) or a registered env id (e.g., \"your_env\").")