mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 03:50:57 +08:00
[rllib] Add self-supervised loss to model (#3291)
What do these changes do? They allow self-supervised losses to be easily defined in custom models, and they add this model loss to the losses of the reference policy graphs.
This commit is contained in:
@@ -94,7 +94,7 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
self.sess,
|
||||
obs_input=self.observations,
|
||||
action_sampler=action_dist.sample(),
|
||||
loss=self.loss.total_loss,
|
||||
loss=self.model.loss() + self.loss.total_loss,
|
||||
loss_inputs=loss_in,
|
||||
state_inputs=self.model.state_in,
|
||||
state_outputs=self.model.state_out,
|
||||
|
||||
@@ -86,6 +86,7 @@ class QNetwork(object):
|
||||
q_out, num_outputs=hidden, activation_fn=activation)
|
||||
self.value = layers.fully_connected(
|
||||
q_out, num_outputs=1, activation_fn=None)
|
||||
self.model = model
|
||||
|
||||
|
||||
class ActorCriticLoss(object):
|
||||
@@ -198,17 +199,17 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
|
||||
# q network evaluation
|
||||
with tf.variable_scope(Q_SCOPE) as scope:
|
||||
q_t = self._build_q_network(self.obs_t, observation_space,
|
||||
self.act_t)
|
||||
q_t, model = self._build_q_network(self.obs_t, observation_space,
|
||||
self.act_t)
|
||||
self.q_func_vars = _scope_vars(scope.name)
|
||||
with tf.variable_scope(Q_SCOPE, reuse=True):
|
||||
q_tp0 = self._build_q_network(self.obs_t, observation_space,
|
||||
output_actions)
|
||||
q_tp0, _ = self._build_q_network(self.obs_t, observation_space,
|
||||
output_actions)
|
||||
|
||||
# target q network evaluation
|
||||
with tf.variable_scope(Q_TARGET_SCOPE) as scope:
|
||||
q_tp1 = self._build_q_network(self.obs_tp1, observation_space,
|
||||
output_actions_estimated)
|
||||
q_tp1, _ = self._build_q_network(self.obs_tp1, observation_space,
|
||||
output_actions_estimated)
|
||||
target_q_func_vars = _scope_vars(scope.name)
|
||||
|
||||
self.loss = self._build_actor_critic_loss(q_t, q_tp1, q_tp0)
|
||||
@@ -258,7 +259,7 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
self.sess,
|
||||
obs_input=self.cur_observations,
|
||||
action_sampler=self.output_actions,
|
||||
loss=self.loss.total_loss,
|
||||
loss=model.loss() + self.loss.total_loss,
|
||||
loss_inputs=self.loss_inputs)
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
@@ -271,12 +272,13 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
self.update_target(tau=1.0)
|
||||
|
||||
def _build_q_network(self, obs, obs_space, actions):
|
||||
return QNetwork(
|
||||
q_net = QNetwork(
|
||||
ModelCatalog.get_model({
|
||||
"obs": obs
|
||||
}, obs_space, 1, self.config["model"]), actions,
|
||||
self.config["critic_hiddens"],
|
||||
self.config["critic_hidden_activation"]).value
|
||||
self.config["critic_hidden_activation"])
|
||||
return q_net.value, q_net.model
|
||||
|
||||
def _build_p_network(self, obs, obs_space):
|
||||
return PNetwork(
|
||||
|
||||
@@ -28,6 +28,7 @@ class QNetwork(object):
|
||||
v_min=-10.0,
|
||||
v_max=10.0,
|
||||
sigma0=0.5):
|
||||
self.model = model
|
||||
with tf.variable_scope("action_value"):
|
||||
action_out = model.last_layer
|
||||
for i in range(len(hiddens)):
|
||||
@@ -274,7 +275,7 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
|
||||
# Action Q network
|
||||
with tf.variable_scope(Q_SCOPE) as scope:
|
||||
q_values, q_logits, q_dist = self._build_q_network(
|
||||
q_values, q_logits, q_dist, _ = self._build_q_network(
|
||||
self.cur_observations, observation_space)
|
||||
self.q_func_vars = _scope_vars(scope.name)
|
||||
|
||||
@@ -294,12 +295,12 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
|
||||
# q network evaluation
|
||||
with tf.variable_scope(Q_SCOPE, reuse=True):
|
||||
q_t, q_logits_t, q_dist_t = self._build_q_network(
|
||||
q_t, q_logits_t, q_dist_t, model = self._build_q_network(
|
||||
self.obs_t, observation_space)
|
||||
|
||||
# target q network evaluation
|
||||
with tf.variable_scope(Q_TARGET_SCOPE) as scope:
|
||||
q_tp1, q_logits_tp1, q_dist_tp1 = self._build_q_network(
|
||||
q_tp1, q_logits_tp1, q_dist_tp1, _ = self._build_q_network(
|
||||
self.obs_tp1, observation_space)
|
||||
self.target_q_func_vars = _scope_vars(scope.name)
|
||||
|
||||
@@ -313,7 +314,7 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
if config["double_q"]:
|
||||
with tf.variable_scope(Q_SCOPE, reuse=True):
|
||||
q_tp1_using_online_net, q_logits_tp1_using_online_net, \
|
||||
q_dist_tp1_using_online_net = self._build_q_network(
|
||||
q_dist_tp1_using_online_net, _ = self._build_q_network(
|
||||
self.obs_tp1, observation_space)
|
||||
q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1)
|
||||
q_tp1_best_one_hot_selection = tf.one_hot(
|
||||
@@ -359,7 +360,7 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
self.sess,
|
||||
obs_input=self.cur_observations,
|
||||
action_sampler=self.output_actions,
|
||||
loss=self.loss.loss,
|
||||
loss=model.loss() + self.loss.loss,
|
||||
loss_inputs=self.loss_inputs)
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
@@ -371,7 +372,7 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
self.config["dueling"], self.config["hiddens"],
|
||||
self.config["noisy"], self.config["num_atoms"],
|
||||
self.config["v_min"], self.config["v_max"], self.config["sigma0"])
|
||||
return qnet.value, qnet.logits, qnet.dist
|
||||
return qnet.value, qnet.logits, qnet.dist, qnet.model
|
||||
|
||||
def _build_q_value_policy(self, q_values):
|
||||
return QValuePolicy(q_values, self.cur_observations, self.num_actions,
|
||||
|
||||
@@ -203,7 +203,7 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
self.sess,
|
||||
obs_input=observations,
|
||||
action_sampler=action_dist.sample(),
|
||||
loss=self.loss.total_loss,
|
||||
loss=self.model.loss() + self.loss.total_loss,
|
||||
loss_inputs=loss_in,
|
||||
state_inputs=self.model.state_in,
|
||||
state_outputs=self.model.state_out,
|
||||
|
||||
@@ -64,7 +64,7 @@ class PGPolicyGraph(TFPolicyGraph):
|
||||
sess,
|
||||
obs_input=obs,
|
||||
action_sampler=action_dist.sample(),
|
||||
loss=loss,
|
||||
loss=self.model.loss() + loss,
|
||||
loss_inputs=loss_in,
|
||||
state_inputs=self.model.state_in,
|
||||
state_outputs=self.model.state_out,
|
||||
|
||||
@@ -230,7 +230,7 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
self.sess,
|
||||
obs_input=obs_ph,
|
||||
action_sampler=self.sampler,
|
||||
loss=self.loss_obj.loss,
|
||||
loss=self.model.loss() + self.loss_obj.loss,
|
||||
loss_inputs=self.loss_in,
|
||||
state_inputs=self.model.state_in,
|
||||
state_outputs=self.model.state_out,
|
||||
|
||||
@@ -144,6 +144,18 @@ class Model(object):
|
||||
return tf.reshape(
|
||||
linear(self.last_layer, 1, "value", normc_initializer(1.0)), [-1])
|
||||
|
||||
def loss(self):
|
||||
"""Builds any built-in (self-supervised) loss for the model.
|
||||
|
||||
For example, this can be used to incorporate auto-encoder style losses.
|
||||
Note that this loss has to be included in the policy graph loss to have
|
||||
an effect (done for built-in algorithms).
|
||||
|
||||
Returns:
|
||||
Scalar tensor for the self-supervised loss.
|
||||
"""
|
||||
return tf.constant(0.0)
|
||||
|
||||
|
||||
def _restore_original_dimensions(input_dict, obs_space):
|
||||
if hasattr(obs_space, "original_space"):
|
||||
|
||||
Reference in New Issue
Block a user