mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 17:18:45 +08:00
Added custom LSTM detection (#4087)
* Added autodetection of custom LSTM usage * Reverted line separators * Added check for LSTM * Update vtrace_policy_graph.py * Update appo_policy_graph.py
This commit is contained in:
committed by
Eric Liang
parent
692bb336a1
commit
a54386e499
@@ -148,7 +148,7 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
tf.get_variable_scope().name)
|
||||
|
||||
def to_batches(tensor):
|
||||
if self.config["model"]["use_lstm"]:
|
||||
if self.model.state_init:
|
||||
B = tf.shape(self.model.seq_lens)[0]
|
||||
T = tf.shape(tensor)[0] // B
|
||||
else:
|
||||
|
||||
@@ -230,7 +230,7 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
tf.get_variable_scope().name)
|
||||
|
||||
def to_batches(tensor):
|
||||
if self.config["model"]["use_lstm"]:
|
||||
if self.model.state_init:
|
||||
B = tf.shape(self.model.seq_lens)[0]
|
||||
T = tf.shape(tensor)[0] // B
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user