mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 02:46:49 +08:00
[rllib] Provide internal access to episode state in compute_actions() and allow returning extra batches (#2559)
The goal of this PR is to allow custom policies to perform model-based rollouts. In the multi-agent setting, this requires access not only to the policies of other agents, but also to their current observations. Also, you might want to return the model-based trajectories as part of the rollout for efficiency. Changes: (1) compute_actions() now takes a new keyword arg `episodes`; (2) pull out the internal episode class into a top-level file; (3) add a function to return extra trajectories from an episode, which will be appended to the sample batch; (4) documentation.
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.evaluation.interface import EvaluatorInterface
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
@@ -15,5 +16,6 @@ __all__ = [
|
||||
"EvaluatorInterface", "PolicyEvaluator", "PolicyGraph", "TFPolicyGraph",
|
||||
"TorchPolicyGraph", "SampleBatch", "MultiAgentBatch", "SampleBatchBuilder",
|
||||
"MultiAgentSampleBatchBuilder", "SyncSampler", "AsyncSampler",
|
||||
"compute_advantages", "compute_targets", "collect_metrics"
|
||||
"compute_advantages", "compute_targets", "collect_metrics",
|
||||
"MultiAgentEpisode"
|
||||
]
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import defaultdict
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MultiAgentEpisode(object):
    """Tracks the current state of a (possibly multi-agent) episode.

    The APIs in this class should be considered experimental, but we should
    avoid changing things for the sake of changing them since users may
    depend on them for advanced algorithms.

    Attributes:
        new_batch_builder (func): Create a new MultiAgentSampleBatchBuilder.
        add_extra_batch (func): Return a built MultiAgentBatch to the sampler.
        batch_builder (obj): Batch builder for the current episode.
        total_reward (float): Summed reward across all agents in this episode.
        length (int): Length of this episode.
        episode_id (int): Randomly drawn id identifying this trajectory.
        agent_rewards (dict): Summed rewards broken down by agent. Keys are
            ``(agent_id, policy_id)`` tuples.

    Use case 1: Model-based rollouts in multi-agent:
        A custom compute_actions() function in a policy graph can inspect the
        current episode state and perform a number of rollouts based on the
        policies and state of other agents in the environment.

    Use case 2: Returning extra rollouts data.
        The model rollouts can be returned back to the sampler by calling:

        >>> batch = episode.new_batch_builder()
        >>> for each transition:
               batch.add_values(...)  # see sampler for usage
        >>> episode.add_extra_batch(batch.build_and_reset())
    """

    def __init__(self, policies, policy_mapping_fn, batch_builder_factory,
                 extra_batch_callback):
        """Initializes the per-episode bookkeeping.

        Arguments:
            policies (dict): map of policy ids to policy objects; looked up
                by rnn_state_for() to obtain the initial RNN state.
            policy_mapping_fn (func): maps an agent id to a policy id.
            batch_builder_factory (func): zero-argument callable returning a
                new batch builder; also exposed as ``new_batch_builder``.
            extra_batch_callback (func): callable accepting a built batch;
                exposed as ``add_extra_batch``.
        """

        self.new_batch_builder = batch_builder_factory
        self.add_extra_batch = extra_batch_callback
        self.batch_builder = batch_builder_factory()
        self.total_reward = 0.0
        self.length = 0
        # randrange() requires an integer bound: passing the float 2e9 is
        # deprecated since Python 3.10 and a TypeError from 3.12 onwards.
        self.episode_id = random.randrange(int(2e9))
        self.agent_rewards = defaultdict(float)
        self._policies = policies
        self._policy_mapping_fn = policy_mapping_fn
        self._agent_to_policy = {}
        self._agent_to_rnn_state = {}
        self._agent_to_last_obs = {}
        self._agent_to_last_action = {}
        self._agent_to_last_pi_info = {}

    def policy_for(self, agent_id):
        """Returns the policy graph for the specified agent.

        If the agent is new, the policy mapping fn will be called to bind the
        agent to a policy for the duration of the episode.
        """

        if agent_id not in self._agent_to_policy:
            self._agent_to_policy[agent_id] = self._policy_mapping_fn(agent_id)
        return self._agent_to_policy[agent_id]

    def last_observation_for(self, agent_id):
        """Returns the last observation for the specified agent.

        Returns None if no observation has been recorded for the agent yet.
        """

        return self._agent_to_last_obs.get(agent_id)

    def last_action_for(self, agent_id):
        """Returns the last action for the specified agent.

        Tuple actions (stored as a list of arrays) are concatenated along
        axis 1 and flattened into a single 1-D array.

        Raises:
            KeyError: if no action has been recorded for the agent yet.
        """

        action = self._agent_to_last_action[agent_id]
        # Concatenate tuple actions
        if isinstance(action, list):
            expanded = []
            for a in action:
                if len(a.shape) == 1:
                    # Promote 1-D components to column vectors so that all
                    # components can be joined along axis 1.
                    expanded.append(np.expand_dims(a, 1))
                else:
                    expanded.append(a)
            action = np.concatenate(expanded, axis=1).flatten()
        return action

    def rnn_state_for(self, agent_id):
        """Returns the last RNN state for the specified agent.

        If no state has been stored yet, the initial state of the agent's
        policy is fetched, cached, and returned.
        """

        if agent_id not in self._agent_to_rnn_state:
            policy = self._policies[self.policy_for(agent_id)]
            self._agent_to_rnn_state[agent_id] = policy.get_initial_state()
        return self._agent_to_rnn_state[agent_id]

    def last_pi_info_for(self, agent_id):
        """Returns the last info object for the specified agent.

        Raises:
            KeyError: if no info has been recorded for the agent yet.
        """

        return self._agent_to_last_pi_info[agent_id]

    def _add_agent_rewards(self, reward_dict):
        # Accumulate per-(agent, policy) and total rewards; None rewards
        # are skipped.
        for agent_id, reward in reward_dict.items():
            if reward is not None:
                self.agent_rewards[agent_id,
                                   self.policy_for(agent_id)] += reward
                self.total_reward += reward

    def _set_rnn_state(self, agent_id, rnn_state):
        self._agent_to_rnn_state[agent_id] = rnn_state

    def _set_last_observation(self, agent_id, obs):
        self._agent_to_last_obs[agent_id] = obs

    def _set_last_action(self, agent_id, action):
        self._agent_to_last_action[agent_id] = action

    def _set_last_pi_info(self, agent_id, pi_info):
        self._agent_to_last_pi_info[agent_id] = pi_info
||||
@@ -292,6 +292,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
batch = self.sampler.get_data()
|
||||
steps_so_far += batch.count
|
||||
batches.append(batch)
|
||||
batches.extend(self.sampler.get_extra_batches())
|
||||
batch = batches[0].concat_samples(batches)
|
||||
|
||||
if self.compress_observations:
|
||||
|
||||
@@ -37,13 +37,20 @@ class PolicyGraph(object):
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
|
||||
def compute_actions(self, obs_batch, state_batches, is_training=False):
|
||||
def compute_actions(self,
|
||||
obs_batch,
|
||||
state_batches,
|
||||
is_training=False,
|
||||
episodes=None):
|
||||
"""Compute actions for the current policy.
|
||||
|
||||
Arguments:
|
||||
obs_batch (np.ndarray): batch of observations
|
||||
state_batches (list): list of RNN state input batches, if any
|
||||
is_training (bool): whether we are training the policy
|
||||
episodes (list): MultiAgentEpisode for each obs in obs_batch.
|
||||
This provides access to all of the internal episode state,
|
||||
which may be useful for model-based or multiagent algorithms.
|
||||
|
||||
Returns:
|
||||
actions (np.ndarray): batch of output actions, with shape like
|
||||
@@ -55,13 +62,20 @@ class PolicyGraph(object):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def compute_single_action(self, obs, state, is_training=False):
|
||||
def compute_single_action(self,
|
||||
obs,
|
||||
state,
|
||||
is_training=False,
|
||||
episode=None):
|
||||
"""Unbatched version of compute_actions.
|
||||
|
||||
Arguments:
|
||||
obs (obj): single observation
|
||||
state_batches (list): list of RNN state inputs, if any
|
||||
is_training (bool): whether we are training the policy
|
||||
episode (MultiAgentEpisode): this provides access to all of the
|
||||
internal episode state, which may be useful for model-based or
|
||||
multiagent algorithms.
|
||||
|
||||
Returns:
|
||||
actions (obj): single action
|
||||
@@ -70,13 +84,16 @@ class PolicyGraph(object):
|
||||
"""
|
||||
|
||||
[action], state_out, info = self.compute_actions(
|
||||
[obs], [[s] for s in state], is_training)
|
||||
[obs], [[s] for s in state], is_training, episodes=[episode])
|
||||
return action, [s[0] for s in state_out], \
|
||||
{k: v[0] for k, v in info.items()}
|
||||
|
||||
def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
|
||||
"""Implements algorithm-specific trajectory postprocessing.
|
||||
|
||||
This will be called on each trajectory fragment computed during policy
|
||||
evaluation. Each fragment is guaranteed to be only from one episode.
|
||||
|
||||
Arguments:
|
||||
sample_batch (SampleBatch): batch of experiences for the policy,
|
||||
which will contain at most one episode trajectory.
|
||||
|
||||
@@ -117,6 +117,11 @@ class MultiAgentSampleBatchBuilder(object):
|
||||
other_batches = pre_batches.copy()
|
||||
del other_batches[agent_id]
|
||||
policy = self.policy_map[self.agent_to_policy[agent_id]]
|
||||
if any(pre_batch["dones"][:-1]) or len(set(
|
||||
pre_batch["eps_id"])) > 1:
|
||||
raise ValueError(
|
||||
"Batches sent to postprocessing must only contain steps "
|
||||
"from a single trajectory.", pre_batch)
|
||||
post_batches[agent_id] = policy.postprocess_trajectory(
|
||||
pre_batch, other_batches)
|
||||
|
||||
|
||||
@@ -3,13 +3,13 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import defaultdict, namedtuple
|
||||
import numpy as np
|
||||
import random
|
||||
import six.moves.queue as queue
|
||||
import threading
|
||||
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode
|
||||
from ray.rllib.evaluation.sample_batch import MultiAgentSampleBatchBuilder, \
|
||||
MultiAgentBatch
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.env.async_vector_env import AsyncVectorEnv
|
||||
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||
|
||||
@@ -44,10 +44,11 @@ class SyncSampler(object):
|
||||
self.policies = policies
|
||||
self.policy_mapping_fn = policy_mapping_fn
|
||||
self._obs_filters = obs_filters
|
||||
self.extra_batches = queue.Queue()
|
||||
self.rollout_provider = _env_runner(
|
||||
self.async_vector_env, self.policies, self.policy_mapping_fn,
|
||||
self.num_local_steps, self.horizon, self._obs_filters, pack,
|
||||
tf_sess)
|
||||
self.async_vector_env, self.extra_batches.put, self.policies,
|
||||
self.policy_mapping_fn, self.num_local_steps, self.horizon,
|
||||
self._obs_filters, pack, tf_sess)
|
||||
self.metrics_queue = queue.Queue()
|
||||
|
||||
def get_data(self):
|
||||
@@ -67,6 +68,15 @@ class SyncSampler(object):
|
||||
break
|
||||
return completed
|
||||
|
||||
def get_extra_batches(self):
    """Drains ``self.extra_batches`` and returns its contents as a list.

    Never blocks: the returned list is empty if nothing is queued.
    """
    drained = []
    try:
        # get_nowait() raises queue.Empty once the queue is exhausted,
        # which is what terminates the drain.
        while True:
            drained.append(self.extra_batches.get_nowait())
    except queue.Empty:
        pass
    return drained
|
||||
|
||||
|
||||
class AsyncSampler(threading.Thread):
|
||||
"""This class interacts with the environment and tells it what to do.
|
||||
@@ -89,6 +99,7 @@ class AsyncSampler(threading.Thread):
|
||||
self.async_vector_env = AsyncVectorEnv.wrap_async(env)
|
||||
threading.Thread.__init__(self)
|
||||
self.queue = queue.Queue(5)
|
||||
self.extra_batches = queue.Queue()
|
||||
self.metrics_queue = queue.Queue()
|
||||
self.num_local_steps = num_local_steps
|
||||
self.horizon = horizon
|
||||
@@ -108,9 +119,9 @@ class AsyncSampler(threading.Thread):
|
||||
|
||||
def _run(self):
|
||||
rollout_provider = _env_runner(
|
||||
self.async_vector_env, self.policies, self.policy_mapping_fn,
|
||||
self.num_local_steps, self.horizon, self._obs_filters, self.pack,
|
||||
self.tf_sess)
|
||||
self.async_vector_env, self.extra_batches.put, self.policies,
|
||||
self.policy_mapping_fn, self.num_local_steps, self.horizon,
|
||||
self._obs_filters, self.pack, self.tf_sess)
|
||||
while True:
|
||||
# The timeout variable exists because apparently, if one worker
|
||||
# dies, the other workers won't die with it, unless the timeout is
|
||||
@@ -153,8 +164,18 @@ class AsyncSampler(threading.Thread):
|
||||
break
|
||||
return completed
|
||||
|
||||
def get_extra_batches(self):
    """Returns every extra batch queued so far, emptying the queue.

    Uses non-blocking reads, so an empty list is returned when no extra
    batches are pending.
    """
    pending = []
    fetch = self.extra_batches.get_nowait
    try:
        # Keep pulling until the queue signals that it is empty.
        while True:
            pending.append(fetch())
    except queue.Empty:
        return pending
|
||||
|
||||
|
||||
def _env_runner(async_vector_env,
|
||||
extra_batch_callback,
|
||||
policies,
|
||||
policy_mapping_fn,
|
||||
num_local_steps,
|
||||
@@ -166,6 +187,7 @@ def _env_runner(async_vector_env,
|
||||
|
||||
Args:
|
||||
async_vector_env (AsyncVectorEnv): env implementing AsyncVectorEnv.
|
||||
extra_batch_callback (fn): function to send extra batch data to.
|
||||
policies (dict): Map of policy ids to PolicyGraph instances.
|
||||
policy_mapping_fn (func): Function that maps agent ids to policy ids.
|
||||
This is called when an agent first enters the environment. The
|
||||
@@ -204,8 +226,8 @@ def _env_runner(async_vector_env,
|
||||
return MultiAgentSampleBatchBuilder(policies)
|
||||
|
||||
def new_episode():
|
||||
return _MultiAgentEpisode(policies, policy_mapping_fn,
|
||||
get_batch_builder)
|
||||
return MultiAgentEpisode(policies, policy_mapping_fn,
|
||||
get_batch_builder, extra_batch_callback)
|
||||
|
||||
active_episodes = defaultdict(new_episode)
|
||||
|
||||
@@ -227,7 +249,7 @@ def _env_runner(async_vector_env,
|
||||
if not new_episode:
|
||||
episode.length += 1
|
||||
episode.batch_builder.count += 1
|
||||
episode.add_agent_rewards(rewards[env_id])
|
||||
episode._add_agent_rewards(rewards[env_id])
|
||||
|
||||
# Check episode termination conditions
|
||||
if dones[env_id]["__all__"] or episode.length >= horizon:
|
||||
@@ -250,7 +272,7 @@ def _env_runner(async_vector_env,
|
||||
episode.rnn_state_for(agent_id)))
|
||||
|
||||
last_observation = episode.last_observation_for(agent_id)
|
||||
episode.set_last_observation(agent_id, filtered_obs)
|
||||
episode._set_last_observation(agent_id, filtered_obs)
|
||||
|
||||
# Record transition info if applicable
|
||||
if last_observation is not None and \
|
||||
@@ -294,7 +316,7 @@ def _env_runner(async_vector_env,
|
||||
policy_id = episode.policy_for(agent_id)
|
||||
filtered_obs = _get_or_raise(obs_filters,
|
||||
policy_id)(raw_obs)
|
||||
episode.set_last_observation(agent_id, filtered_obs)
|
||||
episode._set_last_observation(agent_id, filtered_obs)
|
||||
to_eval[policy_id].append(
|
||||
PolicyEvalData(env_id, agent_id, filtered_obs,
|
||||
episode.rnn_state_for(agent_id)))
|
||||
@@ -302,6 +324,7 @@ def _env_runner(async_vector_env,
|
||||
# Batch eval policy actions if possible
|
||||
if tf_sess:
|
||||
builder = TFRunBuilder(tf_sess, "policy_eval")
|
||||
pending_fetches = {}
|
||||
else:
|
||||
builder = None
|
||||
eval_results = {}
|
||||
@@ -310,16 +333,21 @@ def _env_runner(async_vector_env,
|
||||
rnn_in = _to_column_format([t.rnn_state for t in eval_data])
|
||||
rnn_in_cols[policy_id] = rnn_in
|
||||
policy = _get_or_raise(policies, policy_id)
|
||||
if builder:
|
||||
eval_results[policy_id] = policy.build_compute_actions(
|
||||
if builder and (policy.compute_actions.__code__ is
|
||||
TFPolicyGraph.compute_actions.__code__):
|
||||
pending_fetches[policy_id] = policy.build_compute_actions(
|
||||
builder, [t.obs for t in eval_data],
|
||||
rnn_in,
|
||||
is_training=True)
|
||||
else:
|
||||
eval_results[policy_id] = policy.compute_actions(
|
||||
[t.obs for t in eval_data], rnn_in, is_training=True)
|
||||
[t.obs for t in eval_data],
|
||||
rnn_in,
|
||||
is_training=True,
|
||||
episodes=[active_episodes[t.env_id] for t in eval_data])
|
||||
if builder:
|
||||
eval_results = {k: builder.get(v) for k, v in eval_results.items()}
|
||||
for k, v in pending_fetches.items():
|
||||
eval_results[k] = builder.get(v)
|
||||
|
||||
# Record the policy eval results
|
||||
for policy_id, eval_data in to_eval.items():
|
||||
@@ -335,16 +363,16 @@ def _env_runner(async_vector_env,
|
||||
agent_id = eval_data[i].agent_id
|
||||
actions_to_send[env_id][agent_id] = action
|
||||
episode = active_episodes[env_id]
|
||||
episode.set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
|
||||
episode.set_last_pi_info(
|
||||
episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
|
||||
episode._set_last_pi_info(
|
||||
agent_id, {k: v[i]
|
||||
for k, v in pi_info_cols.items()})
|
||||
if env_id in off_policy_actions and \
|
||||
agent_id in off_policy_actions[env_id]:
|
||||
episode.set_last_action(
|
||||
episode._set_last_action(
|
||||
agent_id, off_policy_actions[env_id][agent_id])
|
||||
else:
|
||||
episode.set_last_action(agent_id, action)
|
||||
episode._set_last_action(agent_id, action)
|
||||
|
||||
# Return computed actions to ready envs. We also send to envs that have
|
||||
# taken off-policy actions; those envs are free to ignore the action.
|
||||
@@ -362,68 +390,3 @@ def _get_or_raise(mapping, policy_id):
|
||||
"Could not find policy for agent: agent policy id `{}` not "
|
||||
"in policy map keys {}.".format(policy_id, mapping.keys()))
|
||||
return mapping[policy_id]
|
||||
|
||||
|
||||
class _MultiAgentEpisode(object):
|
||||
def __init__(self, policies, policy_mapping_fn, batch_builder_factory):
|
||||
self.batch_builder = batch_builder_factory()
|
||||
self.total_reward = 0.0
|
||||
self.length = 0
|
||||
self.episode_id = random.randrange(2e9)
|
||||
self.agent_rewards = defaultdict(float)
|
||||
self._policies = policies
|
||||
self._policy_mapping_fn = policy_mapping_fn
|
||||
self._agent_to_policy = {}
|
||||
self._agent_to_rnn_state = {}
|
||||
self._agent_to_last_obs = {}
|
||||
self._agent_to_last_action = {}
|
||||
self._agent_to_last_pi_info = {}
|
||||
|
||||
def add_agent_rewards(self, reward_dict):
|
||||
for agent_id, reward in reward_dict.items():
|
||||
if reward is not None:
|
||||
self.agent_rewards[agent_id,
|
||||
self.policy_for(agent_id)] += reward
|
||||
self.total_reward += reward
|
||||
|
||||
def policy_for(self, agent_id):
|
||||
if agent_id not in self._agent_to_policy:
|
||||
self._agent_to_policy[agent_id] = self._policy_mapping_fn(agent_id)
|
||||
return self._agent_to_policy[agent_id]
|
||||
|
||||
def rnn_state_for(self, agent_id):
|
||||
if agent_id not in self._agent_to_rnn_state:
|
||||
policy = self._policies[self.policy_for(agent_id)]
|
||||
self._agent_to_rnn_state[agent_id] = policy.get_initial_state()
|
||||
return self._agent_to_rnn_state[agent_id]
|
||||
|
||||
def last_observation_for(self, agent_id):
|
||||
return self._agent_to_last_obs.get(agent_id)
|
||||
|
||||
def last_action_for(self, agent_id):
|
||||
action = self._agent_to_last_action[agent_id]
|
||||
# Concatenate tuple actions
|
||||
if isinstance(action, list):
|
||||
expanded = []
|
||||
for a in action:
|
||||
if len(a.shape) == 1:
|
||||
expanded.append(np.expand_dims(a, 1))
|
||||
else:
|
||||
expanded.append(a)
|
||||
action = np.concatenate(expanded, axis=1).flatten()
|
||||
return action
|
||||
|
||||
def last_pi_info_for(self, agent_id):
|
||||
return self._agent_to_last_pi_info[agent_id]
|
||||
|
||||
def set_rnn_state(self, agent_id, rnn_state):
|
||||
self._agent_to_rnn_state[agent_id] = rnn_state
|
||||
|
||||
def set_last_observation(self, agent_id, obs):
|
||||
self._agent_to_last_obs[agent_id] = obs
|
||||
|
||||
def set_last_action(self, agent_id, action):
|
||||
self._agent_to_last_action[agent_id] = action
|
||||
|
||||
def set_last_pi_info(self, agent_id, pi_info):
|
||||
self._agent_to_last_pi_info[agent_id] = pi_info
|
||||
|
||||
@@ -104,7 +104,8 @@ class TFPolicyGraph(PolicyGraph):
|
||||
builder,
|
||||
obs_batch,
|
||||
state_batches=None,
|
||||
is_training=False):
|
||||
is_training=False,
|
||||
episodes=None):
|
||||
state_batches = state_batches or []
|
||||
assert len(self._state_inputs) == len(state_batches), \
|
||||
(self._state_inputs, state_batches)
|
||||
@@ -118,8 +119,11 @@ class TFPolicyGraph(PolicyGraph):
|
||||
[self.extra_compute_action_fetches()])
|
||||
return fetches[0], fetches[1:-1], fetches[-1]
|
||||
|
||||
def compute_actions(self, obs_batch, state_batches=None,
|
||||
is_training=False):
|
||||
def compute_actions(self,
|
||||
obs_batch,
|
||||
state_batches=None,
|
||||
is_training=False,
|
||||
episodes=None):
|
||||
builder = TFRunBuilder(self._sess, "compute_actions")
|
||||
fetches = self.build_compute_actions(builder, obs_batch, state_batches,
|
||||
is_training)
|
||||
|
||||
@@ -67,8 +67,11 @@ class TorchPolicyGraph(PolicyGraph):
|
||||
"""Custom PyTorch optimizer to use."""
|
||||
return torch.optim.Adam(self._model.parameters())
|
||||
|
||||
def compute_actions(self, obs_batch, state_batches=None,
|
||||
is_training=False):
|
||||
def compute_actions(self,
|
||||
obs_batch,
|
||||
state_batches=None,
|
||||
is_training=False,
|
||||
episodes=None):
|
||||
if state_batches:
|
||||
raise NotImplementedError("Torch RNN support")
|
||||
with self.lock:
|
||||
|
||||
@@ -306,6 +306,51 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
self.assertEqual(batch.policy_batches["p0"]["t"].tolist()[:10],
|
||||
[4, 9, 14, 19, 24, 5, 10, 15, 20, 25])
|
||||
|
||||
def testReturningModelBasedRolloutsData(self):
    """Extra batches added from compute_actions() show up in the sample."""

    class ModelBasedPolicyGraph(PGPolicyGraph):
        def compute_actions(self,
                            obs_batch,
                            state_batches,
                            is_training=False,
                            episodes=None):
            # Pretend we did a model-based rollout and want to return
            # the extra trajectory.
            episode = episodes[0]
            fake_obs = obs_batch[0]
            builder = episode.new_batch_builder()
            rollout_id = random.randint(0, 10000)
            last_step = 4
            for t in range(last_step + 1):
                builder.add_values(
                    agent_id="extra_0",
                    policy_id="p1",  # use p1 so we can easily check it
                    t=t,
                    eps_id=rollout_id,  # new id for each rollout
                    obs=fake_obs,
                    actions=0,
                    rewards=0,
                    dones=(t == last_step),
                    infos={},
                    new_obs=fake_obs)
            episode.add_extra_batch(builder.build_and_reset())

            # Just return zeros for actions
            return [0] * len(obs_batch), [], {}

    single_env = gym.make("CartPole-v0")
    obs_space = single_env.observation_space
    act_space = single_env.action_space
    policy_specs = {
        policy_id: (ModelBasedPolicyGraph, obs_space, act_space, {})
        for policy_id in ("p0", "p1")
    }
    ev = PolicyEvaluator(
        env_creator=lambda _: MultiCartpole(2),
        policy_graph=policy_specs,
        policy_mapping_fn=lambda agent_id: "p0",
        batch_steps=5)
    batch = ev.sample()
    self.assertEqual(batch.count, 5)
    # All real agents are mapped to p0; p1 only ever receives the fake
    # 5-step rollouts injected through add_extra_batch().
    self.assertEqual(batch.policy_batches["p0"].count, 10)
    self.assertEqual(batch.policy_batches["p1"].count, 25)
|
||||
|
||||
def testTrainMultiCartpoleSinglePolicy(self):
|
||||
n = 10
|
||||
register_env("multi_cartpole", lambda _: MultiCartpole(n))
|
||||
|
||||
@@ -17,7 +17,11 @@ from ray.tune.registry import register_env
|
||||
|
||||
|
||||
class MockPolicyGraph(PolicyGraph):
|
||||
def compute_actions(self, obs_batch, state_batches, is_training=False):
|
||||
def compute_actions(self,
|
||||
obs_batch,
|
||||
state_batches,
|
||||
is_training=False,
|
||||
episodes=None):
|
||||
return [0] * len(obs_batch), [], {}
|
||||
|
||||
def postprocess_trajectory(self, batch, other_agent_batches=None):
|
||||
@@ -25,7 +29,11 @@ class MockPolicyGraph(PolicyGraph):
|
||||
|
||||
|
||||
class BadPolicyGraph(PolicyGraph):
|
||||
def compute_actions(self, obs_batch, state_batches, is_training=False):
|
||||
def compute_actions(self,
|
||||
obs_batch,
|
||||
state_batches,
|
||||
is_training=False,
|
||||
episodes=None):
|
||||
raise Exception("intentional error")
|
||||
|
||||
def postprocess_trajectory(self, batch, other_agent_batches=None):
|
||||
|
||||
Reference in New Issue
Block a user