diff --git a/python/ray/rllib/evaluation/sampler.py b/python/ray/rllib/evaluation/sampler.py index ec0f13c4e..78043b3d4 100644 --- a/python/ray/rllib/evaluation/sampler.py +++ b/python/ray/rllib/evaluation/sampler.py @@ -364,6 +364,10 @@ def _env_runner(async_vector_env, # Record the policy eval results for policy_id, eval_data in to_eval.items(): actions, rnn_out_cols, pi_info_cols = eval_results[policy_id] + if len(rnn_in_cols[policy_id]) != len(rnn_out_cols): + raise ValueError( + "Length of RNN in did not match RNN out, got: " + "{} vs {}".format(rnn_in_cols[policy_id], rnn_out_cols)) # Add RNN state info for f_i, column in enumerate(rnn_in_cols[policy_id]): pi_info_cols["state_in_{}".format(f_i)] = column diff --git a/python/ray/rllib/evaluation/tf_policy_graph.py b/python/ray/rllib/evaluation/tf_policy_graph.py index e9119c875..09a84981e 100644 --- a/python/ray/rllib/evaluation/tf_policy_graph.py +++ b/python/ray/rllib/evaluation/tf_policy_graph.py @@ -95,11 +95,18 @@ class TFPolicyGraph(PolicyGraph): self._variables = ray.experimental.TensorFlowVariables( self._loss, self._sess) - assert len(self._state_inputs) == len(self._state_outputs) == \ - len(self.get_initial_state()), \ - (self._state_inputs, self._state_outputs, self.get_initial_state()) - if self._state_inputs: - assert self._seq_lens is not None + if len(self._state_inputs) != len(self._state_outputs): + raise ValueError( + "Number of state input and output tensors must match, got: " + "{} vs {}".format(self._state_inputs, self._state_outputs)) + if len(self.get_initial_state()) != len(self._state_inputs): + raise ValueError( + "Length of initial state must match number of state inputs, " + "got: {} vs {}".format(self.get_initial_state(), + self._state_inputs)) + if self._state_inputs and self._seq_lens is None: + raise ValueError( + "seq_lens tensor must be given if state inputs are defined") def build_compute_actions(self, builder, diff --git a/python/ray/rllib/test/test_multi_agent_env.py b/python/ray/rllib/test/test_multi_agent_env.py index 96eaabaf1..493b338cf 100644 --- a/python/ray/rllib/test/test_multi_agent_env.py +++ b/python/ray/rllib/test/test_multi_agent_env.py @@ -15,6 +15,7 @@ from ray.rllib.optimizers import SyncSamplesOptimizer, \ from ray.rllib.test.test_policy_evaluator import MockEnv, MockEnv2, \ MockPolicyGraph from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator +from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.metrics import collect_metrics from ray.rllib.env.async_vector_env import _MultiAgentEnvToAsync from ray.rllib.env.multi_agent_env import MultiAgentEnv @@ -306,6 +307,31 @@ class TestMultiAgentEnv(unittest.TestCase): self.assertEqual(batch.policy_batches["p0"]["t"].tolist()[:10], [4, 9, 14, 19, 24, 5, 10, 15, 20, 25]) + def testCustomRNNStateValues(self): + h = {"some": {"arbitrary": "structure", "here": [1, 2, 3]}} + + class StatefulPolicyGraph(PolicyGraph): + def compute_actions(self, + obs_batch, + state_batches, + is_training=False, + episodes=None): + return [0] * len(obs_batch), [[h] * len(obs_batch)], {} + + def get_initial_state(self): + return [{}] # empty dict + + ev = PolicyEvaluator( + env_creator=lambda _: gym.make("CartPole-v0"), + policy_graph=StatefulPolicyGraph, + batch_steps=5) + batch = ev.sample() + self.assertEqual(batch.count, 5) + self.assertEqual(batch["state_in_0"][0], {}) + self.assertEqual(batch["state_out_0"][0], h) + self.assertEqual(batch["state_in_0"][1], h) + self.assertEqual(batch["state_out_0"][1], h) + def testReturningModelBasedRolloutsData(self): class ModelBasedPolicyGraph(PGPolicyGraph): def compute_actions(self,