[rllib] Refactor the sampler (#3387)

* refactor

* fix test

* add perf test

* Update sampler.py
This commit is contained in:
Eric Liang
2018-11-24 18:16:54 -08:00
committed by GitHub
parent 3856533065
commit b85e7b43f3
3 changed files with 246 additions and 175 deletions
+231 -173
View File
@@ -262,184 +262,242 @@ def _env_runner(async_vector_env,
unfiltered_obs, rewards, dones, infos, off_policy_actions = \
async_vector_env.poll()
# Map of policy_id to list of PolicyEvalData
to_eval = defaultdict(list)
# Process observations and prepare for policy evaluation
active_envs, to_eval, outputs = _process_observations(
async_vector_env, policies, batch_builder_pool, active_episodes,
unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon,
obs_filters, unroll_length, pack, callbacks)
for o in outputs:
yield o
# Map of env_id -> agent_id -> action replies
actions_to_send = defaultdict(dict)
# Do batched policy eval
eval_results = _do_policy_eval(tf_sess, to_eval, policies,
active_episodes)
# For each environment
for env_id, agent_obs in unfiltered_obs.items():
new_episode = env_id not in active_episodes
episode = active_episodes[env_id]
if not new_episode:
episode.length += 1
episode.batch_builder.count += 1
episode._add_agent_rewards(rewards[env_id])
# Check episode termination conditions
if dones[env_id]["__all__"] or episode.length >= horizon:
all_done = True
atari_metrics = _fetch_atari_metrics(async_vector_env)
if atari_metrics is not None:
for m in atari_metrics:
yield m._replace(custom_metrics=episode.custom_metrics)
else:
yield RolloutMetrics(episode.length, episode.total_reward,
dict(episode.agent_rewards),
episode.custom_metrics)
else:
all_done = False
# At least send an empty dict if not done
actions_to_send[env_id] = {}
# For each agent in the environment
for agent_id, raw_obs in agent_obs.items():
policy_id = episode.policy_for(agent_id)
filtered_obs = _get_or_raise(obs_filters, policy_id)(raw_obs)
agent_done = bool(all_done or dones[env_id].get(agent_id))
if not agent_done:
to_eval[policy_id].append(
PolicyEvalData(env_id, agent_id, filtered_obs,
episode.rnn_state_for(agent_id),
episode.last_action_for(agent_id),
rewards[env_id][agent_id] or 0.0))
last_observation = episode.last_observation_for(agent_id)
episode._set_last_observation(agent_id, filtered_obs)
# Record transition info if applicable
if last_observation is not None and \
infos[env_id][agent_id].get("training_enabled", True):
episode.batch_builder.add_values(
agent_id,
policy_id,
t=episode.length - 1,
eps_id=episode.episode_id,
agent_index=episode._agent_index(agent_id),
obs=last_observation,
actions=episode.last_action_for(agent_id),
rewards=rewards[env_id][agent_id],
prev_actions=episode.prev_action_for(agent_id),
prev_rewards=episode.prev_reward_for(agent_id),
dones=agent_done,
infos=infos[env_id][agent_id],
new_obs=filtered_obs,
**episode.last_pi_info_for(agent_id))
# Invoke the step callback after the step is logged to the episode
if callbacks.get("on_episode_step"):
callbacks["on_episode_step"]({
"env": async_vector_env,
"episode": episode
})
# Cut the batch if we're not packing multiple episodes into one,
# or if we've exceeded the requested batch size.
if episode.batch_builder.has_pending_data():
if (all_done and not pack) or \
episode.batch_builder.count >= unroll_length:
yield episode.batch_builder.build_and_reset(episode)
elif all_done:
# Make sure postprocessor stays within one episode
episode.batch_builder.postprocess_batch_so_far(episode)
if all_done:
# Handle episode termination
batch_builder_pool.append(episode.batch_builder)
if callbacks.get("on_episode_end"):
callbacks["on_episode_end"]({
"env": async_vector_env,
"episode": episode
})
del active_episodes[env_id]
resetted_obs = async_vector_env.try_reset(env_id)
if resetted_obs is None:
# Reset not supported, drop this env from the ready list
assert horizon == float("inf"), \
"Setting episode horizon requires reset() support."
else:
# Creates a new episode
episode = active_episodes[env_id]
for agent_id, raw_obs in resetted_obs.items():
policy_id = episode.policy_for(agent_id)
policy = _get_or_raise(policies, policy_id)
filtered_obs = _get_or_raise(obs_filters,
policy_id)(raw_obs)
episode._set_last_observation(agent_id, filtered_obs)
to_eval[policy_id].append(
PolicyEvalData(
env_id, agent_id, filtered_obs,
episode.rnn_state_for(agent_id),
np.zeros_like(
_flatten_action(
policy.action_space.sample())), 0.0))
# Batch eval policy actions if possible
if tf_sess:
builder = TFRunBuilder(tf_sess, "policy_eval")
pending_fetches = {}
else:
builder = None
eval_results = {}
rnn_in_cols = {}
for policy_id, eval_data in to_eval.items():
rnn_in = _to_column_format([t.rnn_state for t in eval_data])
rnn_in_cols[policy_id] = rnn_in
policy = _get_or_raise(policies, policy_id)
if builder and (policy.compute_actions.__code__ is
TFPolicyGraph.compute_actions.__code__):
pending_fetches[policy_id] = policy.build_compute_actions(
builder, [t.obs for t in eval_data],
rnn_in,
prev_action_batch=[t.prev_action for t in eval_data],
prev_reward_batch=[t.prev_reward for t in eval_data],
is_training=True)
else:
eval_results[policy_id] = policy.compute_actions(
[t.obs for t in eval_data],
rnn_in,
prev_action_batch=[t.prev_action for t in eval_data],
prev_reward_batch=[t.prev_reward for t in eval_data],
is_training=True,
episodes=[active_episodes[t.env_id] for t in eval_data])
if builder:
for k, v in pending_fetches.items():
eval_results[k] = builder.get(v)
# Record the policy eval results
for policy_id, eval_data in to_eval.items():
actions, rnn_out_cols, pi_info_cols = eval_results[policy_id]
if len(rnn_in_cols[policy_id]) != len(rnn_out_cols):
raise ValueError(
"Length of RNN in did not match RNN out, got: "
"{} vs {}".format(rnn_in_cols[policy_id], rnn_out_cols))
# Add RNN state info
for f_i, column in enumerate(rnn_in_cols[policy_id]):
pi_info_cols["state_in_{}".format(f_i)] = column
for f_i, column in enumerate(rnn_out_cols):
pi_info_cols["state_out_{}".format(f_i)] = column
# Save output rows
actions = _unbatch_tuple_actions(actions)
for i, action in enumerate(actions):
env_id = eval_data[i].env_id
agent_id = eval_data[i].agent_id
actions_to_send[env_id][agent_id] = action
episode = active_episodes[env_id]
episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
episode._set_last_pi_info(
agent_id, {k: v[i]
for k, v in pi_info_cols.items()})
if env_id in off_policy_actions and \
agent_id in off_policy_actions[env_id]:
episode._set_last_action(
agent_id, off_policy_actions[env_id][agent_id])
else:
episode._set_last_action(agent_id, action)
# Process results and update episode state
actions_to_send = _process_policy_eval_results(
to_eval, eval_results, active_episodes, active_envs,
off_policy_actions)
# Return computed actions to ready envs. We also send to envs that have
# taken off-policy actions; those envs are free to ignore the action.
async_vector_env.send_actions(dict(actions_to_send))
async_vector_env.send_actions(actions_to_send)
def _process_observations(async_vector_env, policies, batch_builder_pool,
active_episodes, unfiltered_obs, rewards, dones,
infos, off_policy_actions, horizon, obs_filters,
unroll_length, pack, callbacks):
"""Record new data from the environment and prepare for policy evaluation.
Returns:
active_envs: set of non-terminated env ids
to_eval: map of policy_id to list of agent PolicyEvalData
outputs: list of metrics and samples to return from the sampler
"""
active_envs = set()
to_eval = defaultdict(list)
outputs = []
# For each environment
for env_id, agent_obs in unfiltered_obs.items():
new_episode = env_id not in active_episodes
episode = active_episodes[env_id]
if not new_episode:
episode.length += 1
episode.batch_builder.count += 1
episode._add_agent_rewards(rewards[env_id])
# Check episode termination conditions
if dones[env_id]["__all__"] or episode.length >= horizon:
all_done = True
atari_metrics = _fetch_atari_metrics(async_vector_env)
if atari_metrics is not None:
for m in atari_metrics:
outputs.append(
m._replace(custom_metrics=episode.custom_metrics))
else:
outputs.append(
RolloutMetrics(episode.length, episode.total_reward,
dict(episode.agent_rewards),
episode.custom_metrics))
else:
all_done = False
active_envs.add(env_id)
# For each agent in the environment
for agent_id, raw_obs in agent_obs.items():
policy_id = episode.policy_for(agent_id)
filtered_obs = _get_or_raise(obs_filters, policy_id)(raw_obs)
agent_done = bool(all_done or dones[env_id].get(agent_id))
if not agent_done:
to_eval[policy_id].append(
PolicyEvalData(env_id, agent_id, filtered_obs,
episode.rnn_state_for(agent_id),
episode.last_action_for(agent_id),
rewards[env_id][agent_id] or 0.0))
last_observation = episode.last_observation_for(agent_id)
episode._set_last_observation(agent_id, filtered_obs)
# Record transition info if applicable
if last_observation is not None and \
infos[env_id][agent_id].get("training_enabled", True):
episode.batch_builder.add_values(
agent_id,
policy_id,
t=episode.length - 1,
eps_id=episode.episode_id,
agent_index=episode._agent_index(agent_id),
obs=last_observation,
actions=episode.last_action_for(agent_id),
rewards=rewards[env_id][agent_id],
prev_actions=episode.prev_action_for(agent_id),
prev_rewards=episode.prev_reward_for(agent_id),
dones=agent_done,
infos=infos[env_id][agent_id],
new_obs=filtered_obs,
**episode.last_pi_info_for(agent_id))
# Invoke the step callback after the step is logged to the episode
if callbacks.get("on_episode_step"):
callbacks["on_episode_step"]({
"env": async_vector_env,
"episode": episode
})
# Cut the batch if we're not packing multiple episodes into one,
# or if we've exceeded the requested batch size.
if episode.batch_builder.has_pending_data():
if (all_done and not pack) or \
episode.batch_builder.count >= unroll_length:
outputs.append(episode.batch_builder.build_and_reset(episode))
elif all_done:
# Make sure postprocessor stays within one episode
episode.batch_builder.postprocess_batch_so_far(episode)
if all_done:
# Handle episode termination
batch_builder_pool.append(episode.batch_builder)
if callbacks.get("on_episode_end"):
callbacks["on_episode_end"]({
"env": async_vector_env,
"episode": episode
})
del active_episodes[env_id]
resetted_obs = async_vector_env.try_reset(env_id)
if resetted_obs is None:
# Reset not supported, drop this env from the ready list
if horizon != float("inf"):
raise ValueError(
"Setting episode horizon requires reset() support "
"from the environment.")
else:
# Creates a new episode
episode = active_episodes[env_id]
for agent_id, raw_obs in resetted_obs.items():
policy_id = episode.policy_for(agent_id)
policy = _get_or_raise(policies, policy_id)
filtered_obs = _get_or_raise(obs_filters,
policy_id)(raw_obs)
episode._set_last_observation(agent_id, filtered_obs)
to_eval[policy_id].append(
PolicyEvalData(
env_id, agent_id, filtered_obs,
episode.rnn_state_for(agent_id),
np.zeros_like(
_flatten_action(policy.action_space.sample())),
0.0))
return active_envs, to_eval, outputs
def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
"""Call compute actions on observation batches to get next actions.
Returns:
eval_results: dict of policy to compute_action() outputs.
"""
eval_results = {}
if tf_sess:
builder = TFRunBuilder(tf_sess, "policy_eval")
pending_fetches = {}
else:
builder = None
for policy_id, eval_data in to_eval.items():
rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data])
policy = _get_or_raise(policies, policy_id)
if builder and (policy.compute_actions.__code__ is
TFPolicyGraph.compute_actions.__code__):
pending_fetches[policy_id] = policy.build_compute_actions(
builder, [t.obs for t in eval_data],
rnn_in_cols,
prev_action_batch=[t.prev_action for t in eval_data],
prev_reward_batch=[t.prev_reward for t in eval_data],
is_training=True)
else:
eval_results[policy_id] = policy.compute_actions(
[t.obs for t in eval_data],
rnn_in_cols,
prev_action_batch=[t.prev_action for t in eval_data],
prev_reward_batch=[t.prev_reward for t in eval_data],
is_training=True,
episodes=[active_episodes[t.env_id] for t in eval_data])
if builder:
for k, v in pending_fetches.items():
eval_results[k] = builder.get(v)
return eval_results
def _process_policy_eval_results(to_eval, eval_results, active_episodes,
active_envs, off_policy_actions):
"""Process the output of policy neural network evaluation.
Records policy evaluation results into the given episode objects and
returns replies to send back to agents in the env.
Returns:
actions_to_send: nested dict of env id -> agent id -> agent replies.
"""
actions_to_send = defaultdict(dict)
for env_id in active_envs:
actions_to_send[env_id] = {} # at minimum send empty dict
for policy_id, eval_data in to_eval.items():
rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data])
actions, rnn_out_cols, pi_info_cols = eval_results[policy_id]
if len(rnn_in_cols) != len(rnn_out_cols):
raise ValueError("Length of RNN in did not match RNN out, got: "
"{} vs {}".format(rnn_in_cols, rnn_out_cols))
# Add RNN state info
for f_i, column in enumerate(rnn_in_cols):
pi_info_cols["state_in_{}".format(f_i)] = column
for f_i, column in enumerate(rnn_out_cols):
pi_info_cols["state_out_{}".format(f_i)] = column
# Save output rows
actions = _unbatch_tuple_actions(actions)
for i, action in enumerate(actions):
env_id = eval_data[i].env_id
agent_id = eval_data[i].agent_id
actions_to_send[env_id][agent_id] = action
episode = active_episodes[env_id]
episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
episode._set_last_pi_info(
agent_id, {k: v[i]
for k, v in pi_info_cols.items()})
if env_id in off_policy_actions and \
agent_id in off_policy_actions[env_id]:
episode._set_last_action(agent_id,
off_policy_actions[env_id][agent_id])
else:
episode._set_last_action(agent_id, action)
return actions_to_send
def _fetch_atari_metrics(async_vector_env):
+1 -2
View File
@@ -192,8 +192,7 @@ class TestExternalEnv(unittest.TestCase):
episode_horizon=20,
batch_steps=10,
batch_mode="complete_episodes")
ev.sample()
self.assertRaises(Exception, lambda: ev.sample())
self.assertRaises(ValueError, lambda: ev.sample())
if __name__ == '__main__':
@@ -152,6 +152,20 @@ class TestPolicyEvaluator(unittest.TestCase):
to_prev(batch["actions"]))
self.assertGreater(batch["advantages"][0], 1)
# 11/23/18: Samples per second 8501.125113727468
def testBaselinePerformance(self):
    """Print the sampling throughput measured over roughly one second.

    Builds a CartPole evaluator with the mock policy graph, samples
    repeatedly until one second of wall-clock time has passed, and prints
    the resulting samples-per-second figure. Nothing is asserted; the
    printed number is only a reference point.
    """
    evaluator = PolicyEvaluator(
        env_creator=lambda _: gym.make("CartPole-v0"),
        policy_graph=MockPolicyGraph,
        batch_steps=100)
    began = time.time()
    total_samples = 0
    while time.time() - began < 1:
        batch = evaluator.sample()
        total_samples += batch.count
    print()
    # The elapsed time is re-read here so it covers the final sample call.
    throughput = total_samples / (time.time() - began)
    print("Samples per second {}".format(throughput))
    print()
def testGlobalVarsUpdate(self):
agent = A2CAgent(
env="CartPole-v0",