mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 18:48:35 +08:00
bc082e9a9e
* Code for Supporting Shared Models Running (with vnet modification) - needs to be tested for performance Summaries Small refactoring + generalized to more domains Small fix for jenkins Linting linting Addressing changes Addressing changes Update envs.py Addressing changes convnet Merge - new model final touches final linting Changing iterations back removed extra change changes for fast experimentation changes to enable a2c TEMP FOR DEBUGGING ContinuousActions - Still doesn't work InvertedPendulum trains with 8 workers - k=200 huber loss Maxes for InvertedPendulum-v1 - 16w,200steps temp: working with a2c Back to shared model more fixes small nit LSTM to shared models need to fix last_features tuning pong Best record for hitting 0 - with k=16,n=20 nit a2cremoval remove A2c reference and nits nit removed a2c vestiges removing a2c removing example.py Linting nit * Linting + Removing vestigal code * Final Touches * nits * rerun travis
55 lines
2.1 KiB
Python
55 lines
2.1 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
import tensorflow.contrib.rnn as rnn
|
|
import distutils.version
|
|
|
|
from ray.rllib.models.misc import (conv2d, linear, flatten,
|
|
normc_initializer)
|
|
from ray.rllib.models.model import Model
|
|
|
|
# True when the installed TensorFlow is version 1.0.0 or newer; the LSTM model
# below uses this flag to pick where the RNN cell classes are looked up.
use_tf100_api = (
    distutils.version.LooseVersion("1.0.0") <=
    distutils.version.LooseVersion(tf.VERSION))
|
|
|
|
|
|
class LSTM(Model):
    """Convolutional feature stack followed by one LSTM cell and a linear head.

    Besides the values returned by ``_init``, building the model sets:

    * ``self.x``          -- the input tensor passed in.
    * ``self.state_init`` -- ``[c, h]`` zero-filled float32 numpy arrays of
      shape ``(1, state_size)`` usable as the first recurrent state.
    * ``self.state_in``   -- ``[c, h]`` float32 placeholders through which
      the recurrent state is fed.
    * ``self.state_out``  -- ``[c, h]`` tensors holding the first row of the
      final LSTM state.
    """
    # TODO(rliaw): Add LSTM code for other algorithms

    def _init(self, inputs, num_outputs, options):
        """Build the network graph.

        Args:
            inputs: Input tensor; its leading dimension is used as the
                sequence length handed to the RNN.
            num_outputs: Size of the final linear ("action") layer.
            options: Not read by this model.

        Returns:
            A ``(logits, features)`` pair where ``features`` is the LSTM
            output reshaped to ``[-1, 256]`` and ``logits`` is the linear
            layer applied to it.
        """
        self.x = inputs

        # Four ELU-activated conv layers named "l1".."l4" (32 outputs each;
        # the two list arguments are presumably filter size and stride --
        # see ``conv2d`` in ray.rllib.models.misc).
        features = inputs
        for layer_num in range(1, 5):
            scope = "l{}".format(layer_num)
            features = tf.nn.elu(conv2d(features, 32, scope, [3, 3], [2, 2]))

        # Flatten, then prepend a "fake" batch dimension of 1 so the
        # original leading axis becomes the time axis for the LSTM.
        sequence = tf.expand_dims(flatten(features), [0])

        hidden_units = 256
        # TF >= 1.0 exposes the cell classes directly on contrib.rnn; older
        # releases keep them under its ``rnn_cell`` submodule.
        cell_api = rnn if use_tf100_api else rnn.rnn_cell
        cell = cell_api.BasicLSTMCell(hidden_units, state_is_tuple=True)
        # One-element vector holding the leading dimension of the input.
        num_steps = tf.shape(self.x)[:1]

        c_width = cell.state_size.c
        h_width = cell.state_size.h
        self.state_init = [
            np.zeros((1, c_width), np.float32),
            np.zeros((1, h_width), np.float32),
        ]
        c_in = tf.placeholder(tf.float32, [1, c_width])
        h_in = tf.placeholder(tf.float32, [1, h_width])
        self.state_in = [c_in, h_in]

        outputs, final_state = tf.nn.dynamic_rnn(
            cell,
            sequence,
            initial_state=cell_api.LSTMStateTuple(c_in, h_in),
            sequence_length=num_steps,
            time_major=False)
        final_c, final_h = final_state

        # Drop the fake batch dimension again: one row per time step.
        flat_outputs = tf.reshape(outputs, [-1, hidden_units])
        logits = linear(
            flat_outputs, num_outputs, "action", normc_initializer(0.01))
        self.state_out = [final_c[:1, :], final_h[:1, :]]
        return logits, flat_outputs
|