From a54386e499f24909f85e3338c4703dafda7b9ff4 Mon Sep 17 00:00:00 2001 From: Stefan Pantic Date: Fri, 22 Feb 2019 06:07:48 +0100 Subject: [PATCH] Added custom LSTM detection (#4087) * Added autodetection of custom LSTM usage * Reverted line separators * Added check for LSTM * Update vtrace_policy_graph.py * Update appo_policy_graph.py --- python/ray/rllib/agents/impala/vtrace_policy_graph.py | 2 +- python/ray/rllib/agents/ppo/appo_policy_graph.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/ray/rllib/agents/impala/vtrace_policy_graph.py b/python/ray/rllib/agents/impala/vtrace_policy_graph.py index af9f0397f..d8ce1be6e 100644 --- a/python/ray/rllib/agents/impala/vtrace_policy_graph.py +++ b/python/ray/rllib/agents/impala/vtrace_policy_graph.py @@ -148,7 +148,7 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph): tf.get_variable_scope().name) def to_batches(tensor): - if self.config["model"]["use_lstm"]: + if self.model.state_init: B = tf.shape(self.model.seq_lens)[0] T = tf.shape(tensor)[0] // B else: diff --git a/python/ray/rllib/agents/ppo/appo_policy_graph.py b/python/ray/rllib/agents/ppo/appo_policy_graph.py index ace8f39be..e0716f274 100644 --- a/python/ray/rllib/agents/ppo/appo_policy_graph.py +++ b/python/ray/rllib/agents/ppo/appo_policy_graph.py @@ -230,7 +230,7 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph): tf.get_variable_scope().name) def to_batches(tensor): - if self.config["model"]["use_lstm"]: + if self.model.state_init: B = tf.shape(self.model.seq_lens)[0] T = tf.shape(tensor)[0] // B else: