diff --git a/python/ray/rllib/rollout.py b/python/ray/rllib/rollout.py index e1a231093..f7ea31e96 100755 --- a/python/ray/rllib/rollout.py +++ b/python/ray/rllib/rollout.py @@ -132,7 +132,7 @@ def rollout(agent, env_name, num_steps, out=None, no_render=True): state_init = {p: m.get_initial_state() for p, m in policy_map.items()} use_lstm = {p: len(s) > 0 for p, s in state_init.items()} action_init = { - p: m.action_space.sample() + p: _flatten_action(m.action_space.sample()) for p, m in policy_map.items() } else: