Files
ray/python/ray/rllib/utils/sampler.py
T
Eric Liang 1251abf0d1 [rllib] Modularize Torch and TF policy graphs (#2294)
* wip

* cls

* re

* wip

* wip

* a3c working

* torch support

* pg works

* lint

* rm v2

* consumer id

* clean up pg

* clean up more

* fix python 2.7

* tf session management

* docs

* dqn wip

* fix compile

* dqn

* apex runs

* up

* impotrs

* ddpg

* quotes

* fix tests

* fix last r

* fix tests

* lint

* pass checkpoint restore

* kwar

* nits

* policy graph

* fix yapf

* com

* class

* pyt

* vectorization

* update

* test cpe

* unit test

* fix ddpg2

* changes

* wip

* args

* faster test

* common

* fix

* add alg option

* batch mode and policy serving

* multi serving test

* todo

* wip

* serving test

* doc async env

* num envs

* comments

* thread

* remove init hook

* update

* fix ppo

* comments1

* fix

* updates

* add jenkins tests

* fix

* fix pytorch

* fix

* fixes

* fix a3c policy

* fix squeeze

* fix trunc on apex

* fix squeezing for real

* update

* remove horizon test for now

* multiagent wip

* update

* fix race condition

* fix ma

* t

* doc

* st

* wip

* example

* wip

* working

* cartpole

* wip

* batch wip

* fix bug

* make other_batches None default

* working

* debug

* nit

* warn

* comments

* fix ppo

* fix obs filter

* update

* wip

* tf

* update

* fix

* cleanup

* cleanup

* spacing

* model

* fix

* dqn

* fix ddpg

* doc

* keep names

* update

* fix

* com

* docs

* clarify model outputs

* Update torch_policy_graph.py

* fix obs filter

* pass thru worker index

* fix

* rename

* vlad torch comments

* fix log action

* debug name

* fix lstm

* remove unused ddpg net

* remove conv net

* revert lstm

* cast

* clean up

* fix lstm check

* move to end

* fix sphinx

* fix cmd

* remove bad doc

* clarify

* copy

* async sa

* fix
2018-06-26 13:17:15 -07:00

405 lines
16 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import defaultdict, namedtuple
import numpy as np
import six.moves.queue as queue
import threading
from ray.rllib.optimizers.sample_batch import MultiAgentSampleBatchBuilder, \
MultiAgentBatch
from ray.rllib.utils.async_vector_env import AsyncVectorEnv
from ray.rllib.utils.tf_run_builder import TFRunBuilder
RolloutMetrics = namedtuple(
"RolloutMetrics", ["episode_length", "episode_reward", "agent_rewards"])
PolicyEvalData = namedtuple(
"PolicyEvalData", ["env_id", "agent_id", "obs", "rnn_state"])
class SyncSampler(object):
"""This class interacts with the environment and tells it what to do.
Note that batch_size is only a unit of measure here. Batches can
accumulate and the gradient can be calculated on up to 5 batches.
This class provides data on invocation, rather than on a separate
thread."""
def __init__(
self, env, policies, policy_mapping_fn, obs_filters,
num_local_steps, horizon=None, pack=False, tf_sess=None):
self.async_vector_env = AsyncVectorEnv.wrap_async(env)
self.num_local_steps = num_local_steps
self.horizon = horizon
self.policies = policies
self.policy_mapping_fn = policy_mapping_fn
self._obs_filters = obs_filters
self.rollout_provider = _env_runner(
self.async_vector_env, self.policies, self.policy_mapping_fn,
self.num_local_steps, self.horizon, self._obs_filters, pack,
tf_sess)
self.metrics_queue = queue.Queue()
def get_data(self):
while True:
item = next(self.rollout_provider)
if isinstance(item, RolloutMetrics):
self.metrics_queue.put(item)
else:
return item
def get_metrics(self):
completed = []
while True:
try:
completed.append(self.metrics_queue.get_nowait())
except queue.Empty:
break
return completed
class AsyncSampler(threading.Thread):
"""This class interacts with the environment and tells it what to do.
Note that batch_size is only a unit of measure here. Batches can
accumulate and the gradient can be calculated on up to 5 batches."""
def __init__(
self, env, policies, policy_mapping_fn, obs_filters,
num_local_steps, horizon=None, pack=False, tf_sess=None):
for _, f in obs_filters.items():
assert getattr(f, "is_concurrent", False), \
"Observation Filter must support concurrent updates."
self.async_vector_env = AsyncVectorEnv.wrap_async(env)
threading.Thread.__init__(self)
self.queue = queue.Queue(5)
self.metrics_queue = queue.Queue()
self.num_local_steps = num_local_steps
self.horizon = horizon
self.policies = policies
self.policy_mapping_fn = policy_mapping_fn
self._obs_filters = obs_filters
self.daemon = True
self.pack = pack
self.tf_sess = tf_sess
def run(self):
try:
self._run()
except BaseException as e:
self.queue.put(e)
raise e
def _run(self):
rollout_provider = _env_runner(
self.async_vector_env, self.policies, self.policy_mapping_fn,
self.num_local_steps, self.horizon, self._obs_filters, self.pack,
self.tf_sess)
while True:
# The timeout variable exists because apparently, if one worker
# dies, the other workers won't die with it, unless the timeout is
# set to some large number. This is an empirical observation.
item = next(rollout_provider)
if isinstance(item, RolloutMetrics):
self.metrics_queue.put(item)
else:
self.queue.put(item, timeout=600.0)
def get_data(self):
rollout = self.queue.get(timeout=600.0)
# Propagate errors
if isinstance(rollout, BaseException):
raise rollout
# We can't auto-concat rollouts in these modes
if self.async_vector_env.num_envs > 1 or \
isinstance(rollout, MultiAgentBatch):
return rollout
# Auto-concat rollouts; TODO(ekl) is this important for A3C perf?
while not rollout["dones"][-1]:
try:
part = self.queue.get_nowait()
if isinstance(part, BaseException):
raise rollout
rollout = rollout.concat(part)
except queue.Empty:
break
return rollout
def get_metrics(self):
completed = []
while True:
try:
completed.append(self.metrics_queue.get_nowait())
except queue.Empty:
break
return completed
def _env_runner(
async_vector_env, policies, policy_mapping_fn, num_local_steps,
horizon, obs_filters, pack, tf_sess=None):
"""This implements the common experience collection logic.
Args:
async_vector_env (AsyncVectorEnv): env implementing AsyncVectorEnv.
policies (dict): Map of policy ids to PolicyGraph instances.
policy_mapping_fn (func): Function that maps agent ids to policy ids.
This is called when an agent first enters the environment. The
agent is then "bound" to the returned policy for the episode.
num_local_steps (int): Number of episode steps before `SampleBatch` is
yielded. Set to infinity to yield complete episodes.
horizon (int): Horizon of the episode.
obs_filters (dict): Map of policy id to filter used to process
observations for the policy.
pack (bool): Whether to pack multiple episodes into each batch. This
guarantees batches will be exactly `num_local_steps` in size.
tf_sess (Session|None): Optional tensorflow session to use for batching
TF policy evaluations.
Yields:
rollout (SampleBatch): Object containing state, action, reward,
terminal condition, and other fields as dictated by `policy`.
"""
try:
if not horizon:
horizon = async_vector_env.get_unwrapped().spec.max_episode_steps
except Exception:
print("Warning, no horizon specified, assuming infinite")
if not horizon:
horizon = float("inf")
# Pool of batch builders, which can be shared across episodes to pack
# trajectory data.
batch_builder_pool = []
def get_batch_builder():
if batch_builder_pool:
return batch_builder_pool.pop()
else:
return MultiAgentSampleBatchBuilder(policies)
active_episodes = defaultdict(
lambda: _MultiAgentEpisode(
policies, policy_mapping_fn, get_batch_builder))
while True:
# Get observations from all ready agents
unfiltered_obs, rewards, dones, infos, off_policy_actions = \
async_vector_env.poll()
# Map of policy_id to list of PolicyEvalData
to_eval = defaultdict(list)
# Map of env_id -> agent_id -> action replies
actions_to_send = defaultdict(dict)
# For each environment
for env_id, agent_obs in unfiltered_obs.items():
new_episode = env_id not in active_episodes
episode = active_episodes[env_id]
if not new_episode:
episode.length += 1
episode.batch_builder.count += 1
episode.add_agent_rewards(rewards[env_id])
# Check episode termination conditions
if dones[env_id]["__all__"] or episode.length >= horizon:
all_done = True
yield RolloutMetrics(
episode.length, episode.total_reward,
dict(episode.agent_rewards))
else:
all_done = False
# At least send an empty dict if not done
actions_to_send[env_id] = {}
# For each agent in the environment
for agent_id, raw_obs in agent_obs.items():
policy_id = episode.policy_for(agent_id)
filtered_obs = _get_or_raise(obs_filters, policy_id)(raw_obs)
agent_done = bool(all_done or dones[env_id].get(agent_id))
if not agent_done:
to_eval[policy_id].append(
PolicyEvalData(
env_id, agent_id, filtered_obs,
episode.rnn_state_for(agent_id)))
last_observation = episode.last_observation_for(agent_id)
episode.set_last_observation(agent_id, filtered_obs)
# Record transition info if applicable
if last_observation is not None and \
infos[env_id][agent_id].get("training_enabled", True):
episode.batch_builder.add_values(
agent_id,
policy_id,
t=episode.length - 1,
obs=last_observation,
actions=episode.last_action_for(agent_id),
rewards=rewards[env_id][agent_id],
dones=agent_done,
infos=infos[env_id][agent_id],
new_obs=filtered_obs,
**episode.last_pi_info_for(agent_id))
# Cut the batch if we're not packing multiple episodes into one,
# or if we've exceeded the requested batch size.
if episode.batch_builder.has_pending_data():
if (all_done and not pack) or \
episode.batch_builder.count >= num_local_steps:
yield episode.batch_builder.build_and_reset()
elif all_done:
# Make sure postprocessor stays within one episode
episode.batch_builder.postprocess_batch_so_far()
if all_done:
# Handle episode termination
batch_builder_pool.append(episode.batch_builder)
del active_episodes[env_id]
resetted_obs = async_vector_env.try_reset(env_id)
if resetted_obs is None:
# Reset not supported, drop this env from the ready list
assert horizon == float("inf"), \
"Setting episode horizon requires reset() support."
else:
# Creates a new episode
episode = active_episodes[env_id]
for agent_id, raw_obs in resetted_obs.items():
policy_id = episode.policy_for(agent_id)
filtered_obs = _get_or_raise(
obs_filters, policy_id)(raw_obs)
episode.set_last_observation(agent_id, filtered_obs)
to_eval[policy_id].append(
PolicyEvalData(
env_id, agent_id, filtered_obs,
episode.rnn_state_for(agent_id)))
# Batch eval policy actions if possible
if tf_sess:
builder = TFRunBuilder(tf_sess, "policy_eval")
else:
builder = None
eval_results = {}
rnn_in_cols = {}
for policy_id, eval_data in to_eval.items():
rnn_in = _to_column_format([t.rnn_state for t in eval_data])
rnn_in_cols[policy_id] = rnn_in
policy = _get_or_raise(policies, policy_id)
if builder:
eval_results[policy_id] = policy.build_compute_actions(
builder, [t.obs for t in eval_data], rnn_in,
is_training=True)
else:
eval_results[policy_id] = policy.compute_actions(
[t.obs for t in eval_data], rnn_in, is_training=True)
if builder:
eval_results = {k: builder.get(v) for k, v in eval_results.items()}
# Record the policy eval results
for policy_id, eval_data in to_eval.items():
actions, rnn_out_cols, pi_info_cols = eval_results[policy_id]
# Add RNN state info
for f_i, column in enumerate(rnn_in_cols[policy_id]):
pi_info_cols["state_in_{}".format(f_i)] = column
for f_i, column in enumerate(rnn_out_cols):
pi_info_cols["state_out_{}".format(f_i)] = column
# Save output rows
for i, action in enumerate(actions):
env_id = eval_data[i].env_id
agent_id = eval_data[i].agent_id
actions_to_send[env_id][agent_id] = action
episode = active_episodes[env_id]
episode.set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
episode.set_last_pi_info(
agent_id, {k: v[i] for k, v in pi_info_cols.items()})
if env_id in off_policy_actions and \
agent_id in off_policy_actions[env_id]:
episode.set_last_action(
agent_id, off_policy_actions[env_id][agent_id])
else:
episode.set_last_action(agent_id, action)
# Return computed actions to ready envs. We also send to envs that have
# taken off-policy actions; those envs are free to ignore the action.
async_vector_env.send_actions(dict(actions_to_send))
def _to_column_format(rnn_state_rows):
num_cols = len(rnn_state_rows[0])
return [
[row[i] for row in rnn_state_rows] for i in range(num_cols)]
def _get_or_raise(mapping, policy_id):
if policy_id not in mapping:
raise ValueError(
"Could not find policy for agent: agent policy id `{}` not "
"in policy map keys {}.".format(policy_id, mapping.keys()))
return mapping[policy_id]
class _MultiAgentEpisode(object):
def __init__(self, policies, policy_mapping_fn, batch_builder_factory):
self.batch_builder = batch_builder_factory()
self.total_reward = 0.0
self.length = 0
self.agent_rewards = defaultdict(float)
self._policies = policies
self._policy_mapping_fn = policy_mapping_fn
self._agent_to_policy = {}
self._agent_to_rnn_state = {}
self._agent_to_last_obs = {}
self._agent_to_last_action = {}
self._agent_to_last_pi_info = {}
def add_agent_rewards(self, reward_dict):
for agent_id, reward in reward_dict.items():
if reward is not None:
self.agent_rewards[
agent_id, self.policy_for(agent_id)] += reward
self.total_reward += reward
def policy_for(self, agent_id):
if agent_id not in self._agent_to_policy:
self._agent_to_policy[agent_id] = self._policy_mapping_fn(agent_id)
return self._agent_to_policy[agent_id]
def rnn_state_for(self, agent_id):
if agent_id not in self._agent_to_rnn_state:
policy = self._policies[self.policy_for(agent_id)]
self._agent_to_rnn_state[agent_id] = policy.get_initial_state()
return self._agent_to_rnn_state[agent_id]
def last_observation_for(self, agent_id):
return self._agent_to_last_obs.get(agent_id)
def last_action_for(self, agent_id):
action = self._agent_to_last_action[agent_id]
# Concatenate tuple actions
if isinstance(action, list):
action = np.concatenate(action, axis=0).flatten()
return action
def last_pi_info_for(self, agent_id):
return self._agent_to_last_pi_info[agent_id]
def set_rnn_state(self, agent_id, rnn_state):
self._agent_to_rnn_state[agent_id] = rnn_state
def set_last_observation(self, agent_id, obs):
self._agent_to_last_obs[agent_id] = obs
def set_last_action(self, agent_id, action):
self._agent_to_last_action[agent_id] = action
def set_last_pi_info(self, agent_id, pi_info):
self._agent_to_last_pi_info[agent_id] = pi_info