diff --git a/python/ray/rllib/models/lstm.py b/python/ray/rllib/models/lstm.py index b3fb557b8..18f141d09 100644 --- a/python/ray/rllib/models/lstm.py +++ b/python/ray/rllib/models/lstm.py @@ -56,7 +56,7 @@ class LSTM(Model): last_layer = add_time_dimension(features, self.seq_lens) # Setup the LSTM cell - lstm = rnn.BasicLSTMCell(cell_size, state_is_tuple=True) + lstm = tf.nn.rnn_cell.LSTMCell(cell_size, state_is_tuple=True) self.state_init = [ np.zeros(lstm.state_size.c, np.float32), np.zeros(lstm.state_size.h, np.float32)