diff --git a/python/ray/rllib/evaluation/sampler.py b/python/ray/rllib/evaluation/sampler.py index c578ea0a6..2fd2fc4e2 100644 --- a/python/ray/rllib/evaluation/sampler.py +++ b/python/ray/rllib/evaluation/sampler.py @@ -262,184 +262,242 @@ def _env_runner(async_vector_env, unfiltered_obs, rewards, dones, infos, off_policy_actions = \ async_vector_env.poll() - # Map of policy_id to list of PolicyEvalData - to_eval = defaultdict(list) + # Process observations and prepare for policy evaluation + active_envs, to_eval, outputs = _process_observations( + async_vector_env, policies, batch_builder_pool, active_episodes, + unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon, + obs_filters, unroll_length, pack, callbacks) + for o in outputs: + yield o - # Map of env_id -> agent_id -> action replies - actions_to_send = defaultdict(dict) + # Do batched policy eval + eval_results = _do_policy_eval(tf_sess, to_eval, policies, + active_episodes) - # For each environment - for env_id, agent_obs in unfiltered_obs.items(): - new_episode = env_id not in active_episodes - episode = active_episodes[env_id] - if not new_episode: - episode.length += 1 - episode.batch_builder.count += 1 - episode._add_agent_rewards(rewards[env_id]) - - # Check episode termination conditions - if dones[env_id]["__all__"] or episode.length >= horizon: - all_done = True - atari_metrics = _fetch_atari_metrics(async_vector_env) - if atari_metrics is not None: - for m in atari_metrics: - yield m._replace(custom_metrics=episode.custom_metrics) - else: - yield RolloutMetrics(episode.length, episode.total_reward, - dict(episode.agent_rewards), - episode.custom_metrics) - else: - all_done = False - # At least send an empty dict if not done - actions_to_send[env_id] = {} - - # For each agent in the environment - for agent_id, raw_obs in agent_obs.items(): - policy_id = episode.policy_for(agent_id) - filtered_obs = _get_or_raise(obs_filters, policy_id)(raw_obs) - agent_done = bool(all_done or dones[env_id].get(agent_id)) - if not agent_done: - to_eval[policy_id].append( - PolicyEvalData(env_id, agent_id, filtered_obs, - episode.rnn_state_for(agent_id), - episode.last_action_for(agent_id), - rewards[env_id][agent_id] or 0.0)) - - last_observation = episode.last_observation_for(agent_id) - episode._set_last_observation(agent_id, filtered_obs) - - # Record transition info if applicable - if last_observation is not None and \ - infos[env_id][agent_id].get("training_enabled", True): - episode.batch_builder.add_values( - agent_id, - policy_id, - t=episode.length - 1, - eps_id=episode.episode_id, - agent_index=episode._agent_index(agent_id), - obs=last_observation, - actions=episode.last_action_for(agent_id), - rewards=rewards[env_id][agent_id], - prev_actions=episode.prev_action_for(agent_id), - prev_rewards=episode.prev_reward_for(agent_id), - dones=agent_done, - infos=infos[env_id][agent_id], - new_obs=filtered_obs, - **episode.last_pi_info_for(agent_id)) - - # Invoke the step callback after the step is logged to the episode - if callbacks.get("on_episode_step"): - callbacks["on_episode_step"]({ - "env": async_vector_env, - "episode": episode - }) - - # Cut the batch if we're not packing multiple episodes into one, - # or if we've exceeded the requested batch size. - if episode.batch_builder.has_pending_data(): - if (all_done and not pack) or \ - episode.batch_builder.count >= unroll_length: - yield episode.batch_builder.build_and_reset(episode) - elif all_done: - # Make sure postprocessor stays within one episode - episode.batch_builder.postprocess_batch_so_far(episode) - - if all_done: - # Handle episode termination - batch_builder_pool.append(episode.batch_builder) - if callbacks.get("on_episode_end"): - callbacks["on_episode_end"]({ - "env": async_vector_env, - "episode": episode - }) - del active_episodes[env_id] - resetted_obs = async_vector_env.try_reset(env_id) - if resetted_obs is None: - # Reset not supported, drop this env from the ready list - assert horizon == float("inf"), \ - "Setting episode horizon requires reset() support." - else: - # Creates a new episode - episode = active_episodes[env_id] - for agent_id, raw_obs in resetted_obs.items(): - policy_id = episode.policy_for(agent_id) - policy = _get_or_raise(policies, policy_id) - filtered_obs = _get_or_raise(obs_filters, - policy_id)(raw_obs) - episode._set_last_observation(agent_id, filtered_obs) - to_eval[policy_id].append( - PolicyEvalData( - env_id, agent_id, filtered_obs, - episode.rnn_state_for(agent_id), - np.zeros_like( - _flatten_action( - policy.action_space.sample())), 0.0)) - - # Batch eval policy actions if possible - if tf_sess: - builder = TFRunBuilder(tf_sess, "policy_eval") - pending_fetches = {} - else: - builder = None - eval_results = {} - rnn_in_cols = {} - for policy_id, eval_data in to_eval.items(): - rnn_in = _to_column_format([t.rnn_state for t in eval_data]) - rnn_in_cols[policy_id] = rnn_in - policy = _get_or_raise(policies, policy_id) - if builder and (policy.compute_actions.__code__ is - TFPolicyGraph.compute_actions.__code__): - pending_fetches[policy_id] = policy.build_compute_actions( - builder, [t.obs for t in eval_data], - rnn_in, - prev_action_batch=[t.prev_action for t in eval_data], - prev_reward_batch=[t.prev_reward for t in eval_data], - is_training=True) - else: - eval_results[policy_id] = policy.compute_actions( - [t.obs for t in eval_data], - rnn_in, - prev_action_batch=[t.prev_action for t in eval_data], - prev_reward_batch=[t.prev_reward for t in eval_data], - is_training=True, - episodes=[active_episodes[t.env_id] for t in eval_data]) - if builder: - for k, v in pending_fetches.items(): - eval_results[k] = builder.get(v) - - # Record the policy eval results - for policy_id, eval_data in to_eval.items(): - actions, rnn_out_cols, pi_info_cols = eval_results[policy_id] - if len(rnn_in_cols[policy_id]) != len(rnn_out_cols): - raise ValueError( - "Length of RNN in did not match RNN out, got: " - "{} vs {}".format(rnn_in_cols[policy_id], rnn_out_cols)) - # Add RNN state info - for f_i, column in enumerate(rnn_in_cols[policy_id]): - pi_info_cols["state_in_{}".format(f_i)] = column - for f_i, column in enumerate(rnn_out_cols): - pi_info_cols["state_out_{}".format(f_i)] = column - # Save output rows - actions = _unbatch_tuple_actions(actions) - for i, action in enumerate(actions): - env_id = eval_data[i].env_id - agent_id = eval_data[i].agent_id - actions_to_send[env_id][agent_id] = action - episode = active_episodes[env_id] - episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols]) - episode._set_last_pi_info( - agent_id, {k: v[i] - for k, v in pi_info_cols.items()}) - if env_id in off_policy_actions and \ - agent_id in off_policy_actions[env_id]: - episode._set_last_action( - agent_id, off_policy_actions[env_id][agent_id]) - else: - episode._set_last_action(agent_id, action) + # Process results and update episode state + actions_to_send = _process_policy_eval_results( + to_eval, eval_results, active_episodes, active_envs, + off_policy_actions) # Return computed actions to ready envs. We also send to envs that have # taken off-policy actions; those envs are free to ignore the action. - async_vector_env.send_actions(dict(actions_to_send)) + async_vector_env.send_actions(actions_to_send) + + +def _process_observations(async_vector_env, policies, batch_builder_pool, + active_episodes, unfiltered_obs, rewards, dones, + infos, off_policy_actions, horizon, obs_filters, + unroll_length, pack, callbacks): + """Record new data from the environment and prepare for policy evaluation. + + Returns: + active_envs: set of non-terminated env ids + to_eval: map of policy_id to list of agent PolicyEvalData + outputs: list of metrics and samples to return from the sampler + """ + + active_envs = set() + to_eval = defaultdict(list) + outputs = [] + + # For each environment + for env_id, agent_obs in unfiltered_obs.items(): + new_episode = env_id not in active_episodes + episode = active_episodes[env_id] + if not new_episode: + episode.length += 1 + episode.batch_builder.count += 1 + episode._add_agent_rewards(rewards[env_id]) + + # Check episode termination conditions + if dones[env_id]["__all__"] or episode.length >= horizon: + all_done = True + atari_metrics = _fetch_atari_metrics(async_vector_env) + if atari_metrics is not None: + for m in atari_metrics: + outputs.append( + m._replace(custom_metrics=episode.custom_metrics)) + else: + outputs.append( + RolloutMetrics(episode.length, episode.total_reward, + dict(episode.agent_rewards), + episode.custom_metrics)) + else: + all_done = False + active_envs.add(env_id) + + # For each agent in the environment + for agent_id, raw_obs in agent_obs.items(): + policy_id = episode.policy_for(agent_id) + filtered_obs = _get_or_raise(obs_filters, policy_id)(raw_obs) + agent_done = bool(all_done or dones[env_id].get(agent_id)) + if not agent_done: + to_eval[policy_id].append( + PolicyEvalData(env_id, agent_id, filtered_obs, + episode.rnn_state_for(agent_id), + episode.last_action_for(agent_id), + rewards[env_id][agent_id] or 0.0)) + + last_observation = episode.last_observation_for(agent_id) + episode._set_last_observation(agent_id, filtered_obs) + + # Record transition info if applicable + if last_observation is not None and \ + infos[env_id][agent_id].get("training_enabled", True): + episode.batch_builder.add_values( + agent_id, + policy_id, + t=episode.length - 1, + eps_id=episode.episode_id, + agent_index=episode._agent_index(agent_id), + obs=last_observation, + actions=episode.last_action_for(agent_id), + rewards=rewards[env_id][agent_id], + prev_actions=episode.prev_action_for(agent_id), + prev_rewards=episode.prev_reward_for(agent_id), + dones=agent_done, + infos=infos[env_id][agent_id], + new_obs=filtered_obs, + **episode.last_pi_info_for(agent_id)) + + # Invoke the step callback after the step is logged to the episode + if callbacks.get("on_episode_step"): + callbacks["on_episode_step"]({ + "env": async_vector_env, + "episode": episode + }) + + # Cut the batch if we're not packing multiple episodes into one, + # or if we've exceeded the requested batch size. + if episode.batch_builder.has_pending_data(): + if (all_done and not pack) or \ + episode.batch_builder.count >= unroll_length: + outputs.append(episode.batch_builder.build_and_reset(episode)) + elif all_done: + # Make sure postprocessor stays within one episode + episode.batch_builder.postprocess_batch_so_far(episode) + + if all_done: + # Handle episode termination + batch_builder_pool.append(episode.batch_builder) + if callbacks.get("on_episode_end"): + callbacks["on_episode_end"]({ + "env": async_vector_env, + "episode": episode + }) + del active_episodes[env_id] + resetted_obs = async_vector_env.try_reset(env_id) + if resetted_obs is None: + # Reset not supported, drop this env from the ready list + if horizon != float("inf"): + raise ValueError( + "Setting episode horizon requires reset() support " + "from the environment.") + else: + # Creates a new episode + episode = active_episodes[env_id] + for agent_id, raw_obs in resetted_obs.items(): + policy_id = episode.policy_for(agent_id) + policy = _get_or_raise(policies, policy_id) + filtered_obs = _get_or_raise(obs_filters, + policy_id)(raw_obs) + episode._set_last_observation(agent_id, filtered_obs) + to_eval[policy_id].append( + PolicyEvalData( + env_id, agent_id, filtered_obs, + episode.rnn_state_for(agent_id), + np.zeros_like( + _flatten_action(policy.action_space.sample())), + 0.0)) + + return active_envs, to_eval, outputs + + +def _do_policy_eval(tf_sess, to_eval, policies, active_episodes): + """Call compute actions on observation batches to get next actions. + + Returns: + eval_results: dict of policy to compute_action() outputs. + """ + + eval_results = {} + + if tf_sess: + builder = TFRunBuilder(tf_sess, "policy_eval") + pending_fetches = {} + else: + builder = None + for policy_id, eval_data in to_eval.items(): + rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data]) + policy = _get_or_raise(policies, policy_id) + if builder and (policy.compute_actions.__code__ is + TFPolicyGraph.compute_actions.__code__): + pending_fetches[policy_id] = policy.build_compute_actions( + builder, [t.obs for t in eval_data], + rnn_in_cols, + prev_action_batch=[t.prev_action for t in eval_data], + prev_reward_batch=[t.prev_reward for t in eval_data], + is_training=True) + else: + eval_results[policy_id] = policy.compute_actions( + [t.obs for t in eval_data], + rnn_in_cols, + prev_action_batch=[t.prev_action for t in eval_data], + prev_reward_batch=[t.prev_reward for t in eval_data], + is_training=True, + episodes=[active_episodes[t.env_id] for t in eval_data]) + if builder: + for k, v in pending_fetches.items(): + eval_results[k] = builder.get(v) + + return eval_results + + +def _process_policy_eval_results(to_eval, eval_results, active_episodes, + active_envs, off_policy_actions): + """Process the output of policy neural network evaluation. + + Records policy evaluation results into the given episode objects and + returns replies to send back to agents in the env. + + Returns: + actions_to_send: nested dict of env id -> agent id -> agent replies. + """ + + actions_to_send = defaultdict(dict) + for env_id in active_envs: + actions_to_send[env_id] = {} # at minimum send empty dict + + for policy_id, eval_data in to_eval.items(): + rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data]) + actions, rnn_out_cols, pi_info_cols = eval_results[policy_id] + if len(rnn_in_cols) != len(rnn_out_cols): + raise ValueError("Length of RNN in did not match RNN out, got: " + "{} vs {}".format(rnn_in_cols, rnn_out_cols)) + # Add RNN state info + for f_i, column in enumerate(rnn_in_cols): + pi_info_cols["state_in_{}".format(f_i)] = column + for f_i, column in enumerate(rnn_out_cols): + pi_info_cols["state_out_{}".format(f_i)] = column + # Save output rows + actions = _unbatch_tuple_actions(actions) + for i, action in enumerate(actions): + env_id = eval_data[i].env_id + agent_id = eval_data[i].agent_id + actions_to_send[env_id][agent_id] = action + episode = active_episodes[env_id] + episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols]) + episode._set_last_pi_info( + agent_id, {k: v[i] + for k, v in pi_info_cols.items()}) + if env_id in off_policy_actions and \ + agent_id in off_policy_actions[env_id]: + episode._set_last_action(agent_id, + off_policy_actions[env_id][agent_id]) + else: + episode._set_last_action(agent_id, action) + + return actions_to_send def _fetch_atari_metrics(async_vector_env): diff --git a/python/ray/rllib/test/test_external_env.py b/python/ray/rllib/test/test_external_env.py index c574ba633..f7e8308a5 100644 --- a/python/ray/rllib/test/test_external_env.py +++ b/python/ray/rllib/test/test_external_env.py @@ -192,8 +192,7 @@ class TestExternalEnv(unittest.TestCase): episode_horizon=20, batch_steps=10, batch_mode="complete_episodes") - ev.sample() - self.assertRaises(Exception, lambda: ev.sample()) + self.assertRaises(ValueError, lambda: ev.sample()) if __name__ == '__main__': diff --git a/python/ray/rllib/test/test_policy_evaluator.py b/python/ray/rllib/test/test_policy_evaluator.py index ab51771a6..7b4d6c8b5 100644 --- a/python/ray/rllib/test/test_policy_evaluator.py +++ b/python/ray/rllib/test/test_policy_evaluator.py @@ -152,6 +152,20 @@ class TestPolicyEvaluator(unittest.TestCase): to_prev(batch["actions"])) self.assertGreater(batch["advantages"][0], 1) + # 11/23/18: Samples per second 8501.125113727468 + def testBaselinePerformance(self): + ev = PolicyEvaluator( + env_creator=lambda _: gym.make("CartPole-v0"), + policy_graph=MockPolicyGraph, + batch_steps=100) + start = time.time() + count = 0 + while time.time() - start < 1: + count += ev.sample().count + print() + print("Samples per second {}".format(count / (time.time() - start))) + print() + def testGlobalVarsUpdate(self): agent = A2CAgent( env="CartPole-v0",