mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 20:18:33 +08:00
[rllib] Refactor the sampler (#3387)
* refactor * fix test * add perf test * Update sampler.py
This commit is contained in:
@@ -262,184 +262,242 @@ def _env_runner(async_vector_env,
|
||||
unfiltered_obs, rewards, dones, infos, off_policy_actions = \
|
||||
async_vector_env.poll()
|
||||
|
||||
# Map of policy_id to list of PolicyEvalData
|
||||
to_eval = defaultdict(list)
|
||||
# Process observations and prepare for policy evaluation
|
||||
active_envs, to_eval, outputs = _process_observations(
|
||||
async_vector_env, policies, batch_builder_pool, active_episodes,
|
||||
unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon,
|
||||
obs_filters, unroll_length, pack, callbacks)
|
||||
for o in outputs:
|
||||
yield o
|
||||
|
||||
# Map of env_id -> agent_id -> action replies
|
||||
actions_to_send = defaultdict(dict)
|
||||
# Do batched policy eval
|
||||
eval_results = _do_policy_eval(tf_sess, to_eval, policies,
|
||||
active_episodes)
|
||||
|
||||
# For each environment
|
||||
for env_id, agent_obs in unfiltered_obs.items():
|
||||
new_episode = env_id not in active_episodes
|
||||
episode = active_episodes[env_id]
|
||||
if not new_episode:
|
||||
episode.length += 1
|
||||
episode.batch_builder.count += 1
|
||||
episode._add_agent_rewards(rewards[env_id])
|
||||
|
||||
# Check episode termination conditions
|
||||
if dones[env_id]["__all__"] or episode.length >= horizon:
|
||||
all_done = True
|
||||
atari_metrics = _fetch_atari_metrics(async_vector_env)
|
||||
if atari_metrics is not None:
|
||||
for m in atari_metrics:
|
||||
yield m._replace(custom_metrics=episode.custom_metrics)
|
||||
else:
|
||||
yield RolloutMetrics(episode.length, episode.total_reward,
|
||||
dict(episode.agent_rewards),
|
||||
episode.custom_metrics)
|
||||
else:
|
||||
all_done = False
|
||||
# At least send an empty dict if not done
|
||||
actions_to_send[env_id] = {}
|
||||
|
||||
# For each agent in the environment
|
||||
for agent_id, raw_obs in agent_obs.items():
|
||||
policy_id = episode.policy_for(agent_id)
|
||||
filtered_obs = _get_or_raise(obs_filters, policy_id)(raw_obs)
|
||||
agent_done = bool(all_done or dones[env_id].get(agent_id))
|
||||
if not agent_done:
|
||||
to_eval[policy_id].append(
|
||||
PolicyEvalData(env_id, agent_id, filtered_obs,
|
||||
episode.rnn_state_for(agent_id),
|
||||
episode.last_action_for(agent_id),
|
||||
rewards[env_id][agent_id] or 0.0))
|
||||
|
||||
last_observation = episode.last_observation_for(agent_id)
|
||||
episode._set_last_observation(agent_id, filtered_obs)
|
||||
|
||||
# Record transition info if applicable
|
||||
if last_observation is not None and \
|
||||
infos[env_id][agent_id].get("training_enabled", True):
|
||||
episode.batch_builder.add_values(
|
||||
agent_id,
|
||||
policy_id,
|
||||
t=episode.length - 1,
|
||||
eps_id=episode.episode_id,
|
||||
agent_index=episode._agent_index(agent_id),
|
||||
obs=last_observation,
|
||||
actions=episode.last_action_for(agent_id),
|
||||
rewards=rewards[env_id][agent_id],
|
||||
prev_actions=episode.prev_action_for(agent_id),
|
||||
prev_rewards=episode.prev_reward_for(agent_id),
|
||||
dones=agent_done,
|
||||
infos=infos[env_id][agent_id],
|
||||
new_obs=filtered_obs,
|
||||
**episode.last_pi_info_for(agent_id))
|
||||
|
||||
# Invoke the step callback after the step is logged to the episode
|
||||
if callbacks.get("on_episode_step"):
|
||||
callbacks["on_episode_step"]({
|
||||
"env": async_vector_env,
|
||||
"episode": episode
|
||||
})
|
||||
|
||||
# Cut the batch if we're not packing multiple episodes into one,
|
||||
# or if we've exceeded the requested batch size.
|
||||
if episode.batch_builder.has_pending_data():
|
||||
if (all_done and not pack) or \
|
||||
episode.batch_builder.count >= unroll_length:
|
||||
yield episode.batch_builder.build_and_reset(episode)
|
||||
elif all_done:
|
||||
# Make sure postprocessor stays within one episode
|
||||
episode.batch_builder.postprocess_batch_so_far(episode)
|
||||
|
||||
if all_done:
|
||||
# Handle episode termination
|
||||
batch_builder_pool.append(episode.batch_builder)
|
||||
if callbacks.get("on_episode_end"):
|
||||
callbacks["on_episode_end"]({
|
||||
"env": async_vector_env,
|
||||
"episode": episode
|
||||
})
|
||||
del active_episodes[env_id]
|
||||
resetted_obs = async_vector_env.try_reset(env_id)
|
||||
if resetted_obs is None:
|
||||
# Reset not supported, drop this env from the ready list
|
||||
assert horizon == float("inf"), \
|
||||
"Setting episode horizon requires reset() support."
|
||||
else:
|
||||
# Creates a new episode
|
||||
episode = active_episodes[env_id]
|
||||
for agent_id, raw_obs in resetted_obs.items():
|
||||
policy_id = episode.policy_for(agent_id)
|
||||
policy = _get_or_raise(policies, policy_id)
|
||||
filtered_obs = _get_or_raise(obs_filters,
|
||||
policy_id)(raw_obs)
|
||||
episode._set_last_observation(agent_id, filtered_obs)
|
||||
to_eval[policy_id].append(
|
||||
PolicyEvalData(
|
||||
env_id, agent_id, filtered_obs,
|
||||
episode.rnn_state_for(agent_id),
|
||||
np.zeros_like(
|
||||
_flatten_action(
|
||||
policy.action_space.sample())), 0.0))
|
||||
|
||||
# Batch eval policy actions if possible
|
||||
if tf_sess:
|
||||
builder = TFRunBuilder(tf_sess, "policy_eval")
|
||||
pending_fetches = {}
|
||||
else:
|
||||
builder = None
|
||||
eval_results = {}
|
||||
rnn_in_cols = {}
|
||||
for policy_id, eval_data in to_eval.items():
|
||||
rnn_in = _to_column_format([t.rnn_state for t in eval_data])
|
||||
rnn_in_cols[policy_id] = rnn_in
|
||||
policy = _get_or_raise(policies, policy_id)
|
||||
if builder and (policy.compute_actions.__code__ is
|
||||
TFPolicyGraph.compute_actions.__code__):
|
||||
pending_fetches[policy_id] = policy.build_compute_actions(
|
||||
builder, [t.obs for t in eval_data],
|
||||
rnn_in,
|
||||
prev_action_batch=[t.prev_action for t in eval_data],
|
||||
prev_reward_batch=[t.prev_reward for t in eval_data],
|
||||
is_training=True)
|
||||
else:
|
||||
eval_results[policy_id] = policy.compute_actions(
|
||||
[t.obs for t in eval_data],
|
||||
rnn_in,
|
||||
prev_action_batch=[t.prev_action for t in eval_data],
|
||||
prev_reward_batch=[t.prev_reward for t in eval_data],
|
||||
is_training=True,
|
||||
episodes=[active_episodes[t.env_id] for t in eval_data])
|
||||
if builder:
|
||||
for k, v in pending_fetches.items():
|
||||
eval_results[k] = builder.get(v)
|
||||
|
||||
# Record the policy eval results
|
||||
for policy_id, eval_data in to_eval.items():
|
||||
actions, rnn_out_cols, pi_info_cols = eval_results[policy_id]
|
||||
if len(rnn_in_cols[policy_id]) != len(rnn_out_cols):
|
||||
raise ValueError(
|
||||
"Length of RNN in did not match RNN out, got: "
|
||||
"{} vs {}".format(rnn_in_cols[policy_id], rnn_out_cols))
|
||||
# Add RNN state info
|
||||
for f_i, column in enumerate(rnn_in_cols[policy_id]):
|
||||
pi_info_cols["state_in_{}".format(f_i)] = column
|
||||
for f_i, column in enumerate(rnn_out_cols):
|
||||
pi_info_cols["state_out_{}".format(f_i)] = column
|
||||
# Save output rows
|
||||
actions = _unbatch_tuple_actions(actions)
|
||||
for i, action in enumerate(actions):
|
||||
env_id = eval_data[i].env_id
|
||||
agent_id = eval_data[i].agent_id
|
||||
actions_to_send[env_id][agent_id] = action
|
||||
episode = active_episodes[env_id]
|
||||
episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
|
||||
episode._set_last_pi_info(
|
||||
agent_id, {k: v[i]
|
||||
for k, v in pi_info_cols.items()})
|
||||
if env_id in off_policy_actions and \
|
||||
agent_id in off_policy_actions[env_id]:
|
||||
episode._set_last_action(
|
||||
agent_id, off_policy_actions[env_id][agent_id])
|
||||
else:
|
||||
episode._set_last_action(agent_id, action)
|
||||
# Process results and update episode state
|
||||
actions_to_send = _process_policy_eval_results(
|
||||
to_eval, eval_results, active_episodes, active_envs,
|
||||
off_policy_actions)
|
||||
|
||||
# Return computed actions to ready envs. We also send to envs that have
|
||||
# taken off-policy actions; those envs are free to ignore the action.
|
||||
async_vector_env.send_actions(dict(actions_to_send))
|
||||
async_vector_env.send_actions(actions_to_send)
|
||||
|
||||
|
||||
def _process_observations(async_vector_env, policies, batch_builder_pool,
|
||||
active_episodes, unfiltered_obs, rewards, dones,
|
||||
infos, off_policy_actions, horizon, obs_filters,
|
||||
unroll_length, pack, callbacks):
|
||||
"""Record new data from the environment and prepare for policy evaluation.
|
||||
|
||||
Returns:
|
||||
active_envs: set of non-terminated env ids
|
||||
to_eval: map of policy_id to list of agent PolicyEvalData
|
||||
outputs: list of metrics and samples to return from the sampler
|
||||
"""
|
||||
|
||||
active_envs = set()
|
||||
to_eval = defaultdict(list)
|
||||
outputs = []
|
||||
|
||||
# For each environment
|
||||
for env_id, agent_obs in unfiltered_obs.items():
|
||||
new_episode = env_id not in active_episodes
|
||||
episode = active_episodes[env_id]
|
||||
if not new_episode:
|
||||
episode.length += 1
|
||||
episode.batch_builder.count += 1
|
||||
episode._add_agent_rewards(rewards[env_id])
|
||||
|
||||
# Check episode termination conditions
|
||||
if dones[env_id]["__all__"] or episode.length >= horizon:
|
||||
all_done = True
|
||||
atari_metrics = _fetch_atari_metrics(async_vector_env)
|
||||
if atari_metrics is not None:
|
||||
for m in atari_metrics:
|
||||
outputs.append(
|
||||
m._replace(custom_metrics=episode.custom_metrics))
|
||||
else:
|
||||
outputs.append(
|
||||
RolloutMetrics(episode.length, episode.total_reward,
|
||||
dict(episode.agent_rewards),
|
||||
episode.custom_metrics))
|
||||
else:
|
||||
all_done = False
|
||||
active_envs.add(env_id)
|
||||
|
||||
# For each agent in the environment
|
||||
for agent_id, raw_obs in agent_obs.items():
|
||||
policy_id = episode.policy_for(agent_id)
|
||||
filtered_obs = _get_or_raise(obs_filters, policy_id)(raw_obs)
|
||||
agent_done = bool(all_done or dones[env_id].get(agent_id))
|
||||
if not agent_done:
|
||||
to_eval[policy_id].append(
|
||||
PolicyEvalData(env_id, agent_id, filtered_obs,
|
||||
episode.rnn_state_for(agent_id),
|
||||
episode.last_action_for(agent_id),
|
||||
rewards[env_id][agent_id] or 0.0))
|
||||
|
||||
last_observation = episode.last_observation_for(agent_id)
|
||||
episode._set_last_observation(agent_id, filtered_obs)
|
||||
|
||||
# Record transition info if applicable
|
||||
if last_observation is not None and \
|
||||
infos[env_id][agent_id].get("training_enabled", True):
|
||||
episode.batch_builder.add_values(
|
||||
agent_id,
|
||||
policy_id,
|
||||
t=episode.length - 1,
|
||||
eps_id=episode.episode_id,
|
||||
agent_index=episode._agent_index(agent_id),
|
||||
obs=last_observation,
|
||||
actions=episode.last_action_for(agent_id),
|
||||
rewards=rewards[env_id][agent_id],
|
||||
prev_actions=episode.prev_action_for(agent_id),
|
||||
prev_rewards=episode.prev_reward_for(agent_id),
|
||||
dones=agent_done,
|
||||
infos=infos[env_id][agent_id],
|
||||
new_obs=filtered_obs,
|
||||
**episode.last_pi_info_for(agent_id))
|
||||
|
||||
# Invoke the step callback after the step is logged to the episode
|
||||
if callbacks.get("on_episode_step"):
|
||||
callbacks["on_episode_step"]({
|
||||
"env": async_vector_env,
|
||||
"episode": episode
|
||||
})
|
||||
|
||||
# Cut the batch if we're not packing multiple episodes into one,
|
||||
# or if we've exceeded the requested batch size.
|
||||
if episode.batch_builder.has_pending_data():
|
||||
if (all_done and not pack) or \
|
||||
episode.batch_builder.count >= unroll_length:
|
||||
outputs.append(episode.batch_builder.build_and_reset(episode))
|
||||
elif all_done:
|
||||
# Make sure postprocessor stays within one episode
|
||||
episode.batch_builder.postprocess_batch_so_far(episode)
|
||||
|
||||
if all_done:
|
||||
# Handle episode termination
|
||||
batch_builder_pool.append(episode.batch_builder)
|
||||
if callbacks.get("on_episode_end"):
|
||||
callbacks["on_episode_end"]({
|
||||
"env": async_vector_env,
|
||||
"episode": episode
|
||||
})
|
||||
del active_episodes[env_id]
|
||||
resetted_obs = async_vector_env.try_reset(env_id)
|
||||
if resetted_obs is None:
|
||||
# Reset not supported, drop this env from the ready list
|
||||
if horizon != float("inf"):
|
||||
raise ValueError(
|
||||
"Setting episode horizon requires reset() support "
|
||||
"from the environment.")
|
||||
else:
|
||||
# Creates a new episode
|
||||
episode = active_episodes[env_id]
|
||||
for agent_id, raw_obs in resetted_obs.items():
|
||||
policy_id = episode.policy_for(agent_id)
|
||||
policy = _get_or_raise(policies, policy_id)
|
||||
filtered_obs = _get_or_raise(obs_filters,
|
||||
policy_id)(raw_obs)
|
||||
episode._set_last_observation(agent_id, filtered_obs)
|
||||
to_eval[policy_id].append(
|
||||
PolicyEvalData(
|
||||
env_id, agent_id, filtered_obs,
|
||||
episode.rnn_state_for(agent_id),
|
||||
np.zeros_like(
|
||||
_flatten_action(policy.action_space.sample())),
|
||||
0.0))
|
||||
|
||||
return active_envs, to_eval, outputs
|
||||
|
||||
|
||||
def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
|
||||
"""Call compute actions on observation batches to get next actions.
|
||||
|
||||
Returns:
|
||||
eval_results: dict of policy to compute_action() outputs.
|
||||
"""
|
||||
|
||||
eval_results = {}
|
||||
|
||||
if tf_sess:
|
||||
builder = TFRunBuilder(tf_sess, "policy_eval")
|
||||
pending_fetches = {}
|
||||
else:
|
||||
builder = None
|
||||
for policy_id, eval_data in to_eval.items():
|
||||
rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data])
|
||||
policy = _get_or_raise(policies, policy_id)
|
||||
if builder and (policy.compute_actions.__code__ is
|
||||
TFPolicyGraph.compute_actions.__code__):
|
||||
pending_fetches[policy_id] = policy.build_compute_actions(
|
||||
builder, [t.obs for t in eval_data],
|
||||
rnn_in_cols,
|
||||
prev_action_batch=[t.prev_action for t in eval_data],
|
||||
prev_reward_batch=[t.prev_reward for t in eval_data],
|
||||
is_training=True)
|
||||
else:
|
||||
eval_results[policy_id] = policy.compute_actions(
|
||||
[t.obs for t in eval_data],
|
||||
rnn_in_cols,
|
||||
prev_action_batch=[t.prev_action for t in eval_data],
|
||||
prev_reward_batch=[t.prev_reward for t in eval_data],
|
||||
is_training=True,
|
||||
episodes=[active_episodes[t.env_id] for t in eval_data])
|
||||
if builder:
|
||||
for k, v in pending_fetches.items():
|
||||
eval_results[k] = builder.get(v)
|
||||
|
||||
return eval_results
|
||||
|
||||
|
||||
def _process_policy_eval_results(to_eval, eval_results, active_episodes,
|
||||
active_envs, off_policy_actions):
|
||||
"""Process the output of policy neural network evaluation.
|
||||
|
||||
Records policy evaluation results into the given episode objects and
|
||||
returns replies to send back to agents in the env.
|
||||
|
||||
Returns:
|
||||
actions_to_send: nested dict of env id -> agent id -> agent replies.
|
||||
"""
|
||||
|
||||
actions_to_send = defaultdict(dict)
|
||||
for env_id in active_envs:
|
||||
actions_to_send[env_id] = {} # at minimum send empty dict
|
||||
|
||||
for policy_id, eval_data in to_eval.items():
|
||||
rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data])
|
||||
actions, rnn_out_cols, pi_info_cols = eval_results[policy_id]
|
||||
if len(rnn_in_cols) != len(rnn_out_cols):
|
||||
raise ValueError("Length of RNN in did not match RNN out, got: "
|
||||
"{} vs {}".format(rnn_in_cols, rnn_out_cols))
|
||||
# Add RNN state info
|
||||
for f_i, column in enumerate(rnn_in_cols):
|
||||
pi_info_cols["state_in_{}".format(f_i)] = column
|
||||
for f_i, column in enumerate(rnn_out_cols):
|
||||
pi_info_cols["state_out_{}".format(f_i)] = column
|
||||
# Save output rows
|
||||
actions = _unbatch_tuple_actions(actions)
|
||||
for i, action in enumerate(actions):
|
||||
env_id = eval_data[i].env_id
|
||||
agent_id = eval_data[i].agent_id
|
||||
actions_to_send[env_id][agent_id] = action
|
||||
episode = active_episodes[env_id]
|
||||
episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
|
||||
episode._set_last_pi_info(
|
||||
agent_id, {k: v[i]
|
||||
for k, v in pi_info_cols.items()})
|
||||
if env_id in off_policy_actions and \
|
||||
agent_id in off_policy_actions[env_id]:
|
||||
episode._set_last_action(agent_id,
|
||||
off_policy_actions[env_id][agent_id])
|
||||
else:
|
||||
episode._set_last_action(agent_id, action)
|
||||
|
||||
return actions_to_send
|
||||
|
||||
|
||||
def _fetch_atari_metrics(async_vector_env):
|
||||
|
||||
@@ -192,8 +192,7 @@ class TestExternalEnv(unittest.TestCase):
|
||||
episode_horizon=20,
|
||||
batch_steps=10,
|
||||
batch_mode="complete_episodes")
|
||||
ev.sample()
|
||||
self.assertRaises(Exception, lambda: ev.sample())
|
||||
self.assertRaises(ValueError, lambda: ev.sample())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -152,6 +152,20 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
to_prev(batch["actions"]))
|
||||
self.assertGreater(batch["advantages"][0], 1)
|
||||
|
||||
# 11/23/18: Samples per second 8501.125113727468
|
||||
def testBaselinePerformance(self):
    """Print how many samples per second a CartPole evaluator produces.

    Samples repeatedly for roughly one second of wall-clock time and prints
    the measured rate; nothing is asserted.
    """
    evaluator = PolicyEvaluator(
        env_creator=lambda _: gym.make("CartPole-v0"),
        policy_graph=MockPolicyGraph,
        batch_steps=100)
    began = time.time()
    total_samples = 0
    while time.time() - began < 1:
        batch = evaluator.sample()
        total_samples += batch.count
    print()
    elapsed = time.time() - began
    print("Samples per second {}".format(total_samples / elapsed))
    print()
|
||||
|
||||
def testGlobalVarsUpdate(self):
|
||||
agent = A2CAgent(
|
||||
env="CartPole-v0",
|
||||
|
||||
Reference in New Issue
Block a user