[rllib] Add self-supervised loss to model (#3291)

# What do these changes do?

Allow self-supervised losses to be easily defined in custom models by overriding a new `Model.loss()` method. The reference policy graphs now add this model loss to their own loss.
This commit is contained in:
Eric Liang
2018-11-12 18:55:24 -08:00
committed by Richard Liaw
parent ce6e01b988
commit d90f365394
8 changed files with 47 additions and 20 deletions
@@ -94,7 +94,7 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
self.sess,
obs_input=self.observations,
action_sampler=action_dist.sample(),
loss=self.loss.total_loss,
loss=self.model.loss() + self.loss.total_loss,
loss_inputs=loss_in,
state_inputs=self.model.state_in,
state_outputs=self.model.state_out,
@@ -86,6 +86,7 @@ class QNetwork(object):
q_out, num_outputs=hidden, activation_fn=activation)
self.value = layers.fully_connected(
q_out, num_outputs=1, activation_fn=None)
self.model = model
class ActorCriticLoss(object):
@@ -198,17 +199,17 @@ class DDPGPolicyGraph(TFPolicyGraph):
# q network evaluation
with tf.variable_scope(Q_SCOPE) as scope:
q_t = self._build_q_network(self.obs_t, observation_space,
self.act_t)
q_t, model = self._build_q_network(self.obs_t, observation_space,
self.act_t)
self.q_func_vars = _scope_vars(scope.name)
with tf.variable_scope(Q_SCOPE, reuse=True):
q_tp0 = self._build_q_network(self.obs_t, observation_space,
output_actions)
q_tp0, _ = self._build_q_network(self.obs_t, observation_space,
output_actions)
# target q network evaluation
with tf.variable_scope(Q_TARGET_SCOPE) as scope:
q_tp1 = self._build_q_network(self.obs_tp1, observation_space,
output_actions_estimated)
q_tp1, _ = self._build_q_network(self.obs_tp1, observation_space,
output_actions_estimated)
target_q_func_vars = _scope_vars(scope.name)
self.loss = self._build_actor_critic_loss(q_t, q_tp1, q_tp0)
@@ -258,7 +259,7 @@ class DDPGPolicyGraph(TFPolicyGraph):
self.sess,
obs_input=self.cur_observations,
action_sampler=self.output_actions,
loss=self.loss.total_loss,
loss=model.loss() + self.loss.total_loss,
loss_inputs=self.loss_inputs)
self.sess.run(tf.global_variables_initializer())
@@ -271,12 +272,13 @@ class DDPGPolicyGraph(TFPolicyGraph):
self.update_target(tau=1.0)
def _build_q_network(self, obs, obs_space, actions):
return QNetwork(
q_net = QNetwork(
ModelCatalog.get_model({
"obs": obs
}, obs_space, 1, self.config["model"]), actions,
self.config["critic_hiddens"],
self.config["critic_hidden_activation"]).value
self.config["critic_hidden_activation"])
return q_net.value, q_net.model
def _build_p_network(self, obs, obs_space):
return PNetwork(
@@ -28,6 +28,7 @@ class QNetwork(object):
v_min=-10.0,
v_max=10.0,
sigma0=0.5):
self.model = model
with tf.variable_scope("action_value"):
action_out = model.last_layer
for i in range(len(hiddens)):
@@ -274,7 +275,7 @@ class DQNPolicyGraph(TFPolicyGraph):
# Action Q network
with tf.variable_scope(Q_SCOPE) as scope:
q_values, q_logits, q_dist = self._build_q_network(
q_values, q_logits, q_dist, _ = self._build_q_network(
self.cur_observations, observation_space)
self.q_func_vars = _scope_vars(scope.name)
@@ -294,12 +295,12 @@ class DQNPolicyGraph(TFPolicyGraph):
# q network evaluation
with tf.variable_scope(Q_SCOPE, reuse=True):
q_t, q_logits_t, q_dist_t = self._build_q_network(
q_t, q_logits_t, q_dist_t, model = self._build_q_network(
self.obs_t, observation_space)
# target q network evaluation
with tf.variable_scope(Q_TARGET_SCOPE) as scope:
q_tp1, q_logits_tp1, q_dist_tp1 = self._build_q_network(
q_tp1, q_logits_tp1, q_dist_tp1, _ = self._build_q_network(
self.obs_tp1, observation_space)
self.target_q_func_vars = _scope_vars(scope.name)
@@ -313,7 +314,7 @@ class DQNPolicyGraph(TFPolicyGraph):
if config["double_q"]:
with tf.variable_scope(Q_SCOPE, reuse=True):
q_tp1_using_online_net, q_logits_tp1_using_online_net, \
q_dist_tp1_using_online_net = self._build_q_network(
q_dist_tp1_using_online_net, _ = self._build_q_network(
self.obs_tp1, observation_space)
q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1)
q_tp1_best_one_hot_selection = tf.one_hot(
@@ -359,7 +360,7 @@ class DQNPolicyGraph(TFPolicyGraph):
self.sess,
obs_input=self.cur_observations,
action_sampler=self.output_actions,
loss=self.loss.loss,
loss=model.loss() + self.loss.loss,
loss_inputs=self.loss_inputs)
self.sess.run(tf.global_variables_initializer())
@@ -371,7 +372,7 @@ class DQNPolicyGraph(TFPolicyGraph):
self.config["dueling"], self.config["hiddens"],
self.config["noisy"], self.config["num_atoms"],
self.config["v_min"], self.config["v_max"], self.config["sigma0"])
return qnet.value, qnet.logits, qnet.dist
return qnet.value, qnet.logits, qnet.dist, qnet.model
def _build_q_value_policy(self, q_values):
return QValuePolicy(q_values, self.cur_observations, self.num_actions,
@@ -203,7 +203,7 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
self.sess,
obs_input=observations,
action_sampler=action_dist.sample(),
loss=self.loss.total_loss,
loss=self.model.loss() + self.loss.total_loss,
loss_inputs=loss_in,
state_inputs=self.model.state_in,
state_outputs=self.model.state_out,
@@ -64,7 +64,7 @@ class PGPolicyGraph(TFPolicyGraph):
sess,
obs_input=obs,
action_sampler=action_dist.sample(),
loss=loss,
loss=self.model.loss() + loss,
loss_inputs=loss_in,
state_inputs=self.model.state_in,
state_outputs=self.model.state_out,
@@ -230,7 +230,7 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
self.sess,
obs_input=obs_ph,
action_sampler=self.sampler,
loss=self.loss_obj.loss,
loss=self.model.loss() + self.loss_obj.loss,
loss_inputs=self.loss_in,
state_inputs=self.model.state_in,
state_outputs=self.model.state_out,
+12
View File
@@ -144,6 +144,18 @@ class Model(object):
return tf.reshape(
linear(self.last_layer, 1, "value", normc_initializer(1.0)), [-1])
def loss(self):
    """Return the model's own (self-supervised) loss term.

    Custom models can override this to contribute extra objectives, for
    example auto-encoder style losses. The returned value only influences
    training when the policy graph adds it to its own loss, which the
    built-in algorithms do.

    Returns:
        Scalar tensor for the self-supervised loss. This default
        implementation yields a constant 0.0.
    """
    # Default: no auxiliary objective.
    no_aux_loss = tf.constant(0.0)
    return no_aux_loss
def _restore_original_dimensions(input_dict, obs_space):
if hasattr(obs_space, "original_space"):