mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 14:39:44 +08:00
[rllib] Remove need to pass around registry (#2250)
* remove registry * fix * too many _ * fix * cloudpickle * Update registry.py * yapf * fix test * fix kv check
This commit is contained in:
@@ -0,0 +1,31 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ray
|
||||
|
||||
|
||||
def _internal_kv_initialized():
    """Report whether the internal KV store can be used yet.

    Returns:
        bool: True when the global worker has a ``mode`` attribute and that
            attribute is not None; False otherwise.
    """
    global_worker = ray.worker.get_global_worker()
    # A worker without a `mode` attribute is treated as not initialized.
    if not hasattr(global_worker, "mode"):
        return False
    return global_worker.mode is not None
|
||||
|
||||
|
||||
def _internal_kv_get(key):
    """Look up the value stored under a binary key.

    Args:
        key: The binary key to read.

    Returns:
        Whatever the global worker's redis client returns for the "value"
        field of the hash stored at ``key``.
    """
    redis_client = ray.worker.get_global_worker().redis_client
    return redis_client.hget(key, "value")
|
||||
|
||||
|
||||
def _internal_kv_put(key, value):
    """Store a value under a binary key, visible globally.

    The write is a set-if-absent: when the key already holds a value, the
    existing value is left untouched.

    Args:
        key: The binary key to write.
        value: The value to associate with ``key``.

    Returns:
        already_exists (bool): whether the value already exists.
    """
    redis_client = ray.worker.get_global_worker().redis_client
    # hsetnx reports 0 when the "value" field was already present.
    newly_set = redis_client.hsetnx(key, "value", value)
    already_exists = newly_set == 0
    return already_exists
|
||||
@@ -4,6 +4,7 @@ from __future__ import print_function
|
||||
|
||||
import ray
|
||||
import ray.cloudpickle as pickle
|
||||
from ray.experimental.internal_kv import _internal_kv_get, _internal_kv_put
|
||||
|
||||
|
||||
def _calculate_key(name):
|
||||
@@ -29,9 +30,8 @@ def get_actor(name):
|
||||
Returns:
|
||||
The ActorHandle object corresponding to the name.
|
||||
"""
|
||||
worker = ray.worker.get_global_worker()
|
||||
actor_hash = _calculate_key(name)
|
||||
pickled_state = worker.redis_client.hget(actor_hash, name)
|
||||
actor_name = _calculate_key(name)
|
||||
pickled_state = _internal_kv_get(actor_name)
|
||||
if pickled_state is None:
|
||||
raise ValueError("The actor with name={} doesn't exist".format(name))
|
||||
handle = pickle.loads(pickled_state)
|
||||
@@ -45,17 +45,16 @@ def register_actor(name, actor_handle):
|
||||
name: The name of the named actor.
|
||||
actor_handle: The actor object to be associated with this name
|
||||
"""
|
||||
worker = ray.worker.get_global_worker()
|
||||
if not isinstance(name, str):
|
||||
raise TypeError("The name argument must be a string.")
|
||||
if not isinstance(actor_handle, ray.actor.ActorHandle):
|
||||
raise TypeError("The actor_handle argument must be an ActorHandle "
|
||||
"object.")
|
||||
actor_hash = _calculate_key(name)
|
||||
actor_name = _calculate_key(name)
|
||||
pickled_state = pickle.dumps(actor_handle)
|
||||
|
||||
# Add the actor to Redis if it does not already exist.
|
||||
updated = worker.redis_client.hsetnx(actor_hash, name, pickled_state)
|
||||
if updated == 0:
|
||||
already_exists = _internal_kv_put(actor_name, pickled_state)
|
||||
if already_exists:
|
||||
raise ValueError(
|
||||
"Error: the actor with name={} already exists".format(name))
|
||||
|
||||
@@ -102,7 +102,7 @@ class A3CAgent(Agent):
|
||||
batch_steps=self.config["batch_size"],
|
||||
batch_mode="truncate_episodes",
|
||||
tf_session_creator=session_creator,
|
||||
registry=self.registry, env_config=self.config["env_config"],
|
||||
env_config=self.config["env_config"],
|
||||
model_config=self.config["model"], policy_config=self.config,
|
||||
num_envs=self.config["num_envs"])
|
||||
self.remote_evaluators = [
|
||||
@@ -111,7 +111,7 @@ class A3CAgent(Agent):
|
||||
batch_steps=self.config["batch_size"],
|
||||
batch_mode="truncate_episodes", sample_async=True,
|
||||
tf_session_creator=session_creator,
|
||||
registry=self.registry, env_config=self.config["env_config"],
|
||||
env_config=self.config["env_config"],
|
||||
model_config=self.config["model"], policy_config=self.config,
|
||||
num_envs=self.config["num_envs"])
|
||||
for i in range(self.config["num_workers"])]
|
||||
|
||||
@@ -13,8 +13,7 @@ from ray.rllib.utils.tf_policy_graph import TFPolicyGraph
|
||||
class A3CTFPolicyGraph(TFPolicyGraph):
|
||||
"""The TF policy base class."""
|
||||
|
||||
def __init__(self, ob_space, action_space, registry, config):
|
||||
self.registry = registry
|
||||
def __init__(self, ob_space, action_space, config):
|
||||
self.local_steps = 0
|
||||
self.config = config
|
||||
self.summarize = config.get("summarize")
|
||||
|
||||
@@ -17,8 +17,7 @@ from ray.rllib.utils.policy_graph import PolicyGraph
|
||||
class SharedTorchPolicy(PolicyGraph):
|
||||
"""A simple, non-recurrent PyTorch policy example."""
|
||||
|
||||
def __init__(self, obs_space, action_space, registry, config):
|
||||
self.registry = registry
|
||||
def __init__(self, obs_space, action_space, config):
|
||||
self.local_steps = 0
|
||||
self.config = config
|
||||
self.summarize = config.get("summarize")
|
||||
@@ -30,8 +29,7 @@ class SharedTorchPolicy(PolicyGraph):
|
||||
_, self.logit_dim = ModelCatalog.get_action_dist(
|
||||
action_space, self.config["model"])
|
||||
self._model = ModelCatalog.get_torch_model(
|
||||
self.registry, obs_space.shape, self.logit_dim,
|
||||
self.config["model"])
|
||||
obs_space.shape, self.logit_dim, self.config["model"])
|
||||
self.optimizer = torch.optim.Adam(
|
||||
self._model.parameters(), lr=self.config["lr"])
|
||||
|
||||
|
||||
@@ -10,16 +10,16 @@ from ray.rllib.models.catalog import ModelCatalog
|
||||
|
||||
class SharedModel(A3CTFPolicyGraph):
|
||||
|
||||
def __init__(self, ob_space, ac_space, registry, config, **kwargs):
|
||||
def __init__(self, ob_space, ac_space, config, **kwargs):
|
||||
super(SharedModel, self).__init__(
|
||||
ob_space, ac_space, registry, config, **kwargs)
|
||||
ob_space, ac_space, config, **kwargs)
|
||||
|
||||
def _setup_graph(self, ob_space, ac_space):
|
||||
self.x = tf.placeholder(tf.float32, [None] + list(ob_space.shape))
|
||||
dist_class, self.logit_dim = ModelCatalog.get_action_dist(
|
||||
ac_space, self.config["model"])
|
||||
self._model = ModelCatalog.get_model(
|
||||
self.registry, self.x, self.logit_dim, self.config["model"])
|
||||
self.x, self.logit_dim, self.config["model"])
|
||||
self.logits = self._model.outputs
|
||||
self.action_dist = dist_class(self.logits)
|
||||
self.vf = tf.reshape(linear(self._model.last_layer, 1, "value",
|
||||
|
||||
@@ -11,9 +11,9 @@ from ray.rllib.models.lstm import LSTM
|
||||
|
||||
class SharedModelLSTM(A3CTFPolicyGraph):
|
||||
|
||||
def __init__(self, ob_space, ac_space, registry, config, **kwargs):
|
||||
def __init__(self, ob_space, ac_space, config, **kwargs):
|
||||
super(SharedModelLSTM, self).__init__(
|
||||
ob_space, ac_space, registry, config, **kwargs)
|
||||
ob_space, ac_space, config, **kwargs)
|
||||
|
||||
def _setup_graph(self, ob_space, ac_space):
|
||||
self.x = tf.placeholder(tf.float32, [None] + list(ob_space.shape))
|
||||
|
||||
@@ -9,7 +9,7 @@ import os
|
||||
import pickle
|
||||
|
||||
import tensorflow as tf
|
||||
from ray.tune.registry import ENV_CREATOR
|
||||
from ray.tune.registry import ENV_CREATOR, _global_registry
|
||||
from ray.tune.result import TrainingResult
|
||||
from ray.tune.trainable import Trainable
|
||||
|
||||
@@ -56,8 +56,6 @@ class Agent(Trainable):
|
||||
env_creator (func): Function that creates a new training env.
|
||||
config (obj): Algorithm-specific configuration data.
|
||||
logdir (str): Directory in which training outputs should be placed.
|
||||
registry (obj): Tune object registry which holds user-registered
|
||||
classes and objects by name.
|
||||
"""
|
||||
|
||||
_allow_unknown_configs = False
|
||||
@@ -72,16 +70,13 @@ class Agent(Trainable):
|
||||
"The config of this agent is: " + json.dumps(config))
|
||||
|
||||
def __init__(
|
||||
self, config=None, env=None, registry=None,
|
||||
logger_creator=None):
|
||||
self, config=None, env=None, logger_creator=None):
|
||||
"""Initialize an RLLib agent.
|
||||
|
||||
Args:
|
||||
config (dict): Algorithm-specific configuration data.
|
||||
env (str): Name of the environment to use. Note that this can also
|
||||
be specified as the `env` key in config.
|
||||
registry (obj): Object registry for user-defined envs, models, etc.
|
||||
If unspecified, the default registry will be used.
|
||||
logger_creator (func): Function that creates a ray.tune.Logger
|
||||
object. If unspecified, a default logger is created.
|
||||
"""
|
||||
@@ -90,14 +85,14 @@ class Agent(Trainable):
|
||||
|
||||
# Agents allow env ids to be passed directly to the constructor.
|
||||
self._env_id = env or config.get("env")
|
||||
Trainable.__init__(self, config, registry, logger_creator)
|
||||
Trainable.__init__(self, config, logger_creator)
|
||||
|
||||
def _setup(self):
|
||||
env = self._env_id
|
||||
if env:
|
||||
self.config["env"] = env
|
||||
if self.registry and self.registry.contains(ENV_CREATOR, env):
|
||||
self.env_creator = self.registry.get(ENV_CREATOR, env)
|
||||
if _global_registry.contains(ENV_CREATOR, env):
|
||||
self.env_creator = _global_registry.get(ENV_CREATOR, env)
|
||||
else:
|
||||
import gym # soft dependency
|
||||
self.env_creator = lambda env_config: gym.make(env)
|
||||
|
||||
@@ -63,14 +63,13 @@ class BCAgent(Agent):
|
||||
|
||||
def _init(self):
|
||||
self.local_evaluator = BCEvaluator(
|
||||
self.registry, self.env_creator, self.config, self.logdir)
|
||||
self.env_creator, self.config, self.logdir)
|
||||
if self.config["use_gpu_for_workers"]:
|
||||
remote_cls = GPURemoteBCEvaluator
|
||||
else:
|
||||
remote_cls = RemoteBCEvaluator
|
||||
self.remote_evaluators = [
|
||||
remote_cls.remote(
|
||||
self.registry, self.env_creator, self.config, self.logdir)
|
||||
remote_cls.remote(self.env_creator, self.config, self.logdir)
|
||||
for _ in range(self.config["num_workers"])]
|
||||
self.optimizer = AsyncOptimizer(
|
||||
self.config["optimizer"], self.local_evaluator,
|
||||
|
||||
@@ -13,12 +13,11 @@ from ray.rllib.optimizers import PolicyEvaluator
|
||||
|
||||
|
||||
class BCEvaluator(PolicyEvaluator):
|
||||
def __init__(self, registry, env_creator, config, logdir):
|
||||
env = ModelCatalog.get_preprocessor_as_wrapper(registry, env_creator(
|
||||
def __init__(self, env_creator, config, logdir):
|
||||
env = ModelCatalog.get_preprocessor_as_wrapper(env_creator(
|
||||
config["env_config"]), config["model"])
|
||||
self.dataset = ExperienceDataset(config["dataset_path"])
|
||||
self.policy = BCPolicy(registry, env.observation_space,
|
||||
env.action_space, config)
|
||||
self.policy = BCPolicy(env.observation_space, env.action_space, config)
|
||||
self.config = config
|
||||
self.logdir = logdir
|
||||
self.metrics_queue = queue.Queue()
|
||||
|
||||
@@ -10,8 +10,7 @@ from ray.rllib.models.catalog import ModelCatalog
|
||||
|
||||
|
||||
class BCPolicy(object):
|
||||
def __init__(self, registry, obs_space, action_space, config):
|
||||
self.registry = registry
|
||||
def __init__(self, obs_space, action_space, config):
|
||||
self.local_steps = 0
|
||||
self.config = config
|
||||
self.summarize = config.get("summarize")
|
||||
@@ -25,7 +24,7 @@ class BCPolicy(object):
|
||||
dist_class, self.logit_dim = ModelCatalog.get_action_dist(
|
||||
ac_space, self.config["model"])
|
||||
self._model = ModelCatalog.get_model(
|
||||
self.registry, self.x, self.logit_dim, self.config["model"])
|
||||
self.x, self.logit_dim, self.config["model"])
|
||||
self.logits = self._model.outputs
|
||||
self.curr_dist = dist_class(self.logits)
|
||||
self.sample = self.curr_dist.sample()
|
||||
|
||||
@@ -22,12 +22,12 @@ Q_SCOPE = "q_func"
|
||||
Q_TARGET_SCOPE = "target_q_func"
|
||||
|
||||
|
||||
def _build_p_network(registry, inputs, dim_actions, config):
|
||||
def _build_p_network(inputs, dim_actions, config):
|
||||
"""
|
||||
map an observation (i.e., state) to an action where
|
||||
each entry takes value from (0, 1) due to the sigmoid function
|
||||
"""
|
||||
frontend = ModelCatalog.get_model(registry, inputs, 1, config["model"])
|
||||
frontend = ModelCatalog.get_model(inputs, 1, config["model"])
|
||||
|
||||
hiddens = config["actor_hiddens"]
|
||||
action_out = frontend.last_layer
|
||||
@@ -66,8 +66,8 @@ def _build_action_network(p_values, low_action, high_action, stochastic, eps,
|
||||
lambda: deterministic_actions)
|
||||
|
||||
|
||||
def _build_q_network(registry, inputs, action_inputs, config):
|
||||
frontend = ModelCatalog.get_model(registry, inputs, 1, config["model"])
|
||||
def _build_q_network(inputs, action_inputs, config):
|
||||
frontend = ModelCatalog.get_model(inputs, 1, config["model"])
|
||||
|
||||
hiddens = config["critic_hiddens"]
|
||||
|
||||
@@ -81,7 +81,7 @@ def _build_q_network(registry, inputs, action_inputs, config):
|
||||
|
||||
|
||||
class DDPGPolicyGraph(TFPolicyGraph):
|
||||
def __init__(self, observation_space, action_space, registry, config):
|
||||
def __init__(self, observation_space, action_space, config):
|
||||
if not isinstance(action_space, Box):
|
||||
raise UnsupportedSpaceException(
|
||||
"Action space {} is not supported for DDPG.".format(
|
||||
@@ -105,7 +105,7 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
|
||||
# Actor: P (policy) network
|
||||
with tf.variable_scope(P_SCOPE) as scope:
|
||||
p_values = _build_p_network(registry, self.cur_observations,
|
||||
p_values = _build_p_network(self.cur_observations,
|
||||
dim_actions, config)
|
||||
self.p_func_vars = _scope_vars(scope.name)
|
||||
|
||||
@@ -136,13 +136,11 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
|
||||
# p network evaluation
|
||||
with tf.variable_scope(P_SCOPE, reuse=True) as scope:
|
||||
self.p_t = _build_p_network(
|
||||
registry, self.obs_t, dim_actions, config)
|
||||
self.p_t = _build_p_network(self.obs_t, dim_actions, config)
|
||||
|
||||
# target p network evaluation
|
||||
with tf.variable_scope(P_TARGET_SCOPE) as scope:
|
||||
p_tp1 = _build_p_network(
|
||||
registry, self.obs_tp1, dim_actions, config)
|
||||
p_tp1 = _build_p_network(self.obs_tp1, dim_actions, config)
|
||||
target_p_func_vars = _scope_vars(scope.name)
|
||||
|
||||
# Action outputs
|
||||
@@ -161,17 +159,15 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
|
||||
# q network evaluation
|
||||
with tf.variable_scope(Q_SCOPE) as scope:
|
||||
q_t = _build_q_network(
|
||||
registry, self.obs_t, self.act_t, config)
|
||||
q_t = _build_q_network(self.obs_t, self.act_t, config)
|
||||
self.q_func_vars = _scope_vars(scope.name)
|
||||
with tf.variable_scope(Q_SCOPE, reuse=True):
|
||||
q_tp0 = _build_q_network(
|
||||
registry, self.obs_t, output_actions, config)
|
||||
q_tp0 = _build_q_network(self.obs_t, output_actions, config)
|
||||
|
||||
# target q network evalution
|
||||
with tf.variable_scope(Q_TARGET_SCOPE) as scope:
|
||||
q_tp1 = _build_q_network(
|
||||
registry, self.obs_tp1, output_actions_estimated, config)
|
||||
self.obs_tp1, output_actions_estimated, config)
|
||||
target_q_func_vars = _scope_vars(scope.name)
|
||||
|
||||
q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1)
|
||||
|
||||
@@ -6,7 +6,7 @@ from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.utils.atari_wrappers import wrap_deepmind
|
||||
|
||||
|
||||
def wrap_dqn(registry, env, options, random_starts):
|
||||
def wrap_dqn(env, options, random_starts):
|
||||
"""Apply a common set of wrappers for DQN."""
|
||||
|
||||
is_atari = hasattr(env.unwrapped, "ale")
|
||||
@@ -17,4 +17,4 @@ def wrap_dqn(registry, env, options, random_starts):
|
||||
return wrap_deepmind(
|
||||
env, random_starts=random_starts, dim=options.get("dim", 80))
|
||||
|
||||
return ModelCatalog.get_preprocessor_as_wrapper(registry, env, options)
|
||||
return ModelCatalog.get_preprocessor_as_wrapper(env, options)
|
||||
|
||||
@@ -129,7 +129,7 @@ class DQNAgent(Agent):
|
||||
batch_steps=adjusted_batch_size,
|
||||
batch_mode="truncate_episodes", preprocessor_pref="deepmind",
|
||||
compress_observations=True,
|
||||
registry=self.registry, env_config=self.config["env_config"],
|
||||
env_config=self.config["env_config"],
|
||||
model_config=self.config["model"], policy_config=self.config,
|
||||
num_envs=self.config["num_envs"])
|
||||
remote_cls = CommonPolicyEvaluator.as_remote(
|
||||
@@ -141,7 +141,7 @@ class DQNAgent(Agent):
|
||||
batch_steps=adjusted_batch_size,
|
||||
batch_mode="truncate_episodes", preprocessor_pref="deepmind",
|
||||
compress_observations=True,
|
||||
registry=self.registry, env_config=self.config["env_config"],
|
||||
env_config=self.config["env_config"],
|
||||
model_config=self.config["model"], policy_config=self.config,
|
||||
num_envs=self.config["num_envs"])
|
||||
for _ in range(self.config["num_workers"])]
|
||||
|
||||
@@ -46,7 +46,7 @@ def adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones):
|
||||
|
||||
|
||||
class DQNPolicyGraph(TFPolicyGraph):
|
||||
def __init__(self, observation_space, action_space, registry, config):
|
||||
def __init__(self, observation_space, action_space, config):
|
||||
if not isinstance(action_space, Discrete):
|
||||
raise UnsupportedSpaceException(
|
||||
"Action space {} is not supported for DQN.".format(
|
||||
@@ -65,7 +65,7 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
# Action Q network
|
||||
with tf.variable_scope(Q_SCOPE) as scope:
|
||||
q_values = _build_q_network(
|
||||
registry, self.cur_observations, num_actions, config)
|
||||
self.cur_observations, num_actions, config)
|
||||
self.q_func_vars = _scope_vars(scope.name)
|
||||
|
||||
# Action outputs
|
||||
@@ -89,13 +89,11 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
|
||||
# q network evaluation
|
||||
with tf.variable_scope(Q_SCOPE, reuse=True):
|
||||
q_t = _build_q_network(
|
||||
registry, self.obs_t, num_actions, config)
|
||||
q_t = _build_q_network(self.obs_t, num_actions, config)
|
||||
|
||||
# target q network evalution
|
||||
with tf.variable_scope(Q_TARGET_SCOPE) as scope:
|
||||
q_tp1 = _build_q_network(
|
||||
registry, self.obs_tp1, num_actions, config)
|
||||
q_tp1 = _build_q_network(self.obs_tp1, num_actions, config)
|
||||
self.target_q_func_vars = _scope_vars(scope.name)
|
||||
|
||||
# q scores for actions which we know were selected in the given state.
|
||||
@@ -106,7 +104,7 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
if config["double_q"]:
|
||||
with tf.variable_scope(Q_SCOPE, reuse=True):
|
||||
q_tp1_using_online_net = _build_q_network(
|
||||
registry, self.obs_tp1, num_actions, config)
|
||||
self.obs_tp1, num_actions, config)
|
||||
q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1)
|
||||
q_tp1_best = tf.reduce_sum(
|
||||
q_tp1 * tf.one_hot(
|
||||
@@ -236,10 +234,10 @@ def _postprocess_dqn(policy_graph, sample_batch):
|
||||
return batch
|
||||
|
||||
|
||||
def _build_q_network(registry, inputs, num_actions, config):
|
||||
def _build_q_network(inputs, num_actions, config):
|
||||
dueling = config["dueling"]
|
||||
hiddens = config["hiddens"]
|
||||
frontend = ModelCatalog.get_model(registry, inputs, 1, config["model"])
|
||||
frontend = ModelCatalog.get_model(inputs, 1, config["model"])
|
||||
frontend_out = frontend.last_layer
|
||||
|
||||
with tf.variable_scope("action_value"):
|
||||
|
||||
@@ -64,7 +64,7 @@ class SharedNoiseTable(object):
|
||||
|
||||
@ray.remote
|
||||
class Worker(object):
|
||||
def __init__(self, registry, config, policy_params, env_creator, noise,
|
||||
def __init__(self, config, policy_params, env_creator, noise,
|
||||
min_task_runtime=0.2):
|
||||
self.min_task_runtime = min_task_runtime
|
||||
self.config = config
|
||||
@@ -73,12 +73,11 @@ class Worker(object):
|
||||
|
||||
self.env = env_creator(config["env_config"])
|
||||
from ray.rllib import models
|
||||
self.preprocessor = models.ModelCatalog.get_preprocessor(
|
||||
registry, self.env)
|
||||
self.preprocessor = models.ModelCatalog.get_preprocessor(self.env)
|
||||
|
||||
self.sess = utils.make_session(single_threaded=True)
|
||||
self.policy = policies.GenericPolicy(
|
||||
registry, self.sess, self.env.action_space, self.preprocessor,
|
||||
self.sess, self.env.action_space, self.preprocessor,
|
||||
config["observation_filter"], **policy_params)
|
||||
|
||||
def rollout(self, timestep_limit, add_noise=True):
|
||||
@@ -152,12 +151,11 @@ class ESAgent(agent.Agent):
|
||||
|
||||
env = self.env_creator(self.config["env_config"])
|
||||
from ray.rllib import models
|
||||
preprocessor = models.ModelCatalog.get_preprocessor(
|
||||
self.registry, env)
|
||||
preprocessor = models.ModelCatalog.get_preprocessor(env)
|
||||
|
||||
self.sess = utils.make_session(single_threaded=False)
|
||||
self.policy = policies.GenericPolicy(
|
||||
self.registry, self.sess, env.action_space, preprocessor,
|
||||
self.sess, env.action_space, preprocessor,
|
||||
self.config["observation_filter"], **policy_params)
|
||||
self.optimizer = optimizers.Adam(self.policy, self.config["stepsize"])
|
||||
|
||||
@@ -170,8 +168,7 @@ class ESAgent(agent.Agent):
|
||||
print("Creating actors.")
|
||||
self.workers = [
|
||||
Worker.remote(
|
||||
self.registry, self.config, policy_params, self.env_creator,
|
||||
noise_id)
|
||||
self.config, policy_params, self.env_creator, noise_id)
|
||||
for _ in range(self.config["num_workers"])]
|
||||
|
||||
self.episodes_so_far = 0
|
||||
|
||||
@@ -38,7 +38,7 @@ def rollout(policy, env, timestep_limit=None, add_noise=False):
|
||||
|
||||
|
||||
class GenericPolicy(object):
|
||||
def __init__(self, registry, sess, action_space, preprocessor,
|
||||
def __init__(self, sess, action_space, preprocessor,
|
||||
observation_filter, action_noise_std):
|
||||
self.sess = sess
|
||||
self.action_space = action_space
|
||||
@@ -52,7 +52,7 @@ class GenericPolicy(object):
|
||||
# Policy network.
|
||||
dist_class, dist_dim = ModelCatalog.get_action_dist(
|
||||
self.action_space, dist_type="deterministic")
|
||||
model = ModelCatalog.get_model(registry, self.inputs, dist_dim)
|
||||
model = ModelCatalog.get_model(self.inputs, dist_dim)
|
||||
dist = dist_class(model.outputs)
|
||||
self.sampler = dist.sample()
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from gym.envs.registration import register
|
||||
|
||||
import ray
|
||||
import ray.rllib.ppo as ppo
|
||||
from ray.tune.registry import get_registry, register_env
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
env_name = "MultiAgentMountainCarEnv"
|
||||
|
||||
@@ -51,6 +51,6 @@ if __name__ == '__main__':
|
||||
"multiagent_shared_model": False,
|
||||
"multiagent_fcnet_hiddens": [[32, 32]] * 2}
|
||||
config["model"].update({"custom_options": options})
|
||||
alg = ppo.PPOAgent(env=env_name, registry=get_registry(), config=config)
|
||||
alg = ppo.PPOAgent(env=env_name, config=config)
|
||||
for i in range(1):
|
||||
alg.train()
|
||||
|
||||
@@ -8,7 +8,7 @@ from gym.envs.registration import register
|
||||
|
||||
import ray
|
||||
import ray.rllib.ppo as ppo
|
||||
from ray.tune.registry import get_registry, register_env
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
env_name = "MultiAgentPendulumEnv"
|
||||
|
||||
@@ -51,6 +51,6 @@ if __name__ == '__main__':
|
||||
"multiagent_shared_model": True,
|
||||
"multiagent_fcnet_hiddens": [[32, 32]] * 2}
|
||||
config["model"].update({"custom_options": options})
|
||||
alg = ppo.PPOAgent(env=env_name, registry=get_registry(), config=config)
|
||||
alg = ppo.PPOAgent(env=env_name, config=config)
|
||||
for i in range(1):
|
||||
alg.train()
|
||||
|
||||
@@ -8,7 +8,7 @@ import tensorflow as tf
|
||||
from functools import partial
|
||||
|
||||
from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \
|
||||
_default_registry
|
||||
_global_registry
|
||||
|
||||
from ray.rllib.models.action_dist import (
|
||||
Categorical, Deterministic, DiagGaussian, MultiActionDistribution,
|
||||
@@ -47,7 +47,7 @@ class ModelCatalog(object):
|
||||
>>> observation = prep.transform(raw_observation)
|
||||
|
||||
>>> dist_cls, dist_dim = ModelCatalog.get_action_dist(env.action_space)
|
||||
>>> model = ModelCatalog.get_model(registry, inputs, dist_dim)
|
||||
>>> model = ModelCatalog.get_model(inputs, dist_dim)
|
||||
>>> dist = dist_cls(model.outputs)
|
||||
>>> action = dist.sample()
|
||||
"""
|
||||
@@ -130,11 +130,10 @@ class ModelCatalog(object):
|
||||
" not supported".format(action_space))
|
||||
|
||||
@staticmethod
|
||||
def get_model(registry, inputs, num_outputs, options={}):
|
||||
def get_model(inputs, num_outputs, options={}):
|
||||
"""Returns a suitable model conforming to given input and output specs.
|
||||
|
||||
Args:
|
||||
registry (obj): Registry of named objects (ray.tune.registry).
|
||||
inputs (Tensor): The input tensor to the model.
|
||||
num_outputs (int): The size of the output vector of the model.
|
||||
options (dict): Optional args to pass to the model constructor.
|
||||
@@ -146,7 +145,7 @@ class ModelCatalog(object):
|
||||
if "custom_model" in options:
|
||||
model = options["custom_model"]
|
||||
print("Using custom model {}".format(model))
|
||||
return registry.get(RLLIB_MODEL, model)(
|
||||
return _global_registry.get(RLLIB_MODEL, model)(
|
||||
inputs, num_outputs, options)
|
||||
|
||||
obs_rank = len(inputs.shape) - 1
|
||||
@@ -163,12 +162,11 @@ class ModelCatalog(object):
|
||||
return FullyConnectedNetwork(inputs, num_outputs, options)
|
||||
|
||||
@staticmethod
|
||||
def get_torch_model(registry, input_shape, num_outputs, options={}):
|
||||
def get_torch_model(input_shape, num_outputs, options={}):
|
||||
"""Returns a PyTorch suitable model. This is currently only supported
|
||||
in A3C.
|
||||
|
||||
Args:
|
||||
registry (obj): Registry of named objects (ray.tune.registry).
|
||||
input_shape (tuple): The input shape to the model.
|
||||
num_outputs (int): The size of the output vector of the model.
|
||||
options (dict): Optional args to pass to the model constructor.
|
||||
@@ -184,7 +182,7 @@ class ModelCatalog(object):
|
||||
if "custom_model" in options:
|
||||
model = options["custom_model"]
|
||||
print("Using custom torch model {}".format(model))
|
||||
return registry.get(RLLIB_MODEL, model)(
|
||||
return _global_registry.get(RLLIB_MODEL, model)(
|
||||
input_shape, num_outputs, options)
|
||||
|
||||
# TODO(alok): fix to handle Discrete(n) state spaces
|
||||
@@ -198,11 +196,10 @@ class ModelCatalog(object):
|
||||
return PyTorchFCNet(input_shape[0], num_outputs, options)
|
||||
|
||||
@staticmethod
|
||||
def get_preprocessor(registry, env, options={}):
|
||||
def get_preprocessor(env, options={}):
|
||||
"""Returns a suitable processor for the given environment.
|
||||
|
||||
Args:
|
||||
registry (obj): Registry of named objects (ray.tune.registry).
|
||||
env (gym.Env): The gym environment to preprocess.
|
||||
options (dict): Options to pass to the preprocessor.
|
||||
|
||||
@@ -218,18 +215,17 @@ class ModelCatalog(object):
|
||||
if "custom_preprocessor" in options:
|
||||
preprocessor = options["custom_preprocessor"]
|
||||
print("Using custom preprocessor {}".format(preprocessor))
|
||||
return registry.get(RLLIB_PREPROCESSOR, preprocessor)(
|
||||
return _global_registry.get(RLLIB_PREPROCESSOR, preprocessor)(
|
||||
env.observation_space, options)
|
||||
|
||||
preprocessor = get_preprocessor(env.observation_space)
|
||||
return preprocessor(env.observation_space, options)
|
||||
|
||||
@staticmethod
|
||||
def get_preprocessor_as_wrapper(registry, env, options={}):
|
||||
def get_preprocessor_as_wrapper(env, options={}):
|
||||
"""Returns a preprocessor as a gym observation wrapper.
|
||||
|
||||
Args:
|
||||
registry (obj): Registry of named objects (ray.tune.registry).
|
||||
env (gym.Env): The gym environment to wrap.
|
||||
options (dict): Options to pass to the preprocessor.
|
||||
|
||||
@@ -237,7 +233,7 @@ class ModelCatalog(object):
|
||||
wrapper (gym.ObservationWrapper): Preprocessor in wrapper form.
|
||||
"""
|
||||
|
||||
preprocessor = ModelCatalog.get_preprocessor(registry, env, options)
|
||||
preprocessor = ModelCatalog.get_preprocessor(env, options)
|
||||
return _RLlibPreprocessorWrapper(env, preprocessor)
|
||||
|
||||
@staticmethod
|
||||
@@ -251,7 +247,7 @@ class ModelCatalog(object):
|
||||
preprocessor_name (str): Name to register the preprocessor under.
|
||||
preprocessor_class (type): Python class of the preprocessor.
|
||||
"""
|
||||
_default_registry.register(
|
||||
_global_registry.register(
|
||||
RLLIB_PREPROCESSOR, preprocessor_name, preprocessor_class)
|
||||
|
||||
@staticmethod
|
||||
@@ -265,7 +261,7 @@ class ModelCatalog(object):
|
||||
model_name (str): Name to register the model under.
|
||||
model_class (type): Python class of the model.
|
||||
"""
|
||||
_default_registry.register(RLLIB_MODEL, model_name, model_class)
|
||||
_global_registry.register(RLLIB_MODEL, model_name, model_class)
|
||||
|
||||
|
||||
class _RLlibPreprocessorWrapper(gym.ObservationWrapper):
|
||||
|
||||
@@ -55,7 +55,6 @@ class PGAgent(Agent):
|
||||
"policy_graph": PGPolicyGraph,
|
||||
"batch_steps": self.config["batch_size"],
|
||||
"batch_mode": "truncate_episodes",
|
||||
"registry": self.registry,
|
||||
"model_config": self.config["model"],
|
||||
"env_config": self.config["env_config"],
|
||||
"policy_config": self.config,
|
||||
|
||||
@@ -11,7 +11,7 @@ from ray.rllib.utils.tf_policy_graph import TFPolicyGraph
|
||||
|
||||
class PGPolicyGraph(TFPolicyGraph):
|
||||
|
||||
def __init__(self, obs_space, action_space, registry, config):
|
||||
def __init__(self, obs_space, action_space, config):
|
||||
self.config = config
|
||||
|
||||
# setup policy
|
||||
@@ -19,7 +19,7 @@ class PGPolicyGraph(TFPolicyGraph):
|
||||
dist_class, self.logit_dim = ModelCatalog.get_action_dist(
|
||||
action_space, self.config["model"])
|
||||
self.model = ModelCatalog.get_model(
|
||||
registry, self.x, self.logit_dim, options=self.config["model"])
|
||||
self.x, self.logit_dim, options=self.config["model"])
|
||||
self.dist = dist_class(self.model.outputs) # logit for each action
|
||||
|
||||
# setup policy loss
|
||||
|
||||
@@ -16,14 +16,14 @@ class ProximalPolicyGraph(object):
|
||||
self, observation_space, action_space,
|
||||
observations, value_targets, advantages, actions,
|
||||
prev_logits, prev_vf_preds, logit_dim,
|
||||
kl_coeff, distribution_class, config, sess, registry):
|
||||
kl_coeff, distribution_class, config, sess):
|
||||
self.prev_dist = distribution_class(prev_logits)
|
||||
|
||||
# Saved so that we can compute actions given different observations
|
||||
self.observations = observations
|
||||
|
||||
self.curr_logits = ModelCatalog.get_model(
|
||||
registry, observations, logit_dim, config["model"]).outputs
|
||||
observations, logit_dim, config["model"]).outputs
|
||||
self.curr_dist = distribution_class(self.curr_logits)
|
||||
self.sampler = self.curr_dist.sample()
|
||||
|
||||
@@ -35,7 +35,7 @@ class ProximalPolicyGraph(object):
|
||||
vf_config["free_log_std"] = False
|
||||
with tf.variable_scope("value_function"):
|
||||
self.value_function = ModelCatalog.get_model(
|
||||
registry, observations, 1, vf_config).outputs
|
||||
observations, 1, vf_config).outputs
|
||||
self.value_function = tf.reshape(self.value_function, [-1])
|
||||
|
||||
# Make loss functions.
|
||||
|
||||
@@ -103,14 +103,13 @@ class PPOAgent(Agent):
|
||||
def _init(self):
|
||||
self.global_step = 0
|
||||
self.local_evaluator = PPOEvaluator(
|
||||
self.registry, self.env_creator, self.config, self.logdir, False)
|
||||
self.env_creator, self.config, self.logdir, False)
|
||||
RemotePPOEvaluator = ray.remote(
|
||||
num_cpus=self.config["num_cpus_per_worker"],
|
||||
num_gpus=self.config["num_gpus_per_worker"])(PPOEvaluator)
|
||||
self.remote_evaluators = [
|
||||
RemotePPOEvaluator.remote(
|
||||
self.registry, self.env_creator, self.config, self.logdir,
|
||||
True)
|
||||
self.env_creator, self.config, self.logdir, True)
|
||||
for _ in range(self.config["num_workers"])]
|
||||
|
||||
self.optimizer = LocalMultiGPUOptimizer(
|
||||
|
||||
@@ -24,12 +24,11 @@ class PPOEvaluator(TFMultiGPUSupport):
|
||||
network weights. When run as a remote agent, only this graph is used.
|
||||
"""
|
||||
|
||||
def __init__(self, registry, env_creator, config, logdir, is_remote):
|
||||
self.registry = registry
|
||||
def __init__(self, env_creator, config, logdir, is_remote):
|
||||
self.config = config
|
||||
self.logdir = logdir
|
||||
self.env = ModelCatalog.get_preprocessor_as_wrapper(
|
||||
registry, env_creator(config["env_config"]), config["model"])
|
||||
env_creator(config["env_config"]), config["model"])
|
||||
if is_remote:
|
||||
config_proto = tf.ConfigProto()
|
||||
else:
|
||||
@@ -92,7 +91,7 @@ class PPOEvaluator(TFMultiGPUSupport):
|
||||
self.env.observation_space, self.env.action_space,
|
||||
obs, vtargets, advs, acts, plog, pvf_preds, self.logit_dim,
|
||||
self.kl_coeff, self.distribution_class, self.config,
|
||||
self.sess, self.registry)
|
||||
self.sess)
|
||||
|
||||
def init_extra_ops(self, device_losses):
|
||||
self.extra_ops = OrderedDict()
|
||||
|
||||
@@ -14,7 +14,6 @@ import ray
|
||||
from ray.rllib.agent import get_agent_class
|
||||
from ray.rllib.dqn.common.wrappers import wrap_dqn
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.tune.registry import get_registry
|
||||
|
||||
EXAMPLE_USAGE = """
|
||||
example usage:
|
||||
@@ -74,10 +73,9 @@ if __name__ == "__main__":
|
||||
|
||||
if args.run == "DQN":
|
||||
env = gym.make(args.env)
|
||||
env = wrap_dqn(get_registry(), env, args.config.get("model", {}))
|
||||
env = wrap_dqn(env, args.config.get("model", {}))
|
||||
else:
|
||||
env = ModelCatalog.get_preprocessor_as_wrapper(get_registry(),
|
||||
gym.make(args.env))
|
||||
env = ModelCatalog.get_preprocessor_as_wrapper(gym.make(args.env))
|
||||
if args.out is not None:
|
||||
rollouts = []
|
||||
steps = 0
|
||||
|
||||
@@ -5,7 +5,6 @@ import unittest
|
||||
from gym.spaces import Box, Discrete, Tuple
|
||||
|
||||
import ray
|
||||
from ray.tune.registry import get_registry
|
||||
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models.model import Model
|
||||
@@ -33,12 +32,10 @@ class ModelCatalogTest(unittest.TestCase):
|
||||
ray.worker.cleanup()
|
||||
|
||||
def testGymPreprocessors(self):
|
||||
p1 = ModelCatalog.get_preprocessor(
|
||||
get_registry(), gym.make("CartPole-v0"))
|
||||
p1 = ModelCatalog.get_preprocessor(gym.make("CartPole-v0"))
|
||||
self.assertEqual(type(p1), NoPreprocessor)
|
||||
|
||||
p2 = ModelCatalog.get_preprocessor(
|
||||
get_registry(), gym.make("FrozenLake-v0"))
|
||||
p2 = ModelCatalog.get_preprocessor(gym.make("FrozenLake-v0"))
|
||||
self.assertEqual(type(p2), OneHotPreprocessor)
|
||||
|
||||
def testTuplePreprocessor(self):
|
||||
@@ -48,8 +45,7 @@ class ModelCatalogTest(unittest.TestCase):
|
||||
def __init__(self):
|
||||
self.observation_space = Tuple(
|
||||
[Discrete(5), Box(0, 1, shape=(3,), dtype=np.float32)])
|
||||
p1 = ModelCatalog.get_preprocessor(
|
||||
get_registry(), TupleEnv())
|
||||
p1 = ModelCatalog.get_preprocessor(TupleEnv())
|
||||
self.assertEqual(p1.shape, (8,))
|
||||
self.assertEqual(
|
||||
list(p1.transform((0, [1, 2, 3]))),
|
||||
@@ -60,33 +56,29 @@ class ModelCatalogTest(unittest.TestCase):
|
||||
ModelCatalog.register_custom_preprocessor("foo", CustomPreprocessor)
|
||||
ModelCatalog.register_custom_preprocessor("bar", CustomPreprocessor2)
|
||||
env = gym.make("CartPole-v0")
|
||||
p1 = ModelCatalog.get_preprocessor(
|
||||
get_registry(), env, {"custom_preprocessor": "foo"})
|
||||
p1 = ModelCatalog.get_preprocessor(env, {"custom_preprocessor": "foo"})
|
||||
self.assertEqual(str(type(p1)), str(CustomPreprocessor))
|
||||
p2 = ModelCatalog.get_preprocessor(
|
||||
get_registry(), env, {"custom_preprocessor": "bar"})
|
||||
p2 = ModelCatalog.get_preprocessor(env, {"custom_preprocessor": "bar"})
|
||||
self.assertEqual(str(type(p2)), str(CustomPreprocessor2))
|
||||
p3 = ModelCatalog.get_preprocessor(get_registry(), env)
|
||||
p3 = ModelCatalog.get_preprocessor(env)
|
||||
self.assertEqual(type(p3), NoPreprocessor)
|
||||
|
||||
def testDefaultModels(self):
|
||||
ray.init()
|
||||
|
||||
with tf.variable_scope("test1"):
|
||||
p1 = ModelCatalog.get_model(
|
||||
get_registry(), np.zeros((10, 3), dtype=np.float32), 5)
|
||||
p1 = ModelCatalog.get_model(np.zeros((10, 3), dtype=np.float32), 5)
|
||||
self.assertEqual(type(p1), FullyConnectedNetwork)
|
||||
|
||||
with tf.variable_scope("test2"):
|
||||
p2 = ModelCatalog.get_model(
|
||||
get_registry(), np.zeros((10, 80, 80, 3), dtype=np.float32), 5)
|
||||
np.zeros((10, 80, 80, 3), dtype=np.float32), 5)
|
||||
self.assertEqual(type(p2), VisionNetwork)
|
||||
|
||||
def testCustomModel(self):
|
||||
ray.init()
|
||||
ModelCatalog.register_custom_model("foo", CustomModel)
|
||||
p1 = ModelCatalog.get_model(
|
||||
get_registry(), 1, 5, {"custom_model": "foo"})
|
||||
p1 = ModelCatalog.get_model(1, 5, {"custom_model": "foo"})
|
||||
self.assertEqual(str(type(p1)), str(CustomModel))
|
||||
|
||||
|
||||
|
||||
@@ -59,8 +59,7 @@ def check_support(alg, config, stats):
|
||||
for o_name, obs_space in OBSERVATION_SPACES_TO_TEST.items():
|
||||
print("=== Testing", alg, action_space, obs_space, "===")
|
||||
stub_env = make_stub_env(action_space, obs_space)
|
||||
register_env(
|
||||
"stub_env", lambda c: stub_env())
|
||||
register_env("stub_env", lambda c: stub_env())
|
||||
stat = "ok"
|
||||
a = None
|
||||
try:
|
||||
|
||||
@@ -18,7 +18,6 @@ from ray.rllib.utils.sampler import AsyncSampler, SyncSampler
|
||||
from ray.rllib.utils.serving_env import ServingEnv, _ServingEnvToAsync
|
||||
from ray.rllib.utils.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.utils.vector_env import VectorEnv
|
||||
from ray.tune.registry import get_registry
|
||||
from ray.tune.result import TrainingResult
|
||||
|
||||
|
||||
@@ -97,7 +96,6 @@ class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
compress_observations=False,
|
||||
num_envs=1,
|
||||
observation_filter="NoFilter",
|
||||
registry=None,
|
||||
env_config=None,
|
||||
model_config=None,
|
||||
policy_config=None):
|
||||
@@ -137,15 +135,11 @@ class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
and vectorize the computation of actions. This has no effect if
|
||||
the env already implements VectorEnv.
|
||||
observation_filter (str): Name of observation filter to use.
|
||||
registry (tune.Registry): User-registered objects. Pass in the
|
||||
value from tune.registry.get_registry() if you're having
|
||||
trouble resolving things like custom envs.
|
||||
env_config (dict): Config to pass to the env creator.
|
||||
model_config (dict): Config to use when creating the policy model.
|
||||
policy_config (dict): Config to pass to the policy.
|
||||
"""
|
||||
|
||||
registry = registry or get_registry()
|
||||
env_config = env_config or {}
|
||||
policy_config = policy_config or {}
|
||||
model_config = model_config or {}
|
||||
@@ -169,7 +163,7 @@ class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
else:
|
||||
def wrap(env):
|
||||
return ModelCatalog.get_preprocessor_as_wrapper(
|
||||
registry, env, model_config)
|
||||
env, model_config)
|
||||
self.env = wrap(self.env)
|
||||
|
||||
def make_env():
|
||||
@@ -187,11 +181,11 @@ class CommonPolicyEvaluator(PolicyEvaluator):
|
||||
with self.sess.as_default():
|
||||
policy = policy_graph(
|
||||
self.env.observation_space, self.env.action_space,
|
||||
registry, policy_config)
|
||||
policy_config)
|
||||
else:
|
||||
policy = policy_graph(
|
||||
self.env.observation_space, self.env.action_space,
|
||||
registry, policy_config)
|
||||
policy_config)
|
||||
self.policy_map = {
|
||||
"default": policy
|
||||
}
|
||||
|
||||
@@ -17,11 +17,10 @@ class PolicyGraph(object):
|
||||
graphs and multi-GPU support.
|
||||
"""
|
||||
|
||||
def __init__(self, registry, observation_space, action_space, config):
|
||||
def __init__(self, observation_space, action_space, config):
|
||||
"""Initialize the graph.
|
||||
|
||||
Args:
|
||||
registry (obj): Object registry for user-defined envs, models, etc.
|
||||
observation_space (gym.Space): Observation space of the env.
|
||||
action_space (gym.Space): Action space of the env.
|
||||
config (dict): Policy-specific configuration data.
|
||||
|
||||
+38
-41
@@ -4,10 +4,10 @@ from __future__ import print_function
|
||||
|
||||
from types import FunctionType
|
||||
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray.local_scheduler import ObjectID
|
||||
import ray.cloudpickle as pickle
|
||||
from ray.experimental.internal_kv import _internal_kv_initialized, \
|
||||
_internal_kv_get, _internal_kv_put
|
||||
|
||||
TRAINABLE_CLASS = "trainable_class"
|
||||
ENV_CREATOR = "env_creator"
|
||||
@@ -35,7 +35,7 @@ def register_trainable(name, trainable):
|
||||
if not issubclass(trainable, Trainable):
|
||||
raise TypeError("Second argument must be convertable to Trainable",
|
||||
trainable)
|
||||
_default_registry.register(TRAINABLE_CLASS, name, trainable)
|
||||
_global_registry.register(TRAINABLE_CLASS, name, trainable)
|
||||
|
||||
|
||||
def register_env(name, env_creator):
|
||||
@@ -48,62 +48,59 @@ def register_env(name, env_creator):
|
||||
|
||||
if not isinstance(env_creator, FunctionType):
|
||||
raise TypeError("Second argument must be a function.", env_creator)
|
||||
_default_registry.register(ENV_CREATOR, name, env_creator)
|
||||
_global_registry.register(ENV_CREATOR, name, env_creator)
|
||||
|
||||
|
||||
def get_registry():
|
||||
"""Use this to access the registry. This requires ray to be initialized."""
|
||||
def _make_key(category, key):
|
||||
"""Generate a binary key for the given category and key.
|
||||
|
||||
_default_registry.flush_values_to_object_store()
|
||||
Args:
|
||||
category (str): The category of the item
|
||||
key (str): The unique identifier for the item
|
||||
|
||||
# returns a registry copy that doesn't include the hard refs
|
||||
return _Registry(_default_registry._all_objects)
|
||||
|
||||
|
||||
def _to_pinnable(obj):
|
||||
"""Converts obj to a form that can be pinned in object store memory.
|
||||
|
||||
Currently only numpy arrays are pinned in memory, if you have a strong
|
||||
reference to the array value.
|
||||
Returns:
|
||||
The key to use for storing the value.
|
||||
"""
|
||||
|
||||
return (obj, np.zeros(1))
|
||||
|
||||
|
||||
def _from_pinnable(obj):
|
||||
"""Retrieve from _to_pinnable format."""
|
||||
|
||||
return obj[0]
|
||||
return (b"TuneRegistry:" + category.encode("ascii") + b"/" +
|
||||
key.encode("ascii"))
|
||||
|
||||
|
||||
class _Registry(object):
|
||||
def __init__(self, objs=None):
|
||||
self._all_objects = {} if objs is None else objs.copy()
|
||||
self._refs = [] # hard refs that prevent eviction of objects
|
||||
def __init__(self):
|
||||
self._to_flush = {}
|
||||
|
||||
def register(self, category, key, value):
|
||||
if category not in KNOWN_CATEGORIES:
|
||||
from ray.tune import TuneError
|
||||
raise TuneError("Unknown category {} not among {}".format(
|
||||
category, KNOWN_CATEGORIES))
|
||||
self._all_objects[(category, key)] = value
|
||||
self._to_flush[(category, key)] = pickle.dumps(value)
|
||||
if _internal_kv_initialized():
|
||||
self.flush_values()
|
||||
|
||||
def contains(self, category, key):
|
||||
return (category, key) in self._all_objects
|
||||
if _internal_kv_initialized():
|
||||
value = _internal_kv_get(_make_key(category, key))
|
||||
return value is not None
|
||||
else:
|
||||
return (category, key) in self._to_flush
|
||||
|
||||
def get(self, category, key):
|
||||
value = self._all_objects[(category, key)]
|
||||
if type(value) == ObjectID:
|
||||
return _from_pinnable(ray.get(value))
|
||||
if _internal_kv_initialized():
|
||||
value = _internal_kv_get(_make_key(category, key))
|
||||
if value is None:
|
||||
raise ValueError(
|
||||
"Registry value for {}/{} doesn't exist.".format(
|
||||
category, key))
|
||||
return pickle.loads(value)
|
||||
else:
|
||||
return value
|
||||
return pickle.loads(self._to_flush[(category, key)])
|
||||
|
||||
def flush_values_to_object_store(self):
|
||||
for k, v in self._all_objects.items():
|
||||
if type(v) != ObjectID:
|
||||
obj = ray.put(_to_pinnable(v))
|
||||
self._all_objects[k] = obj
|
||||
self._refs.append(ray.get(obj))
|
||||
def flush_values(self):
|
||||
for (category, key), value in self._to_flush.items():
|
||||
_internal_kv_put(_make_key(category, key), value)
|
||||
self._to_flush.clear()
|
||||
|
||||
|
||||
_default_registry = _Registry()
|
||||
_global_registry = _Registry()
|
||||
ray.worker._post_init_hooks.append(_global_registry.flush_values)
|
||||
|
||||
@@ -11,7 +11,7 @@ from ray.rllib import _register_all
|
||||
|
||||
from ray.tune import Trainable, TuneError
|
||||
from ray.tune import register_env, register_trainable, run_experiments
|
||||
from ray.tune.registry import _default_registry, TRAINABLE_CLASS
|
||||
from ray.tune.registry import _global_registry, TRAINABLE_CLASS
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR, TrainingResult
|
||||
from ray.tune.util import pin_in_object_store, get_pinned_object
|
||||
from ray.tune.experiment import Experiment
|
||||
@@ -595,7 +595,7 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
|
||||
def testTrialErrorOnStart(self):
|
||||
ray.init()
|
||||
_default_registry.register(TRAINABLE_CLASS, "asdf", None)
|
||||
_global_registry.register(TRAINABLE_CLASS, "asdf", None)
|
||||
trial = Trial("asdf", resources=Resources(1, 0))
|
||||
try:
|
||||
trial.start()
|
||||
@@ -690,7 +690,7 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
},
|
||||
"resources": Resources(cpu=1, gpu=1),
|
||||
}
|
||||
_default_registry.register(TRAINABLE_CLASS, "asdf", None)
|
||||
_global_registry.register(TRAINABLE_CLASS, "asdf", None)
|
||||
trials = [Trial("asdf", **kwargs), Trial("__fake", **kwargs)]
|
||||
for t in trials:
|
||||
runner.add_trial(t)
|
||||
|
||||
@@ -46,11 +46,9 @@ class Trainable(object):
|
||||
Attributes:
|
||||
config (obj): The hyperparam configuration for this trial.
|
||||
logdir (str): Directory in which training outputs should be placed.
|
||||
registry (obj): Tune object registry which holds user-registered
|
||||
classes and objects by name.
|
||||
"""
|
||||
|
||||
def __init__(self, config=None, registry=None, logger_creator=None):
|
||||
def __init__(self, config=None, logger_creator=None):
|
||||
"""Initialize a Trainable.
|
||||
|
||||
Subclasses should prefer defining ``_setup()`` instead of overriding
|
||||
@@ -58,20 +56,13 @@ class Trainable(object):
|
||||
|
||||
Args:
|
||||
config (dict): Trainable-specific configuration data.
|
||||
registry (obj): Object registry for user-defined envs, models, etc.
|
||||
If unspecified, the default registry will be used.
|
||||
logger_creator (func): Function that creates a ray.tune.Logger
|
||||
object. If unspecified, a default logger is created.
|
||||
"""
|
||||
|
||||
if registry is None:
|
||||
from ray.tune.registry import get_registry
|
||||
registry = get_registry()
|
||||
|
||||
self._initialize_ok = False
|
||||
self._experiment_id = uuid.uuid4().hex
|
||||
self.config = config or {}
|
||||
self.registry = registry
|
||||
|
||||
if logger_creator:
|
||||
self._result_logger = logger_creator(self.config)
|
||||
|
||||
@@ -57,7 +57,7 @@ class Resources(
|
||||
|
||||
|
||||
def has_trainable(trainable_name):
|
||||
return ray.tune.registry._default_registry.contains(
|
||||
return ray.tune.registry._global_registry.contains(
|
||||
ray.tune.registry.TRAINABLE_CLASS, trainable_name)
|
||||
|
||||
|
||||
@@ -377,12 +377,10 @@ class Trial(object):
|
||||
# Logging for trials is handled centrally by TrialRunner, so
|
||||
# configure the remote runner to use a noop-logger.
|
||||
self.runner = cls.remote(
|
||||
config=self.config,
|
||||
registry=ray.tune.registry.get_registry(),
|
||||
logger_creator=logger_creator)
|
||||
config=self.config, logger_creator=logger_creator)
|
||||
|
||||
def _get_trainable_cls(self):
|
||||
return ray.tune.registry.get_registry().get(
|
||||
return ray.tune.registry._global_registry.get(
|
||||
ray.tune.registry.TRAINABLE_CLASS, self.trainable_name)
|
||||
|
||||
def set_verbose(self, verbose):
|
||||
|
||||
+18
-2
@@ -2,12 +2,12 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import base64
|
||||
from six.moves import queue
|
||||
import base64
|
||||
import numpy as np
|
||||
import threading
|
||||
|
||||
import ray
|
||||
from ray.tune.registry import _to_pinnable, _from_pinnable
|
||||
|
||||
_pinned_objects = []
|
||||
_fetch_requests = queue.Queue()
|
||||
@@ -63,6 +63,22 @@ def _serve_get_pin_requests():
|
||||
pass
|
||||
|
||||
|
||||
def _to_pinnable(obj):
|
||||
"""Converts obj to a form that can be pinned in object store memory.
|
||||
|
||||
Currently only numpy arrays are pinned in memory, if you have a strong
|
||||
reference to the array value.
|
||||
"""
|
||||
|
||||
return (obj, np.zeros(1))
|
||||
|
||||
|
||||
def _from_pinnable(obj):
|
||||
"""Retrieve from _to_pinnable format."""
|
||||
|
||||
return obj[0]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ray.init()
|
||||
X = pin_in_object_store("hello")
|
||||
|
||||
@@ -1741,7 +1741,7 @@ def init(redis_address=None,
|
||||
redis_address = services.address_to_ip(redis_address)
|
||||
|
||||
info = {"node_ip_address": node_ip_address, "redis_address": redis_address}
|
||||
return _init(
|
||||
ret = _init(
|
||||
address_info=info,
|
||||
start_ray_local=(redis_address is None),
|
||||
num_workers=num_workers,
|
||||
@@ -1758,6 +1758,13 @@ def init(redis_address=None,
|
||||
include_webui=include_webui,
|
||||
object_store_memory=object_store_memory,
|
||||
use_raylet=use_raylet)
|
||||
for hook in _post_init_hooks:
|
||||
hook()
|
||||
return ret
|
||||
|
||||
|
||||
# Functions to run as callbacks after a successful ray init
|
||||
_post_init_hooks = []
|
||||
|
||||
|
||||
def cleanup(worker=global_worker):
|
||||
|
||||
Reference in New Issue
Block a user