[rllib] Add unit test and some better error messages for custom policy states (#3032)

This commit is contained in:
Eric Liang
2018-10-13 00:03:52 -07:00
committed by GitHub
parent 87639b9e26
commit 473ee4eb3f
3 changed files with 42 additions and 5 deletions
+4
View File
@@ -364,6 +364,10 @@ def _env_runner(async_vector_env,
# Record the policy eval results
for policy_id, eval_data in to_eval.items():
actions, rnn_out_cols, pi_info_cols = eval_results[policy_id]
if len(rnn_in_cols[policy_id]) != len(rnn_out_cols):
raise ValueError(
"Length of RNN in did not match RNN out, got: "
"{} vs {}".format(rnn_in_cols[policy_id], rnn_out_cols))
# Add RNN state info
for f_i, column in enumerate(rnn_in_cols[policy_id]):
pi_info_cols["state_in_{}".format(f_i)] = column
+12 -5
View File
@@ -95,11 +95,18 @@ class TFPolicyGraph(PolicyGraph):
self._variables = ray.experimental.TensorFlowVariables(
self._loss, self._sess)
assert len(self._state_inputs) == len(self._state_outputs) == \
len(self.get_initial_state()), \
(self._state_inputs, self._state_outputs, self.get_initial_state())
if self._state_inputs:
assert self._seq_lens is not None
if len(self._state_inputs) != len(self._state_outputs):
raise ValueError(
"Number of state input and output tensors must match, got: "
"{} vs {}".format(self._state_inputs, self._state_outputs))
if len(self.get_initial_state()) != len(self._state_inputs):
raise ValueError(
"Length of initial state must match number of state inputs, "
"got: {} vs {}".format(self.get_initial_state(),
self._state_inputs))
if self._state_inputs and self._seq_lens is None:
raise ValueError(
"seq_lens tensor must be given if state inputs are defined")
def build_compute_actions(self,
builder,
@@ -15,6 +15,7 @@ from ray.rllib.optimizers import SyncSamplesOptimizer, \
from ray.rllib.test.test_policy_evaluator import MockEnv, MockEnv2, \
MockPolicyGraph
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.env.async_vector_env import _MultiAgentEnvToAsync
from ray.rllib.env.multi_agent_env import MultiAgentEnv
@@ -306,6 +307,31 @@ class TestMultiAgentEnv(unittest.TestCase):
self.assertEqual(batch.policy_batches["p0"]["t"].tolist()[:10],
[4, 9, 14, 19, 24, 5, 10, 15, 20, 25])
def testCustomRNNStateValues(self):
h = {"some": {"arbitrary": "structure", "here": [1, 2, 3]}}
class StatefulPolicyGraph(PolicyGraph):
def compute_actions(self,
obs_batch,
state_batches,
is_training=False,
episodes=None):
return [0] * len(obs_batch), [[h] * len(obs_batch)], {}
def get_initial_state(self):
return [{}] # empty dict
ev = PolicyEvaluator(
env_creator=lambda _: gym.make("CartPole-v0"),
policy_graph=StatefulPolicyGraph,
batch_steps=5)
batch = ev.sample()
self.assertEqual(batch.count, 5)
self.assertEqual(batch["state_in_0"][0], {})
self.assertEqual(batch["state_out_0"][0], h)
self.assertEqual(batch["state_in_0"][1], h)
self.assertEqual(batch["state_out_0"][1], h)
def testReturningModelBasedRolloutsData(self):
class ModelBasedPolicyGraph(PGPolicyGraph):
def compute_actions(self,