[rllib] TF model custom_loss() should actually allow access to full rollout data (#4220)

This commit is contained in:
Eric Liang
2019-03-02 22:57:51 -08:00
committed by GitHub
parent ff6dd8459a
commit ba03048254
6 changed files with 27 additions and 18 deletions
@@ -333,13 +333,6 @@ class DDPGPolicyGraph(TFPolicyGraph):
self.loss.critic_loss += (
config["l2_reg"] * 0.5 * tf.nn.l2_loss(var))
# Model self-supervised losses
self.loss.actor_loss = self.p_model.custom_loss(self.loss.actor_loss)
self.loss.critic_loss = self.q_model.custom_loss(self.loss.critic_loss)
if self.config["twin_q"]:
self.loss.critic_loss = self.twin_q_model.custom_loss(
self.loss.critic_loss)
# update_target_fn will be called periodically to copy Q network to
# target Q network
self.tau_value = config.get("tau")
@@ -375,6 +368,17 @@ class DDPGPolicyGraph(TFPolicyGraph):
("dones", self.done_mask),
("weights", self.importance_weights),
]
input_dict = dict(self.loss_inputs)
# Model self-supervised losses
self.loss.actor_loss = self.p_model.custom_loss(
self.loss.actor_loss, input_dict)
self.loss.critic_loss = self.q_model.custom_loss(
self.loss.critic_loss, input_dict)
if self.config["twin_q"]:
self.loss.critic_loss = self.twin_q_model.custom_loss(
self.loss.critic_loss, input_dict)
TFPolicyGraph.__init__(
self,
observation_space,
@@ -108,12 +108,6 @@ class TFPolicyGraph(PolicyGraph):
self._prev_action_input = prev_action_input
self._prev_reward_input = prev_reward_input
self._sampler = action_sampler
if self.model:
self._loss = self.model.custom_loss(loss)
self._stats_fetches = {"model": self.model.custom_stats()}
else:
self._loss = loss
self._stats_fetches = {}
self._loss_inputs = loss_inputs
self._loss_input_dict = dict(self._loss_inputs)
self._is_training = self._get_is_training_placeholder()
@@ -126,6 +120,13 @@ class TFPolicyGraph(PolicyGraph):
self._max_seq_len = max_seq_len
self._batch_divisibility_req = batch_divisibility_req
if self.model:
self._loss = self.model.custom_loss(loss, self._loss_input_dict)
self._stats_fetches = {"model": self.model.custom_stats()}
else:
self._loss = loss
self._stats_fetches = {}
self._optimizer = self.optimizer()
self._grads_and_vars = [(g, v)
for (g, v) in self.gradients(self._optimizer)
+4 -2
View File
@@ -43,7 +43,7 @@ class CustomLossModel(Model):
num_outputs, options)
return self.fcnet.outputs, self.fcnet.last_layer
def custom_loss(self, policy_loss):
def custom_loss(self, policy_loss, loss_inputs):
# create a new input reader per worker
reader = JsonReader(self.options["custom_options"]["input_files"])
input_ops = reader.tf_input_ops()
@@ -59,7 +59,9 @@ class CustomLossModel(Model):
# You can also add self-supervised losses easily by referencing tensors
# created during _build_layers_v2(). For example, an autoencoder-style
# loss can be added as follows:
# ae_loss = squared_diff(self.obs_in, Decoder(self.fcnet.last_layer))
# ae_loss = squared_diff(
# loss_inputs["obs"], Decoder(self.fcnet.last_layer))
print("FYI: You can also use these tensors: {}, ".format(loss_inputs))
# compute the IL loss
action_dist = Categorical(logits)
+2 -1
View File
@@ -146,7 +146,7 @@ class Model(object):
linear(self.last_layer, 1, "value", normc_initializer(1.0)), [-1])
@PublicAPI
def custom_loss(self, policy_loss):
def custom_loss(self, policy_loss, loss_inputs):
"""Override to customize the loss function used to optimize this model.
This can be used to incorporate self-supervised losses (by defining
@@ -158,6 +158,7 @@ class Model(object):
Arguments:
policy_loss (Tensor): scalar policy loss from the policy graph.
loss_inputs (dict): map of rollout field name to input placeholder.
Returns:
Scalar tensor for the customized loss for this model.
+1 -1
View File
@@ -42,7 +42,7 @@ class InputReader(object):
Example:
>>> class MyModel(rllib.model.Model):
... def custom_loss(self, policy_loss):
... def custom_loss(self, policy_loss, loss_inputs):
... reader = JsonReader(...)
... input_ops = reader.tf_input_ops()
... with tf.variable_scope(