mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 05:09:18 +08:00
[rllib] [tune] Custom preprocessors and models, various fixes (#1372)
This commit is contained in:
@@ -2,18 +2,17 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# Note: do not introduce unnecessary library dependencies here, e.g. gym
|
||||
from ray.tune.registry import register_trainable
|
||||
from ray.rllib import ppo, es, dqn, a3c
|
||||
from ray.rllib.agent import _MockAgent, _SigmoidFakeData
|
||||
from ray.rllib.agent import get_agent_class
|
||||
|
||||
|
||||
def _register_all():
|
||||
register_trainable("PPO", ppo.PPOAgent)
|
||||
register_trainable("ES", es.ESAgent)
|
||||
register_trainable("DQN", dqn.DQNAgent)
|
||||
register_trainable("A3C", a3c.A3CAgent)
|
||||
register_trainable("__fake", _MockAgent)
|
||||
register_trainable("__sigmoid_fake_data", _SigmoidFakeData)
|
||||
for key in ["PPO", "ES", "DQN", "A3C", "__fake", "__sigmoid_fake_data"]:
|
||||
try:
|
||||
register_trainable(key, get_agent_class(key))
|
||||
except ImportError as e:
|
||||
print("Warning: could not import {}: {}".format(key, e))
|
||||
|
||||
|
||||
_register_all()
|
||||
|
||||
@@ -9,7 +9,7 @@ import os
|
||||
import ray
|
||||
from ray.rllib.agent import Agent
|
||||
from ray.rllib.optimizers import AsyncOptimizer
|
||||
from ray.rllib.a3c.base_evaluator import A3CEvaluator, RemoteA3CEvaluator
|
||||
from ray.rllib.a3c.a3c_evaluator import A3CEvaluator, RemoteA3CEvaluator
|
||||
from ray.tune.result import TrainingResult
|
||||
|
||||
|
||||
@@ -38,8 +38,8 @@ DEFAULT_CONFIG = {
|
||||
"vf_loss_coeff": 0.5,
|
||||
# Entropy coefficient
|
||||
"entropy_coeff": -0.01,
|
||||
# Preprocessing for environment
|
||||
"preprocessing": {
|
||||
# Model and preprocessor options
|
||||
"model": {
|
||||
# (Image statespace) - Converts image to Channels = 1
|
||||
"grayscale": True,
|
||||
# (Image statespace) - Each pixel
|
||||
@@ -49,8 +49,6 @@ DEFAULT_CONFIG = {
|
||||
# (Image statespace) - Converts image shape to (C, dim, dim)
|
||||
"channel_major": False
|
||||
},
|
||||
# Configuration for model specification
|
||||
"model": {},
|
||||
# Arguments to pass to the rllib optimizer
|
||||
"optimizer": {
|
||||
# Number of gradients applied for each `train` step
|
||||
@@ -66,10 +64,11 @@ class A3CAgent(Agent):
|
||||
|
||||
def _init(self):
|
||||
self.local_evaluator = A3CEvaluator(
|
||||
self.env_creator, self.config, self.logdir, start_sampler=False)
|
||||
self.registry, self.env_creator, self.config, self.logdir,
|
||||
start_sampler=False)
|
||||
self.remote_evaluators = [
|
||||
RemoteA3CEvaluator.remote(
|
||||
self.env_creator, self.config, self.logdir)
|
||||
self.registry, self.env_creator, self.config, self.logdir)
|
||||
for i in range(self.config["num_workers"])]
|
||||
self.optimizer = AsyncOptimizer(
|
||||
self.config["optimizer"], self.local_evaluator,
|
||||
|
||||
@@ -5,7 +5,7 @@ from __future__ import print_function
|
||||
import pickle
|
||||
|
||||
import ray
|
||||
from ray.rllib.envs import create_and_wrap
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.optimizers import Evaluator
|
||||
from ray.rllib.a3c.common import get_policy_cls
|
||||
from ray.rllib.utils.filter import get_filter
|
||||
@@ -25,12 +25,15 @@ class A3CEvaluator(Evaluator):
|
||||
rollouts.
|
||||
logdir: Directory for logging.
|
||||
"""
|
||||
def __init__(self, env_creator, config, logdir, start_sampler=True):
|
||||
self.env = env = create_and_wrap(env_creator, config["preprocessing"])
|
||||
def __init__(
|
||||
self, registry, env_creator, config, logdir, start_sampler=True):
|
||||
env = ModelCatalog.get_preprocessor_as_wrapper(
|
||||
registry, env_creator(), config["model"])
|
||||
self.env = env
|
||||
policy_cls = get_policy_cls(config)
|
||||
# TODO(rliaw): should change this to be just env.observation_space
|
||||
self.policy = policy_cls(
|
||||
env.observation_space.shape, env.action_space, config)
|
||||
registry, env.observation_space.shape, env.action_space, config)
|
||||
self.config = config
|
||||
|
||||
# Technically not needed when not remote
|
||||
@@ -13,14 +13,15 @@ class SharedModel(TFPolicy):
|
||||
other_output = ["vf_preds"]
|
||||
is_recurrent = False
|
||||
|
||||
def __init__(self, ob_space, ac_space, config, **kwargs):
|
||||
super(SharedModel, self).__init__(ob_space, ac_space, config, **kwargs)
|
||||
def __init__(self, registry, ob_space, ac_space, config, **kwargs):
|
||||
super(SharedModel, self).__init__(
|
||||
registry, ob_space, ac_space, config, **kwargs)
|
||||
|
||||
def _setup_graph(self, ob_space, ac_space):
|
||||
self.x = tf.placeholder(tf.float32, [None] + list(ob_space))
|
||||
dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
|
||||
self._model = ModelCatalog.get_model(
|
||||
self.x, self.logit_dim, self.config["model"])
|
||||
self.registry, self.x, self.logit_dim, self.config["model"])
|
||||
self.logits = self._model.outputs
|
||||
self.curr_dist = dist_class(self.logits)
|
||||
# with tf.variable_scope("vf"):
|
||||
|
||||
@@ -21,9 +21,9 @@ class SharedModelLSTM(TFPolicy):
|
||||
other_output = ["vf_preds", "features"]
|
||||
is_recurrent = True
|
||||
|
||||
def __init__(self, ob_space, ac_space, config, **kwargs):
|
||||
def __init__(self, registry, ob_space, ac_space, config, **kwargs):
|
||||
super(SharedModelLSTM, self).__init__(
|
||||
ob_space, ac_space, config, **kwargs)
|
||||
registry, ob_space, ac_space, config, **kwargs)
|
||||
|
||||
def _setup_graph(self, ob_space, ac_space):
|
||||
self.x = tf.placeholder(tf.float32, [None] + list(ob_space))
|
||||
|
||||
@@ -24,7 +24,7 @@ class SharedTorchPolicy(TorchPolicy):
|
||||
def _setup_graph(self, ob_space, ac_space):
|
||||
_, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
|
||||
self._model = ModelCatalog.get_torch_model(
|
||||
ob_space, self.logit_dim, self.config["model"])
|
||||
self.registry, ob_space, self.logit_dim, self.config["model"])
|
||||
self.optimizer = torch.optim.Adam(
|
||||
self._model.parameters(), lr=self.config["lr"])
|
||||
|
||||
|
||||
@@ -10,8 +10,9 @@ from ray.rllib.a3c.policy import Policy
|
||||
|
||||
class TFPolicy(Policy):
|
||||
"""The policy base class."""
|
||||
def __init__(self, ob_space, action_space, config,
|
||||
def __init__(self, registry, ob_space, action_space, config,
|
||||
name="local", summarize=True):
|
||||
self.registry = registry
|
||||
self.local_steps = 0
|
||||
self.config = config
|
||||
self.summarize = summarize
|
||||
|
||||
@@ -15,8 +15,9 @@ class TorchPolicy(Policy):
|
||||
The model is a separate object than the policy. This could be changed
|
||||
in the future."""
|
||||
|
||||
def __init__(self, ob_space, action_space, config,
|
||||
def __init__(self, registry, ob_space, action_space, config,
|
||||
name="local", summarize=True):
|
||||
self.registry = registry
|
||||
self.local_steps = 0
|
||||
self.config = config
|
||||
self.summarize = summarize
|
||||
|
||||
@@ -15,9 +15,11 @@ import tempfile
|
||||
import time
|
||||
import uuid
|
||||
|
||||
# Note: avoid introducing unnecessary library dependencies here, e.g. gym
|
||||
# until https://github.com/ray-project/ray/issues/1144 is resolved
|
||||
import tensorflow as tf
|
||||
from ray.tune.logger import UnifiedLogger
|
||||
from ray.tune.registry import ENV_CREATOR
|
||||
from ray.tune.registry import ENV_CREATOR, get_registry
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR, TrainingResult
|
||||
from ray.tune.trainable import Trainable
|
||||
|
||||
@@ -74,7 +76,8 @@ class Agent(Trainable):
|
||||
_allow_unknown_subkeys = []
|
||||
|
||||
def __init__(
|
||||
self, config={}, env=None, registry=None, logger_creator=None):
|
||||
self, config={}, env=None, registry=get_registry(),
|
||||
logger_creator=None):
|
||||
"""Initialize an RLLib agent.
|
||||
|
||||
Args:
|
||||
@@ -91,11 +94,13 @@ class Agent(Trainable):
|
||||
env = env or config.get("env")
|
||||
if env:
|
||||
config["env"] = env
|
||||
if registry and registry.contains(ENV_CREATOR, env):
|
||||
self.env_creator = registry.get(ENV_CREATOR, env)
|
||||
if registry and registry.contains(ENV_CREATOR, env):
|
||||
self.env_creator = registry.get(ENV_CREATOR, env)
|
||||
else:
|
||||
import gym # soft dependency
|
||||
self.env_creator = lambda: gym.make(env)
|
||||
else:
|
||||
import gym
|
||||
self.env_creator = lambda: gym.make(env)
|
||||
self.env_creator = lambda: None
|
||||
self.config = self._default_config.copy()
|
||||
self.registry = registry
|
||||
|
||||
|
||||
@@ -213,7 +213,7 @@ class FrameStack(gym.Wrapper):
|
||||
return LazyFrames(list(self.frames))
|
||||
|
||||
|
||||
def wrap_dqn(env, options):
|
||||
def wrap_dqn(registry, env, options):
|
||||
"""Apply a common set of wrappers for DQN."""
|
||||
|
||||
is_atari = hasattr(env.unwrapped, "ale")
|
||||
@@ -226,7 +226,7 @@ def wrap_dqn(env, options):
|
||||
if 'FIRE' in env.unwrapped.get_action_meanings():
|
||||
env = FireResetEnv(env)
|
||||
|
||||
env = ModelCatalog.get_preprocessor_as_wrapper(env, options)
|
||||
env = ModelCatalog.get_preprocessor_as_wrapper(registry, env, options)
|
||||
|
||||
if is_atari:
|
||||
env = FrameStack(env, 4)
|
||||
|
||||
@@ -8,8 +8,8 @@ import os
|
||||
import tensorflow as tf
|
||||
|
||||
import ray
|
||||
from ray.rllib.dqn.base_evaluator import DQNEvaluator
|
||||
from ray.rllib.dqn.replay_evaluator import DQNReplayEvaluator
|
||||
from ray.rllib.dqn.dqn_evaluator import DQNEvaluator
|
||||
from ray.rllib.dqn.dqn_replay_evaluator import DQNReplayEvaluator
|
||||
from ray.rllib.optimizers import AsyncOptimizer, LocalMultiGPUOptimizer, \
|
||||
LocalSyncOptimizer
|
||||
from ray.rllib.agent import Agent
|
||||
@@ -113,7 +113,7 @@ class DQNAgent(Agent):
|
||||
def _init(self):
|
||||
if self.config["async_updates"]:
|
||||
self.local_evaluator = DQNEvaluator(
|
||||
self.env_creator, self.config, self.logdir)
|
||||
self.registry, self.env_creator, self.config, self.logdir)
|
||||
remote_cls = ray.remote(
|
||||
num_cpus=1, num_gpus=self.config["num_gpus_per_worker"])(
|
||||
DQNReplayEvaluator)
|
||||
@@ -122,12 +122,13 @@ class DQNAgent(Agent):
|
||||
# own replay buffer (i.e. the replay buffer is sharded).
|
||||
self.remote_evaluators = [
|
||||
remote_cls.remote(
|
||||
self.env_creator, remote_config, self.logdir)
|
||||
self.registry, self.env_creator, remote_config,
|
||||
self.logdir)
|
||||
for _ in range(self.config["num_workers"])]
|
||||
optimizer_cls = AsyncOptimizer
|
||||
else:
|
||||
self.local_evaluator = DQNReplayEvaluator(
|
||||
self.env_creator, self.config, self.logdir)
|
||||
self.registry, self.env_creator, self.config, self.logdir)
|
||||
# No remote evaluators. If num_workers > 1, the DQNReplayEvaluator
|
||||
# will internally create more workers for parallelism. This means
|
||||
# there is only one replay buffer regardless of num_workers.
|
||||
|
||||
@@ -15,15 +15,15 @@ from ray.rllib.optimizers import SampleBatch, TFMultiGPUSupport
|
||||
class DQNEvaluator(TFMultiGPUSupport):
|
||||
"""The base DQN Evaluator that does not include the replay buffer."""
|
||||
|
||||
def __init__(self, env_creator, config, logdir):
|
||||
def __init__(self, registry, env_creator, config, logdir):
|
||||
env = env_creator()
|
||||
env = wrap_dqn(env, config["model"])
|
||||
env = wrap_dqn(registry, env, config["model"])
|
||||
self.env = env
|
||||
self.config = config
|
||||
|
||||
tf_config = tf.ConfigProto(**config["tf_session_args"])
|
||||
self.sess = tf.Session(config=tf_config)
|
||||
self.dqn_graph = models.DQNGraph(env, config, logdir)
|
||||
self.dqn_graph = models.DQNGraph(registry, env, config, logdir)
|
||||
|
||||
# Create the schedule for exploration starting from 1.
|
||||
self.exploration = LinearSchedule(
|
||||
+3
-3
@@ -5,7 +5,7 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray.rllib.dqn.base_evaluator import DQNEvaluator
|
||||
from ray.rllib.dqn.dqn_evaluator import DQNEvaluator
|
||||
from ray.rllib.dqn.common.schedules import LinearSchedule
|
||||
from ray.rllib.dqn.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
|
||||
from ray.rllib.optimizers import SampleBatch
|
||||
@@ -21,8 +21,8 @@ class DQNReplayEvaluator(DQNEvaluator):
|
||||
Samples will be collected from a number of remote workers.
|
||||
"""
|
||||
|
||||
def __init__(self, env_creator, config, logdir):
|
||||
DQNEvaluator.__init__(self, env_creator, config, logdir)
|
||||
def __init__(self, registry, env_creator, config, logdir):
|
||||
DQNEvaluator.__init__(self, registry, env_creator, config, logdir)
|
||||
|
||||
# Create extra workers if needed
|
||||
if self.config["num_workers"] > 1:
|
||||
@@ -6,13 +6,13 @@ import tensorflow as tf
|
||||
import tensorflow.contrib.layers as layers
|
||||
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.parallel import TOWER_SCOPE_NAME
|
||||
from ray.rllib.optimizers.multi_gpu_impl import TOWER_SCOPE_NAME
|
||||
|
||||
|
||||
def _build_q_network(inputs, num_actions, config):
|
||||
def _build_q_network(registry, inputs, num_actions, config):
|
||||
dueling = config["dueling"]
|
||||
hiddens = config["hiddens"]
|
||||
frontend = ModelCatalog.get_model(inputs, 1, config["model"])
|
||||
frontend = ModelCatalog.get_model(registry, inputs, 1, config["model"])
|
||||
frontend_out = frontend.last_layer
|
||||
|
||||
with tf.variable_scope("action_value"):
|
||||
@@ -106,15 +106,16 @@ class ModelAndLoss(object):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, num_actions, config,
|
||||
self, registry, num_actions, config,
|
||||
obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights):
|
||||
# q network evaluation
|
||||
with tf.variable_scope("q_func", reuse=True):
|
||||
self.q_t = _build_q_network(obs_t, num_actions, config)
|
||||
self.q_t = _build_q_network(registry, obs_t, num_actions, config)
|
||||
|
||||
# target q network evaluation
|
||||
with tf.variable_scope("target_q_func") as scope:
|
||||
self.q_tp1 = _build_q_network(obs_tp1, num_actions, config)
|
||||
self.q_tp1 = _build_q_network(
|
||||
registry, obs_tp1, num_actions, config)
|
||||
self.target_q_func_vars = _scope_vars(scope.name)
|
||||
|
||||
# q scores for actions which we know were selected in the given state.
|
||||
@@ -125,7 +126,7 @@ class ModelAndLoss(object):
|
||||
if config["double_q"]:
|
||||
with tf.variable_scope("q_func", reuse=True):
|
||||
q_tp1_using_online_net = _build_q_network(
|
||||
obs_tp1, num_actions, config)
|
||||
registry, obs_tp1, num_actions, config)
|
||||
q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1)
|
||||
q_tp1_best = tf.reduce_sum(
|
||||
self.q_tp1 * tf.one_hot(
|
||||
@@ -147,7 +148,7 @@ class ModelAndLoss(object):
|
||||
|
||||
|
||||
class DQNGraph(object):
|
||||
def __init__(self, env, config, logdir):
|
||||
def __init__(self, registry, env, config, logdir):
|
||||
self.env = env
|
||||
num_actions = env.action_space.n
|
||||
optimizer = tf.train.AdamOptimizer(learning_rate=config["lr"])
|
||||
@@ -162,7 +163,7 @@ class DQNGraph(object):
|
||||
q_scope_name = TOWER_SCOPE_NAME + "/q_func"
|
||||
with tf.variable_scope(q_scope_name) as scope:
|
||||
q_values = _build_q_network(
|
||||
self.cur_observations, num_actions, config)
|
||||
registry, self.cur_observations, num_actions, config)
|
||||
q_func_vars = _scope_vars(scope.name)
|
||||
|
||||
# Action outputs
|
||||
@@ -187,6 +188,7 @@ class DQNGraph(object):
|
||||
def build_loss(
|
||||
obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights):
|
||||
return ModelAndLoss(
|
||||
registry,
|
||||
num_actions, config,
|
||||
obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights)
|
||||
|
||||
|
||||
@@ -1,85 +0,0 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gym
|
||||
import logging
|
||||
import time
|
||||
|
||||
from ray.rllib.models import ModelCatalog
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def create_and_wrap(env_creator, options):
    """Create an environment and wrap it for preprocessing and diagnostics.

    Args:
        env_creator: Zero-argument callable that returns the raw environment.
        options: Preprocessor options, forwarded unchanged to
            ModelCatalog.get_preprocessor_as_wrapper.

    Returns:
        The environment wrapped by the preprocessor wrapper and then by
        Diagnostic (outermost).
    """
    preprocessed = ModelCatalog.get_preprocessor_as_wrapper(
        env_creator(), options)
    return Diagnostic(preprocessed)
|
||||
|
||||
|
||||
class Diagnostic(gym.Wrapper):
    """Gym wrapper that passes reset/step results through a DiagnosticsLogger.

    The wrapped environment is driven as usual; whatever it returns is handed
    to the logger, and the logger's return value is what the caller sees.
    """

    def __init__(self, env=None):
        super(Diagnostic, self).__init__(env)
        # One logger per wrapped environment; it keeps the episode statistics.
        self.diagnostics = DiagnosticsLogger()

    def _reset(self):
        """Reset the wrapped env and return the logger-processed observation."""
        first_observation = self.env.reset()
        return self.diagnostics._after_reset(first_observation)

    def _step(self, action):
        """Step the wrapped env and return the logger-processed transition."""
        transition = self.env.step(action)
        return self.diagnostics._after_step(*transition)
|
||||
|
||||
|
||||
class DiagnosticsLogger(object):
    """Accumulates per-episode reward, length and wall-clock time.

    The logger is fed every reset and step of an environment. When an episode
    finishes it emits a summary dict and clears its per-episode counters.

    Args:
        log_interval (int): Every `log_interval` steps the `_last_time`
            timestamp is refreshed.
    """

    def __init__(self, log_interval=503):
        self._episode_time = time.time()
        self._last_time = time.time()
        self._local_t = 0
        self._log_interval = log_interval
        self._episode_reward = 0
        self._episode_length = 0
        self._all_rewards = []
        self._last_episode_id = -1

    def _clear_episode_stats(self):
        """Zero the counters that describe the episode in progress."""
        self._episode_reward = 0
        self._episode_length = 0
        self._all_rewards = []

    def _after_reset(self, observation):
        """Record an environment reset and return `observation` unchanged."""
        logger.info("Resetting environment")
        self._clear_episode_stats()
        return observation

    def _after_step(self, observation, reward, done, info):
        """Record one environment step.

        Args:
            observation: Observation returned by the environment.
            reward: Reward returned by the environment; `None` is not counted.
            done (bool): Whether the episode ended on this step.
            info: Info returned by the environment. It is not used; the
                returned tuple carries this logger's own dict instead.

        Returns:
            Tuple `(observation, reward, done, to_log)`. `to_log` is empty
            unless `done` is true, in which case it holds the
            "global/episode_reward", "global/episode_length",
            "global/episode_time" and "global/reward_per_time" entries.
        """
        to_log = {}
        if self._episode_length == 0:
            # First counted step of an episode: start the episode timer.
            self._episode_time = time.time()

        self._local_t += 1

        if self._local_t % self._log_interval == 0:
            self._last_time = time.time()

        if reward is not None:
            self._episode_reward += reward
            if observation is not None:
                self._episode_length += 1
            self._all_rewards.append(reward)

        if done:
            logger.info("Episode terminating: episode_reward=%s "
                        "episode_length=%s",
                        self._episode_reward, self._episode_length)
            total_time = time.time() - self._episode_time
            to_log["global/episode_reward"] = self._episode_reward
            to_log["global/episode_length"] = self._episode_length
            to_log["global/episode_time"] = total_time
            # time.time() has finite resolution and may step backwards, so the
            # elapsed time can be zero or negative for very short episodes.
            # Report a rate of 0.0 then instead of raising ZeroDivisionError
            # or emitting a negative rate.
            if total_time > 0:
                reward_per_time = self._episode_reward / total_time
            else:
                reward_per_time = 0.0
            to_log["global/reward_per_time"] = reward_per_time
            self._clear_episode_stats()

        return observation, reward, done, to_log
|
||||
@@ -63,7 +63,7 @@ class SharedNoiseTable(object):
|
||||
|
||||
@ray.remote
|
||||
class Worker(object):
|
||||
def __init__(self, config, policy_params, env_creator, noise,
|
||||
def __init__(self, registry, config, policy_params, env_creator, noise,
|
||||
min_task_runtime=0.2):
|
||||
self.min_task_runtime = min_task_runtime
|
||||
self.config = config
|
||||
@@ -71,13 +71,12 @@ class Worker(object):
|
||||
self.noise = SharedNoiseTable(noise)
|
||||
|
||||
self.env = env_creator()
|
||||
self.preprocessor = ModelCatalog.get_preprocessor(self.env)
|
||||
self.preprocessor = ModelCatalog.get_preprocessor(registry, self.env)
|
||||
|
||||
self.sess = utils.make_session(single_threaded=True)
|
||||
self.policy = policies.GenericPolicy(self.sess, self.env.action_space,
|
||||
self.preprocessor,
|
||||
config["observation_filter"],
|
||||
**policy_params)
|
||||
self.policy = policies.GenericPolicy(
|
||||
registry, self.sess, self.env.action_space, self.preprocessor,
|
||||
config["observation_filter"], **policy_params)
|
||||
|
||||
def rollout(self, timestep_limit, add_noise=True):
|
||||
rollout_rewards, rollout_length = policies.rollout(
|
||||
@@ -143,11 +142,11 @@ class ESAgent(Agent):
|
||||
}
|
||||
|
||||
env = self.env_creator()
|
||||
preprocessor = ModelCatalog.get_preprocessor(env)
|
||||
preprocessor = ModelCatalog.get_preprocessor(self.registry, env)
|
||||
|
||||
self.sess = utils.make_session(single_threaded=False)
|
||||
self.policy = policies.GenericPolicy(
|
||||
self.sess, env.action_space, preprocessor,
|
||||
self.registry, self.sess, env.action_space, preprocessor,
|
||||
self.config["observation_filter"], **policy_params)
|
||||
self.optimizer = optimizers.Adam(self.policy, self.config["stepsize"])
|
||||
|
||||
@@ -160,7 +159,8 @@ class ESAgent(Agent):
|
||||
print("Creating actors.")
|
||||
self.workers = [
|
||||
Worker.remote(
|
||||
self.config, policy_params, self.env_creator, noise_id)
|
||||
self.registry, self.config, policy_params, self.env_creator,
|
||||
noise_id)
|
||||
for _ in range(self.config["num_workers"])]
|
||||
|
||||
self.episodes_so_far = 0
|
||||
|
||||
@@ -39,7 +39,7 @@ def rollout(policy, env, timestep_limit=None, add_noise=False):
|
||||
|
||||
|
||||
class GenericPolicy(object):
|
||||
def __init__(self, sess, action_space, preprocessor,
|
||||
def __init__(self, registry, sess, action_space, preprocessor,
|
||||
observation_filter, action_noise_std):
|
||||
self.sess = sess
|
||||
self.action_space = action_space
|
||||
@@ -53,7 +53,7 @@ class GenericPolicy(object):
|
||||
# Policy network.
|
||||
dist_class, dist_dim = ModelCatalog.get_action_dist(
|
||||
self.action_space, dist_type="deterministic")
|
||||
model = ModelCatalog.get_model(self.inputs, dist_dim)
|
||||
model = ModelCatalog.get_model(registry, self.inputs, dist_dim)
|
||||
dist = dist_class(model.outputs)
|
||||
self.sampler = dist.sample()
|
||||
|
||||
|
||||
@@ -4,6 +4,9 @@ from __future__ import print_function
|
||||
|
||||
import gym
|
||||
|
||||
from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \
|
||||
_default_registry
|
||||
|
||||
from ray.rllib.models.action_dist import (
|
||||
Categorical, Deterministic, DiagGaussian)
|
||||
from ray.rllib.models.preprocessors import (
|
||||
@@ -14,6 +17,7 @@ from ray.rllib.models.visionnet import VisionNetwork
|
||||
|
||||
|
||||
MODEL_CONFIGS = [
|
||||
# === Built-in options ===
|
||||
"conv_filters", # Number of filters
|
||||
"dim", # Dimension for ATARI
|
||||
"grayscale", # Converts ATARI frame to 1 Channel Grayscale image
|
||||
@@ -23,6 +27,11 @@ MODEL_CONFIGS = [
|
||||
"fcnet_hiddens", # Number of hidden layers for fully connected net
|
||||
"free_log_std", # Documented in ray.rllib.models.Model
|
||||
"channel_major", # Pytorch conv requires images to be channel-major
|
||||
|
||||
# === Options for custom models ===
|
||||
"custom_preprocessor", # Name of a custom preprocessor to use
|
||||
"custom_model", # Name of a custom model to use
|
||||
"custom_options", # Extra options to pass to the custom classes
|
||||
]
|
||||
|
||||
|
||||
@@ -32,8 +41,6 @@ class ModelCatalog(object):
|
||||
ATARI_OBS_SHAPE = (210, 160, 3)
|
||||
ATARI_RAM_OBS_SHAPE = (128,)
|
||||
|
||||
_registered_preprocessor = dict()
|
||||
|
||||
@staticmethod
|
||||
def get_action_dist(action_space, dist_type=None):
|
||||
"""Returns action distribution class and size for the given action space.
|
||||
@@ -59,10 +66,11 @@ class ModelCatalog(object):
|
||||
"Unsupported args: {} {}".format(action_space, dist_type))
|
||||
|
||||
@staticmethod
|
||||
def get_model(inputs, num_outputs, options=dict()):
|
||||
def get_model(registry, inputs, num_outputs, options=dict()):
|
||||
"""Returns a suitable model conforming to given input and output specs.
|
||||
|
||||
Args:
|
||||
registry (obj): Registry of named objects (ray.tune.registry).
|
||||
inputs (Tensor): The input tensor to the model.
|
||||
num_outputs (int): The size of the output vector of the model.
|
||||
options (dict): Optional args to pass to the model constructor.
|
||||
@@ -71,7 +79,13 @@ class ModelCatalog(object):
|
||||
model (Model): Neural network model.
|
||||
"""
|
||||
|
||||
obs_rank = len(inputs.get_shape()) - 1
|
||||
if "custom_model" in options:
|
||||
model = options["custom_model"]
|
||||
print("Using custom model {}".format(model))
|
||||
return registry.get(RLLIB_MODEL, model)(
|
||||
inputs, num_outputs, options)
|
||||
|
||||
obs_rank = len(inputs.shape) - 1
|
||||
|
||||
if obs_rank > 1:
|
||||
return VisionNetwork(inputs, num_outputs, options)
|
||||
@@ -79,11 +93,12 @@ class ModelCatalog(object):
|
||||
return FullyConnectedNetwork(inputs, num_outputs, options)
|
||||
|
||||
@staticmethod
|
||||
def get_torch_model(input_shape, num_outputs, options=dict()):
|
||||
def get_torch_model(registry, input_shape, num_outputs, options=dict()):
|
||||
"""Returns a PyTorch suitable model. This is currently only supported
|
||||
in A3C.
|
||||
|
||||
Args:
|
||||
registry (obj): Registry of named objects (ray.tune.registry).
|
||||
input_shape (tuple): The input shape to the model.
|
||||
num_outputs (int): The size of the output vector of the model.
|
||||
options (dict): Optional args to pass to the model constructor.
|
||||
@@ -96,6 +111,12 @@ class ModelCatalog(object):
|
||||
from ray.rllib.models.pytorch.visionnet import (
|
||||
VisionNetwork as PyTorchVisionNet)
|
||||
|
||||
if "custom_model" in options:
|
||||
model = options["custom_model"]
|
||||
print("Using custom torch model {}".format(model))
|
||||
return registry.get(RLLIB_MODEL, model)(
|
||||
input_shape, num_outputs, options)
|
||||
|
||||
obs_rank = len(input_shape) - 1
|
||||
|
||||
if obs_rank > 1:
|
||||
@@ -103,11 +124,12 @@ class ModelCatalog(object):
|
||||
|
||||
return PyTorchFCNet(input_shape[0], num_outputs, options)
|
||||
|
||||
@classmethod
|
||||
def get_preprocessor(cls, env, options=dict()):
|
||||
@staticmethod
|
||||
def get_preprocessor(registry, env, options=dict()):
|
||||
"""Returns a suitable processor for the given environment.
|
||||
|
||||
Args:
|
||||
registry (obj): Registry of named objects (ray.tune.registry).
|
||||
env (gym.Env): The gym environment to preprocess.
|
||||
options (dict): Options to pass to the preprocessor.
|
||||
|
||||
@@ -120,7 +142,6 @@ class ModelCatalog(object):
|
||||
isinstance(env.observation_space, gym.spaces.Discrete):
|
||||
env.observation_space.shape = ()
|
||||
|
||||
env_name = env.spec.id
|
||||
obs_shape = env.observation_space.shape
|
||||
|
||||
for k in options.keys():
|
||||
@@ -131,30 +152,33 @@ class ModelCatalog(object):
|
||||
|
||||
print("Observation shape is {}".format(obs_shape))
|
||||
|
||||
if env_name in cls._registered_preprocessor:
|
||||
return cls._registered_preprocessor[env_name](
|
||||
env.observation_space, options)
|
||||
if "custom_preprocessor" in options:
|
||||
preprocessor = options["custom_preprocessor"]
|
||||
print("Using custom preprocessor {}".format(preprocessor))
|
||||
return registry.get(RLLIB_PREPROCESSOR, preprocessor)(
|
||||
env.observation_space, options)
|
||||
|
||||
if obs_shape == ():
|
||||
print("Using one-hot preprocessor for discrete envs.")
|
||||
preprocessor = OneHotPreprocessor
|
||||
elif obs_shape == cls.ATARI_OBS_SHAPE:
|
||||
elif obs_shape == ModelCatalog.ATARI_OBS_SHAPE:
|
||||
print("Assuming Atari pixel env, using AtariPixelPreprocessor.")
|
||||
preprocessor = AtariPixelPreprocessor
|
||||
elif obs_shape == cls.ATARI_RAM_OBS_SHAPE:
|
||||
elif obs_shape == ModelCatalog.ATARI_RAM_OBS_SHAPE:
|
||||
print("Assuming Atari ram env, using AtariRamPreprocessor.")
|
||||
preprocessor = AtariRamPreprocessor
|
||||
else:
|
||||
print("Non-atari env, not using any observation preprocessor.")
|
||||
print("Not using any observation preprocessor.")
|
||||
preprocessor = NoPreprocessor
|
||||
|
||||
return preprocessor(env.observation_space, options)
|
||||
|
||||
@classmethod
|
||||
def get_preprocessor_as_wrapper(cls, env, options=dict()):
|
||||
@staticmethod
|
||||
def get_preprocessor_as_wrapper(registry, env, options=dict()):
|
||||
"""Returns a preprocessor as a gym observation wrapper.
|
||||
|
||||
Args:
|
||||
registry (obj): Registry of named objects (ray.tune.registry).
|
||||
env (gym.Env): The gym environment to wrap.
|
||||
options (dict): Options to pass to the preprocessor.
|
||||
|
||||
@@ -162,20 +186,35 @@ class ModelCatalog(object):
|
||||
wrapper (gym.ObservationWrapper): Preprocessor in wrapper form.
|
||||
"""
|
||||
|
||||
preprocessor = cls.get_preprocessor(env, options)
|
||||
preprocessor = ModelCatalog.get_preprocessor(registry, env, options)
|
||||
return _RLlibPreprocessorWrapper(env, preprocessor)
|
||||
|
||||
@classmethod
|
||||
def register_preprocessor(cls, env_name, preprocessor_class):
|
||||
"""Register a preprocessor class for a specific environment.
|
||||
@staticmethod
|
||||
def register_custom_preprocessor(preprocessor_name, preprocessor_class):
|
||||
"""Register a custom preprocessor class by name.
|
||||
|
||||
The preprocessor can be later used by specifying
|
||||
{"custom_preprocessor": preprocessor_name} in the model config.
|
||||
|
||||
Args:
|
||||
env_name (str): Name of the gym env we register the
|
||||
preprocessor for.
|
||||
preprocessor_class (type):
|
||||
Python class of the distribution.
|
||||
preprocessor_name (str): Name to register the preprocessor under.
|
||||
preprocessor_class (type): Python class of the preprocessor.
|
||||
"""
|
||||
cls._registered_preprocessor[env_name] = preprocessor_class
|
||||
_default_registry.register(
|
||||
RLLIB_PREPROCESSOR, preprocessor_name, preprocessor_class)
|
||||
|
||||
@staticmethod
|
||||
def register_custom_model(model_name, model_class):
|
||||
"""Register a custom model class by name.
|
||||
|
||||
The model can be later used by specifying {"custom_model": model_name}
|
||||
in the model config.
|
||||
|
||||
Args:
|
||||
model_name (str): Name to register the model under.
|
||||
model_class (type): Python class of the model.
|
||||
"""
|
||||
_default_registry.register(RLLIB_MODEL, model_name, model_class)
|
||||
|
||||
|
||||
class _RLlibPreprocessorWrapper(gym.ObservationWrapper):
|
||||
|
||||
@@ -10,7 +10,7 @@ import ray
|
||||
from ray.rllib.optimizers.evaluator import TFMultiGPUSupport
|
||||
from ray.rllib.optimizers.optimizer import Optimizer
|
||||
from ray.rllib.optimizers.sample_batch import SampleBatch
|
||||
from ray.rllib.parallel import LocalSyncParallelOptimizer
|
||||
from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ class LocalMultiGPUOptimizer(Optimizer):
|
||||
Samples are pulled synchronously from multiple remote evaluators,
|
||||
concatenated, and then split across the memory of multiple local GPUs.
|
||||
A number of SGD passes are then taken over the in-memory data. For more
|
||||
details, see `ray.rllib.parallel.LocalSyncParallelOptimizer`.
|
||||
details, see `multi_gpu_impl.LocalSyncParallelOptimizer`.
|
||||
|
||||
This optimizer is Tensorflow-specific and requires evaluators to implement
|
||||
the TFMultiGPUSupport API.
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.models import ModelCatalog
|
||||
|
||||
|
||||
class BatchedEnv(object):
    """Holds several gym envs and steps all of them together.

    Observations from every env are passed through one shared preprocessor
    and stacked into a single array. An env that has finished keeps
    contributing zero observations and zero rewards until `reset` is called.
    """

    def __init__(self, env_creator, batchsize, options):
        self.envs = [env_creator() for _ in range(batchsize)]
        first_env = self.envs[0]
        self.observation_space = first_env.observation_space
        self.action_space = first_env.action_space
        self.batchsize = batchsize
        # The preprocessor is chosen from the first env and used for all envs.
        self.preprocessor = ModelCatalog.get_preprocessor(
            first_env, options["model"])
        self.extra_frameskip = options.get("extra_frameskip", 1)
        assert self.extra_frameskip >= 1

    def reset(self):
        """Reset every env and return the stacked preprocessed observations."""
        initial = []
        for env in self.envs:
            initial.append(self.preprocessor.transform(env.reset())[None])
        # Remembered so finished envs can be padded with zeros in `step`.
        self.shape = initial[0].shape
        self.dones = [False] * self.batchsize
        return np.vstack(initial)

    def step(self, actions, render=False):
        """Apply one action per env, repeating it `extra_frameskip` times.

        Returns:
            Tuple `(observations, rewards, dones)`: the stacked preprocessed
            observations, a float32 array of rewards summed over the repeated
            frames, and an array of the per-env done flags.
        """
        obs_batch = []
        reward_batch = []
        for idx, act in enumerate(actions):
            if self.dones[idx]:
                # Finished envs are not stepped; emit padding instead.
                obs_batch.append(np.zeros(self.shape))
                reward_batch.append(0.0)
                continue
            env = self.envs[idx]
            total = 0.0
            for _ in range(self.extra_frameskip):
                raw_obs, step_reward, finished, _info = env.step(act)
                total += step_reward
                if finished:
                    break
            if render:
                self.envs[0].render()
            obs_batch.append(self.preprocessor.transform(raw_obs)[None])
            reward_batch.append(total)
            self.dones[idx] = finished
        stacked = np.vstack(obs_batch)
        rewards = np.array(reward_batch, dtype="float32")
        return (stacked, rewards, np.array(self.dones))
|
||||
@@ -17,7 +17,7 @@ class ProximalPolicyLoss(object):
|
||||
self, observation_space, action_space,
|
||||
observations, value_targets, advantages, actions,
|
||||
prev_logits, prev_vf_preds, logit_dim,
|
||||
kl_coeff, distribution_class, config, sess):
|
||||
kl_coeff, distribution_class, config, sess, registry):
|
||||
assert (isinstance(action_space, gym.spaces.Discrete) or
|
||||
isinstance(action_space, gym.spaces.Box))
|
||||
self.prev_dist = distribution_class(prev_logits)
|
||||
@@ -26,7 +26,7 @@ class ProximalPolicyLoss(object):
|
||||
self.observations = observations
|
||||
|
||||
self.curr_logits = ModelCatalog.get_model(
|
||||
observations, logit_dim, config["model"]).outputs
|
||||
registry, observations, logit_dim, config["model"]).outputs
|
||||
self.curr_dist = distribution_class(self.curr_logits)
|
||||
self.sampler = self.curr_dist.sample()
|
||||
|
||||
@@ -38,7 +38,7 @@ class ProximalPolicyLoss(object):
|
||||
vf_config["free_log_std"] = False
|
||||
with tf.variable_scope("value_function"):
|
||||
self.value_function = ModelCatalog.get_model(
|
||||
observations, 1, vf_config).outputs
|
||||
registry, observations, 1, vf_config).outputs
|
||||
self.value_function = tf.reshape(self.value_function, [-1])
|
||||
|
||||
# Make loss functions.
|
||||
|
||||
@@ -94,10 +94,11 @@ class PPOAgent(Agent):
|
||||
self.global_step = 0
|
||||
self.kl_coeff = self.config["kl_coeff"]
|
||||
self.model = Runner(
|
||||
self.env_creator, self.config, self.logdir, False)
|
||||
self.registry, self.env_creator, self.config, self.logdir, False)
|
||||
self.agents = [
|
||||
RemoteRunner.remote(
|
||||
self.env_creator, self.config, self.logdir, True)
|
||||
self.registry, self.env_creator, self.config, self.logdir,
|
||||
True)
|
||||
for _ in range(self.config["num_workers"])]
|
||||
self.start_time = time.time()
|
||||
if self.config["write_logs"]:
|
||||
|
||||
@@ -12,9 +12,8 @@ from tensorflow.python import debug as tf_debug
|
||||
import numpy as np
|
||||
import ray
|
||||
|
||||
from ray.rllib.parallel import LocalSyncParallelOptimizer
|
||||
from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.envs import create_and_wrap
|
||||
from ray.rllib.utils.sampler import SyncSampler
|
||||
from ray.rllib.utils.filter import get_filter, MeanStdFilter
|
||||
from ray.rllib.utils.process_rollout import process_rollout
|
||||
@@ -38,7 +37,8 @@ class Runner(object):
|
||||
network weights. When run as a remote agent, only this graph is used.
|
||||
"""
|
||||
|
||||
def __init__(self, env_creator, config, logdir, is_remote):
|
||||
def __init__(self, registry, env_creator, config, logdir, is_remote):
|
||||
self.registry = registry
|
||||
self.is_remote = is_remote
|
||||
if is_remote:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
||||
@@ -48,7 +48,8 @@ class Runner(object):
|
||||
self.devices = devices
|
||||
self.config = config
|
||||
self.logdir = logdir
|
||||
self.env = create_and_wrap(env_creator, config["model"])
|
||||
self.env = ModelCatalog.get_preprocessor_as_wrapper(
|
||||
registry, env_creator(), config["model"])
|
||||
if is_remote:
|
||||
config_proto = tf.ConfigProto()
|
||||
else:
|
||||
@@ -105,7 +106,7 @@ class Runner(object):
|
||||
self.env.observation_space, self.env.action_space,
|
||||
obs, vtargets, advs, acts, plog, pvf_preds, self.logit_dim,
|
||||
self.kl_coeff, self.distribution_class, self.config,
|
||||
self.sess)
|
||||
self.sess, self.registry)
|
||||
|
||||
self.par_opt = LocalSyncParallelOptimizer(
|
||||
tf.train.AdamOptimizer(self.config["sgd_stepsize"]),
|
||||
|
||||
@@ -1,22 +1,79 @@
|
||||
import gym
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
from ray.tune.registry import get_registry
|
||||
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models.preprocessors import Preprocessor
|
||||
from ray.rllib.models.model import Model
|
||||
from ray.rllib.models.preprocessors import (
|
||||
NoPreprocessor, OneHotPreprocessor, Preprocessor)
|
||||
from ray.rllib.models.fcnet import FullyConnectedNetwork
|
||||
from ray.rllib.models.visionnet import VisionNetwork
|
||||
|
||||
|
||||
class FakePreprocessor(Preprocessor):
    """Preprocessor stub whose ``_init`` hook does nothing."""

    def _init(self):
        """Intentionally empty initialisation hook."""
|
||||
class CustomPreprocessor(Preprocessor):
    """Empty Preprocessor subclass registered as a custom preprocessor."""
|
||||
|
||||
|
||||
class FakeEnv(object):
    """Minimal environment stand-in exposing ``observation_space`` and ``spec``."""

    def __init__(self):
        # Bare lambdas serve as cheap objects that accept arbitrary
        # attributes; they are never meant to be called.
        space = lambda: None  # noqa: E731
        spec = lambda: None  # noqa: E731
        space.shape = ()
        spec.id = "FakeEnv-v0"
        self.observation_space = space
        self.spec = spec
|
||||
class CustomPreprocessor2(Preprocessor):
    """Second empty Preprocessor subclass, distinct from the first by type."""
|
||||
|
||||
|
||||
def test_preprocessor():
    """Register FakePreprocessor for "FakeEnv-v0" and check it is returned."""
    ModelCatalog.register_preprocessor("FakeEnv-v0", FakePreprocessor)
    fake_env = FakeEnv()
    looked_up = ModelCatalog.get_preprocessor(fake_env)
    assert type(looked_up) == FakePreprocessor
|
||||
class CustomModel(Model):
    """Model stub used to exercise custom-model registration."""

    def _init(self, *args):
        """Ignore all arguments and return a pair of ``None`` values."""
        return (None, None)
|
||||
|
||||
|
||||
class ModelCatalogTest(unittest.TestCase):
    """Checks ModelCatalog lookups for default and custom components.

    Type checks use ``assertIs`` rather than bare ``assert`` so they still
    run under ``python -O`` and report both types on failure.
    """

    def tearDown(self):
        # Shut Ray down after each test; presumably this lets the next test
        # call ray.init() again.
        ray.worker.cleanup()

    def testGymPreprocessors(self):
        """Built-in gym envs resolve to the expected default preprocessors."""
        p1 = ModelCatalog.get_preprocessor(
            get_registry(), gym.make("CartPole-v0"))
        self.assertIs(type(p1), NoPreprocessor)

        p2 = ModelCatalog.get_preprocessor(
            get_registry(), gym.make("FrozenLake-v0"))
        self.assertIs(type(p2), OneHotPreprocessor)

    def testCustomPreprocessor(self):
        """Registered custom preprocessors are selected by name."""
        ray.init()
        ModelCatalog.register_custom_preprocessor("foo", CustomPreprocessor)
        ModelCatalog.register_custom_preprocessor("bar", CustomPreprocessor2)
        env = gym.make("CartPole-v0")
        p1 = ModelCatalog.get_preprocessor(
            get_registry(), env, {"custom_preprocessor": "foo"})
        self.assertIs(type(p1), CustomPreprocessor)
        p2 = ModelCatalog.get_preprocessor(
            get_registry(), env, {"custom_preprocessor": "bar"})
        self.assertIs(type(p2), CustomPreprocessor2)
        # Without the option, the default preprocessor is still used.
        p3 = ModelCatalog.get_preprocessor(get_registry(), env)
        self.assertIs(type(p3), NoPreprocessor)

    def testDefaultModels(self):
        """Default model class depends on the shape of the input."""
        ray.init()

        with tf.variable_scope("test1"):
            p1 = ModelCatalog.get_model(
                get_registry(), np.zeros((10, 3), dtype=np.float32), 5)
            self.assertIs(type(p1), FullyConnectedNetwork)

        with tf.variable_scope("test2"):
            p2 = ModelCatalog.get_model(
                get_registry(), np.zeros((10, 80, 80, 3), dtype=np.float32), 5)
            self.assertIs(type(p2), VisionNetwork)

    def testCustomModel(self):
        """A registered custom model is selected by name."""
        ray.init()
        ModelCatalog.register_custom_model("foo", CustomModel)
        p1 = ModelCatalog.get_model(
            get_registry(), 1, 5, {"custom_model": "foo"})
        self.assertIs(type(p1), CustomModel)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import gym
|
||||
import ray
|
||||
from ray.rllib.ppo import PPOAgent, DEFAULT_CONFIG
|
||||
|
||||
|
||||
# Start from the default PPO configuration and override four settings.
config = DEFAULT_CONFIG.copy()
config.update({
    "num_workers": 3,
    "num_sgd_iter": 6,
    "sgd_batchsize": 128,
    "timesteps_per_batch": 4000,
})

ray.init()

# Train one agent for ten iterations, printing every result.
agent = PPOAgent("CartPole-v0", config)
for _ in range(10):
    result = agent.train()
    print(result)

# Save a checkpoint and load it into a second agent built from a copy of
# the same configuration.
checkpoint_path = agent.save()
trained_config = config.copy()
test_agent = PPOAgent("CartPole-v0", trained_config)
test_agent.restore(checkpoint_path)

# Roll out twenty episodes with the restored agent and record each return.
results = []
env = gym.make("CartPole-v0")
for _ in range(20):
    state, done, cumulative_reward = env.reset(), False, 0
    while not done:
        state, reward, done, _ = env.step(test_agent.compute_action(state))
        cumulative_reward += reward
    results.append(cumulative_reward)

mean_result = np.mean(results)
print("All results", results)
print("Mean result", mean_result)

# The restored agent must reach more than 90% of the mean episode reward
# reported by the last training iteration.
assert mean_result > 0.9 * result.episode_reward_mean
|
||||
@@ -4,6 +4,8 @@ from __future__ import print_function
|
||||
|
||||
from types import FunctionType
|
||||
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray.tune import TuneError
|
||||
from ray.local_scheduler import ObjectID
|
||||
@@ -11,7 +13,10 @@ from ray.tune.trainable import Trainable, wrap_function
|
||||
|
||||
TRAINABLE_CLASS = "trainable_class"
|
||||
ENV_CREATOR = "env_creator"
|
||||
KNOWN_CATEGORIES = [TRAINABLE_CLASS, ENV_CREATOR]
|
||||
RLLIB_MODEL = "rllib_model"
|
||||
RLLIB_PREPROCESSOR = "rllib_preprocessor"
|
||||
KNOWN_CATEGORIES = [
|
||||
TRAINABLE_CLASS, ENV_CREATOR, RLLIB_MODEL, RLLIB_PREPROCESSOR]
|
||||
|
||||
|
||||
def register_trainable(name, trainable):
|
||||
@@ -55,6 +60,22 @@ def get_registry():
|
||||
return _Registry(_default_registry._all_objects)
|
||||
|
||||
|
||||
def _to_pinnable(obj):
|
||||
"""Converts obj to a form that can be pinned in object store memory.
|
||||
|
||||
Currently only numpy arrays are pinned in memory, if you have a strong
|
||||
reference to the array value.
|
||||
"""
|
||||
|
||||
return (obj, np.zeros(1))
|
||||
|
||||
|
||||
def _from_pinnable(obj):
|
||||
"""Retrieve from _to_pinnable format."""
|
||||
|
||||
return obj[0]
|
||||
|
||||
|
||||
class _Registry(object):
|
||||
def __init__(self, objs={}):
|
||||
self._all_objects = objs
|
||||
@@ -72,14 +93,14 @@ class _Registry(object):
|
||||
def get(self, category, key):
|
||||
value = self._all_objects[(category, key)]
|
||||
if type(value) == ObjectID:
|
||||
return ray.get(value)
|
||||
return _from_pinnable(ray.get(value))
|
||||
else:
|
||||
return value
|
||||
|
||||
def flush_values_to_object_store(self):
|
||||
for k, v in self._all_objects.items():
|
||||
if type(v) != ObjectID:
|
||||
obj = ray.put(v)
|
||||
obj = ray.put(_to_pinnable(v))
|
||||
self._all_objects[k] = obj
|
||||
self._refs.append(ray.get(obj))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user