From 798944fbfaf9c13c409765d81b3356f861cabf9b Mon Sep 17 00:00:00 2001 From: Risto Vuorio Date: Fri, 29 Mar 2019 16:31:59 -0400 Subject: [PATCH] =?UTF-8?q?Fixes=20Inconsistent=20weight=20assignment=20op?= =?UTF-8?q?erations=20in=20DQNPolicyGraph=20(#4=E2=80=A6=20(#4504)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fixes Inconsistent weight assignment operations in DQNPolicyGraph (#4502) * Update dqn_policy_graph.py --- python/ray/rllib/agents/dqn/dqn_policy_graph.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/ray/rllib/agents/dqn/dqn_policy_graph.py b/python/ray/rllib/agents/dqn/dqn_policy_graph.py index 0ba87314d..56af1e04b 100644 --- a/python/ray/rllib/agents/dqn/dqn_policy_graph.py +++ b/python/ray/rllib/agents/dqn/dqn_policy_graph.py @@ -431,9 +431,9 @@ class DQNPolicyGraph(DQNPostprocessing, TFPolicyGraph): # update_target_fn will be called periodically to copy Q network to # target Q network update_target_expr = [] - for var, var_target in zip( - sorted(self.q_func_vars, key=lambda v: v.name), - sorted(self.target_q_func_vars, key=lambda v: v.name)): + assert len(self.q_func_vars) == len(self.target_q_func_vars), \ + (self.q_func_vars, self.target_q_func_vars) + for var, var_target in zip(self.q_func_vars, self.target_q_func_vars): update_target_expr.append(var_target.assign(var)) self.update_target_expr = tf.group(*update_target_expr)