mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 12:45:44 +08:00
[rllib] Custom supervised loss API (#4083)
This commit is contained in:
@@ -98,7 +98,8 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
obs_input=self.observations,
|
||||
action_sampler=action_dist.sample(),
|
||||
action_prob=action_dist.sampled_action_prob(),
|
||||
loss=self.model.loss() + self.loss.total_loss,
|
||||
loss=self.loss.total_loss,
|
||||
model=self.model,
|
||||
loss_inputs=loss_in,
|
||||
state_inputs=self.model.state_in,
|
||||
state_outputs=self.model.state_out,
|
||||
|
||||
@@ -243,7 +243,7 @@ class Agent(Trainable):
|
||||
self.global_vars = {"timestep": 0}
|
||||
|
||||
# Agents allow env ids to be passed directly to the constructor.
|
||||
self._env_id = _register_if_needed(env or config.get("env"))
|
||||
self._env_id = self._register_if_needed(env or config.get("env"))
|
||||
|
||||
# Create a default logger creator if no logger_creator is specified
|
||||
if logger_creator is None:
|
||||
@@ -671,11 +671,14 @@ class Agent(Trainable):
|
||||
if "optimizer" in state:
|
||||
self.optimizer.restore(state["optimizer"])
|
||||
|
||||
|
||||
def _register_if_needed(env_object):
|
||||
if isinstance(env_object, six.string_types):
|
||||
return env_object
|
||||
elif isinstance(env_object, type):
|
||||
name = env_object.__name__
|
||||
register_env(name, lambda config: env_object(config))
|
||||
return name
|
||||
def _register_if_needed(self, env_object):
|
||||
if isinstance(env_object, six.string_types):
|
||||
return env_object
|
||||
elif isinstance(env_object, type):
|
||||
name = env_object.__name__
|
||||
register_env(name, lambda config: env_object(config))
|
||||
return name
|
||||
raise ValueError(
|
||||
"{} is an invalid env specification. ".format(env_object) +
|
||||
"You can specify a custom env as either a class "
|
||||
"(e.g., YourEnvCls) or a registered env id (e.g., \"your_env\").")
|
||||
|
||||
@@ -334,10 +334,11 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
config["l2_reg"] * 0.5 * tf.nn.l2_loss(var))
|
||||
|
||||
# Model self-supervised losses
|
||||
self.loss.actor_loss += self.p_model.loss()
|
||||
self.loss.critic_loss += self.q_model.loss()
|
||||
self.loss.actor_loss = self.p_model.custom_loss(self.loss.actor_loss)
|
||||
self.loss.critic_loss = self.q_model.custom_loss(self.loss.critic_loss)
|
||||
if self.config["twin_q"]:
|
||||
self.loss.critic_loss += self.twin_q_model.loss()
|
||||
self.loss.critic_loss = self.twin_q_model.custom_loss(
|
||||
self.loss.critic_loss)
|
||||
|
||||
# update_target_fn will be called periodically to copy Q network to
|
||||
# target Q network
|
||||
|
||||
@@ -410,7 +410,8 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
obs_input=self.cur_observations,
|
||||
action_sampler=self.output_actions,
|
||||
action_prob=self.action_prob,
|
||||
loss=model.loss() + self.loss.loss,
|
||||
loss=self.loss.loss,
|
||||
model=model,
|
||||
loss_inputs=self.loss_inputs,
|
||||
update_ops=q_batchnorm_update_ops)
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
@@ -216,7 +216,8 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
obs_input=observations,
|
||||
action_sampler=action_dist.sample(),
|
||||
action_prob=action_dist.sampled_action_prob(),
|
||||
loss=self.model.loss() + self.loss.total_loss,
|
||||
loss=self.loss.total_loss,
|
||||
model=self.model,
|
||||
loss_inputs=loss_in,
|
||||
state_inputs=self.model.state_in,
|
||||
state_outputs=self.model.state_out,
|
||||
|
||||
@@ -114,7 +114,8 @@ class MARWILPolicyGraph(TFPolicyGraph):
|
||||
obs_input=self.obs_t,
|
||||
action_sampler=self.output_actions,
|
||||
action_prob=action_dist.sampled_action_prob(),
|
||||
loss=self.model.loss() + objective,
|
||||
loss=objective,
|
||||
model=self.model,
|
||||
loss_inputs=self.loss_inputs,
|
||||
state_inputs=self.model.state_in,
|
||||
state_outputs=self.model.state_out,
|
||||
|
||||
@@ -46,6 +46,9 @@ class _MockAgent(Agent):
|
||||
self.info = info
|
||||
self.restored = True
|
||||
|
||||
def _register_if_needed(self, env_object):
|
||||
pass
|
||||
|
||||
def set_info(self, info):
|
||||
self.info = info
|
||||
return info
|
||||
|
||||
@@ -68,8 +68,9 @@ class PGPolicyGraph(TFPolicyGraph):
|
||||
obs_input=obs,
|
||||
action_sampler=action_dist.sample(),
|
||||
action_prob=action_dist.sampled_action_prob(),
|
||||
loss=self.model.loss() + loss,
|
||||
loss=loss,
|
||||
loss_inputs=loss_in,
|
||||
model=self.model,
|
||||
state_inputs=self.model.state_in,
|
||||
state_outputs=self.model.state_out,
|
||||
prev_action_input=prev_actions,
|
||||
|
||||
@@ -321,7 +321,8 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
obs_input=observations,
|
||||
action_sampler=action_dist.sample(),
|
||||
action_prob=action_dist.sampled_action_prob(),
|
||||
loss=self.model.loss() + self.loss.total_loss,
|
||||
loss=self.loss.total_loss,
|
||||
model=self.model,
|
||||
loss_inputs=loss_in,
|
||||
state_inputs=self.model.state_in,
|
||||
state_outputs=self.model.state_out,
|
||||
@@ -339,7 +340,6 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
values_batched = to_batches(values)
|
||||
self.stats_fetches = {
|
||||
"stats": {
|
||||
"model_loss": self.model.loss(),
|
||||
"cur_lr": tf.cast(self.cur_lr, tf.float64),
|
||||
"policy_loss": self.loss.pi_loss,
|
||||
"entropy": self.loss.entropy,
|
||||
|
||||
@@ -148,6 +148,8 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
existing_state_in = None
|
||||
existing_seq_lens = None
|
||||
self.observations = obs_ph
|
||||
self.prev_actions = prev_actions_ph
|
||||
self.prev_rewards = prev_rewards_ph
|
||||
|
||||
self.loss_in = [
|
||||
("obs", obs_ph),
|
||||
@@ -245,7 +247,8 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
obs_input=obs_ph,
|
||||
action_sampler=self.sampler,
|
||||
action_prob=curr_action_dist.sampled_action_prob(),
|
||||
loss=self.model.loss() + self.loss_obj.loss,
|
||||
loss=self.loss_obj.loss,
|
||||
model=self.model,
|
||||
loss_inputs=self.loss_in,
|
||||
state_inputs=self.model.state_in,
|
||||
state_outputs=self.model.state_out,
|
||||
@@ -289,7 +292,9 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
next_state = []
|
||||
for i in range(len(self.model.state_in)):
|
||||
next_state.append([sample_batch["state_out_{}".format(i)][-1]])
|
||||
last_r = self._value(sample_batch["new_obs"][-1], *next_state)
|
||||
last_r = self._value(sample_batch["new_obs"][-1],
|
||||
sample_batch["actions"][-1],
|
||||
sample_batch["rewards"][-1], *next_state)
|
||||
batch = compute_advantages(
|
||||
sample_batch,
|
||||
last_r,
|
||||
@@ -336,8 +341,13 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
self.kl_coeff.load(self.kl_coeff_val, session=self.sess)
|
||||
return self.kl_coeff_val
|
||||
|
||||
def _value(self, ob, *args):
|
||||
feed_dict = {self.observations: [ob], self.model.seq_lens: [1]}
|
||||
def _value(self, ob, prev_action, prev_reward, *args):
|
||||
feed_dict = {
|
||||
self.observations: [ob],
|
||||
self.prev_actions: [prev_action],
|
||||
self.prev_rewards: [prev_reward],
|
||||
self.model.seq_lens: [1]
|
||||
}
|
||||
assert len(args) == len(self.model.state_in), \
|
||||
(args, self.model.state_in)
|
||||
for k, v in zip(self.model.state_in, args):
|
||||
|
||||
@@ -8,10 +8,12 @@ import pickle
|
||||
import tensorflow as tf
|
||||
|
||||
import ray
|
||||
from ray.rllib.env.base_env import BaseEnv
|
||||
from ray.rllib.env.atari_wrappers import wrap_deepmind, is_atari
|
||||
from ray.rllib.env.base_env import BaseEnv
|
||||
from ray.rllib.env.env_context import EnvContext
|
||||
from ray.rllib.env.external_env import ExternalEnv
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from ray.rllib.env.vector_env import VectorEnv
|
||||
from ray.rllib.evaluation.interface import EvaluatorInterface
|
||||
from ray.rllib.evaluation.sample_batch import MultiAgentBatch, \
|
||||
DEFAULT_POLICY_ID
|
||||
@@ -222,7 +224,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
self.compress_observations = compress_observations
|
||||
self.preprocessing_enabled = True
|
||||
|
||||
self.env = env_creator(env_context)
|
||||
self.env = _validate_env(env_creator(env_context))
|
||||
if isinstance(self.env, MultiAgentEnv) or \
|
||||
isinstance(self.env, BaseEnv):
|
||||
|
||||
@@ -701,6 +703,20 @@ def _validate_and_canonicalize(policy_graph, env):
|
||||
}
|
||||
|
||||
|
||||
def _validate_env(env):
|
||||
# allow this as a special case (assumed gym.Env)
|
||||
if hasattr(env, "observation_space") and hasattr(env, "action_space"):
|
||||
return env
|
||||
|
||||
allowed_types = [gym.Env, MultiAgentEnv, ExternalEnv, VectorEnv, BaseEnv]
|
||||
if not any(isinstance(env, tpe) for tpe in allowed_types):
|
||||
raise ValueError(
|
||||
"Returned env should be an instance of gym.Env, MultiAgentEnv, "
|
||||
"ExternalEnv, VectorEnv, or BaseEnv. The provided env creator "
|
||||
"function returned {} ({}).".format(env, type(env)))
|
||||
return env
|
||||
|
||||
|
||||
def _monitor(env, path):
|
||||
return gym.wrappers.Monitor(env, path, resume=True)
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ class TFPolicyGraph(PolicyGraph):
|
||||
Attributes:
|
||||
observation_space (gym.Space): observation space of the policy.
|
||||
action_space (gym.Space): action space of the policy.
|
||||
model (rllib.models.Model): RLlib model used for the policy.
|
||||
|
||||
Examples:
|
||||
>>> policy = TFPolicyGraphSubclass(
|
||||
@@ -53,6 +54,7 @@ class TFPolicyGraph(PolicyGraph):
|
||||
action_sampler,
|
||||
loss,
|
||||
loss_inputs,
|
||||
model=None,
|
||||
action_prob=None,
|
||||
state_inputs=None,
|
||||
state_outputs=None,
|
||||
@@ -79,6 +81,8 @@ class TFPolicyGraph(PolicyGraph):
|
||||
and has shape [BATCH_SIZE, data...]. These keys will be read
|
||||
from postprocessed sample batches and fed into the specified
|
||||
placeholders during loss computation.
|
||||
model (rllib.models.Model): used to integrate custom losses and
|
||||
stats from user-defined RLlib models.
|
||||
action_prob (Tensor): probability of the sampled action.
|
||||
state_inputs (list): list of RNN state input Tensors.
|
||||
state_outputs (list): list of RNN state output Tensors.
|
||||
@@ -98,12 +102,18 @@ class TFPolicyGraph(PolicyGraph):
|
||||
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
self.model = model
|
||||
self._sess = sess
|
||||
self._obs_input = obs_input
|
||||
self._prev_action_input = prev_action_input
|
||||
self._prev_reward_input = prev_reward_input
|
||||
self._sampler = action_sampler
|
||||
self._loss = loss
|
||||
if self.model:
|
||||
self._loss = self.model.custom_loss(loss)
|
||||
self._stats_fetches = {"model": self.model.custom_stats()}
|
||||
else:
|
||||
self._loss = loss
|
||||
self._stats_fetches = {}
|
||||
self._loss_inputs = loss_inputs
|
||||
self._loss_input_dict = dict(self._loss_inputs)
|
||||
self._is_training = self._get_is_training_placeholder()
|
||||
@@ -375,7 +385,7 @@ class TFPolicyGraph(PolicyGraph):
|
||||
builder.add_feed_dict({self._is_training: True})
|
||||
builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch))
|
||||
fetches = builder.add_fetches(
|
||||
[self._grads, self.extra_compute_grad_fetches()])
|
||||
[self._grads, self._get_grad_and_stats_fetches()])
|
||||
return fetches[0], fetches[1]
|
||||
|
||||
def _build_apply_gradients(self, builder, gradients):
|
||||
@@ -397,11 +407,18 @@ class TFPolicyGraph(PolicyGraph):
|
||||
builder.add_feed_dict({self._is_training: True})
|
||||
fetches = builder.add_fetches([
|
||||
self._apply_op,
|
||||
self.extra_compute_grad_fetches(),
|
||||
self._get_grad_and_stats_fetches(),
|
||||
self.extra_apply_grad_fetches()
|
||||
])
|
||||
return fetches[1], fetches[2]
|
||||
|
||||
def _get_grad_and_stats_fetches(self):
|
||||
fetches = self.extra_compute_grad_fetches()
|
||||
if self._stats_fetches:
|
||||
fetches["stats"] = dict(self._stats_fetches,
|
||||
**fetches.get("stats", {}))
|
||||
return fetches
|
||||
|
||||
def _get_loss_inputs_dict(self, batch):
|
||||
feed_dict = {}
|
||||
if self._batch_divisibility_req > 1:
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
"""Example of using custom_loss() with an imitation learning loss.
|
||||
|
||||
The default input file is too small to learn a good policy, but you can
|
||||
generate new experiences for IL training as follows:
|
||||
|
||||
To generate experiences:
|
||||
$ ./train.py --run=PG --config='{"output": "/tmp/cartpole"}' --env=CartPole-v0
|
||||
|
||||
To train on experiences with joint PG + IL loss:
|
||||
$ python custom_loss.py --input-files=/tmp/cartpole
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import tensorflow as tf
|
||||
|
||||
import ray
|
||||
from ray.tune import run_experiments
|
||||
from ray.rllib.models import (Categorical, FullyConnectedNetwork, Model,
|
||||
ModelCatalog)
|
||||
from ray.rllib.models.model import restore_original_dimensions
|
||||
from ray.rllib.offline import JsonReader
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--iters", type=int, default=200)
|
||||
parser.add_argument(
|
||||
"--input-files",
|
||||
type=str,
|
||||
default=os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"../test/data/cartpole_small"))
|
||||
|
||||
|
||||
class CustomLossModel(Model):
|
||||
"""Custom model that adds an imitation loss on top of the policy loss."""
|
||||
|
||||
def _build_layers_v2(self, input_dict, num_outputs, options):
|
||||
self.obs_in = input_dict["obs"]
|
||||
self.fcnet = FullyConnectedNetwork(input_dict, self.obs_space,
|
||||
num_outputs, options)
|
||||
return self.fcnet.outputs, self.fcnet.last_layer
|
||||
|
||||
def custom_loss(self, policy_loss):
|
||||
# create a new input reader per worker
|
||||
reader = JsonReader(self.options["custom_options"]["input_files"])
|
||||
input_ops = reader.tf_input_ops()
|
||||
|
||||
# define a secondary loss by building a graph copy with weight sharing
|
||||
with tf.variable_scope(
|
||||
self.scope, reuse=tf.AUTO_REUSE, auxiliary_name_scope=False):
|
||||
logits, _ = self._build_layers_v2({
|
||||
"obs": restore_original_dimensions(input_ops["obs"],
|
||||
self.obs_space)
|
||||
}, self.num_outputs, self.options)
|
||||
|
||||
# You can also add self-supervised losses easily by referencing tensors
|
||||
# created during _build_layers_v2(). For example, an autoencoder-style
|
||||
# loss can be added as follows:
|
||||
# ae_loss = squared_diff(self.obs_in, Decoder(self.fcnet.last_layer))
|
||||
|
||||
# compute the IL loss
|
||||
action_dist = Categorical(logits)
|
||||
self.policy_loss = policy_loss
|
||||
self.imitation_loss = tf.reduce_mean(
|
||||
-action_dist.logp(input_ops["actions"]))
|
||||
return policy_loss + 10 * self.imitation_loss
|
||||
|
||||
def custom_stats(self):
|
||||
return {
|
||||
"policy_loss": self.policy_loss,
|
||||
"imitation_loss": self.imitation_loss,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ray.init()
|
||||
args = parser.parse_args()
|
||||
|
||||
ModelCatalog.register_custom_model("custom_loss", CustomLossModel)
|
||||
run_experiments({
|
||||
"custom_loss": {
|
||||
"run": "PG",
|
||||
"env": "CartPole-v0",
|
||||
"stop": {
|
||||
"training_iteration": args.iters,
|
||||
},
|
||||
"config": {
|
||||
"num_workers": 0,
|
||||
"model": {
|
||||
"custom_model": "custom_loss",
|
||||
"custom_options": {
|
||||
"input_files": args.input_files,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -1,8 +1,7 @@
|
||||
"""Example of handling variable length and/or parametric action spaces.
|
||||
|
||||
This is a toy example of the action-embedding based approach for handling large
|
||||
discrete action spaces (potentially infinite in size), similar to how
|
||||
OpenAI Five works:
|
||||
discrete action spaces (potentially infinite in size), similar to this:
|
||||
|
||||
https://neuro.cs.ut.ee/the-use-of-embeddings-in-openai-five/
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import tensorflow as tf
|
||||
|
||||
from ray.rllib.models.misc import linear, normc_initializer
|
||||
from ray.rllib.models.preprocessors import get_preprocessor
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
|
||||
|
||||
|
||||
@PublicAPI
|
||||
@@ -58,6 +58,11 @@ class Model(object):
|
||||
self.state_init = []
|
||||
self.state_in = state_in or []
|
||||
self.state_out = []
|
||||
self.obs_space = obs_space
|
||||
self.num_outputs = num_outputs
|
||||
self.options = options
|
||||
self.scope = tf.get_variable_scope()
|
||||
self.session = tf.get_default_session()
|
||||
if seq_lens is not None:
|
||||
self.seq_lens = seq_lens
|
||||
else:
|
||||
@@ -69,9 +74,11 @@ class Model(object):
|
||||
assert num_outputs % 2 == 0
|
||||
num_outputs = num_outputs // 2
|
||||
try:
|
||||
restored = input_dict.copy()
|
||||
restored["obs"] = restore_original_dimensions(
|
||||
input_dict["obs"], obs_space)
|
||||
self.outputs, self.last_layer = self._build_layers_v2(
|
||||
_restore_original_dimensions(input_dict, obs_space),
|
||||
num_outputs, options)
|
||||
restored, num_outputs, options)
|
||||
except NotImplementedError:
|
||||
self.outputs, self.last_layer = self._build_layers(
|
||||
input_dict["obs"], num_outputs, options)
|
||||
@@ -139,17 +146,46 @@ class Model(object):
|
||||
linear(self.last_layer, 1, "value", normc_initializer(1.0)), [-1])
|
||||
|
||||
@PublicAPI
|
||||
def loss(self):
|
||||
"""Builds any built-in (self-supervised) loss for the model.
|
||||
def custom_loss(self, policy_loss):
|
||||
"""Override to customize the loss function used to optimize this model.
|
||||
|
||||
For example, this can be used to incorporate auto-encoder style losses.
|
||||
Note that this loss has to be included in the policy graph loss to have
|
||||
an effect (done for built-in algorithms).
|
||||
This can be used to incorporate self-supervised losses (by defining
|
||||
a loss over existing input and output tensors of this model), and
|
||||
supervised losses (by defining losses over a variable-sharing copy of
|
||||
this model's layers).
|
||||
|
||||
You can find an runnable example in examples/custom_loss.py.
|
||||
|
||||
Arguments:
|
||||
policy_loss (Tensor): scalar policy loss from the policy graph.
|
||||
|
||||
Returns:
|
||||
Scalar tensor for the self-supervised loss.
|
||||
Scalar tensor for the customized loss for this model.
|
||||
"""
|
||||
return tf.constant(0.0)
|
||||
if self.loss() is not None:
|
||||
raise DeprecationWarning(
|
||||
"self.loss() is deprecated, use self.custom_loss() instead.")
|
||||
return policy_loss
|
||||
|
||||
@PublicAPI
|
||||
def custom_stats(self):
|
||||
"""Override to return custom metrics from your model.
|
||||
|
||||
The stats will be reported as part of the learner stats, i.e.,
|
||||
info:
|
||||
learner:
|
||||
model:
|
||||
key1: metric1
|
||||
key2: metric2
|
||||
|
||||
Returns:
|
||||
Dict of string keys to scalar tensors.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def loss(self):
|
||||
"""Deprecated: use self.custom_loss()."""
|
||||
return None
|
||||
|
||||
def _validate_output_shape(self):
|
||||
"""Checks that the model has the correct number of outputs."""
|
||||
@@ -165,15 +201,29 @@ class Model(object):
|
||||
self._num_outputs, shape))
|
||||
|
||||
|
||||
def _restore_original_dimensions(input_dict, obs_space, tensorlib=tf):
|
||||
@DeveloperAPI
|
||||
def restore_original_dimensions(obs, obs_space, tensorlib=tf):
|
||||
"""Unpacks Dict and Tuple space observations into their original form.
|
||||
|
||||
This is needed since we flatten Dict and Tuple observations in transit.
|
||||
Before sending them to the model though, we should unflatten them into
|
||||
Dicts or Tuples of tensors.
|
||||
|
||||
Arguments:
|
||||
obs: The flattened observation tensor.
|
||||
obs_space: The flattened obs space. If this has the `original_space`
|
||||
attribute, we will unflatten the tensor to that shape.
|
||||
tensorlib: The library used to unflatten (reshape) the array/tensor.
|
||||
|
||||
Returns:
|
||||
single tensor or dict / tuple of tensors matching the original
|
||||
observation space.
|
||||
"""
|
||||
|
||||
if hasattr(obs_space, "original_space"):
|
||||
return dict(
|
||||
input_dict,
|
||||
obs=_unpack_obs(
|
||||
input_dict["obs"],
|
||||
obs_space.original_space,
|
||||
tensorlib=tensorlib))
|
||||
return input_dict
|
||||
return _unpack_obs(obs, obs_space.original_space, tensorlib=tensorlib)
|
||||
else:
|
||||
return obs
|
||||
|
||||
|
||||
def _unpack_obs(obs, space, tensorlib=tf):
|
||||
|
||||
@@ -5,7 +5,7 @@ from __future__ import print_function
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ray.rllib.models.model import _restore_original_dimensions
|
||||
from ray.rllib.models.model import restore_original_dimensions
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
|
||||
|
||||
@@ -31,8 +31,8 @@ class TorchModel(nn.Module):
|
||||
def forward(self, input_dict, hidden_state):
|
||||
"""Wraps _forward() to unpack flattened Dict and Tuple observations."""
|
||||
input_dict["obs"] = input_dict["obs"].float() # TODO(ekl): avoid cast
|
||||
input_dict = _restore_original_dimensions(
|
||||
input_dict, self.obs_space, tensorlib=torch)
|
||||
input_dict["obs"] = restore_original_dimensions(
|
||||
input_dict["obs"], self.obs_space, tensorlib=torch)
|
||||
outputs, features, vf, h = self._forward(input_dict, hidden_state)
|
||||
return outputs, features, vf, h
|
||||
|
||||
|
||||
@@ -2,8 +2,16 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import threading
|
||||
|
||||
from ray.rllib.evaluation.sample_batch import MultiAgentBatch
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@PublicAPI
|
||||
class InputReader(object):
|
||||
@@ -17,3 +25,97 @@ class InputReader(object):
|
||||
SampleBatch or MultiAgentBatch read.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@PublicAPI
|
||||
def tf_input_ops(self, queue_size=1):
|
||||
"""Returns TensorFlow queue ops for reading inputs from this reader.
|
||||
|
||||
The main use of these ops is for integration into custom model losses.
|
||||
For example, you can use tf_input_ops() to read from files of external
|
||||
experiences to add an imitation learning loss to your model.
|
||||
|
||||
This method creates a queue runner thread that will call next() on this
|
||||
reader repeatedly to feed the TensorFlow queue.
|
||||
|
||||
Arguments:
|
||||
queue_size (int): Max elements to allow in the TF queue.
|
||||
|
||||
Example:
|
||||
>>> class MyModel(rllib.model.Model):
|
||||
... def custom_loss(self, policy_loss):
|
||||
... reader = JsonReader(...)
|
||||
... input_ops = reader.tf_input_ops()
|
||||
... with tf.variable_scope(
|
||||
... self.scope, reuse=tf.AUTO_REUSE,
|
||||
... auxiliary_name_scope=False):
|
||||
... logits, _ = self._build_layers_v2(
|
||||
... {"obs": input_ops["obs"]},
|
||||
... self.num_outputs, self.options)
|
||||
... il_loss = imitation_loss(logits, input_ops["action"])
|
||||
... return policy_loss + il_loss
|
||||
|
||||
You can find a runnable version of this in examples/custom_loss.py.
|
||||
|
||||
Returns:
|
||||
dict of Tensors, one for each column of the read SampleBatch.
|
||||
"""
|
||||
|
||||
if hasattr(self, "_queue_runner"):
|
||||
raise ValueError(
|
||||
"A queue runner already exists for this input reader. "
|
||||
"You can only call tf_input_ops() once per reader.")
|
||||
|
||||
logger.info("Reading initial batch of data from input reader.")
|
||||
batch = self.next()
|
||||
if isinstance(batch, MultiAgentBatch):
|
||||
raise NotImplementedError(
|
||||
"tf_input_ops() is not implemented for multi agent batches")
|
||||
|
||||
keys = [
|
||||
k for k in sorted(list(batch.keys()))
|
||||
if np.issubdtype(batch[k].dtype, np.number)
|
||||
]
|
||||
dtypes = [batch[k].dtype for k in keys]
|
||||
shapes = {
|
||||
k: (-1, ) + s[1:]
|
||||
for (k, s) in [(k, batch[k].shape) for k in keys]
|
||||
}
|
||||
queue = tf.FIFOQueue(capacity=queue_size, dtypes=dtypes, names=keys)
|
||||
tensors = queue.dequeue()
|
||||
|
||||
logger.info("Creating TF queue runner for {}".format(self))
|
||||
self._queue_runner = _QueueRunner(self, queue, keys, dtypes)
|
||||
self._queue_runner.enqueue(batch)
|
||||
self._queue_runner.start()
|
||||
|
||||
out = {k: tf.reshape(t, shapes[k]) for k, t in tensors.items()}
|
||||
return out
|
||||
|
||||
|
||||
class _QueueRunner(threading.Thread):
|
||||
"""Thread that feeds a TF queue from a InputReader."""
|
||||
|
||||
def __init__(self, input_reader, queue, keys, dtypes):
|
||||
threading.Thread.__init__(self)
|
||||
self.sess = tf.get_default_session()
|
||||
self.daemon = True
|
||||
self.input_reader = input_reader
|
||||
self.keys = keys
|
||||
self.queue = queue
|
||||
self.placeholders = [tf.placeholder(dtype) for dtype in dtypes]
|
||||
self.enqueue_op = queue.enqueue(dict(zip(keys, self.placeholders)))
|
||||
|
||||
def enqueue(self, batch):
|
||||
data = {
|
||||
self.placeholders[i]: batch[key]
|
||||
for i, key in enumerate(self.keys)
|
||||
}
|
||||
self.sess.run(self.enqueue_op, feed_dict=data)
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
try:
|
||||
batch = self.input_reader.next()
|
||||
self.enqueue(batch)
|
||||
except Exception:
|
||||
logger.exception("Error reading from input")
|
||||
|
||||
@@ -311,9 +311,10 @@ class TrialRunner(object):
|
||||
for state, trials in states.items()
|
||||
}
|
||||
total_number_of_trials = sum(num_trials_per_state.values())
|
||||
messages.append("Number of trials: {} ({})"
|
||||
"".format(total_number_of_trials,
|
||||
num_trials_per_state))
|
||||
if total_number_of_trials > 0:
|
||||
messages.append("Number of trials: {} ({})"
|
||||
"".format(total_number_of_trials,
|
||||
num_trials_per_state))
|
||||
|
||||
for state, trials in sorted(states.items()):
|
||||
limit = limit_per_state[state]
|
||||
|
||||
Reference in New Issue
Block a user