[rllib] Remove need to pass around registry (#2250)

* remove registry

* fix

* too many _

* fix

* cloudpickle

* Update registry.py

* yapf

* fix test

* fix kv check
This commit is contained in:
Eric Liang
2018-06-19 22:47:00 -07:00
committed by GitHub
parent 30684446a6
commit 30f7c08ca7
36 changed files with 202 additions and 208 deletions
+31
View File
@@ -0,0 +1,31 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ray
def _internal_kv_initialized():
    """Report whether the global worker has a non-None ``mode``.

    Returns:
        bool: False when the global worker lacks a ``mode`` attribute or
            that attribute is None; True otherwise.
    """
    # NOTE(review): presumably a set mode means the worker has been
    # connected (e.g. via ray.init()) -- confirm against ray.worker.
    global_worker = ray.worker.get_global_worker()
    if not hasattr(global_worker, "mode"):
        return False
    return global_worker.mode is not None
def _internal_kv_get(key):
    """Look up the value stored under a binary key.

    Args:
        key: Key whose "value" hash field is read through the global
            worker's Redis client.

    Returns:
        The result of ``hget`` for the "value" field of ``key``.
    """
    redis_client = ray.worker.get_global_worker().redis_client
    return redis_client.hget(key, "value")
def _internal_kv_put(key, value):
    """Store a value under a binary key unless one is already present.

    The write goes through ``hsetnx`` on the "value" hash field, so an
    existing value is left untouched.

    Args:
        key: Binary key to associate the value with.
        value: Value to store in the key's "value" field.

    Returns:
        already_exists (bool): whether the value already exists.
    """
    redis_client = ray.worker.get_global_worker().redis_client
    num_written = redis_client.hsetnx(key, "value", value)
    # A result of 0 means nothing was written: the key already had a value.
    already_exists = num_written == 0
    return already_exists
+6 -7
View File
@@ -4,6 +4,7 @@ from __future__ import print_function
import ray
import ray.cloudpickle as pickle
from ray.experimental.internal_kv import _internal_kv_get, _internal_kv_put
def _calculate_key(name):
@@ -29,9 +30,8 @@ def get_actor(name):
Returns:
The ActorHandle object corresponding to the name.
"""
worker = ray.worker.get_global_worker()
actor_hash = _calculate_key(name)
pickled_state = worker.redis_client.hget(actor_hash, name)
actor_name = _calculate_key(name)
pickled_state = _internal_kv_get(actor_name)
if pickled_state is None:
raise ValueError("The actor with name={} doesn't exist".format(name))
handle = pickle.loads(pickled_state)
@@ -45,17 +45,16 @@ def register_actor(name, actor_handle):
name: The name of the named actor.
actor_handle: The actor object to be associated with this name
"""
worker = ray.worker.get_global_worker()
if not isinstance(name, str):
raise TypeError("The name argument must be a string.")
if not isinstance(actor_handle, ray.actor.ActorHandle):
raise TypeError("The actor_handle argument must be an ActorHandle "
"object.")
actor_hash = _calculate_key(name)
actor_name = _calculate_key(name)
pickled_state = pickle.dumps(actor_handle)
# Add the actor to Redis if it does not already exist.
updated = worker.redis_client.hsetnx(actor_hash, name, pickled_state)
if updated == 0:
already_exists = _internal_kv_put(actor_name, pickled_state)
if already_exists:
raise ValueError(
"Error: the actor with name={} already exists".format(name))
+2 -2
View File
@@ -102,7 +102,7 @@ class A3CAgent(Agent):
batch_steps=self.config["batch_size"],
batch_mode="truncate_episodes",
tf_session_creator=session_creator,
registry=self.registry, env_config=self.config["env_config"],
env_config=self.config["env_config"],
model_config=self.config["model"], policy_config=self.config,
num_envs=self.config["num_envs"])
self.remote_evaluators = [
@@ -111,7 +111,7 @@ class A3CAgent(Agent):
batch_steps=self.config["batch_size"],
batch_mode="truncate_episodes", sample_async=True,
tf_session_creator=session_creator,
registry=self.registry, env_config=self.config["env_config"],
env_config=self.config["env_config"],
model_config=self.config["model"], policy_config=self.config,
num_envs=self.config["num_envs"])
for i in range(self.config["num_workers"])]
+1 -2
View File
@@ -13,8 +13,7 @@ from ray.rllib.utils.tf_policy_graph import TFPolicyGraph
class A3CTFPolicyGraph(TFPolicyGraph):
"""The TF policy base class."""
def __init__(self, ob_space, action_space, registry, config):
self.registry = registry
def __init__(self, ob_space, action_space, config):
self.local_steps = 0
self.config = config
self.summarize = config.get("summarize")
+2 -4
View File
@@ -17,8 +17,7 @@ from ray.rllib.utils.policy_graph import PolicyGraph
class SharedTorchPolicy(PolicyGraph):
"""A simple, non-recurrent PyTorch policy example."""
def __init__(self, obs_space, action_space, registry, config):
self.registry = registry
def __init__(self, obs_space, action_space, config):
self.local_steps = 0
self.config = config
self.summarize = config.get("summarize")
@@ -30,8 +29,7 @@ class SharedTorchPolicy(PolicyGraph):
_, self.logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])
self._model = ModelCatalog.get_torch_model(
self.registry, obs_space.shape, self.logit_dim,
self.config["model"])
obs_space.shape, self.logit_dim, self.config["model"])
self.optimizer = torch.optim.Adam(
self._model.parameters(), lr=self.config["lr"])
+3 -3
View File
@@ -10,16 +10,16 @@ from ray.rllib.models.catalog import ModelCatalog
class SharedModel(A3CTFPolicyGraph):
def __init__(self, ob_space, ac_space, registry, config, **kwargs):
def __init__(self, ob_space, ac_space, config, **kwargs):
super(SharedModel, self).__init__(
ob_space, ac_space, registry, config, **kwargs)
ob_space, ac_space, config, **kwargs)
def _setup_graph(self, ob_space, ac_space):
self.x = tf.placeholder(tf.float32, [None] + list(ob_space.shape))
dist_class, self.logit_dim = ModelCatalog.get_action_dist(
ac_space, self.config["model"])
self._model = ModelCatalog.get_model(
self.registry, self.x, self.logit_dim, self.config["model"])
self.x, self.logit_dim, self.config["model"])
self.logits = self._model.outputs
self.action_dist = dist_class(self.logits)
self.vf = tf.reshape(linear(self._model.last_layer, 1, "value",
+2 -2
View File
@@ -11,9 +11,9 @@ from ray.rllib.models.lstm import LSTM
class SharedModelLSTM(A3CTFPolicyGraph):
def __init__(self, ob_space, ac_space, registry, config, **kwargs):
def __init__(self, ob_space, ac_space, config, **kwargs):
super(SharedModelLSTM, self).__init__(
ob_space, ac_space, registry, config, **kwargs)
ob_space, ac_space, config, **kwargs)
def _setup_graph(self, ob_space, ac_space):
self.x = tf.placeholder(tf.float32, [None] + list(ob_space.shape))
+5 -10
View File
@@ -9,7 +9,7 @@ import os
import pickle
import tensorflow as tf
from ray.tune.registry import ENV_CREATOR
from ray.tune.registry import ENV_CREATOR, _global_registry
from ray.tune.result import TrainingResult
from ray.tune.trainable import Trainable
@@ -56,8 +56,6 @@ class Agent(Trainable):
env_creator (func): Function that creates a new training env.
config (obj): Algorithm-specific configuration data.
logdir (str): Directory in which training outputs should be placed.
registry (obj): Tune object registry which holds user-registered
classes and objects by name.
"""
_allow_unknown_configs = False
@@ -72,16 +70,13 @@ class Agent(Trainable):
"The config of this agent is: " + json.dumps(config))
def __init__(
self, config=None, env=None, registry=None,
logger_creator=None):
self, config=None, env=None, logger_creator=None):
"""Initialize an RLLib agent.
Args:
config (dict): Algorithm-specific configuration data.
env (str): Name of the environment to use. Note that this can also
be specified as the `env` key in config.
registry (obj): Object registry for user-defined envs, models, etc.
If unspecified, the default registry will be used.
logger_creator (func): Function that creates a ray.tune.Logger
object. If unspecified, a default logger is created.
"""
@@ -90,14 +85,14 @@ class Agent(Trainable):
# Agents allow env ids to be passed directly to the constructor.
self._env_id = env or config.get("env")
Trainable.__init__(self, config, registry, logger_creator)
Trainable.__init__(self, config, logger_creator)
def _setup(self):
env = self._env_id
if env:
self.config["env"] = env
if self.registry and self.registry.contains(ENV_CREATOR, env):
self.env_creator = self.registry.get(ENV_CREATOR, env)
if _global_registry.contains(ENV_CREATOR, env):
self.env_creator = _global_registry.get(ENV_CREATOR, env)
else:
import gym # soft dependency
self.env_creator = lambda env_config: gym.make(env)
+2 -3
View File
@@ -63,14 +63,13 @@ class BCAgent(Agent):
def _init(self):
self.local_evaluator = BCEvaluator(
self.registry, self.env_creator, self.config, self.logdir)
self.env_creator, self.config, self.logdir)
if self.config["use_gpu_for_workers"]:
remote_cls = GPURemoteBCEvaluator
else:
remote_cls = RemoteBCEvaluator
self.remote_evaluators = [
remote_cls.remote(
self.registry, self.env_creator, self.config, self.logdir)
remote_cls.remote(self.env_creator, self.config, self.logdir)
for _ in range(self.config["num_workers"])]
self.optimizer = AsyncOptimizer(
self.config["optimizer"], self.local_evaluator,
+3 -4
View File
@@ -13,12 +13,11 @@ from ray.rllib.optimizers import PolicyEvaluator
class BCEvaluator(PolicyEvaluator):
def __init__(self, registry, env_creator, config, logdir):
env = ModelCatalog.get_preprocessor_as_wrapper(registry, env_creator(
def __init__(self, env_creator, config, logdir):
env = ModelCatalog.get_preprocessor_as_wrapper(env_creator(
config["env_config"]), config["model"])
self.dataset = ExperienceDataset(config["dataset_path"])
self.policy = BCPolicy(registry, env.observation_space,
env.action_space, config)
self.policy = BCPolicy(env.observation_space, env.action_space, config)
self.config = config
self.logdir = logdir
self.metrics_queue = queue.Queue()
+2 -3
View File
@@ -10,8 +10,7 @@ from ray.rllib.models.catalog import ModelCatalog
class BCPolicy(object):
def __init__(self, registry, obs_space, action_space, config):
self.registry = registry
def __init__(self, obs_space, action_space, config):
self.local_steps = 0
self.config = config
self.summarize = config.get("summarize")
@@ -25,7 +24,7 @@ class BCPolicy(object):
dist_class, self.logit_dim = ModelCatalog.get_action_dist(
ac_space, self.config["model"])
self._model = ModelCatalog.get_model(
self.registry, self.x, self.logit_dim, self.config["model"])
self.x, self.logit_dim, self.config["model"])
self.logits = self._model.outputs
self.curr_dist = dist_class(self.logits)
self.sample = self.curr_dist.sample()
+11 -15
View File
@@ -22,12 +22,12 @@ Q_SCOPE = "q_func"
Q_TARGET_SCOPE = "target_q_func"
def _build_p_network(registry, inputs, dim_actions, config):
def _build_p_network(inputs, dim_actions, config):
"""
map an observation (i.e., state) to an action where
each entry takes value from (0, 1) due to the sigmoid function
"""
frontend = ModelCatalog.get_model(registry, inputs, 1, config["model"])
frontend = ModelCatalog.get_model(inputs, 1, config["model"])
hiddens = config["actor_hiddens"]
action_out = frontend.last_layer
@@ -66,8 +66,8 @@ def _build_action_network(p_values, low_action, high_action, stochastic, eps,
lambda: deterministic_actions)
def _build_q_network(registry, inputs, action_inputs, config):
frontend = ModelCatalog.get_model(registry, inputs, 1, config["model"])
def _build_q_network(inputs, action_inputs, config):
frontend = ModelCatalog.get_model(inputs, 1, config["model"])
hiddens = config["critic_hiddens"]
@@ -81,7 +81,7 @@ def _build_q_network(registry, inputs, action_inputs, config):
class DDPGPolicyGraph(TFPolicyGraph):
def __init__(self, observation_space, action_space, registry, config):
def __init__(self, observation_space, action_space, config):
if not isinstance(action_space, Box):
raise UnsupportedSpaceException(
"Action space {} is not supported for DDPG.".format(
@@ -105,7 +105,7 @@ class DDPGPolicyGraph(TFPolicyGraph):
# Actor: P (policy) network
with tf.variable_scope(P_SCOPE) as scope:
p_values = _build_p_network(registry, self.cur_observations,
p_values = _build_p_network(self.cur_observations,
dim_actions, config)
self.p_func_vars = _scope_vars(scope.name)
@@ -136,13 +136,11 @@ class DDPGPolicyGraph(TFPolicyGraph):
# p network evaluation
with tf.variable_scope(P_SCOPE, reuse=True) as scope:
self.p_t = _build_p_network(
registry, self.obs_t, dim_actions, config)
self.p_t = _build_p_network(self.obs_t, dim_actions, config)
# target p network evaluation
with tf.variable_scope(P_TARGET_SCOPE) as scope:
p_tp1 = _build_p_network(
registry, self.obs_tp1, dim_actions, config)
p_tp1 = _build_p_network(self.obs_tp1, dim_actions, config)
target_p_func_vars = _scope_vars(scope.name)
# Action outputs
@@ -161,17 +159,15 @@ class DDPGPolicyGraph(TFPolicyGraph):
# q network evaluation
with tf.variable_scope(Q_SCOPE) as scope:
q_t = _build_q_network(
registry, self.obs_t, self.act_t, config)
q_t = _build_q_network(self.obs_t, self.act_t, config)
self.q_func_vars = _scope_vars(scope.name)
with tf.variable_scope(Q_SCOPE, reuse=True):
q_tp0 = _build_q_network(
registry, self.obs_t, output_actions, config)
q_tp0 = _build_q_network(self.obs_t, output_actions, config)
# target q network evaluation
with tf.variable_scope(Q_TARGET_SCOPE) as scope:
q_tp1 = _build_q_network(
registry, self.obs_tp1, output_actions_estimated, config)
self.obs_tp1, output_actions_estimated, config)
target_q_func_vars = _scope_vars(scope.name)
q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1)
+2 -2
View File
@@ -6,7 +6,7 @@ from ray.rllib.models import ModelCatalog
from ray.rllib.utils.atari_wrappers import wrap_deepmind
def wrap_dqn(registry, env, options, random_starts):
def wrap_dqn(env, options, random_starts):
"""Apply a common set of wrappers for DQN."""
is_atari = hasattr(env.unwrapped, "ale")
@@ -17,4 +17,4 @@ def wrap_dqn(registry, env, options, random_starts):
return wrap_deepmind(
env, random_starts=random_starts, dim=options.get("dim", 80))
return ModelCatalog.get_preprocessor_as_wrapper(registry, env, options)
return ModelCatalog.get_preprocessor_as_wrapper(env, options)
+2 -2
View File
@@ -129,7 +129,7 @@ class DQNAgent(Agent):
batch_steps=adjusted_batch_size,
batch_mode="truncate_episodes", preprocessor_pref="deepmind",
compress_observations=True,
registry=self.registry, env_config=self.config["env_config"],
env_config=self.config["env_config"],
model_config=self.config["model"], policy_config=self.config,
num_envs=self.config["num_envs"])
remote_cls = CommonPolicyEvaluator.as_remote(
@@ -141,7 +141,7 @@ class DQNAgent(Agent):
batch_steps=adjusted_batch_size,
batch_mode="truncate_episodes", preprocessor_pref="deepmind",
compress_observations=True,
registry=self.registry, env_config=self.config["env_config"],
env_config=self.config["env_config"],
model_config=self.config["model"], policy_config=self.config,
num_envs=self.config["num_envs"])
for _ in range(self.config["num_workers"])]
+7 -9
View File
@@ -46,7 +46,7 @@ def adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones):
class DQNPolicyGraph(TFPolicyGraph):
def __init__(self, observation_space, action_space, registry, config):
def __init__(self, observation_space, action_space, config):
if not isinstance(action_space, Discrete):
raise UnsupportedSpaceException(
"Action space {} is not supported for DQN.".format(
@@ -65,7 +65,7 @@ class DQNPolicyGraph(TFPolicyGraph):
# Action Q network
with tf.variable_scope(Q_SCOPE) as scope:
q_values = _build_q_network(
registry, self.cur_observations, num_actions, config)
self.cur_observations, num_actions, config)
self.q_func_vars = _scope_vars(scope.name)
# Action outputs
@@ -89,13 +89,11 @@ class DQNPolicyGraph(TFPolicyGraph):
# q network evaluation
with tf.variable_scope(Q_SCOPE, reuse=True):
q_t = _build_q_network(
registry, self.obs_t, num_actions, config)
q_t = _build_q_network(self.obs_t, num_actions, config)
# target q network evaluation
with tf.variable_scope(Q_TARGET_SCOPE) as scope:
q_tp1 = _build_q_network(
registry, self.obs_tp1, num_actions, config)
q_tp1 = _build_q_network(self.obs_tp1, num_actions, config)
self.target_q_func_vars = _scope_vars(scope.name)
# q scores for actions which we know were selected in the given state.
@@ -106,7 +104,7 @@ class DQNPolicyGraph(TFPolicyGraph):
if config["double_q"]:
with tf.variable_scope(Q_SCOPE, reuse=True):
q_tp1_using_online_net = _build_q_network(
registry, self.obs_tp1, num_actions, config)
self.obs_tp1, num_actions, config)
q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1)
q_tp1_best = tf.reduce_sum(
q_tp1 * tf.one_hot(
@@ -236,10 +234,10 @@ def _postprocess_dqn(policy_graph, sample_batch):
return batch
def _build_q_network(registry, inputs, num_actions, config):
def _build_q_network(inputs, num_actions, config):
dueling = config["dueling"]
hiddens = config["hiddens"]
frontend = ModelCatalog.get_model(registry, inputs, 1, config["model"])
frontend = ModelCatalog.get_model(inputs, 1, config["model"])
frontend_out = frontend.last_layer
with tf.variable_scope("action_value"):
+6 -9
View File
@@ -64,7 +64,7 @@ class SharedNoiseTable(object):
@ray.remote
class Worker(object):
def __init__(self, registry, config, policy_params, env_creator, noise,
def __init__(self, config, policy_params, env_creator, noise,
min_task_runtime=0.2):
self.min_task_runtime = min_task_runtime
self.config = config
@@ -73,12 +73,11 @@ class Worker(object):
self.env = env_creator(config["env_config"])
from ray.rllib import models
self.preprocessor = models.ModelCatalog.get_preprocessor(
registry, self.env)
self.preprocessor = models.ModelCatalog.get_preprocessor(self.env)
self.sess = utils.make_session(single_threaded=True)
self.policy = policies.GenericPolicy(
registry, self.sess, self.env.action_space, self.preprocessor,
self.sess, self.env.action_space, self.preprocessor,
config["observation_filter"], **policy_params)
def rollout(self, timestep_limit, add_noise=True):
@@ -152,12 +151,11 @@ class ESAgent(agent.Agent):
env = self.env_creator(self.config["env_config"])
from ray.rllib import models
preprocessor = models.ModelCatalog.get_preprocessor(
self.registry, env)
preprocessor = models.ModelCatalog.get_preprocessor(env)
self.sess = utils.make_session(single_threaded=False)
self.policy = policies.GenericPolicy(
self.registry, self.sess, env.action_space, preprocessor,
self.sess, env.action_space, preprocessor,
self.config["observation_filter"], **policy_params)
self.optimizer = optimizers.Adam(self.policy, self.config["stepsize"])
@@ -170,8 +168,7 @@ class ESAgent(agent.Agent):
print("Creating actors.")
self.workers = [
Worker.remote(
self.registry, self.config, policy_params, self.env_creator,
noise_id)
self.config, policy_params, self.env_creator, noise_id)
for _ in range(self.config["num_workers"])]
self.episodes_so_far = 0
+2 -2
View File
@@ -38,7 +38,7 @@ def rollout(policy, env, timestep_limit=None, add_noise=False):
class GenericPolicy(object):
def __init__(self, registry, sess, action_space, preprocessor,
def __init__(self, sess, action_space, preprocessor,
observation_filter, action_noise_std):
self.sess = sess
self.action_space = action_space
@@ -52,7 +52,7 @@ class GenericPolicy(object):
# Policy network.
dist_class, dist_dim = ModelCatalog.get_action_dist(
self.action_space, dist_type="deterministic")
model = ModelCatalog.get_model(registry, self.inputs, dist_dim)
model = ModelCatalog.get_model(self.inputs, dist_dim)
dist = dist_class(model.outputs)
self.sampler = dist.sample()
@@ -8,7 +8,7 @@ from gym.envs.registration import register
import ray
import ray.rllib.ppo as ppo
from ray.tune.registry import get_registry, register_env
from ray.tune.registry import register_env
env_name = "MultiAgentMountainCarEnv"
@@ -51,6 +51,6 @@ if __name__ == '__main__':
"multiagent_shared_model": False,
"multiagent_fcnet_hiddens": [[32, 32]] * 2}
config["model"].update({"custom_options": options})
alg = ppo.PPOAgent(env=env_name, registry=get_registry(), config=config)
alg = ppo.PPOAgent(env=env_name, config=config)
for i in range(1):
alg.train()
@@ -8,7 +8,7 @@ from gym.envs.registration import register
import ray
import ray.rllib.ppo as ppo
from ray.tune.registry import get_registry, register_env
from ray.tune.registry import register_env
env_name = "MultiAgentPendulumEnv"
@@ -51,6 +51,6 @@ if __name__ == '__main__':
"multiagent_shared_model": True,
"multiagent_fcnet_hiddens": [[32, 32]] * 2}
config["model"].update({"custom_options": options})
alg = ppo.PPOAgent(env=env_name, registry=get_registry(), config=config)
alg = ppo.PPOAgent(env=env_name, config=config)
for i in range(1):
alg.train()
+12 -16
View File
@@ -8,7 +8,7 @@ import tensorflow as tf
from functools import partial
from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \
_default_registry
_global_registry
from ray.rllib.models.action_dist import (
Categorical, Deterministic, DiagGaussian, MultiActionDistribution,
@@ -47,7 +47,7 @@ class ModelCatalog(object):
>>> observation = prep.transform(raw_observation)
>>> dist_cls, dist_dim = ModelCatalog.get_action_dist(env.action_space)
>>> model = ModelCatalog.get_model(registry, inputs, dist_dim)
>>> model = ModelCatalog.get_model(inputs, dist_dim)
>>> dist = dist_cls(model.outputs)
>>> action = dist.sample()
"""
@@ -130,11 +130,10 @@ class ModelCatalog(object):
" not supported".format(action_space))
@staticmethod
def get_model(registry, inputs, num_outputs, options={}):
def get_model(inputs, num_outputs, options={}):
"""Returns a suitable model conforming to given input and output specs.
Args:
registry (obj): Registry of named objects (ray.tune.registry).
inputs (Tensor): The input tensor to the model.
num_outputs (int): The size of the output vector of the model.
options (dict): Optional args to pass to the model constructor.
@@ -146,7 +145,7 @@ class ModelCatalog(object):
if "custom_model" in options:
model = options["custom_model"]
print("Using custom model {}".format(model))
return registry.get(RLLIB_MODEL, model)(
return _global_registry.get(RLLIB_MODEL, model)(
inputs, num_outputs, options)
obs_rank = len(inputs.shape) - 1
@@ -163,12 +162,11 @@ class ModelCatalog(object):
return FullyConnectedNetwork(inputs, num_outputs, options)
@staticmethod
def get_torch_model(registry, input_shape, num_outputs, options={}):
def get_torch_model(input_shape, num_outputs, options={}):
"""Returns a PyTorch suitable model. This is currently only supported
in A3C.
Args:
registry (obj): Registry of named objects (ray.tune.registry).
input_shape (tuple): The input shape to the model.
num_outputs (int): The size of the output vector of the model.
options (dict): Optional args to pass to the model constructor.
@@ -184,7 +182,7 @@ class ModelCatalog(object):
if "custom_model" in options:
model = options["custom_model"]
print("Using custom torch model {}".format(model))
return registry.get(RLLIB_MODEL, model)(
return _global_registry.get(RLLIB_MODEL, model)(
input_shape, num_outputs, options)
# TODO(alok): fix to handle Discrete(n) state spaces
@@ -198,11 +196,10 @@ class ModelCatalog(object):
return PyTorchFCNet(input_shape[0], num_outputs, options)
@staticmethod
def get_preprocessor(registry, env, options={}):
def get_preprocessor(env, options={}):
"""Returns a suitable processor for the given environment.
Args:
registry (obj): Registry of named objects (ray.tune.registry).
env (gym.Env): The gym environment to preprocess.
options (dict): Options to pass to the preprocessor.
@@ -218,18 +215,17 @@ class ModelCatalog(object):
if "custom_preprocessor" in options:
preprocessor = options["custom_preprocessor"]
print("Using custom preprocessor {}".format(preprocessor))
return registry.get(RLLIB_PREPROCESSOR, preprocessor)(
return _global_registry.get(RLLIB_PREPROCESSOR, preprocessor)(
env.observation_space, options)
preprocessor = get_preprocessor(env.observation_space)
return preprocessor(env.observation_space, options)
@staticmethod
def get_preprocessor_as_wrapper(registry, env, options={}):
def get_preprocessor_as_wrapper(env, options={}):
"""Returns a preprocessor as a gym observation wrapper.
Args:
registry (obj): Registry of named objects (ray.tune.registry).
env (gym.Env): The gym environment to wrap.
options (dict): Options to pass to the preprocessor.
@@ -237,7 +233,7 @@ class ModelCatalog(object):
wrapper (gym.ObservationWrapper): Preprocessor in wrapper form.
"""
preprocessor = ModelCatalog.get_preprocessor(registry, env, options)
preprocessor = ModelCatalog.get_preprocessor(env, options)
return _RLlibPreprocessorWrapper(env, preprocessor)
@staticmethod
@@ -251,7 +247,7 @@ class ModelCatalog(object):
preprocessor_name (str): Name to register the preprocessor under.
preprocessor_class (type): Python class of the preprocessor.
"""
_default_registry.register(
_global_registry.register(
RLLIB_PREPROCESSOR, preprocessor_name, preprocessor_class)
@staticmethod
@@ -265,7 +261,7 @@ class ModelCatalog(object):
model_name (str): Name to register the model under.
model_class (type): Python class of the model.
"""
_default_registry.register(RLLIB_MODEL, model_name, model_class)
_global_registry.register(RLLIB_MODEL, model_name, model_class)
class _RLlibPreprocessorWrapper(gym.ObservationWrapper):
-1
View File
@@ -55,7 +55,6 @@ class PGAgent(Agent):
"policy_graph": PGPolicyGraph,
"batch_steps": self.config["batch_size"],
"batch_mode": "truncate_episodes",
"registry": self.registry,
"model_config": self.config["model"],
"env_config": self.config["env_config"],
"policy_config": self.config,
+2 -2
View File
@@ -11,7 +11,7 @@ from ray.rllib.utils.tf_policy_graph import TFPolicyGraph
class PGPolicyGraph(TFPolicyGraph):
def __init__(self, obs_space, action_space, registry, config):
def __init__(self, obs_space, action_space, config):
self.config = config
# setup policy
@@ -19,7 +19,7 @@ class PGPolicyGraph(TFPolicyGraph):
dist_class, self.logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])
self.model = ModelCatalog.get_model(
registry, self.x, self.logit_dim, options=self.config["model"])
self.x, self.logit_dim, options=self.config["model"])
self.dist = dist_class(self.model.outputs) # logit for each action
# setup policy loss
+3 -3
View File
@@ -16,14 +16,14 @@ class ProximalPolicyGraph(object):
self, observation_space, action_space,
observations, value_targets, advantages, actions,
prev_logits, prev_vf_preds, logit_dim,
kl_coeff, distribution_class, config, sess, registry):
kl_coeff, distribution_class, config, sess):
self.prev_dist = distribution_class(prev_logits)
# Saved so that we can compute actions given different observations
self.observations = observations
self.curr_logits = ModelCatalog.get_model(
registry, observations, logit_dim, config["model"]).outputs
observations, logit_dim, config["model"]).outputs
self.curr_dist = distribution_class(self.curr_logits)
self.sampler = self.curr_dist.sample()
@@ -35,7 +35,7 @@ class ProximalPolicyGraph(object):
vf_config["free_log_std"] = False
with tf.variable_scope("value_function"):
self.value_function = ModelCatalog.get_model(
registry, observations, 1, vf_config).outputs
observations, 1, vf_config).outputs
self.value_function = tf.reshape(self.value_function, [-1])
# Make loss functions.
+2 -3
View File
@@ -103,14 +103,13 @@ class PPOAgent(Agent):
def _init(self):
self.global_step = 0
self.local_evaluator = PPOEvaluator(
self.registry, self.env_creator, self.config, self.logdir, False)
self.env_creator, self.config, self.logdir, False)
RemotePPOEvaluator = ray.remote(
num_cpus=self.config["num_cpus_per_worker"],
num_gpus=self.config["num_gpus_per_worker"])(PPOEvaluator)
self.remote_evaluators = [
RemotePPOEvaluator.remote(
self.registry, self.env_creator, self.config, self.logdir,
True)
self.env_creator, self.config, self.logdir, True)
for _ in range(self.config["num_workers"])]
self.optimizer = LocalMultiGPUOptimizer(
+3 -4
View File
@@ -24,12 +24,11 @@ class PPOEvaluator(TFMultiGPUSupport):
network weights. When run as a remote agent, only this graph is used.
"""
def __init__(self, registry, env_creator, config, logdir, is_remote):
self.registry = registry
def __init__(self, env_creator, config, logdir, is_remote):
self.config = config
self.logdir = logdir
self.env = ModelCatalog.get_preprocessor_as_wrapper(
registry, env_creator(config["env_config"]), config["model"])
env_creator(config["env_config"]), config["model"])
if is_remote:
config_proto = tf.ConfigProto()
else:
@@ -92,7 +91,7 @@ class PPOEvaluator(TFMultiGPUSupport):
self.env.observation_space, self.env.action_space,
obs, vtargets, advs, acts, plog, pvf_preds, self.logit_dim,
self.kl_coeff, self.distribution_class, self.config,
self.sess, self.registry)
self.sess)
def init_extra_ops(self, device_losses):
self.extra_ops = OrderedDict()
+2 -4
View File
@@ -14,7 +14,6 @@ import ray
from ray.rllib.agent import get_agent_class
from ray.rllib.dqn.common.wrappers import wrap_dqn
from ray.rllib.models import ModelCatalog
from ray.tune.registry import get_registry
EXAMPLE_USAGE = """
example usage:
@@ -74,10 +73,9 @@ if __name__ == "__main__":
if args.run == "DQN":
env = gym.make(args.env)
env = wrap_dqn(get_registry(), env, args.config.get("model", {}))
env = wrap_dqn(env, args.config.get("model", {}))
else:
env = ModelCatalog.get_preprocessor_as_wrapper(get_registry(),
gym.make(args.env))
env = ModelCatalog.get_preprocessor_as_wrapper(gym.make(args.env))
if args.out is not None:
rollouts = []
steps = 0
+9 -17
View File
@@ -5,7 +5,6 @@ import unittest
from gym.spaces import Box, Discrete, Tuple
import ray
from ray.tune.registry import get_registry
from ray.rllib.models import ModelCatalog
from ray.rllib.models.model import Model
@@ -33,12 +32,10 @@ class ModelCatalogTest(unittest.TestCase):
ray.worker.cleanup()
def testGymPreprocessors(self):
p1 = ModelCatalog.get_preprocessor(
get_registry(), gym.make("CartPole-v0"))
p1 = ModelCatalog.get_preprocessor(gym.make("CartPole-v0"))
self.assertEqual(type(p1), NoPreprocessor)
p2 = ModelCatalog.get_preprocessor(
get_registry(), gym.make("FrozenLake-v0"))
p2 = ModelCatalog.get_preprocessor(gym.make("FrozenLake-v0"))
self.assertEqual(type(p2), OneHotPreprocessor)
def testTuplePreprocessor(self):
@@ -48,8 +45,7 @@ class ModelCatalogTest(unittest.TestCase):
def __init__(self):
self.observation_space = Tuple(
[Discrete(5), Box(0, 1, shape=(3,), dtype=np.float32)])
p1 = ModelCatalog.get_preprocessor(
get_registry(), TupleEnv())
p1 = ModelCatalog.get_preprocessor(TupleEnv())
self.assertEqual(p1.shape, (8,))
self.assertEqual(
list(p1.transform((0, [1, 2, 3]))),
@@ -60,33 +56,29 @@ class ModelCatalogTest(unittest.TestCase):
ModelCatalog.register_custom_preprocessor("foo", CustomPreprocessor)
ModelCatalog.register_custom_preprocessor("bar", CustomPreprocessor2)
env = gym.make("CartPole-v0")
p1 = ModelCatalog.get_preprocessor(
get_registry(), env, {"custom_preprocessor": "foo"})
p1 = ModelCatalog.get_preprocessor(env, {"custom_preprocessor": "foo"})
self.assertEqual(str(type(p1)), str(CustomPreprocessor))
p2 = ModelCatalog.get_preprocessor(
get_registry(), env, {"custom_preprocessor": "bar"})
p2 = ModelCatalog.get_preprocessor(env, {"custom_preprocessor": "bar"})
self.assertEqual(str(type(p2)), str(CustomPreprocessor2))
p3 = ModelCatalog.get_preprocessor(get_registry(), env)
p3 = ModelCatalog.get_preprocessor(env)
self.assertEqual(type(p3), NoPreprocessor)
def testDefaultModels(self):
ray.init()
with tf.variable_scope("test1"):
p1 = ModelCatalog.get_model(
get_registry(), np.zeros((10, 3), dtype=np.float32), 5)
p1 = ModelCatalog.get_model(np.zeros((10, 3), dtype=np.float32), 5)
self.assertEqual(type(p1), FullyConnectedNetwork)
with tf.variable_scope("test2"):
p2 = ModelCatalog.get_model(
get_registry(), np.zeros((10, 80, 80, 3), dtype=np.float32), 5)
np.zeros((10, 80, 80, 3), dtype=np.float32), 5)
self.assertEqual(type(p2), VisionNetwork)
def testCustomModel(self):
ray.init()
ModelCatalog.register_custom_model("foo", CustomModel)
p1 = ModelCatalog.get_model(
get_registry(), 1, 5, {"custom_model": "foo"})
p1 = ModelCatalog.get_model(1, 5, {"custom_model": "foo"})
self.assertEqual(str(type(p1)), str(CustomModel))
@@ -59,8 +59,7 @@ def check_support(alg, config, stats):
for o_name, obs_space in OBSERVATION_SPACES_TO_TEST.items():
print("=== Testing", alg, action_space, obs_space, "===")
stub_env = make_stub_env(action_space, obs_space)
register_env(
"stub_env", lambda c: stub_env())
register_env("stub_env", lambda c: stub_env())
stat = "ok"
a = None
try:
@@ -18,7 +18,6 @@ from ray.rllib.utils.sampler import AsyncSampler, SyncSampler
from ray.rllib.utils.serving_env import ServingEnv, _ServingEnvToAsync
from ray.rllib.utils.tf_policy_graph import TFPolicyGraph
from ray.rllib.utils.vector_env import VectorEnv
from ray.tune.registry import get_registry
from ray.tune.result import TrainingResult
@@ -97,7 +96,6 @@ class CommonPolicyEvaluator(PolicyEvaluator):
compress_observations=False,
num_envs=1,
observation_filter="NoFilter",
registry=None,
env_config=None,
model_config=None,
policy_config=None):
@@ -137,15 +135,11 @@ class CommonPolicyEvaluator(PolicyEvaluator):
and vectorize the computation of actions. This has no effect if
the env already implements VectorEnv.
observation_filter (str): Name of observation filter to use.
registry (tune.Registry): User-registered objects. Pass in the
value from tune.registry.get_registry() if you're having
trouble resolving things like custom envs.
env_config (dict): Config to pass to the env creator.
model_config (dict): Config to use when creating the policy model.
policy_config (dict): Config to pass to the policy.
"""
registry = registry or get_registry()
env_config = env_config or {}
policy_config = policy_config or {}
model_config = model_config or {}
@@ -169,7 +163,7 @@ class CommonPolicyEvaluator(PolicyEvaluator):
else:
def wrap(env):
return ModelCatalog.get_preprocessor_as_wrapper(
registry, env, model_config)
env, model_config)
self.env = wrap(self.env)
def make_env():
@@ -187,11 +181,11 @@ class CommonPolicyEvaluator(PolicyEvaluator):
with self.sess.as_default():
policy = policy_graph(
self.env.observation_space, self.env.action_space,
registry, policy_config)
policy_config)
else:
policy = policy_graph(
self.env.observation_space, self.env.action_space,
registry, policy_config)
policy_config)
self.policy_map = {
"default": policy
}
+1 -2
View File
@@ -17,11 +17,10 @@ class PolicyGraph(object):
graphs and multi-GPU support.
"""
def __init__(self, registry, observation_space, action_space, config):
def __init__(self, observation_space, action_space, config):
"""Initialize the graph.
Args:
registry (obj): Object registry for user-defined envs, models, etc.
observation_space (gym.Space): Observation space of the env.
action_space (gym.Space): Action space of the env.
config (dict): Policy-specific configuration data.
+38 -41
View File
@@ -4,10 +4,10 @@ from __future__ import print_function
from types import FunctionType
import numpy as np
import ray
from ray.local_scheduler import ObjectID
import ray.cloudpickle as pickle
from ray.experimental.internal_kv import _internal_kv_initialized, \
_internal_kv_get, _internal_kv_put
TRAINABLE_CLASS = "trainable_class"
ENV_CREATOR = "env_creator"
@@ -35,7 +35,7 @@ def register_trainable(name, trainable):
if not issubclass(trainable, Trainable):
raise TypeError("Second argument must be convertable to Trainable",
trainable)
_default_registry.register(TRAINABLE_CLASS, name, trainable)
_global_registry.register(TRAINABLE_CLASS, name, trainable)
def register_env(name, env_creator):
@@ -48,62 +48,59 @@ def register_env(name, env_creator):
if not isinstance(env_creator, FunctionType):
raise TypeError("Second argument must be a function.", env_creator)
_default_registry.register(ENV_CREATOR, name, env_creator)
_global_registry.register(ENV_CREATOR, name, env_creator)
def get_registry():
"""Use this to access the registry. This requires ray to be initialized."""
def _make_key(category, key):
"""Generate a binary key for the given category and key.
_default_registry.flush_values_to_object_store()
Args:
category (str): The category of the item
key (str): The unique identifier for the item
# returns a registry copy that doesn't include the hard refs
return _Registry(_default_registry._all_objects)
def _to_pinnable(obj):
"""Converts obj to a form that can be pinned in object store memory.
Currently only numpy arrays are pinned in memory, if you have a strong
reference to the array value.
Returns:
The key to use for storing the value.
"""
return (obj, np.zeros(1))
def _from_pinnable(obj):
"""Retrieve from _to_pinnable format."""
return obj[0]
return (b"TuneRegistry:" + category.encode("ascii") + b"/" +
key.encode("ascii"))
class _Registry(object):
def __init__(self, objs=None):
self._all_objects = {} if objs is None else objs.copy()
self._refs = [] # hard refs that prevent eviction of objects
def __init__(self):
self._to_flush = {}
def register(self, category, key, value):
if category not in KNOWN_CATEGORIES:
from ray.tune import TuneError
raise TuneError("Unknown category {} not among {}".format(
category, KNOWN_CATEGORIES))
self._all_objects[(category, key)] = value
self._to_flush[(category, key)] = pickle.dumps(value)
if _internal_kv_initialized():
self.flush_values()
def contains(self, category, key):
return (category, key) in self._all_objects
if _internal_kv_initialized():
value = _internal_kv_get(_make_key(category, key))
return value is not None
else:
return (category, key) in self._to_flush
def get(self, category, key):
value = self._all_objects[(category, key)]
if type(value) == ObjectID:
return _from_pinnable(ray.get(value))
if _internal_kv_initialized():
value = _internal_kv_get(_make_key(category, key))
if value is None:
raise ValueError(
"Registry value for {}/{} doesn't exist.".format(
category, key))
return pickle.loads(value)
else:
return value
return pickle.loads(self._to_flush[(category, key)])
def flush_values_to_object_store(self):
for k, v in self._all_objects.items():
if type(v) != ObjectID:
obj = ray.put(_to_pinnable(v))
self._all_objects[k] = obj
self._refs.append(ray.get(obj))
def flush_values(self):
for (category, key), value in self._to_flush.items():
_internal_kv_put(_make_key(category, key), value)
self._to_flush.clear()
_default_registry = _Registry()
_global_registry = _Registry()
ray.worker._post_init_hooks.append(_global_registry.flush_values)
+3 -3
View File
@@ -11,7 +11,7 @@ from ray.rllib import _register_all
from ray.tune import Trainable, TuneError
from ray.tune import register_env, register_trainable, run_experiments
from ray.tune.registry import _default_registry, TRAINABLE_CLASS
from ray.tune.registry import _global_registry, TRAINABLE_CLASS
from ray.tune.result import DEFAULT_RESULTS_DIR, TrainingResult
from ray.tune.util import pin_in_object_store, get_pinned_object
from ray.tune.experiment import Experiment
@@ -595,7 +595,7 @@ class TrialRunnerTest(unittest.TestCase):
def testTrialErrorOnStart(self):
ray.init()
_default_registry.register(TRAINABLE_CLASS, "asdf", None)
_global_registry.register(TRAINABLE_CLASS, "asdf", None)
trial = Trial("asdf", resources=Resources(1, 0))
try:
trial.start()
@@ -690,7 +690,7 @@ class TrialRunnerTest(unittest.TestCase):
},
"resources": Resources(cpu=1, gpu=1),
}
_default_registry.register(TRAINABLE_CLASS, "asdf", None)
_global_registry.register(TRAINABLE_CLASS, "asdf", None)
trials = [Trial("asdf", **kwargs), Trial("__fake", **kwargs)]
for t in trials:
runner.add_trial(t)
+1 -10
View File
@@ -46,11 +46,9 @@ class Trainable(object):
Attributes:
config (obj): The hyperparam configuration for this trial.
logdir (str): Directory in which training outputs should be placed.
registry (obj): Tune object registry which holds user-registered
classes and objects by name.
"""
def __init__(self, config=None, registry=None, logger_creator=None):
def __init__(self, config=None, logger_creator=None):
"""Initialize an Trainable.
Subclasses should prefer defining ``_setup()`` instead of overriding
@@ -58,20 +56,13 @@ class Trainable(object):
Args:
config (dict): Trainable-specific configuration data.
registry (obj): Object registry for user-defined envs, models, etc.
If unspecified, the default registry will be used.
logger_creator (func): Function that creates a ray.tune.Logger
object. If unspecified, a default logger is created.
"""
if registry is None:
from ray.tune.registry import get_registry
registry = get_registry()
self._initialize_ok = False
self._experiment_id = uuid.uuid4().hex
self.config = config or {}
self.registry = registry
if logger_creator:
self._result_logger = logger_creator(self.config)
+3 -5
View File
@@ -57,7 +57,7 @@ class Resources(
def has_trainable(trainable_name):
return ray.tune.registry._default_registry.contains(
return ray.tune.registry._global_registry.contains(
ray.tune.registry.TRAINABLE_CLASS, trainable_name)
@@ -377,12 +377,10 @@ class Trial(object):
# Logging for trials is handled centrally by TrialRunner, so
# configure the remote runner to use a noop-logger.
self.runner = cls.remote(
config=self.config,
registry=ray.tune.registry.get_registry(),
logger_creator=logger_creator)
config=self.config, logger_creator=logger_creator)
def _get_trainable_cls(self):
return ray.tune.registry.get_registry().get(
return ray.tune.registry._global_registry.get(
ray.tune.registry.TRAINABLE_CLASS, self.trainable_name)
def set_verbose(self, verbose):
+18 -2
View File
@@ -2,12 +2,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import base64
from six.moves import queue
import base64
import numpy as np
import threading
import ray
from ray.tune.registry import _to_pinnable, _from_pinnable
_pinned_objects = []
_fetch_requests = queue.Queue()
@@ -63,6 +63,22 @@ def _serve_get_pin_requests():
pass
def _to_pinnable(obj):
"""Converts obj to a form that can be pinned in object store memory.
Currently only numpy arrays are pinned in memory, if you have a strong
reference to the array value.
"""
return (obj, np.zeros(1))
def _from_pinnable(obj):
"""Retrieve from _to_pinnable format."""
return obj[0]
if __name__ == '__main__':
ray.init()
X = pin_in_object_store("hello")
+8 -1
View File
@@ -1741,7 +1741,7 @@ def init(redis_address=None,
redis_address = services.address_to_ip(redis_address)
info = {"node_ip_address": node_ip_address, "redis_address": redis_address}
return _init(
ret = _init(
address_info=info,
start_ray_local=(redis_address is None),
num_workers=num_workers,
@@ -1758,6 +1758,13 @@ def init(redis_address=None,
include_webui=include_webui,
object_store_memory=object_store_memory,
use_raylet=use_raylet)
for hook in _post_init_hooks:
hook()
return ret
# Functions to run as callbacks after a successful ray.init().
_post_init_hooks = []
def cleanup(worker=global_worker):