[rllib] Custom supervised loss API (#4083)

This commit is contained in:
Eric Liang
2019-02-24 15:36:13 -08:00
committed by GitHub
parent 7b04ed059e
commit d9da183c7d
24 changed files with 551 additions and 181 deletions
@@ -98,7 +98,8 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
obs_input=self.observations,
action_sampler=action_dist.sample(),
action_prob=action_dist.sampled_action_prob(),
loss=self.model.loss() + self.loss.total_loss,
loss=self.loss.total_loss,
model=self.model,
loss_inputs=loss_in,
state_inputs=self.model.state_in,
state_outputs=self.model.state_out,
+12 -9
View File
@@ -243,7 +243,7 @@ class Agent(Trainable):
self.global_vars = {"timestep": 0}
# Agents allow env ids to be passed directly to the constructor.
self._env_id = _register_if_needed(env or config.get("env"))
self._env_id = self._register_if_needed(env or config.get("env"))
# Create a default logger creator if no logger_creator is specified
if logger_creator is None:
@@ -671,11 +671,14 @@ class Agent(Trainable):
if "optimizer" in state:
self.optimizer.restore(state["optimizer"])
def _register_if_needed(env_object):
if isinstance(env_object, six.string_types):
return env_object
elif isinstance(env_object, type):
name = env_object.__name__
register_env(name, lambda config: env_object(config))
return name
def _register_if_needed(self, env_object):
    """Turn an env specification into an env id string.

    Arguments:
        env_object: either a string env id, which is returned unchanged,
            or a class, which is registered under its ``__name__``.

    Returns:
        The env id string.

    Raises:
        ValueError: if env_object is neither a string nor a class.
    """
    # Strings are taken to be env ids already and pass straight through.
    if isinstance(env_object, six.string_types):
        return env_object

    # Anything that is not a class at this point cannot be registered.
    if not isinstance(env_object, type):
        raise ValueError(
            "{} is an invalid env specification. ".format(env_object) +
            "You can specify a custom env as either a class "
            "(e.g., YourEnvCls) or a registered env id (e.g., \"your_env\").")

    # Register the class under its own name; the creator forwards the config.
    env_name = env_object.__name__
    register_env(env_name, lambda config: env_object(config))
    return env_name
@@ -334,10 +334,11 @@ class DDPGPolicyGraph(TFPolicyGraph):
config["l2_reg"] * 0.5 * tf.nn.l2_loss(var))
# Model self-supervised losses
self.loss.actor_loss += self.p_model.loss()
self.loss.critic_loss += self.q_model.loss()
self.loss.actor_loss = self.p_model.custom_loss(self.loss.actor_loss)
self.loss.critic_loss = self.q_model.custom_loss(self.loss.critic_loss)
if self.config["twin_q"]:
self.loss.critic_loss += self.twin_q_model.loss()
self.loss.critic_loss = self.twin_q_model.custom_loss(
self.loss.critic_loss)
# update_target_fn will be called periodically to copy Q network to
# target Q network
@@ -410,7 +410,8 @@ class DQNPolicyGraph(TFPolicyGraph):
obs_input=self.cur_observations,
action_sampler=self.output_actions,
action_prob=self.action_prob,
loss=model.loss() + self.loss.loss,
loss=self.loss.loss,
model=model,
loss_inputs=self.loss_inputs,
update_ops=q_batchnorm_update_ops)
self.sess.run(tf.global_variables_initializer())
@@ -216,7 +216,8 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
obs_input=observations,
action_sampler=action_dist.sample(),
action_prob=action_dist.sampled_action_prob(),
loss=self.model.loss() + self.loss.total_loss,
loss=self.loss.total_loss,
model=self.model,
loss_inputs=loss_in,
state_inputs=self.model.state_in,
state_outputs=self.model.state_out,
@@ -114,7 +114,8 @@ class MARWILPolicyGraph(TFPolicyGraph):
obs_input=self.obs_t,
action_sampler=self.output_actions,
action_prob=action_dist.sampled_action_prob(),
loss=self.model.loss() + objective,
loss=objective,
model=self.model,
loss_inputs=self.loss_inputs,
state_inputs=self.model.state_in,
state_outputs=self.model.state_out,
+3
View File
@@ -46,6 +46,9 @@ class _MockAgent(Agent):
self.info = info
self.restored = True
def _register_if_needed(self, env_object):
# No-op override: the mock agent performs no env registration and returns
# None for any env_object.
pass
def set_info(self, info):
self.info = info
return info
@@ -68,8 +68,9 @@ class PGPolicyGraph(TFPolicyGraph):
obs_input=obs,
action_sampler=action_dist.sample(),
action_prob=action_dist.sampled_action_prob(),
loss=self.model.loss() + loss,
loss=loss,
loss_inputs=loss_in,
model=self.model,
state_inputs=self.model.state_in,
state_outputs=self.model.state_out,
prev_action_input=prev_actions,
@@ -321,7 +321,8 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
obs_input=observations,
action_sampler=action_dist.sample(),
action_prob=action_dist.sampled_action_prob(),
loss=self.model.loss() + self.loss.total_loss,
loss=self.loss.total_loss,
model=self.model,
loss_inputs=loss_in,
state_inputs=self.model.state_in,
state_outputs=self.model.state_out,
@@ -339,7 +340,6 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
values_batched = to_batches(values)
self.stats_fetches = {
"stats": {
"model_loss": self.model.loss(),
"cur_lr": tf.cast(self.cur_lr, tf.float64),
"policy_loss": self.loss.pi_loss,
"entropy": self.loss.entropy,
@@ -148,6 +148,8 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
existing_state_in = None
existing_seq_lens = None
self.observations = obs_ph
self.prev_actions = prev_actions_ph
self.prev_rewards = prev_rewards_ph
self.loss_in = [
("obs", obs_ph),
@@ -245,7 +247,8 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
obs_input=obs_ph,
action_sampler=self.sampler,
action_prob=curr_action_dist.sampled_action_prob(),
loss=self.model.loss() + self.loss_obj.loss,
loss=self.loss_obj.loss,
model=self.model,
loss_inputs=self.loss_in,
state_inputs=self.model.state_in,
state_outputs=self.model.state_out,
@@ -289,7 +292,9 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
next_state = []
for i in range(len(self.model.state_in)):
next_state.append([sample_batch["state_out_{}".format(i)][-1]])
last_r = self._value(sample_batch["new_obs"][-1], *next_state)
last_r = self._value(sample_batch["new_obs"][-1],
sample_batch["actions"][-1],
sample_batch["rewards"][-1], *next_state)
batch = compute_advantages(
sample_batch,
last_r,
@@ -336,8 +341,13 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
self.kl_coeff.load(self.kl_coeff_val, session=self.sess)
return self.kl_coeff_val
def _value(self, ob, *args):
feed_dict = {self.observations: [ob], self.model.seq_lens: [1]}
def _value(self, ob, prev_action, prev_reward, *args):
feed_dict = {
self.observations: [ob],
self.prev_actions: [prev_action],
self.prev_rewards: [prev_reward],
self.model.seq_lens: [1]
}
assert len(args) == len(self.model.state_in), \
(args, self.model.state_in)
for k, v in zip(self.model.state_in, args):
@@ -8,10 +8,12 @@ import pickle
import tensorflow as tf
import ray
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.atari_wrappers import wrap_deepmind, is_atari
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.evaluation.interface import EvaluatorInterface
from ray.rllib.evaluation.sample_batch import MultiAgentBatch, \
DEFAULT_POLICY_ID
@@ -222,7 +224,7 @@ class PolicyEvaluator(EvaluatorInterface):
self.compress_observations = compress_observations
self.preprocessing_enabled = True
self.env = env_creator(env_context)
self.env = _validate_env(env_creator(env_context))
if isinstance(self.env, MultiAgentEnv) or \
isinstance(self.env, BaseEnv):
@@ -701,6 +703,20 @@ def _validate_and_canonicalize(policy_graph, env):
}
def _validate_env(env):
    """Check that an env creator returned a supported environment object.

    Arguments:
        env: the object returned by the env creator function.

    Returns:
        The same env object, unchanged, if it is acceptable.

    Raises:
        ValueError: if env is neither duck-typed like a gym.Env nor an
            instance of one of the supported env classes.
    """
    # allow this as a special case (assumed gym.Env): anything exposing both
    # spaces is accepted without an isinstance check.
    if hasattr(env, "observation_space") and hasattr(env, "action_space"):
        return env

    # isinstance() accepts a tuple of classes, so no any()/generator needed.
    allowed_types = (gym.Env, MultiAgentEnv, ExternalEnv, VectorEnv, BaseEnv)
    if not isinstance(env, allowed_types):
        raise ValueError(
            "Returned env should be an instance of gym.Env, MultiAgentEnv, "
            "ExternalEnv, VectorEnv, or BaseEnv. The provided env creator "
            "function returned {} ({}).".format(env, type(env)))
    return env
def _monitor(env, path):
# Wraps env in gym.wrappers.Monitor, passing path as the second argument
# and resume=True, and returns the wrapped env.
return gym.wrappers.Monitor(env, path, resume=True)
+20 -3
View File
@@ -32,6 +32,7 @@ class TFPolicyGraph(PolicyGraph):
Attributes:
observation_space (gym.Space): observation space of the policy.
action_space (gym.Space): action space of the policy.
model (rllib.models.Model): RLlib model used for the policy.
Examples:
>>> policy = TFPolicyGraphSubclass(
@@ -53,6 +54,7 @@ class TFPolicyGraph(PolicyGraph):
action_sampler,
loss,
loss_inputs,
model=None,
action_prob=None,
state_inputs=None,
state_outputs=None,
@@ -79,6 +81,8 @@ class TFPolicyGraph(PolicyGraph):
and has shape [BATCH_SIZE, data...]. These keys will be read
from postprocessed sample batches and fed into the specified
placeholders during loss computation.
model (rllib.models.Model): used to integrate custom losses and
stats from user-defined RLlib models.
action_prob (Tensor): probability of the sampled action.
state_inputs (list): list of RNN state input Tensors.
state_outputs (list): list of RNN state output Tensors.
@@ -98,12 +102,18 @@ class TFPolicyGraph(PolicyGraph):
self.observation_space = observation_space
self.action_space = action_space
self.model = model
self._sess = sess
self._obs_input = obs_input
self._prev_action_input = prev_action_input
self._prev_reward_input = prev_reward_input
self._sampler = action_sampler
self._loss = loss
if self.model:
self._loss = self.model.custom_loss(loss)
self._stats_fetches = {"model": self.model.custom_stats()}
else:
self._loss = loss
self._stats_fetches = {}
self._loss_inputs = loss_inputs
self._loss_input_dict = dict(self._loss_inputs)
self._is_training = self._get_is_training_placeholder()
@@ -375,7 +385,7 @@ class TFPolicyGraph(PolicyGraph):
builder.add_feed_dict({self._is_training: True})
builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch))
fetches = builder.add_fetches(
[self._grads, self.extra_compute_grad_fetches()])
[self._grads, self._get_grad_and_stats_fetches()])
return fetches[0], fetches[1]
def _build_apply_gradients(self, builder, gradients):
@@ -397,11 +407,18 @@ class TFPolicyGraph(PolicyGraph):
builder.add_feed_dict({self._is_training: True})
fetches = builder.add_fetches([
self._apply_op,
self.extra_compute_grad_fetches(),
self._get_grad_and_stats_fetches(),
self.extra_apply_grad_fetches()
])
return fetches[1], fetches[2]
def _get_grad_and_stats_fetches(self):
    """Return the compute-grad fetches with the model stats merged in.

    The entries of ``self._stats_fetches`` are merged into the "stats"
    entry of ``self.extra_compute_grad_fetches()``. On key collisions the
    values already present under "stats" win.

    Returns:
        The fetches dict. When there are model stats to merge, this is a
        shallow copy, so the dict returned by
        ``extra_compute_grad_fetches()`` (which subclasses may keep and
        reuse) is never modified in place.
    """
    fetches = self.extra_compute_grad_fetches()
    if not self._stats_fetches:
        return fetches

    # Copy before merging so a dict owned by the subclass is left untouched.
    merged_fetches = dict(fetches)
    stats = dict(self._stats_fetches)
    # update() rather than **-expansion: same precedence, but does not
    # require every key to be a string.
    stats.update(fetches.get("stats", {}))
    merged_fetches["stats"] = stats
    return merged_fetches
def _get_loss_inputs_dict(self, batch):
feed_dict = {}
if self._batch_divisibility_req > 1:
+100
View File
@@ -0,0 +1,100 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Example of using custom_loss() with an imitation learning loss.
The default input file is too small to learn a good policy, but you can
generate new experiences for IL training as follows:
To generate experiences:
$ ./train.py --run=PG --config='{"output": "/tmp/cartpole"}' --env=CartPole-v0
To train on experiences with joint PG + IL loss:
$ python custom_loss.py --input-files=/tmp/cartpole
"""
import argparse
import os
import tensorflow as tf
import ray
from ray.tune import run_experiments
from ray.rllib.models import (Categorical, FullyConnectedNetwork, Model,
ModelCatalog)
from ray.rllib.models.model import restore_original_dimensions
from ray.rllib.offline import JsonReader
# Command-line options for this example script.
parser = argparse.ArgumentParser()
# Used below as the "training_iteration" stop condition of the experiment.
parser.add_argument("--iters", type=int, default=200)
# Forwarded to the model as custom_options["input_files"]; the default is
# resolved relative to the directory containing this file.
parser.add_argument(
"--input-files",
type=str,
default=os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"../test/data/cartpole_small"))
class CustomLossModel(Model):
"""Custom model that adds an imitation loss on top of the policy loss."""
def _build_layers_v2(self, input_dict, num_outputs, options):
# Keep handles on the obs input and the wrapped network so that losses
# can reference them later (see the autoencoder note in custom_loss()).
self.obs_in = input_dict["obs"]
self.fcnet = FullyConnectedNetwork(input_dict, self.obs_space,
num_outputs, options)
return self.fcnet.outputs, self.fcnet.last_layer
def custom_loss(self, policy_loss):
# create a new input reader per worker
reader = JsonReader(self.options["custom_options"]["input_files"])
input_ops = reader.tf_input_ops()
# define a secondary loss by building a graph copy with weight sharing
# NOTE(review): this second _build_layers_v2() call also reassigns
# self.obs_in and self.fcnet to the copy built on the reader's inputs.
with tf.variable_scope(
self.scope, reuse=tf.AUTO_REUSE, auxiliary_name_scope=False):
logits, _ = self._build_layers_v2({
"obs": restore_original_dimensions(input_ops["obs"],
self.obs_space)
}, self.num_outputs, self.options)
# You can also add self-supervised losses easily by referencing tensors
# created during _build_layers_v2(). For example, an autoencoder-style
# loss can be added as follows:
# ae_loss = squared_diff(self.obs_in, Decoder(self.fcnet.last_layer))
# compute the IL loss
action_dist = Categorical(logits)
# Both losses are stored on self so custom_stats() can report them.
self.policy_loss = policy_loss
self.imitation_loss = tf.reduce_mean(
-action_dist.logp(input_ops["actions"]))
# The imitation term is added with a fixed weight of 10.
return policy_loss + 10 * self.imitation_loss
def custom_stats(self):
# Reads attributes that are only assigned in custom_loss(), so this must
# be called after custom_loss() has run.
return {
"policy_loss": self.policy_loss,
"imitation_loss": self.imitation_loss,
}
# Script entry point: registers the custom model and runs a single PG
# experiment on CartPole-v0 that uses it.
if __name__ == "__main__":
ray.init()
args = parser.parse_args()
# The name registered here is what "custom_model" refers to below.
ModelCatalog.register_custom_model("custom_loss", CustomLossModel)
run_experiments({
"custom_loss": {
"run": "PG",
"env": "CartPole-v0",
"stop": {
"training_iteration": args.iters,
},
"config": {
"num_workers": 0,
"model": {
"custom_model": "custom_loss",
# Read by CustomLossModel.custom_loss() via
# self.options["custom_options"]["input_files"].
"custom_options": {
"input_files": args.input_files,
},
},
},
},
})
@@ -1,8 +1,7 @@
"""Example of handling variable length and/or parametric action spaces.
This is a toy example of the action-embedding based approach for handling large
discrete action spaces (potentially infinite in size), similar to how
OpenAI Five works:
discrete action spaces (potentially infinite in size), similar to this:
https://neuro.cs.ut.ee/the-use-of-embeddings-in-openai-five/
+68 -18
View File
@@ -9,7 +9,7 @@ import tensorflow as tf
from ray.rllib.models.misc import linear, normc_initializer
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
@PublicAPI
@@ -58,6 +58,11 @@ class Model(object):
self.state_init = []
self.state_in = state_in or []
self.state_out = []
self.obs_space = obs_space
self.num_outputs = num_outputs
self.options = options
self.scope = tf.get_variable_scope()
self.session = tf.get_default_session()
if seq_lens is not None:
self.seq_lens = seq_lens
else:
@@ -69,9 +74,11 @@ class Model(object):
assert num_outputs % 2 == 0
num_outputs = num_outputs // 2
try:
restored = input_dict.copy()
restored["obs"] = restore_original_dimensions(
input_dict["obs"], obs_space)
self.outputs, self.last_layer = self._build_layers_v2(
_restore_original_dimensions(input_dict, obs_space),
num_outputs, options)
restored, num_outputs, options)
except NotImplementedError:
self.outputs, self.last_layer = self._build_layers(
input_dict["obs"], num_outputs, options)
@@ -139,17 +146,46 @@ class Model(object):
linear(self.last_layer, 1, "value", normc_initializer(1.0)), [-1])
@PublicAPI
def loss(self):
"""Builds any built-in (self-supervised) loss for the model.
def custom_loss(self, policy_loss):
"""Override to customize the loss function used to optimize this model.
For example, this can be used to incorporate auto-encoder style losses.
Note that this loss has to be included in the policy graph loss to have
an effect (done for built-in algorithms).
This can be used to incorporate self-supervised losses (by defining
a loss over existing input and output tensors of this model), and
supervised losses (by defining losses over a variable-sharing copy of
this model's layers).
You can find a runnable example in examples/custom_loss.py.
Arguments:
policy_loss (Tensor): scalar policy loss from the policy graph.
Returns:
Scalar tensor for the self-supervised loss.
Scalar tensor for the customized loss for this model.
"""
return tf.constant(0.0)
if self.loss() is not None:
raise DeprecationWarning(
"self.loss() is deprecated, use self.custom_loss() instead.")
return policy_loss
@PublicAPI
def custom_stats(self):
"""Override to return custom metrics from your model.
The stats will be reported as part of the learner stats, i.e.,
info:
learner:
model:
key1: metric1
key2: metric2
Returns:
Dict of string keys to scalar tensors. The default implementation
returns an empty dict.
"""
return {}
def loss(self):
"""Deprecated: use self.custom_loss().

The default returns None. The default custom_loss() raises a
DeprecationWarning if an override of this method returns anything else.
"""
return None
def _validate_output_shape(self):
"""Checks that the model has the correct number of outputs."""
@@ -165,15 +201,29 @@ class Model(object):
self._num_outputs, shape))
def _restore_original_dimensions(input_dict, obs_space, tensorlib=tf):
@DeveloperAPI
def restore_original_dimensions(obs, obs_space, tensorlib=tf):
"""Unpacks Dict and Tuple space observations into their original form.
This is needed since we flatten Dict and Tuple observations in transit.
Before sending them to the model though, we should unflatten them into
Dicts or Tuples of tensors.
Arguments:
obs: The flattened observation tensor.
obs_space: The flattened obs space. If this has the `original_space`
attribute, we will unflatten the tensor to that shape.
tensorlib: The library used to unflatten (reshape) the array/tensor.
Returns:
single tensor or dict / tuple of tensors matching the original
observation space.
"""
if hasattr(obs_space, "original_space"):
return dict(
input_dict,
obs=_unpack_obs(
input_dict["obs"],
obs_space.original_space,
tensorlib=tensorlib))
return input_dict
return _unpack_obs(obs, obs_space.original_space, tensorlib=tensorlib)
else:
return obs
def _unpack_obs(obs, space, tensorlib=tf):
+3 -3
View File
@@ -5,7 +5,7 @@ from __future__ import print_function
import torch
import torch.nn as nn
from ray.rllib.models.model import _restore_original_dimensions
from ray.rllib.models.model import restore_original_dimensions
from ray.rllib.utils.annotations import PublicAPI
@@ -31,8 +31,8 @@ class TorchModel(nn.Module):
def forward(self, input_dict, hidden_state):
"""Wraps _forward() to unpack flattened Dict and Tuple observations."""
input_dict["obs"] = input_dict["obs"].float() # TODO(ekl): avoid cast
input_dict = _restore_original_dimensions(
input_dict, self.obs_space, tensorlib=torch)
input_dict["obs"] = restore_original_dimensions(
input_dict["obs"], self.obs_space, tensorlib=torch)
outputs, features, vf, h = self._forward(input_dict, hidden_state)
return outputs, features, vf, h
+102
View File
@@ -2,8 +2,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import numpy as np
import tensorflow as tf
import threading
from ray.rllib.evaluation.sample_batch import MultiAgentBatch
from ray.rllib.utils.annotations import PublicAPI
logger = logging.getLogger(__name__)
@PublicAPI
class InputReader(object):
@@ -17,3 +25,97 @@ class InputReader(object):
SampleBatch or MultiAgentBatch read.
"""
raise NotImplementedError
@PublicAPI
def tf_input_ops(self, queue_size=1):
"""Returns TensorFlow queue ops for reading inputs from this reader.
The main use of these ops is for integration into custom model losses.
For example, you can use tf_input_ops() to read from files of external
experiences to add an imitation learning loss to your model.
This method creates a queue runner thread that will call next() on this
reader repeatedly to feed the TensorFlow queue.
Arguments:
queue_size (int): Max elements to allow in the TF queue.
Example:
>>> class MyModel(rllib.model.Model):
... def custom_loss(self, policy_loss):
... reader = JsonReader(...)
... input_ops = reader.tf_input_ops()
... with tf.variable_scope(
... self.scope, reuse=tf.AUTO_REUSE,
... auxiliary_name_scope=False):
... logits, _ = self._build_layers_v2(
... {"obs": input_ops["obs"]},
... self.num_outputs, self.options)
... il_loss = imitation_loss(logits, input_ops["actions"])
... return policy_loss + il_loss
You can find a runnable version of this in examples/custom_loss.py.
Returns:
dict of Tensors, one for each numeric column of the read SampleBatch.
Raises:
ValueError: if tf_input_ops() was already called on this reader.
NotImplementedError: if the reader yields a MultiAgentBatch.
"""
# A reader owns at most one queue runner; the attribute doubles as the
# "already called" marker.
if hasattr(self, "_queue_runner"):
raise ValueError(
"A queue runner already exists for this input reader. "
"You can only call tf_input_ops() once per reader.")
logger.info("Reading initial batch of data from input reader.")
# The first batch is read eagerly: its columns define the queue's keys,
# dtypes and shapes.
batch = self.next()
if isinstance(batch, MultiAgentBatch):
raise NotImplementedError(
"tf_input_ops() is not implemented for multi agent batches")
# Only columns with a numeric dtype are exposed, in sorted key order.
keys = [
k for k in sorted(list(batch.keys()))
if np.issubdtype(batch[k].dtype, np.number)
]
dtypes = [batch[k].dtype for k in keys]
# Leading dimension is left as -1 for the reshape below; the remaining
# dimensions are taken from the first batch.
shapes = {
k: (-1, ) + s[1:]
for (k, s) in [(k, batch[k].shape) for k in keys]
}
queue = tf.FIFOQueue(capacity=queue_size, dtypes=dtypes, names=keys)
tensors = queue.dequeue()
logger.info("Creating TF queue runner for {}".format(self))
self._queue_runner = _QueueRunner(self, queue, keys, dtypes)
# Enqueue the batch that was already read above before the background
# thread starts pulling further batches.
self._queue_runner.enqueue(batch)
self._queue_runner.start()
out = {k: tf.reshape(t, shapes[k]) for k, t in tensors.items()}
return out
class _QueueRunner(threading.Thread):
"""Thread that feeds a TF queue from an InputReader."""
def __init__(self, input_reader, queue, keys, dtypes):
threading.Thread.__init__(self)
# The session is captured at construction time and reused by enqueue().
# NOTE(review): presumably a default session must be active here, since
# enqueue() calls self.sess.run() -- confirm against callers.
self.sess = tf.get_default_session()
self.daemon = True
self.input_reader = input_reader
self.keys = keys
self.queue = queue
# One placeholder per key, in the same order as keys/dtypes.
self.placeholders = [tf.placeholder(dtype) for dtype in dtypes]
self.enqueue_op = queue.enqueue(dict(zip(keys, self.placeholders)))
def enqueue(self, batch):
# Feed each placeholder with the batch column for its key, then run
# the enqueue op in the captured session.
data = {
self.placeholders[i]: batch[key]
for i, key in enumerate(self.keys)
}
self.sess.run(self.enqueue_op, feed_dict=data)
def run(self):
# Loops forever: read the next batch and enqueue it. Any exception is
# logged and the loop continues.
# NOTE(review): a persistently failing reader would be retried
# immediately with no delay -- confirm this is intended.
while True:
try:
batch = self.input_reader.next()
self.enqueue(batch)
except Exception:
logger.exception("Error reading from input")
+4 -3
View File
@@ -311,9 +311,10 @@ class TrialRunner(object):
for state, trials in states.items()
}
total_number_of_trials = sum(num_trials_per_state.values())
messages.append("Number of trials: {} ({})"
"".format(total_number_of_trials,
num_trials_per_state))
if total_number_of_trials > 0:
messages.append("Number of trials: {} ({})"
"".format(total_number_of_trials,
num_trials_per_state))
for state, trials in sorted(states.items()):
limit = limit_per_state[state]