mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 09:57:46 +08:00
[rllib] Add end-to-end tests for RNN sequencing (#4258)
This commit is contained in:
@@ -6,9 +6,9 @@ SCRIPT=$1
|
||||
shift
|
||||
|
||||
if [ -x $DIRECTORY/../$SCRIPT ]; then
|
||||
$DIRECTORY/../$SCRIPT "$@" >$TMPFILE 2>&1
|
||||
time $DIRECTORY/../$SCRIPT "$@" >$TMPFILE 2>&1
|
||||
else
|
||||
python $DIRECTORY/../$SCRIPT "$@" >$TMPFILE 2>&1
|
||||
time python $DIRECTORY/../$SCRIPT "$@" >$TMPFILE 2>&1
|
||||
fi
|
||||
|
||||
CODE=$?
|
||||
|
||||
@@ -2,9 +2,20 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import pickle
|
||||
import unittest
|
||||
import tensorflow as tf
|
||||
import tensorflow.contrib.rnn as rnn
|
||||
|
||||
from ray.rllib.models.lstm import chop_into_sequences
|
||||
import ray
|
||||
from ray.rllib.agents.ppo import PPOAgent
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models.lstm import add_time_dimension, chop_into_sequences
|
||||
from ray.rllib.models.misc import linear, normc_initializer
|
||||
from ray.rllib.models.model import Model
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
|
||||
class LSTMUtilsTest(unittest.TestCase):
|
||||
@@ -48,5 +59,209 @@ class LSTMUtilsTest(unittest.TestCase):
|
||||
self.assertEqual(seq_lens.tolist(), [1, 2])
|
||||
|
||||
|
||||
class RNNSpyModel(Model):
|
||||
capture_index = 0
|
||||
|
||||
def _build_layers_v2(self, input_dict, num_outputs, options):
|
||||
def spy(sequences, state_in, state_out, seq_lens):
|
||||
if len(sequences) == 1:
|
||||
return 0 # don't capture inference inputs
|
||||
# TF runs this function in an isolated context, so we have to use
|
||||
# redis to communicate back to our suite
|
||||
ray.experimental.internal_kv._internal_kv_put(
|
||||
"rnn_spy_in_{}".format(RNNSpyModel.capture_index),
|
||||
pickle.dumps({
|
||||
"sequences": sequences,
|
||||
"state_in": state_in,
|
||||
"state_out": state_out,
|
||||
"seq_lens": seq_lens
|
||||
}),
|
||||
overwrite=True)
|
||||
RNNSpyModel.capture_index += 1
|
||||
return 0
|
||||
|
||||
features = input_dict["obs"]
|
||||
cell_size = 3
|
||||
last_layer = add_time_dimension(features, self.seq_lens)
|
||||
|
||||
# Setup the LSTM cell
|
||||
lstm = rnn.BasicLSTMCell(cell_size, state_is_tuple=True)
|
||||
self.state_init = [
|
||||
np.zeros(lstm.state_size.c, np.float32),
|
||||
np.zeros(lstm.state_size.h, np.float32)
|
||||
]
|
||||
|
||||
# Setup LSTM inputs
|
||||
if self.state_in:
|
||||
c_in, h_in = self.state_in
|
||||
else:
|
||||
c_in = tf.placeholder(
|
||||
tf.float32, [None, lstm.state_size.c], name="c")
|
||||
h_in = tf.placeholder(
|
||||
tf.float32, [None, lstm.state_size.h], name="h")
|
||||
self.state_in = [c_in, h_in]
|
||||
|
||||
# Setup LSTM outputs
|
||||
state_in = rnn.LSTMStateTuple(c_in, h_in)
|
||||
lstm_out, lstm_state = tf.nn.dynamic_rnn(
|
||||
lstm,
|
||||
last_layer,
|
||||
initial_state=state_in,
|
||||
sequence_length=self.seq_lens,
|
||||
time_major=False,
|
||||
dtype=tf.float32)
|
||||
|
||||
self.state_out = list(lstm_state)
|
||||
spy_fn = tf.py_func(
|
||||
spy, [
|
||||
last_layer,
|
||||
self.state_in,
|
||||
self.state_out,
|
||||
self.seq_lens,
|
||||
],
|
||||
tf.int64,
|
||||
stateful=True)
|
||||
|
||||
# Compute outputs
|
||||
with tf.control_dependencies([spy_fn]):
|
||||
last_layer = tf.reshape(lstm_out, [-1, cell_size])
|
||||
logits = linear(last_layer, num_outputs, "action",
|
||||
normc_initializer(0.01))
|
||||
return logits, last_layer
|
||||
|
||||
|
||||
class DebugCounterEnv(gym.Env):
|
||||
def __init__(self):
|
||||
self.action_space = gym.spaces.Discrete(2)
|
||||
self.observation_space = gym.spaces.Box(0, 100, (1, ))
|
||||
self.i = 0
|
||||
|
||||
def reset(self):
|
||||
self.i = 0
|
||||
return [self.i]
|
||||
|
||||
def step(self, action):
|
||||
self.i += 1
|
||||
return [self.i], self.i % 3, self.i >= 15, {}
|
||||
|
||||
|
||||
class RNNSequencing(unittest.TestCase):
|
||||
def testSimpleOptimizerSequencing(self):
|
||||
ModelCatalog.register_custom_model("rnn", RNNSpyModel)
|
||||
register_env("counter", lambda _: DebugCounterEnv())
|
||||
ppo = PPOAgent(
|
||||
env="counter",
|
||||
config={
|
||||
"num_workers": 0,
|
||||
"sample_batch_size": 10,
|
||||
"train_batch_size": 10,
|
||||
"sgd_minibatch_size": 10,
|
||||
"vf_share_layers": True,
|
||||
"simple_optimizer": True,
|
||||
"num_sgd_iter": 1,
|
||||
"model": {
|
||||
"custom_model": "rnn",
|
||||
"max_seq_len": 4,
|
||||
},
|
||||
})
|
||||
ppo.train()
|
||||
ppo.train()
|
||||
|
||||
batch0 = pickle.loads(
|
||||
ray.experimental.internal_kv._internal_kv_get("rnn_spy_in_0"))
|
||||
self.assertEqual(
|
||||
batch0["sequences"].tolist(),
|
||||
[[[0], [1], [2], [3]], [[4], [5], [6], [7]], [[8], [9], [0], [0]]])
|
||||
self.assertEqual(batch0["seq_lens"].tolist(), [4, 4, 2])
|
||||
self.assertEqual(batch0["state_in"][0][0].tolist(), [0, 0, 0])
|
||||
self.assertEqual(batch0["state_in"][1][0].tolist(), [0, 0, 0])
|
||||
self.assertGreater(abs(np.sum(batch0["state_in"][0][1])), 0)
|
||||
self.assertGreater(abs(np.sum(batch0["state_in"][1][1])), 0)
|
||||
self.assertTrue(
|
||||
np.allclose(batch0["state_in"][0].tolist()[1:],
|
||||
batch0["state_out"][0].tolist()[:-1]))
|
||||
self.assertTrue(
|
||||
np.allclose(batch0["state_in"][1].tolist()[1:],
|
||||
batch0["state_out"][1].tolist()[:-1]))
|
||||
|
||||
batch1 = pickle.loads(
|
||||
ray.experimental.internal_kv._internal_kv_get("rnn_spy_in_1"))
|
||||
self.assertEqual(batch1["sequences"].tolist(), [
|
||||
[[10], [11], [12], [13]],
|
||||
[[14], [0], [0], [0]],
|
||||
[[0], [1], [2], [3]],
|
||||
[[4], [0], [0], [0]],
|
||||
])
|
||||
self.assertEqual(batch1["seq_lens"].tolist(), [4, 1, 4, 1])
|
||||
self.assertEqual(batch1["state_in"][0][2].tolist(), [0, 0, 0])
|
||||
self.assertEqual(batch1["state_in"][1][2].tolist(), [0, 0, 0])
|
||||
self.assertGreater(abs(np.sum(batch1["state_in"][0][0])), 0)
|
||||
self.assertGreater(abs(np.sum(batch1["state_in"][1][0])), 0)
|
||||
self.assertGreater(abs(np.sum(batch1["state_in"][0][1])), 0)
|
||||
self.assertGreater(abs(np.sum(batch1["state_in"][1][1])), 0)
|
||||
self.assertGreater(abs(np.sum(batch1["state_in"][0][3])), 0)
|
||||
self.assertGreater(abs(np.sum(batch1["state_in"][1][3])), 0)
|
||||
|
||||
def testMinibatchSequencing(self):
|
||||
ModelCatalog.register_custom_model("rnn", RNNSpyModel)
|
||||
register_env("counter", lambda _: DebugCounterEnv())
|
||||
ppo = PPOAgent(
|
||||
env="counter",
|
||||
config={
|
||||
"num_workers": 0,
|
||||
"sample_batch_size": 20,
|
||||
"train_batch_size": 20,
|
||||
"sgd_minibatch_size": 10,
|
||||
"vf_share_layers": True,
|
||||
"simple_optimizer": False,
|
||||
"num_sgd_iter": 1,
|
||||
"model": {
|
||||
"custom_model": "rnn",
|
||||
"max_seq_len": 4,
|
||||
},
|
||||
})
|
||||
ppo.train()
|
||||
ppo.train()
|
||||
|
||||
# first epoch: 20 observations get split into 2 minibatches of 8
|
||||
# four observations are discarded
|
||||
batch0 = pickle.loads(
|
||||
ray.experimental.internal_kv._internal_kv_get("rnn_spy_in_0"))
|
||||
batch1 = pickle.loads(
|
||||
ray.experimental.internal_kv._internal_kv_get("rnn_spy_in_1"))
|
||||
if batch0["sequences"][0][0][0] > batch1["sequences"][0][0][0]:
|
||||
batch0, batch1 = batch1, batch0 # sort minibatches
|
||||
self.assertEqual(batch0["seq_lens"].tolist(), [4, 4])
|
||||
self.assertEqual(batch1["seq_lens"].tolist(), [4, 3])
|
||||
self.assertEqual(batch0["sequences"].tolist(), [
|
||||
[[0], [1], [2], [3]],
|
||||
[[4], [5], [6], [7]],
|
||||
])
|
||||
self.assertEqual(batch1["sequences"].tolist(), [
|
||||
[[8], [9], [10], [11]],
|
||||
[[12], [13], [14], [0]],
|
||||
])
|
||||
|
||||
# second epoch: 20 observations get split into 2 minibatches of 8
|
||||
# four observations are discarded
|
||||
batch2 = pickle.loads(
|
||||
ray.experimental.internal_kv._internal_kv_get("rnn_spy_in_2"))
|
||||
batch3 = pickle.loads(
|
||||
ray.experimental.internal_kv._internal_kv_get("rnn_spy_in_3"))
|
||||
if batch2["sequences"][0][0][0] > batch3["sequences"][0][0][0]:
|
||||
batch2, batch3 = batch3, batch2
|
||||
self.assertEqual(batch2["seq_lens"].tolist(), [4, 4])
|
||||
self.assertEqual(batch3["seq_lens"].tolist(), [2, 4])
|
||||
self.assertEqual(batch2["sequences"].tolist(), [
|
||||
[[5], [6], [7], [8]],
|
||||
[[9], [10], [11], [12]],
|
||||
])
|
||||
self.assertEqual(batch3["sequences"].tolist(), [
|
||||
[[13], [14], [0], [0]],
|
||||
[[0], [1], [2], [3]],
|
||||
])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ray.init(num_cpus=4)
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user