mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 18:04:09 +08:00
[rllib] TF model custom_loss() should actually allow access to full rollout data (#4220)
This commit is contained in:
@@ -333,13 +333,6 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
self.loss.critic_loss += (
|
||||
config["l2_reg"] * 0.5 * tf.nn.l2_loss(var))
|
||||
|
||||
# Model self-supervised losses
|
||||
self.loss.actor_loss = self.p_model.custom_loss(self.loss.actor_loss)
|
||||
self.loss.critic_loss = self.q_model.custom_loss(self.loss.critic_loss)
|
||||
if self.config["twin_q"]:
|
||||
self.loss.critic_loss = self.twin_q_model.custom_loss(
|
||||
self.loss.critic_loss)
|
||||
|
||||
# update_target_fn will be called periodically to copy Q network to
|
||||
# target Q network
|
||||
self.tau_value = config.get("tau")
|
||||
@@ -375,6 +368,17 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
("dones", self.done_mask),
|
||||
("weights", self.importance_weights),
|
||||
]
|
||||
input_dict = dict(self.loss_inputs)
|
||||
|
||||
# Model self-supervised losses
|
||||
self.loss.actor_loss = self.p_model.custom_loss(
|
||||
self.loss.actor_loss, input_dict)
|
||||
self.loss.critic_loss = self.q_model.custom_loss(
|
||||
self.loss.critic_loss, input_dict)
|
||||
if self.config["twin_q"]:
|
||||
self.loss.critic_loss = self.twin_q_model.custom_loss(
|
||||
self.loss.critic_loss, input_dict)
|
||||
|
||||
TFPolicyGraph.__init__(
|
||||
self,
|
||||
observation_space,
|
||||
|
||||
@@ -108,12 +108,6 @@ class TFPolicyGraph(PolicyGraph):
|
||||
self._prev_action_input = prev_action_input
|
||||
self._prev_reward_input = prev_reward_input
|
||||
self._sampler = action_sampler
|
||||
if self.model:
|
||||
self._loss = self.model.custom_loss(loss)
|
||||
self._stats_fetches = {"model": self.model.custom_stats()}
|
||||
else:
|
||||
self._loss = loss
|
||||
self._stats_fetches = {}
|
||||
self._loss_inputs = loss_inputs
|
||||
self._loss_input_dict = dict(self._loss_inputs)
|
||||
self._is_training = self._get_is_training_placeholder()
|
||||
@@ -126,6 +120,13 @@ class TFPolicyGraph(PolicyGraph):
|
||||
self._max_seq_len = max_seq_len
|
||||
self._batch_divisibility_req = batch_divisibility_req
|
||||
|
||||
if self.model:
|
||||
self._loss = self.model.custom_loss(loss, self._loss_input_dict)
|
||||
self._stats_fetches = {"model": self.model.custom_stats()}
|
||||
else:
|
||||
self._loss = loss
|
||||
self._stats_fetches = {}
|
||||
|
||||
self._optimizer = self.optimizer()
|
||||
self._grads_and_vars = [(g, v)
|
||||
for (g, v) in self.gradients(self._optimizer)
|
||||
|
||||
@@ -43,7 +43,7 @@ class CustomLossModel(Model):
|
||||
num_outputs, options)
|
||||
return self.fcnet.outputs, self.fcnet.last_layer
|
||||
|
||||
def custom_loss(self, policy_loss):
|
||||
def custom_loss(self, policy_loss, loss_inputs):
|
||||
# create a new input reader per worker
|
||||
reader = JsonReader(self.options["custom_options"]["input_files"])
|
||||
input_ops = reader.tf_input_ops()
|
||||
@@ -59,7 +59,9 @@ class CustomLossModel(Model):
|
||||
# You can also add self-supervised losses easily by referencing tensors
|
||||
# created during _build_layers_v2(). For example, an autoencoder-style
|
||||
# loss can be added as follows:
|
||||
# ae_loss = squared_diff(self.obs_in, Decoder(self.fcnet.last_layer))
|
||||
# ae_loss = squared_diff(
|
||||
# loss_inputs["obs"], Decoder(self.fcnet.last_layer))
|
||||
print("FYI: You can also use these tensors: {}, ".format(loss_inputs))
|
||||
|
||||
# compute the IL loss
|
||||
action_dist = Categorical(logits)
|
||||
|
||||
@@ -146,7 +146,7 @@ class Model(object):
|
||||
linear(self.last_layer, 1, "value", normc_initializer(1.0)), [-1])
|
||||
|
||||
@PublicAPI
|
||||
def custom_loss(self, policy_loss):
|
||||
def custom_loss(self, policy_loss, loss_inputs):
|
||||
"""Override to customize the loss function used to optimize this model.
|
||||
|
||||
This can be used to incorporate self-supervised losses (by defining
|
||||
@@ -158,6 +158,7 @@ class Model(object):
|
||||
|
||||
Arguments:
|
||||
policy_loss (Tensor): scalar policy loss from the policy graph.
|
||||
loss_inputs (dict): map of input placeholders for rollout data.
|
||||
|
||||
Returns:
|
||||
Scalar tensor for the customized loss for this model.
|
||||
|
||||
@@ -42,7 +42,7 @@ class InputReader(object):
|
||||
|
||||
Example:
|
||||
>>> class MyModel(rllib.model.Model):
|
||||
... def custom_loss(self, policy_loss):
|
||||
... def custom_loss(self, policy_loss, loss_inputs):
|
||||
... reader = JsonReader(...)
|
||||
... input_ops = reader.tf_input_ops()
|
||||
... with tf.variable_scope(
|
||||
|
||||
Reference in New Issue
Block a user