mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:06:31 +08:00
[rllib] Add unit test and some better error messages for custom policy states (#3032)
This commit is contained in:
@@ -364,6 +364,10 @@ def _env_runner(async_vector_env,
         # Record the policy eval results
         for policy_id, eval_data in to_eval.items():
             actions, rnn_out_cols, pi_info_cols = eval_results[policy_id]
+            if len(rnn_in_cols[policy_id]) != len(rnn_out_cols):
+                raise ValueError(
+                    "Length of RNN in did not match RNN out, got: "
+                    "{} vs {}".format(rnn_in_cols[policy_id], rnn_out_cols))
             # Add RNN state info
             for f_i, column in enumerate(rnn_in_cols[policy_id]):
                 pi_info_cols["state_in_{}".format(f_i)] = column

@@ -95,11 +95,18 @@ class TFPolicyGraph(PolicyGraph):
         self._variables = ray.experimental.TensorFlowVariables(
             self._loss, self._sess)
 
-        assert len(self._state_inputs) == len(self._state_outputs) == \
-            len(self.get_initial_state()), \
-            (self._state_inputs, self._state_outputs, self.get_initial_state())
-        if self._state_inputs:
-            assert self._seq_lens is not None
+        if len(self._state_inputs) != len(self._state_outputs):
+            raise ValueError(
+                "Number of state input and output tensors must match, got: "
+                "{} vs {}".format(self._state_inputs, self._state_outputs))
+        if len(self.get_initial_state()) != len(self._state_inputs):
+            raise ValueError(
+                "Length of initial state must match number of state inputs, "
+                "got: {} vs {}".format(self.get_initial_state(),
+                                       self._state_inputs))
+        if self._state_inputs and self._seq_lens is None:
+            raise ValueError(
+                "seq_lens tensor must be given if state inputs are defined")
 
     def build_compute_actions(self,
                               builder,

@@ -15,6 +15,7 @@ from ray.rllib.optimizers import SyncSamplesOptimizer, \
 from ray.rllib.test.test_policy_evaluator import MockEnv, MockEnv2, \
     MockPolicyGraph
 from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
+from ray.rllib.evaluation.policy_graph import PolicyGraph
 from ray.rllib.evaluation.metrics import collect_metrics
 from ray.rllib.env.async_vector_env import _MultiAgentEnvToAsync
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
@@ -306,6 +307,31 @@ class TestMultiAgentEnv(unittest.TestCase):
         self.assertEqual(batch.policy_batches["p0"]["t"].tolist()[:10],
                          [4, 9, 14, 19, 24, 5, 10, 15, 20, 25])
 
+    def testCustomRNNStateValues(self):
+        h = {"some": {"arbitrary": "structure", "here": [1, 2, 3]}}
+
+        class StatefulPolicyGraph(PolicyGraph):
+            def compute_actions(self,
+                                obs_batch,
+                                state_batches,
+                                is_training=False,
+                                episodes=None):
+                return [0] * len(obs_batch), [[h] * len(obs_batch)], {}
+
+            def get_initial_state(self):
+                return [{}]  # empty dict
+
+        ev = PolicyEvaluator(
+            env_creator=lambda _: gym.make("CartPole-v0"),
+            policy_graph=StatefulPolicyGraph,
+            batch_steps=5)
+        batch = ev.sample()
+        self.assertEqual(batch.count, 5)
+        self.assertEqual(batch["state_in_0"][0], {})
+        self.assertEqual(batch["state_out_0"][0], h)
+        self.assertEqual(batch["state_in_0"][1], h)
+        self.assertEqual(batch["state_out_0"][1], h)
+
     def testReturningModelBasedRolloutsData(self):
         class ModelBasedPolicyGraph(PGPolicyGraph):
             def compute_actions(self,

Reference in New Issue
Block a user