[rllib] Provide internal access to episode state in compute_actions() and allow returning extra batches (#2559)

The goal of this PR is to allow custom policies to perform model-based rollouts. In the multi-agent setting, this requires access to not only policies of other agents, but also their current observations.
Also, you might want to return the model-based trajectories as part of the rollout for efficiency.

  compute_actions() now takes a new keyword arg episodes
  pull out internal episode class into a top-level file
  add function to return extra trajectories from an episode that will be appended to the sample batch
  documentation
This commit is contained in:
Eric Liang
2018-08-16 14:37:21 -07:00
committed by GitHub
parent 127cf291a3
commit 5f430da180
12 changed files with 286 additions and 97 deletions
+3 -1
View File
@@ -1,3 +1,4 @@
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.interface import EvaluatorInterface
from ray.rllib.evaluation.policy_graph import PolicyGraph
@@ -15,5 +16,6 @@ __all__ = [
"EvaluatorInterface", "PolicyEvaluator", "PolicyGraph", "TFPolicyGraph",
"TorchPolicyGraph", "SampleBatch", "MultiAgentBatch", "SampleBatchBuilder",
"MultiAgentSampleBatchBuilder", "SyncSampler", "AsyncSampler",
"compute_advantages", "compute_targets", "collect_metrics"
"compute_advantages", "compute_targets", "collect_metrics",
"MultiAgentEpisode"
]
+119
View File
@@ -0,0 +1,119 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import defaultdict
import random
import numpy as np
class MultiAgentEpisode(object):
"""Tracks the current state of a (possibly multi-agent) episode.
The APIs in this class should be considered experimental, but we should
avoid changing things for the sake of changing them since users may
depend on them for advanced algorithms.
Attributes:
new_batch_builder (func): Create a new MultiAgentSampleBatchBuilder.
add_extra_batch (func): Return a built MultiAgentBatch to the sampler.
batch_builder (obj): Batch builder for the current episode.
total_reward (float): Summed reward across all agents in this episode.
length (int): Length of this episode.
episode_id (int): Unique id identifying this trajectory.
agent_rewards (dict): Summed rewards broken down by agent.
Use case 1: Model-based rollouts in multi-agent:
A custom compute_actions() function in a policy graph can inspect the
current episode state and perform a number of rollouts based on the
policies and state of other agents in the environment.
Use case 2: Returning extra rollouts data.
The model rollouts can be returned back to the sampler by calling:
>>> batch = episode.new_batch_builder()
>>> for each transition:
batch.add_values(...) # see sampler for usage
>>> episode.extra_batches.add(batch.build_and_reset())
"""
def __init__(self, policies, policy_mapping_fn, batch_builder_factory,
extra_batch_callback):
self.new_batch_builder = batch_builder_factory
self.add_extra_batch = extra_batch_callback
self.batch_builder = batch_builder_factory()
self.total_reward = 0.0
self.length = 0
self.episode_id = random.randrange(2e9)
self.agent_rewards = defaultdict(float)
self._policies = policies
self._policy_mapping_fn = policy_mapping_fn
self._agent_to_policy = {}
self._agent_to_rnn_state = {}
self._agent_to_last_obs = {}
self._agent_to_last_action = {}
self._agent_to_last_pi_info = {}
def policy_for(self, agent_id):
"""Returns the policy graph for the specified agent.
If the agent is new, the policy mapping fn will be called to bind the
agent to a policy for the duration of the episode.
"""
if agent_id not in self._agent_to_policy:
self._agent_to_policy[agent_id] = self._policy_mapping_fn(agent_id)
return self._agent_to_policy[agent_id]
def last_observation_for(self, agent_id):
"""Returns the last observation for the specified agent."""
return self._agent_to_last_obs.get(agent_id)
def last_action_for(self, agent_id):
"""Returns the last action for the specified agent."""
action = self._agent_to_last_action[agent_id]
# Concatenate tuple actions
if isinstance(action, list):
expanded = []
for a in action:
if len(a.shape) == 1:
expanded.append(np.expand_dims(a, 1))
else:
expanded.append(a)
action = np.concatenate(expanded, axis=1).flatten()
return action
def rnn_state_for(self, agent_id):
"""Returns the last RNN state for the specified agent."""
if agent_id not in self._agent_to_rnn_state:
policy = self._policies[self.policy_for(agent_id)]
self._agent_to_rnn_state[agent_id] = policy.get_initial_state()
return self._agent_to_rnn_state[agent_id]
def last_pi_info_for(self, agent_id):
"""Returns the last info object for the specified agent."""
return self._agent_to_last_pi_info[agent_id]
def _add_agent_rewards(self, reward_dict):
for agent_id, reward in reward_dict.items():
if reward is not None:
self.agent_rewards[agent_id,
self.policy_for(agent_id)] += reward
self.total_reward += reward
def _set_rnn_state(self, agent_id, rnn_state):
self._agent_to_rnn_state[agent_id] = rnn_state
def _set_last_observation(self, agent_id, obs):
self._agent_to_last_obs[agent_id] = obs
def _set_last_action(self, agent_id, action):
self._agent_to_last_action[agent_id] = action
def _set_last_pi_info(self, agent_id, pi_info):
self._agent_to_last_pi_info[agent_id] = pi_info
@@ -292,6 +292,7 @@ class PolicyEvaluator(EvaluatorInterface):
batch = self.sampler.get_data()
steps_so_far += batch.count
batches.append(batch)
batches.extend(self.sampler.get_extra_batches())
batch = batches[0].concat_samples(batches)
if self.compress_observations:
+20 -3
View File
@@ -37,13 +37,20 @@ class PolicyGraph(object):
self.observation_space = observation_space
self.action_space = action_space
def compute_actions(self, obs_batch, state_batches, is_training=False):
def compute_actions(self,
obs_batch,
state_batches,
is_training=False,
episodes=None):
"""Compute actions for the current policy.
Arguments:
obs_batch (np.ndarray): batch of observations
state_batches (list): list of RNN state input batches, if any
is_training (bool): whether we are training the policy
episodes (list): MultiAgentEpisode for each obs in obs_batch.
This provides access to all of the internal episode state,
which may be useful for model-based or multiagent algorithms.
Returns:
actions (np.ndarray): batch of output actions, with shape like
@@ -55,13 +62,20 @@ class PolicyGraph(object):
"""
raise NotImplementedError
def compute_single_action(self, obs, state, is_training=False):
def compute_single_action(self,
obs,
state,
is_training=False,
episode=None):
"""Unbatched version of compute_actions.
Arguments:
obs (obj): single observation
state_batches (list): list of RNN state inputs, if any
is_training (bool): whether we are training the policy
episode (MultiAgentEpisode): this provides access to all of the
internal episode state, which may be useful for model-based or
multiagent algorithms.
Returns:
actions (obj): single action
@@ -70,13 +84,16 @@ class PolicyGraph(object):
"""
[action], state_out, info = self.compute_actions(
[obs], [[s] for s in state], is_training)
[obs], [[s] for s in state], is_training, episodes=[episode])
return action, [s[0] for s in state_out], \
{k: v[0] for k, v in info.items()}
def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
"""Implements algorithm-specific trajectory postprocessing.
This will be called on each trajectory fragment computed during policy
evaluation. Each fragment is guaranteed to be only from one episode.
Arguments:
sample_batch (SampleBatch): batch of experiences for the policy,
which will contain at most one episode trajectory.
@@ -117,6 +117,11 @@ class MultiAgentSampleBatchBuilder(object):
other_batches = pre_batches.copy()
del other_batches[agent_id]
policy = self.policy_map[self.agent_to_policy[agent_id]]
if any(pre_batch["dones"][:-1]) or len(set(
pre_batch["eps_id"])) > 1:
raise ValueError(
"Batches sent to postprocessing must only contain steps "
"from a single trajectory.", pre_batch)
post_batches[agent_id] = policy.postprocess_trajectory(
pre_batch, other_batches)
+49 -86
View File
@@ -3,13 +3,13 @@ from __future__ import division
from __future__ import print_function
from collections import defaultdict, namedtuple
import numpy as np
import random
import six.moves.queue as queue
import threading
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.sample_batch import MultiAgentSampleBatchBuilder, \
MultiAgentBatch
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.env.async_vector_env import AsyncVectorEnv
from ray.rllib.utils.tf_run_builder import TFRunBuilder
@@ -44,10 +44,11 @@ class SyncSampler(object):
self.policies = policies
self.policy_mapping_fn = policy_mapping_fn
self._obs_filters = obs_filters
self.extra_batches = queue.Queue()
self.rollout_provider = _env_runner(
self.async_vector_env, self.policies, self.policy_mapping_fn,
self.num_local_steps, self.horizon, self._obs_filters, pack,
tf_sess)
self.async_vector_env, self.extra_batches.put, self.policies,
self.policy_mapping_fn, self.num_local_steps, self.horizon,
self._obs_filters, pack, tf_sess)
self.metrics_queue = queue.Queue()
def get_data(self):
@@ -67,6 +68,15 @@ class SyncSampler(object):
break
return completed
def get_extra_batches(self):
extra = []
while True:
try:
extra.append(self.extra_batches.get_nowait())
except queue.Empty:
break
return extra
class AsyncSampler(threading.Thread):
"""This class interacts with the environment and tells it what to do.
@@ -89,6 +99,7 @@ class AsyncSampler(threading.Thread):
self.async_vector_env = AsyncVectorEnv.wrap_async(env)
threading.Thread.__init__(self)
self.queue = queue.Queue(5)
self.extra_batches = queue.Queue()
self.metrics_queue = queue.Queue()
self.num_local_steps = num_local_steps
self.horizon = horizon
@@ -108,9 +119,9 @@ class AsyncSampler(threading.Thread):
def _run(self):
rollout_provider = _env_runner(
self.async_vector_env, self.policies, self.policy_mapping_fn,
self.num_local_steps, self.horizon, self._obs_filters, self.pack,
self.tf_sess)
self.async_vector_env, self.extra_batches.put, self.policies,
self.policy_mapping_fn, self.num_local_steps, self.horizon,
self._obs_filters, self.pack, self.tf_sess)
while True:
# The timeout variable exists because apparently, if one worker
# dies, the other workers won't die with it, unless the timeout is
@@ -153,8 +164,18 @@ class AsyncSampler(threading.Thread):
break
return completed
def get_extra_batches(self):
extra = []
while True:
try:
extra.append(self.extra_batches.get_nowait())
except queue.Empty:
break
return extra
def _env_runner(async_vector_env,
extra_batch_callback,
policies,
policy_mapping_fn,
num_local_steps,
@@ -166,6 +187,7 @@ def _env_runner(async_vector_env,
Args:
async_vector_env (AsyncVectorEnv): env implementing AsyncVectorEnv.
extra_batch_callback (fn): function to send extra batch data to.
policies (dict): Map of policy ids to PolicyGraph instances.
policy_mapping_fn (func): Function that maps agent ids to policy ids.
This is called when an agent first enters the environment. The
@@ -204,8 +226,8 @@ def _env_runner(async_vector_env,
return MultiAgentSampleBatchBuilder(policies)
def new_episode():
return _MultiAgentEpisode(policies, policy_mapping_fn,
get_batch_builder)
return MultiAgentEpisode(policies, policy_mapping_fn,
get_batch_builder, extra_batch_callback)
active_episodes = defaultdict(new_episode)
@@ -227,7 +249,7 @@ def _env_runner(async_vector_env,
if not new_episode:
episode.length += 1
episode.batch_builder.count += 1
episode.add_agent_rewards(rewards[env_id])
episode._add_agent_rewards(rewards[env_id])
# Check episode termination conditions
if dones[env_id]["__all__"] or episode.length >= horizon:
@@ -250,7 +272,7 @@ def _env_runner(async_vector_env,
episode.rnn_state_for(agent_id)))
last_observation = episode.last_observation_for(agent_id)
episode.set_last_observation(agent_id, filtered_obs)
episode._set_last_observation(agent_id, filtered_obs)
# Record transition info if applicable
if last_observation is not None and \
@@ -294,7 +316,7 @@ def _env_runner(async_vector_env,
policy_id = episode.policy_for(agent_id)
filtered_obs = _get_or_raise(obs_filters,
policy_id)(raw_obs)
episode.set_last_observation(agent_id, filtered_obs)
episode._set_last_observation(agent_id, filtered_obs)
to_eval[policy_id].append(
PolicyEvalData(env_id, agent_id, filtered_obs,
episode.rnn_state_for(agent_id)))
@@ -302,6 +324,7 @@ def _env_runner(async_vector_env,
# Batch eval policy actions if possible
if tf_sess:
builder = TFRunBuilder(tf_sess, "policy_eval")
pending_fetches = {}
else:
builder = None
eval_results = {}
@@ -310,16 +333,21 @@ def _env_runner(async_vector_env,
rnn_in = _to_column_format([t.rnn_state for t in eval_data])
rnn_in_cols[policy_id] = rnn_in
policy = _get_or_raise(policies, policy_id)
if builder:
eval_results[policy_id] = policy.build_compute_actions(
if builder and (policy.compute_actions.__code__ is
TFPolicyGraph.compute_actions.__code__):
pending_fetches[policy_id] = policy.build_compute_actions(
builder, [t.obs for t in eval_data],
rnn_in,
is_training=True)
else:
eval_results[policy_id] = policy.compute_actions(
[t.obs for t in eval_data], rnn_in, is_training=True)
[t.obs for t in eval_data],
rnn_in,
is_training=True,
episodes=[active_episodes[t.env_id] for t in eval_data])
if builder:
eval_results = {k: builder.get(v) for k, v in eval_results.items()}
for k, v in pending_fetches.items():
eval_results[k] = builder.get(v)
# Record the policy eval results
for policy_id, eval_data in to_eval.items():
@@ -335,16 +363,16 @@ def _env_runner(async_vector_env,
agent_id = eval_data[i].agent_id
actions_to_send[env_id][agent_id] = action
episode = active_episodes[env_id]
episode.set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
episode.set_last_pi_info(
episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
episode._set_last_pi_info(
agent_id, {k: v[i]
for k, v in pi_info_cols.items()})
if env_id in off_policy_actions and \
agent_id in off_policy_actions[env_id]:
episode.set_last_action(
episode._set_last_action(
agent_id, off_policy_actions[env_id][agent_id])
else:
episode.set_last_action(agent_id, action)
episode._set_last_action(agent_id, action)
# Return computed actions to ready envs. We also send to envs that have
# taken off-policy actions; those envs are free to ignore the action.
@@ -362,68 +390,3 @@ def _get_or_raise(mapping, policy_id):
"Could not find policy for agent: agent policy id `{}` not "
"in policy map keys {}.".format(policy_id, mapping.keys()))
return mapping[policy_id]
class _MultiAgentEpisode(object):
def __init__(self, policies, policy_mapping_fn, batch_builder_factory):
self.batch_builder = batch_builder_factory()
self.total_reward = 0.0
self.length = 0
self.episode_id = random.randrange(2e9)
self.agent_rewards = defaultdict(float)
self._policies = policies
self._policy_mapping_fn = policy_mapping_fn
self._agent_to_policy = {}
self._agent_to_rnn_state = {}
self._agent_to_last_obs = {}
self._agent_to_last_action = {}
self._agent_to_last_pi_info = {}
def add_agent_rewards(self, reward_dict):
for agent_id, reward in reward_dict.items():
if reward is not None:
self.agent_rewards[agent_id,
self.policy_for(agent_id)] += reward
self.total_reward += reward
def policy_for(self, agent_id):
if agent_id not in self._agent_to_policy:
self._agent_to_policy[agent_id] = self._policy_mapping_fn(agent_id)
return self._agent_to_policy[agent_id]
def rnn_state_for(self, agent_id):
if agent_id not in self._agent_to_rnn_state:
policy = self._policies[self.policy_for(agent_id)]
self._agent_to_rnn_state[agent_id] = policy.get_initial_state()
return self._agent_to_rnn_state[agent_id]
def last_observation_for(self, agent_id):
return self._agent_to_last_obs.get(agent_id)
def last_action_for(self, agent_id):
action = self._agent_to_last_action[agent_id]
# Concatenate tuple actions
if isinstance(action, list):
expanded = []
for a in action:
if len(a.shape) == 1:
expanded.append(np.expand_dims(a, 1))
else:
expanded.append(a)
action = np.concatenate(expanded, axis=1).flatten()
return action
def last_pi_info_for(self, agent_id):
return self._agent_to_last_pi_info[agent_id]
def set_rnn_state(self, agent_id, rnn_state):
self._agent_to_rnn_state[agent_id] = rnn_state
def set_last_observation(self, agent_id, obs):
self._agent_to_last_obs[agent_id] = obs
def set_last_action(self, agent_id, action):
self._agent_to_last_action[agent_id] = action
def set_last_pi_info(self, agent_id, pi_info):
self._agent_to_last_pi_info[agent_id] = pi_info
@@ -104,7 +104,8 @@ class TFPolicyGraph(PolicyGraph):
builder,
obs_batch,
state_batches=None,
is_training=False):
is_training=False,
episodes=None):
state_batches = state_batches or []
assert len(self._state_inputs) == len(state_batches), \
(self._state_inputs, state_batches)
@@ -118,8 +119,11 @@ class TFPolicyGraph(PolicyGraph):
[self.extra_compute_action_fetches()])
return fetches[0], fetches[1:-1], fetches[-1]
def compute_actions(self, obs_batch, state_batches=None,
is_training=False):
def compute_actions(self,
obs_batch,
state_batches=None,
is_training=False,
episodes=None):
builder = TFRunBuilder(self._sess, "compute_actions")
fetches = self.build_compute_actions(builder, obs_batch, state_batches,
is_training)
@@ -67,8 +67,11 @@ class TorchPolicyGraph(PolicyGraph):
"""Custom PyTorch optimizer to use."""
return torch.optim.Adam(self._model.parameters())
def compute_actions(self, obs_batch, state_batches=None,
is_training=False):
def compute_actions(self,
obs_batch,
state_batches=None,
is_training=False,
episodes=None):
if state_batches:
raise NotImplementedError("Torch RNN support")
with self.lock:
@@ -306,6 +306,51 @@ class TestMultiAgentEnv(unittest.TestCase):
self.assertEqual(batch.policy_batches["p0"]["t"].tolist()[:10],
[4, 9, 14, 19, 24, 5, 10, 15, 20, 25])
def testReturningModelBasedRolloutsData(self):
class ModelBasedPolicyGraph(PGPolicyGraph):
def compute_actions(self,
obs_batch,
state_batches,
is_training=False,
episodes=None):
# Pretend we did a model-based rollout and want to return
# the extra trajectory.
builder = episodes[0].new_batch_builder()
rollout_id = random.randint(0, 10000)
for t in range(5):
builder.add_values(
agent_id="extra_0",
policy_id="p1", # use p1 so we can easily check it
t=t,
eps_id=rollout_id, # new id for each rollout
obs=obs_batch[0],
actions=0,
rewards=0,
dones=t == 4,
infos={},
new_obs=obs_batch[0])
batch = builder.build_and_reset()
episodes[0].add_extra_batch(batch)
# Just return zeros for actions
return [0] * len(obs_batch), [], {}
single_env = gym.make("CartPole-v0")
obs_space = single_env.observation_space
act_space = single_env.action_space
ev = PolicyEvaluator(
env_creator=lambda _: MultiCartpole(2),
policy_graph={
"p0": (ModelBasedPolicyGraph, obs_space, act_space, {}),
"p1": (ModelBasedPolicyGraph, obs_space, act_space, {}),
},
policy_mapping_fn=lambda agent_id: "p0",
batch_steps=5)
batch = ev.sample()
self.assertEqual(batch.count, 5)
self.assertEqual(batch.policy_batches["p0"].count, 10)
self.assertEqual(batch.policy_batches["p1"].count, 25)
def testTrainMultiCartpoleSinglePolicy(self):
n = 10
register_env("multi_cartpole", lambda _: MultiCartpole(n))
+10 -2
View File
@@ -17,7 +17,11 @@ from ray.tune.registry import register_env
class MockPolicyGraph(PolicyGraph):
def compute_actions(self, obs_batch, state_batches, is_training=False):
def compute_actions(self,
obs_batch,
state_batches,
is_training=False,
episodes=None):
return [0] * len(obs_batch), [], {}
def postprocess_trajectory(self, batch, other_agent_batches=None):
@@ -25,7 +29,11 @@ class MockPolicyGraph(PolicyGraph):
class BadPolicyGraph(PolicyGraph):
def compute_actions(self, obs_batch, state_batches, is_training=False):
def compute_actions(self,
obs_batch,
state_batches,
is_training=False,
episodes=None):
raise Exception("intentional error")
def postprocess_trajectory(self, batch, other_agent_batches=None):