mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 09:21:06 +08:00
[rllib] Model self loss isn't included in all algorithms (#3679)
This commit is contained in:
@@ -142,7 +142,7 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def gradients(self, optimizer):
|
||||
grads = tf.gradients(self.loss.total_loss, self.var_list)
|
||||
grads = tf.gradients(self._loss, self.var_list)
|
||||
self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"])
|
||||
clipped_grads = list(zip(self.grads, self.var_list))
|
||||
return clipped_grads
|
||||
|
||||
@@ -40,6 +40,7 @@ class PNetwork(object):
|
||||
# shape of action_scores is [batch_size, dim_actions]
|
||||
self.action_scores = layers.fully_connected(
|
||||
action_out, num_outputs=dim_actions, activation_fn=tf.nn.sigmoid)
|
||||
self.model = model
|
||||
|
||||
|
||||
class ActionNetwork(object):
|
||||
@@ -177,8 +178,6 @@ class ActorCriticLoss(object):
|
||||
self.actor_loss = (-1.0 * actor_loss_coeff * policy_delay_mask *
|
||||
tf.reduce_mean(q_tp0))
|
||||
|
||||
self.total_loss = self.actor_loss + self.critic_loss
|
||||
|
||||
|
||||
class DDPGPolicyGraph(TFPolicyGraph):
|
||||
def __init__(self, observation_space, action_space, config):
|
||||
@@ -207,8 +206,8 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
|
||||
# Actor: P (policy) network
|
||||
with tf.variable_scope(P_SCOPE) as scope:
|
||||
p_values = self._build_p_network(self.cur_observations,
|
||||
observation_space)
|
||||
p_values, self.p_model = self._build_p_network(
|
||||
self.cur_observations, observation_space)
|
||||
self.p_func_vars = _scope_vars(scope.name)
|
||||
|
||||
# Action outputs
|
||||
@@ -241,14 +240,14 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
# p network evaluation
|
||||
with tf.variable_scope(P_SCOPE, reuse=True) as scope:
|
||||
prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
|
||||
self.p_t = self._build_p_network(self.obs_t, observation_space)
|
||||
self.p_t, _ = self._build_p_network(self.obs_t, observation_space)
|
||||
p_batchnorm_update_ops = list(
|
||||
set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) -
|
||||
prev_update_ops)
|
||||
|
||||
# target p network evaluation
|
||||
with tf.variable_scope(P_TARGET_SCOPE) as scope:
|
||||
p_tp1 = self._build_p_network(self.obs_tp1, observation_space)
|
||||
p_tp1, _ = self._build_p_network(self.obs_tp1, observation_space)
|
||||
target_p_func_vars = _scope_vars(scope.name)
|
||||
|
||||
# Action outputs
|
||||
@@ -267,15 +266,15 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
# q network evaluation
|
||||
prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
|
||||
with tf.variable_scope(Q_SCOPE) as scope:
|
||||
q_t, model = self._build_q_network(self.obs_t, observation_space,
|
||||
self.act_t)
|
||||
q_t, self.q_model = self._build_q_network(
|
||||
self.obs_t, observation_space, self.act_t)
|
||||
self.q_func_vars = _scope_vars(scope.name)
|
||||
with tf.variable_scope(Q_SCOPE, reuse=True):
|
||||
q_tp0, _ = self._build_q_network(self.obs_t, observation_space,
|
||||
output_actions)
|
||||
if self.config["twin_q"]:
|
||||
with tf.variable_scope(TWIN_Q_SCOPE) as scope:
|
||||
twin_q_t, twin_model = self._build_q_network(
|
||||
twin_q_t, self.twin_q_model = self._build_q_network(
|
||||
self.obs_t, observation_space, self.act_t)
|
||||
self.twin_q_func_vars = _scope_vars(scope.name)
|
||||
q_batchnorm_update_ops = list(
|
||||
@@ -313,6 +312,12 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
self.loss.critic_loss += (
|
||||
config["l2_reg"] * 0.5 * tf.nn.l2_loss(var))
|
||||
|
||||
# Model self-supervised losses
|
||||
self.loss.actor_loss += self.p_model.loss()
|
||||
self.loss.critic_loss += self.q_model.loss()
|
||||
if self.config["twin_q"]:
|
||||
self.loss.critic_loss += self.twin_q_model.loss()
|
||||
|
||||
# update_target_fn will be called periodically to copy Q network to
|
||||
# target Q network
|
||||
self.tau_value = config.get("tau")
|
||||
@@ -355,7 +360,7 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
self.sess,
|
||||
obs_input=self.cur_observations,
|
||||
action_sampler=self.output_actions,
|
||||
loss=model.loss() + self.loss.total_loss,
|
||||
loss=self.loss.actor_loss + self.loss.critic_loss,
|
||||
loss_inputs=self.loss_inputs,
|
||||
update_ops=q_batchnorm_update_ops + p_batchnorm_update_ops)
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
@@ -448,13 +453,14 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
return q_net.value, q_net.model
|
||||
|
||||
def _build_p_network(self, obs, obs_space):
|
||||
return PNetwork(
|
||||
policy_net = PNetwork(
|
||||
ModelCatalog.get_model({
|
||||
"obs": obs,
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
}, obs_space, 1, self.config["model"]), self.dim_actions,
|
||||
self.config["actor_hiddens"],
|
||||
self.config["actor_hidden_activation"]).action_scores
|
||||
self.config["actor_hidden_activation"])
|
||||
return policy_net.action_scores, policy_net.model
|
||||
|
||||
def _build_action_network(self, p_values, stochastic, eps,
|
||||
is_target=False):
|
||||
|
||||
@@ -403,7 +403,7 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
if self.config["grad_norm_clipping"] is not None:
|
||||
grads_and_vars = _minimize_and_clip(
|
||||
optimizer,
|
||||
self.loss.loss,
|
||||
self._loss,
|
||||
var_list=self.q_func_vars,
|
||||
clip_val=self.config["grad_norm_clipping"])
|
||||
else:
|
||||
|
||||
@@ -263,7 +263,7 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
def gradients(self, optimizer):
|
||||
grads = tf.gradients(self.loss.total_loss, self.var_list)
|
||||
grads = tf.gradients(self._loss, self.var_list)
|
||||
self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"])
|
||||
clipped_grads = list(zip(self.grads, self.var_list))
|
||||
return clipped_grads
|
||||
|
||||
@@ -4,7 +4,7 @@ pendulum-ddpg:
|
||||
run: DDPG
|
||||
stop:
|
||||
episode_reward_mean: -160
|
||||
time_total_s: 600 # 10 minutes
|
||||
timesteps_total: 100000
|
||||
config:
|
||||
# === Model ===
|
||||
actor_hiddens: [64, 64]
|
||||
|
||||
Reference in New Issue
Block a user