Files
ray/python/ray/rllib/models/lstm.py
T
Richard Liaw bc082e9a9e [rllib] Additional support for Shared Models in A3C (#866)
* Code for Supporting Shared Models

Running (with vnet modification) - needs to be tested for performance

Summaries

Small refactoring + generalized to more domains

Small fix for jenkins

Linting

linting

Addressing changes

Addressing changes

Update envs.py

Addressing changes

convnet

Merge - new model

final touches

final linting

Changing iterations back

removed extra change

changes for fast experimentation

changes to enable a2c

TEMP FOR DEBUGGING

ContinuousActions - Still doesn't work

InvertedPendulum trains with 8 workers - k=200

huber loss

Maxes for InvertedPendulum-v1 - 16w,200steps

temp: working with a2c

Back to shared model

more fixes

small

nit

LSTM to shared models

need to fix last_features

tuning pong

Best record for hitting 0 - with k=16,n=20

nit

a2c removal

remove A2c reference and nits

nit

removed a2c vestiges

removing a2c

removing example.py

Linting

nit

* Linting + Removing vestigial code

* Final Touches

* nits

* rerun travis
2017-08-28 12:23:14 -07:00

55 lines
2.1 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import tensorflow.contrib.rnn as rnn
import distutils.version
from ray.rllib.models.misc import (conv2d, linear, flatten,
normc_initializer)
from ray.rllib.models.model import Model
# True when the installed TensorFlow reports a version of 1.0.0 or newer.
# The LSTM model below uses it to choose between the ``rnn`` namespace and
# the pre-1.0 ``rnn.rnn_cell`` namespace for its cell classes.
use_tf100_api = (
    distutils.version.LooseVersion(tf.VERSION)
    >= distutils.version.LooseVersion("1.0.0")
)
class LSTM(Model):
    """Convolutional feature extractor followed by a single LSTM layer.

    The leading dimension of the model inputs is treated as the time axis of
    a single sequence. After the conv stack, the features are given a batch
    dimension of 1 and unrolled with ``tf.nn.dynamic_rnn``. The recurrent
    state is threaded through explicitly via ``state_in`` / ``state_out``.
    """

    # TODO(rliaw): Add LSTM code for other algorithms
    def _init(self, inputs, num_outputs, options):
        """Build the graph and return ``(logits, last_layer)``.

        Args:
            inputs: Input tensor. Its first dimension is used as the LSTM
                sequence length. NOTE(review): presumably an image batch
                laid out the way ``conv2d`` expects — confirm with caller.
            num_outputs: Size argument passed to the ``"action"`` linear
                output layer.
            options: Not read by this model.

        Returns:
            A ``(logits, last_layer)`` pair. ``logits`` is the output of the
            ``"action"`` linear layer; ``last_layer`` is the LSTM output
            reshaped to ``[-1, 256]``.

        Side effects:
            Sets ``self.x`` (the raw inputs), ``self.state_init`` (float32
            zero arrays for c and h, each with a leading dimension of 1),
            ``self.state_in`` (the c/h placeholders to feed) and
            ``self.state_out`` (first row of the final c/h state tensors).
        """
        self.x = inputs
        features = inputs
        # Four ELU-activated conv2d layers, scoped "l1" through "l4".
        for layer_num in range(1, 5):
            features = tf.nn.elu(conv2d(
                features, 32, "l{}".format(layer_num), [3, 3], [2, 2]))
        # Flatten each time step, then add a "fake" batch dimension of 1 so
        # the LSTM runs over the original leading (time) dimension.
        features = tf.expand_dims(flatten(features), [0])

        cell_size = 256
        # Before TF 1.0 the cell classes live under ``rnn.rnn_cell``.
        cell_ns = rnn if use_tf100_api else rnn.rnn_cell
        cell = cell_ns.BasicLSTMCell(cell_size, state_is_tuple=True)
        # A single sequence whose length is the number of input rows.
        seq_len = tf.shape(self.x)[:1]

        c_dim = cell.state_size.c
        h_dim = cell.state_size.h
        self.state_init = [np.zeros((1, c_dim), np.float32),
                           np.zeros((1, h_dim), np.float32)]
        c_ph = tf.placeholder(tf.float32, [1, c_dim])
        h_ph = tf.placeholder(tf.float32, [1, h_dim])
        self.state_in = [c_ph, h_ph]

        start_state = cell_ns.LSTMStateTuple(c_ph, h_ph)
        rnn_outputs, final_state = tf.nn.dynamic_rnn(
            cell, features,
            initial_state=start_state,
            sequence_length=seq_len,
            time_major=False)
        final_c, final_h = final_state

        last_layer = tf.reshape(rnn_outputs, [-1, cell_size])
        logits = linear(last_layer, num_outputs, "action",
                        normc_initializer(0.01))
        self.state_out = [final_c[:1, :], final_h[:1, :]]
        return logits, last_layer