[rllib] Part 2 of multiagent support (#2286)

* wip

* cls

* re

* wip

* wip

* a3c working

* torch support

* pg works

* lint

* rm v2

* consumer id

* clean up pg

* clean up more

* fix python 2.7

* tf session management

* docs

* dqn wip

* fix compile

* dqn

* apex runs

* up

* imports

* ddpg

* quotes

* fix tests

* fix last r

* fix tests

* lint

* pass checkpoint restore

* kwar

* nits

* policy graph

* fix yapf

* com

* class

* pyt

* vectorization

* update

* test cpe

* unit test

* fix ddpg2

* changes

* wip

* args

* faster test

* common

* fix

* add alg option

* batch mode and policy serving

* multi serving test

* todo

* wip

* serving test

* doc async env

* num envs

* comments

* thread

* remove init hook

* update

* fix ppo

* comments1

* fix

* updates

* add jenkins tests

* fix

* fix pytorch

* fix

* fixes

* fix a3c policy

* fix squeeze

* fix trunc on apex

* fix squeezing for real

* update

* remove horizon test for now

* multiagent wip

* update

* fix race condition

* fix ma

* t

* doc

* st

* wip

* example

* wip

* working

* cartpole

* wip

* batch wip

* fix bug

* make other_batches None default

* working

* debug

* nit

* warn

* comments

* fix ppo

* fix obs filter

* update

* fix obs filter

* pass thru worker index

* fix

* fix log action

* debug name

* fix sphinx
This commit is contained in:
Eric Liang
2018-06-25 22:33:57 -07:00
committed by GitHub
parent 739ddfa229
commit a9a26b7560
32 changed files with 939 additions and 202 deletions
+15 -4
View File
@@ -63,13 +63,18 @@ DEFAULT_CONFIG = {
},
# Arguments to pass to the env creator
"env_config": {},
# === Multiagent ===
"multiagent": {
"policy_graphs": {},
"policy_mapping_fn": None,
},
}
class A3CAgent(Agent):
_agent_name = "A3C"
_default_config = DEFAULT_CONFIG
_allow_unknown_subkeys = ["model", "optimizer", "env_config"]
@classmethod
def default_resource_request(cls, config):
@@ -98,7 +103,9 @@ class A3CAgent(Agent):
remote_cls = CommonPolicyEvaluator.as_remote(
num_gpus=1 if self.config["use_gpu_for_workers"] else 0)
self.local_evaluator = CommonPolicyEvaluator(
self.env_creator, self.policy_cls,
self.env_creator,
self.config["multiagent"]["policy_graphs"] or self.policy_cls,
policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"],
batch_steps=self.config["batch_size"],
batch_mode="truncate_episodes",
tf_session_creator=session_creator,
@@ -107,13 +114,17 @@ class A3CAgent(Agent):
num_envs=self.config["num_envs"])
self.remote_evaluators = [
remote_cls.remote(
self.env_creator, self.policy_cls,
self.env_creator,
self.config["multiagent"]["policy_graphs"] or self.policy_cls,
policy_mapping_fn=(
self.config["multiagent"]["policy_mapping_fn"]),
batch_steps=self.config["batch_size"],
batch_mode="truncate_episodes", sample_async=True,
tf_session_creator=session_creator,
env_config=self.config["env_config"],
model_config=self.config["model"], policy_config=self.config,
num_envs=self.config["num_envs"])
num_envs=self.config["num_envs"],
worker_index=i+1)
for i in range(self.config["num_workers"])]
self.optimizer = AsyncOptimizer(
+3 -1
View File
@@ -5,6 +5,7 @@ from __future__ import print_function
import tensorflow as tf
import gym
import ray
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.process_rollout import compute_advantages
from ray.rllib.utils.tf_policy_graph import TFPolicyGraph
@@ -14,6 +15,7 @@ class A3CTFPolicyGraph(TFPolicyGraph):
"""The TF policy base class."""
def __init__(self, ob_space, action_space, config):
config = dict(ray.rllib.a3c.a3c.DEFAULT_CONFIG, **config)
self.local_steps = 0
self.config = config
self.summarize = config.get("summarize")
@@ -27,7 +29,7 @@ class A3CTFPolicyGraph(TFPolicyGraph):
self.sess = tf.get_default_session()
TFPolicyGraph.__init__(
self, self.sess, obs_input=self.x,
self, ob_space, action_space, self.sess, obs_input=self.x,
action_sampler=self.action_dist.sample(), loss=self.loss,
loss_inputs=self.loss_in, is_training=self.is_training,
state_inputs=self.state_in, state_outputs=self.state_out)
+2
View File
@@ -8,6 +8,7 @@ from threading import Lock
import torch
import torch.nn.functional as F
import ray
from ray.rllib.models.pytorch.misc import var_to_np, convert_batch
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.process_rollout import compute_advantages
@@ -18,6 +19,7 @@ class SharedTorchPolicy(PolicyGraph):
"""A simple, non-recurrent PyTorch policy example."""
def __init__(self, obs_space, action_space, config):
config = dict(ray.rllib.a3c.a3c.DEFAULT_CONFIG, **config)
PolicyGraph.__init__(self, obs_space, action_space, config)
self.local_steps = 0
self.config = config
+2 -1
View File
@@ -59,7 +59,8 @@ class Agent(Trainable):
"""
_allow_unknown_configs = False
_allow_unknown_subkeys = ["env_config", "model", "optimizer"]
_allow_unknown_subkeys = [
"tf_session_args", "env_config", "model", "optimizer", "multiagent"]
@classmethod
def resource_help(cls, config):
+7 -3
View File
@@ -108,14 +108,18 @@ DEFAULT_CONFIG = {
# Whether to use a distribution of epsilons across workers for exploration.
"per_worker_exploration": False,
# Whether to compute priorities on workers.
"worker_side_prioritization": False
"worker_side_prioritization": False,
# === Multiagent ===
"multiagent": {
"policy_graphs": {},
"policy_mapping_fn": None,
},
}
class DDPGAgent(DQNAgent):
_agent_name = "DDPG"
_allow_unknown_subkeys = [
"model", "optimizer", "tf_session_args", "env_config"]
_default_config = DEFAULT_CONFIG
_policy_graph = DDPGPolicyGraph
+3 -1
View File
@@ -82,6 +82,7 @@ def _build_q_network(inputs, action_inputs, config):
class DDPGPolicyGraph(TFPolicyGraph):
def __init__(self, observation_space, action_space, config):
config = dict(ray.rllib.ddpg.ddpg.DEFAULT_CONFIG, **config)
if not isinstance(action_space, Box):
raise UnsupportedSpaceException(
"Action space {} is not supported for DDPG.".format(
@@ -232,7 +233,8 @@ class DDPGPolicyGraph(TFPolicyGraph):
]
self.is_training = tf.placeholder_with_default(True, ())
TFPolicyGraph.__init__(
self, self.sess, obs_input=self.cur_observations,
self, observation_space, action_space, self.sess,
obs_input=self.cur_observations,
action_sampler=self.output_actions, loss=self.loss,
loss_inputs=self.loss_inputs, is_training=self.is_training)
self.sess.run(tf.global_variables_initializer())
+17 -10
View File
@@ -102,14 +102,18 @@ DEFAULT_CONFIG = {
# Whether to use a distribution of epsilons across workers for exploration.
"per_worker_exploration": False,
# Whether to compute priorities on workers.
"worker_side_prioritization": False
"worker_side_prioritization": False,
# === Multiagent ===
"multiagent": {
"policy_graphs": {},
"policy_mapping_fn": None,
},
}
class DQNAgent(Agent):
_agent_name = "DQN"
_allow_unknown_subkeys = [
"model", "optimizer", "tf_session_args", "env_config"]
_default_config = DEFAULT_CONFIG
_policy_graph = DQNPolicyGraph
@@ -125,7 +129,9 @@ class DQNAgent(Agent):
adjusted_batch_size = (
self.config["sample_batch_size"] + self.config["n_step"] - 1)
self.local_evaluator = CommonPolicyEvaluator(
self.env_creator, self._policy_graph,
self.env_creator,
self.config["multiagent"]["policy_graphs"] or self._policy_graph,
policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"],
batch_steps=adjusted_batch_size,
batch_mode="truncate_episodes", preprocessor_pref="deepmind",
compress_observations=True,
@@ -143,8 +149,9 @@ class DQNAgent(Agent):
compress_observations=True,
env_config=self.config["env_config"],
model_config=self.config["model"], policy_config=self.config,
num_envs=self.config["num_envs"])
for _ in range(self.config["num_workers"])]
num_envs=self.config["num_envs"],
worker_index=i+1)
for i in range(self.config["num_workers"])]
self.exploration0 = self._make_exploration_schedule(0)
self.explorations = [
@@ -185,7 +192,7 @@ class DQNAgent(Agent):
def update_target_if_needed(self):
if self.global_timestep - self.last_target_update_ts > \
self.config["target_network_update_freq"]:
self.local_evaluator.for_policy(lambda p: p.update_target())
self.local_evaluator.foreach_policy(lambda p, _: p.update_target())
self.last_target_update_ts = self.global_timestep
self.num_target_updates += 1
@@ -198,11 +205,11 @@ class DQNAgent(Agent):
self.update_target_if_needed()
exp_vals = [self.exploration0.value(self.global_timestep)]
self.local_evaluator.for_policy(
lambda p: p.set_epsilon(exp_vals[0]))
self.local_evaluator.foreach_policy(
lambda p, _: p.set_epsilon(exp_vals[0]))
for i, e in enumerate(self.remote_evaluators):
exp_val = self.explorations[i].value(self.global_timestep)
e.for_policy.remote(lambda p: p.set_epsilon(exp_val))
e.foreach_policy.remote(lambda p, _: p.set_epsilon(exp_val))
exp_vals.append(exp_val)
result = collect_metrics(
+4 -1
View File
@@ -7,6 +7,7 @@ import numpy as np
import tensorflow as tf
import tensorflow.contrib.layers as layers
import ray
from ray.rllib.models import ModelCatalog
from ray.rllib.optimizers.sample_batch import SampleBatch
from ray.rllib.utils.error import UnsupportedSpaceException
@@ -47,6 +48,7 @@ def adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones):
class DQNPolicyGraph(TFPolicyGraph):
def __init__(self, observation_space, action_space, config):
config = dict(ray.rllib.dqn.dqn.DEFAULT_CONFIG, **config)
if not isinstance(action_space, Discrete):
raise UnsupportedSpaceException(
"Action space {} is not supported for DQN.".format(
@@ -144,7 +146,8 @@ class DQNPolicyGraph(TFPolicyGraph):
]
self.is_training = tf.placeholder_with_default(True, ())
TFPolicyGraph.__init__(
self, self.sess, obs_input=self.cur_observations,
self, observation_space, action_space, self.sess,
obs_input=self.cur_observations,
action_sampler=self.output_actions, loss=self.loss,
loss_inputs=self.loss_inputs, is_training=self.is_training)
self.sess.run(tf.global_variables_initializer())
-1
View File
@@ -137,7 +137,6 @@ class Worker(object):
class ESAgent(agent.Agent):
_agent_name = "ES"
_default_config = DEFAULT_CONFIG
_allow_unknown_subkeys = ["env_config"]
@classmethod
def default_resource_request(cls, config):
@@ -0,0 +1,70 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Simple example of setting up a multi-agent policy mapping.
Control the number of agents and policies via --num-agents and --num-policies.
This works with hundreds of agents and policies, but note that initializing
many TF policy graphs will take some time.
Also, TF evals might slow down with large numbers of policies. To debug TF
execution, set the TF_TIMELINE_DIR environment variable.
"""
import argparse
import gym
import random
import ray
from ray.rllib.pg.pg import PGAgent
from ray.rllib.pg.pg_policy_graph import PGPolicyGraph
from ray.rllib.test.test_multi_agent_env import MultiCartpole
from ray.tune.logger import pretty_print
from ray.tune.registry import register_env
parser = argparse.ArgumentParser()
parser.add_argument("--num-agents", type=int, default=4)
parser.add_argument("--num-policies", type=int, default=2)
parser.add_argument("--num-iters", type=int, default=20)
if __name__ == "__main__":
    args = parser.parse_args()
    ray.init()

    # Register an environment made of `num_agents` independent cartpoles.
    register_env("multi_cartpole", lambda _: MultiCartpole(args.num_agents))

    # A plain single-agent CartPole supplies the per-agent spaces.
    cartpole = gym.make("CartPole-v0")
    obs_space = cartpole.observation_space
    act_space = cartpole.action_space

    def gen_policy():
        """Return a (graph class, obs space, action space, config) spec.

        The config values are drawn at random so that every generated
        policy differs from the others.
        """
        gamma = random.choice([0.5, 0.8, 0.9, 0.95, 0.99])
        n_step = random.choice([1, 2, 3, 4, 5])
        return (PGPolicyGraph, obs_space, act_space,
                {"gamma": gamma, "n_step": n_step})

    # Set up PG with an ensemble of `num_policies` different policy graphs.
    policy_graphs = {}
    for index in range(args.num_policies):
        policy_graphs["policy_{}".format(index)] = gen_policy()
    policy_ids = list(policy_graphs)

    def pick_policy(agent_id):
        """Map an agent to a policy id chosen uniformly at random."""
        return random.choice(policy_ids)

    multiagent_config = {
        "policy_graphs": policy_graphs,
        "policy_mapping_fn": pick_policy,
    }
    agent = PGAgent(
        env="multi_cartpole", config={"multiagent": multiagent_config})

    for iteration in range(args.num_iters):
        print("== Iteration", iteration, "==")
        print(pretty_print(agent.train()))
@@ -217,6 +217,7 @@ class ApexOptimizer(PolicyOptimizer):
with self.timers["sample_processing"]:
for ev, sample_batch in self.sample_tasks.completed():
self._check_not_multiagent(sample_batch)
sample_timesteps += self.sample_batch_size
# Send the data to the replay buffer
@@ -20,6 +20,9 @@ class AsyncOptimizer(PolicyOptimizer):
self.dispatch_timer = TimerStat()
self.grads_per_step = grads_per_step
self.batch_size = batch_size
if not self.remote_evaluators:
raise ValueError(
"Async optimizer requires at least 1 remote evaluator")
def step(self):
weights = ray.put(self.local_evaluator.get_weights())
@@ -2,13 +2,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import numpy as np
import ray
from ray.rllib.optimizers.replay_buffer import ReplayBuffer, \
PrioritizedReplayBuffer
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.optimizers.sample_batch import SampleBatch
from ray.rllib.optimizers.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
MultiAgentBatch
from ray.rllib.utils.compression import pack_if_needed
from ray.rllib.utils.filter import RunningStat
from ray.rllib.utils.timer import TimerStat
@@ -41,11 +43,15 @@ class LocalSyncReplayOptimizer(PolicyOptimizer):
# Set up replay buffer
if prioritized_replay:
self.replay_buffer = PrioritizedReplayBuffer(
buffer_size, alpha=prioritized_replay_alpha,
clip_rewards=clip_rewards)
def new_buffer():
return PrioritizedReplayBuffer(
buffer_size, alpha=prioritized_replay_alpha,
clip_rewards=clip_rewards)
else:
self.replay_buffer = ReplayBuffer(buffer_size, clip_rewards)
def new_buffer():
return ReplayBuffer(buffer_size, clip_rewards)
self.replay_buffers = collections.defaultdict(new_buffer)
assert buffer_size >= self.replay_starts
@@ -63,47 +69,64 @@ class LocalSyncReplayOptimizer(PolicyOptimizer):
[e.sample.remote() for e in self.remote_evaluators]))
else:
batch = self.local_evaluator.sample()
for row in batch.rows():
self.replay_buffer.add(
pack_if_needed(row["obs"]), row["actions"], row["rewards"],
pack_if_needed(row["new_obs"]),
row["dones"], row["weights"])
if len(self.replay_buffer) >= self.replay_starts:
# Handle everything as if multiagent
if isinstance(batch, SampleBatch):
batch = MultiAgentBatch(
{DEFAULT_POLICY_ID: batch}, batch.count)
for policy_id, s in batch.policy_batches.items():
for row in s.rows():
if "weights" not in row:
row["weights"] = np.ones_like(row["rewards"])
self.replay_buffers[policy_id].add(
pack_if_needed(row["obs"]), row["actions"],
row["rewards"], pack_if_needed(row["new_obs"]),
row["dones"], row["weights"])
if self.num_steps_sampled >= self.replay_starts:
self._optimize()
self.num_steps_sampled += batch.count
def _optimize(self):
with self.replay_timer:
if isinstance(self.replay_buffer, PrioritizedReplayBuffer):
(obses_t, actions, rewards, obses_tp1,
dones, weights, batch_indexes) = self.replay_buffer.sample(
self.train_batch_size,
beta=self.prioritized_replay_beta)
else:
(obses_t, actions, rewards, obses_tp1,
dones) = self.replay_buffer.sample(
self.train_batch_size)
weights = np.ones_like(rewards)
batch_indexes = - np.ones_like(rewards)
samples = SampleBatch({
"obs": obses_t, "actions": actions, "rewards": rewards,
"new_obs": obses_tp1, "dones": dones, "weights": weights,
"batch_indexes": batch_indexes})
samples = self._replay()
with self.grad_timer:
info = self.local_evaluator.compute_apply(samples)
if isinstance(self.replay_buffer, PrioritizedReplayBuffer):
td_error = info["td_error"]
new_priorities = (
np.abs(td_error) + self.prioritized_replay_eps)
self.replay_buffer.update_priorities(
samples["batch_indexes"], new_priorities)
info_dict = self.local_evaluator.compute_apply(samples)
for policy_id, info in info_dict.items():
replay_buffer = self.replay_buffers[policy_id]
if isinstance(replay_buffer, PrioritizedReplayBuffer):
td_error = info["td_error"]
new_priorities = (
np.abs(td_error) + self.prioritized_replay_eps)
replay_buffer.update_priorities(
samples.policy_batches[policy_id]["batch_indexes"],
new_priorities)
self.grad_timer.push_units_processed(samples.count)
self.num_steps_trained += samples.count
def _replay(self):
    """Draw one training batch per policy from the replay buffers.

    Every buffer in ``self.replay_buffers`` is sampled for
    ``self.train_batch_size`` rows while ``self.replay_timer`` is active.

    Returns:
        MultiAgentBatch: Maps each policy id to a SampleBatch with the
            columns obs, actions, rewards, new_obs, dones, weights and
            batch_indexes. Its count is ``self.train_batch_size``.
    """
    per_policy = {}
    with self.replay_timer:
        for policy_id, buf in self.replay_buffers.items():
            if not isinstance(buf, PrioritizedReplayBuffer):
                (obses_t, actions, rewards, obses_tp1,
                 dones) = buf.sample(self.train_batch_size)
                # Plain buffers return neither weights nor indexes, so
                # fill in unit weights and -1 as the index placeholder.
                weights = np.ones_like(rewards)
                batch_indexes = -np.ones_like(rewards)
            else:
                (obses_t, actions, rewards, obses_tp1, dones, weights,
                 batch_indexes) = buf.sample(
                     self.train_batch_size,
                     beta=self.prioritized_replay_beta)
            per_policy[policy_id] = SampleBatch({
                "obs": obses_t,
                "actions": actions,
                "rewards": rewards,
                "new_obs": obses_tp1,
                "dones": dones,
                "weights": weights,
                "batch_indexes": batch_indexes,
            })
    return MultiAgentBatch(per_policy, self.train_batch_size)
def stats(self):
return dict(PolicyOptimizer.stats(self), **{
"sample_time_ms": round(1000 * self.sample_timer.mean, 3),
+1 -2
View File
@@ -9,7 +9,6 @@ import tensorflow as tf
import ray
from ray.rllib.optimizers.policy_evaluator import TFMultiGPUSupport
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.optimizers.sample_batch import SampleBatch
from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer
from ray.rllib.utils.timer import TimerStat
@@ -90,7 +89,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
self.timesteps_per_batch)
else:
samples = self.local_evaluator.sample()
assert isinstance(samples, SampleBatch)
self._check_not_multiagent(samples)
if postprocess_fn:
postprocess_fn(samples)
@@ -20,7 +20,8 @@ class PolicyEvaluator(object):
This method must be implemented by subclasses.
Returns:
SampleBatch: A columnar batch of experiences (e.g., tensors).
SampleBatch|MultiAgentBatch: A columnar batch of experiences
(e.g., tensors), or a multi-agent batch.
Examples:
>>> print(ev.sample())
@@ -35,8 +36,10 @@ class PolicyEvaluator(object):
This method must be implemented by subclasses.
Returns:
object: A gradient that can be applied on a compatible evaluator.
info: dictionary of extra metadata.
(grads, info): A list of gradients that can be applied on a
compatible evaluator. In the multi-agent case, returns a dict
of gradients keyed by policy graph ids. An info dictionary of
extra metadata is also returned.
Examples:
>>> batch = ev.sample()
@@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
import ray
from ray.rllib.optimizers.sample_batch import MultiAgentBatch
class PolicyOptimizer(object):
@@ -54,8 +55,8 @@ class PolicyOptimizer(object):
else:
local_evaluator = evaluator_cls(**evaluator_args)
remote_evaluators = [
remote_cls.remote(**evaluator_args)
for _ in range(num_workers)]
remote_cls.remote(worker_index=i+1, **evaluator_args)
for i in range(num_workers)]
return cls(optimizer_config, local_evaluator, remote_evaluators)
def __init__(self, config, local_evaluator, remote_evaluators):
@@ -130,3 +131,8 @@ class PolicyOptimizer(object):
[ev.apply.remote(func, i + 1)
for i, ev in enumerate(self.remote_evaluators)])
return local_result + remote_results
def _check_not_multiagent(self, sample_batch):
    """Reject multi-agent batches in optimizers that cannot handle them.

    Args:
        sample_batch: The batch to inspect.

    Raises:
        NotImplementedError: If `sample_batch` is a MultiAgentBatch.
    """
    is_multiagent = isinstance(sample_batch, MultiAgentBatch)
    if not is_multiagent:
        return
    raise NotImplementedError(
        "This optimizer does not support multi-agent yet.")
+47 -10
View File
@@ -6,6 +6,10 @@ import collections
import numpy as np
# Default policy id for single-agent environments
DEFAULT_POLICY_ID = "default"
class SampleBatchBuilder(object):
"""Util to build a SampleBatch incrementally.
@@ -107,7 +111,7 @@ class MultiAgentSampleBatchBuilder(object):
pre_batch, other_batches)
# Append into policy batches and reset
for agent_id, post_batch in post_batches.items():
for agent_id, post_batch in sorted(post_batches.items()):
self.policy_builders[self.agent_to_policy[agent_id]].add_batch(
post_batch)
self.agent_builders.clear()
@@ -122,33 +126,62 @@ class MultiAgentSampleBatchBuilder(object):
self.postprocess_batch_so_far()
policy_batches = {}
for policy_id, policy_batch_builder in self.policy_builders.items():
policy_batches[policy_id] = policy_batch_builder.build_and_reset()
for policy_id, builder in self.policy_builders.items():
if builder.count > 0:
policy_batches[policy_id] = builder.build_and_reset()
old_count = self.count
self.count = 0
return MultiAgentBatch.wrap_as_needed(policy_batches)
return MultiAgentBatch.wrap_as_needed(policy_batches, old_count)
class MultiAgentBatch(object):
def __init__(self, policy_batches):
"""A batch of experiences from multiple policies in the environment.
Attributes:
policy_batches (dict): Mapping from policy id to a normal SampleBatch
of experiences. Note that these batches may be of different length.
count (int): The number of timesteps in the environment this batch
contains. This can differ from the number of transitions this batch
contains across all policies in total (see total()).
"""
def __init__(self, policy_batches, count):
self.policy_batches = policy_batches
self.count = count
@staticmethod
def wrap_as_needed(batches):
if len(batches) == 1 and "default" in batches:
return batches["default"]
return MultiAgentBatch(batches)
def wrap_as_needed(batches, count):
if len(batches) == 1 and DEFAULT_POLICY_ID in batches:
return batches[DEFAULT_POLICY_ID]
return MultiAgentBatch(batches, count)
@staticmethod
def concat_samples(samples):
policy_batches = collections.defaultdict(list)
total_count = 0
for s in samples:
assert isinstance(s, MultiAgentBatch)
for policy_id, batch in s.policy_batches.items():
policy_batches[policy_id].append(batch)
total_count += s.count
out = {}
for policy_id, batches in policy_batches.items():
out[policy_id] = SampleBatch.concat_samples(batches)
return MultiAgentBatch(out)
return MultiAgentBatch(out, total_count)
def total(self):
    """Return the number of rows summed over all per-policy batches."""
    return sum(batch.count for batch in self.policy_batches.values())
def __str__(self):
    """Render the per-policy batches together with the timestep count."""
    batches_text = str(self.policy_batches)
    template = "MultiAgentBatch({}, count={})"
    return template.format(batches_text, self.count)
def __repr__(self):
    """Return the same text as ``__str__`` for debugging output."""
    batches_text = str(self.policy_batches)
    template = "MultiAgentBatch({}, count={})"
    return template.format(batches_text, self.count)
class SampleBatch(object):
@@ -166,11 +199,15 @@ class SampleBatch(object):
for k, v in self.data.copy().items():
assert type(k) == str, self
lengths.append(len(v))
if not lengths:
raise ValueError("Empty sample batch")
assert len(set(lengths)) == 1, "data columns must be same length"
self.count = lengths[0]
@staticmethod
def concat_samples(samples):
if isinstance(samples[0], MultiAgentBatch):
return MultiAgentBatch.concat_samples(samples)
out = {}
samples = [s for s in samples if s.count > 0]
for k in samples[0].keys():
+11 -1
View File
@@ -29,6 +29,12 @@ DEFAULT_CONFIG = {
"model": {"fcnet_hiddens": [128, 128]},
# Arguments to pass to the env creator
"env_config": {},
# === Multiagent ===
"multiagent": {
"policy_graphs": {},
"policy_mapping_fn": None,
},
}
@@ -52,7 +58,11 @@ class PGAgent(Agent):
evaluator_cls=CommonPolicyEvaluator,
evaluator_args={
"env_creator": self.env_creator,
"policy_graph": PGPolicyGraph,
"policy_graph": (
self.config["multiagent"]["policy_graphs"] or
PGPolicyGraph),
"policy_mapping_fn":
self.config["multiagent"]["policy_mapping_fn"],
"batch_steps": self.config["batch_size"],
"batch_mode": "truncate_episodes",
"model_config": self.config["model"],
+3 -1
View File
@@ -4,6 +4,7 @@ from __future__ import print_function
import tensorflow as tf
import ray
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.process_rollout import compute_advantages
from ray.rllib.utils.tf_policy_graph import TFPolicyGraph
@@ -12,6 +13,7 @@ from ray.rllib.utils.tf_policy_graph import TFPolicyGraph
class PGPolicyGraph(TFPolicyGraph):
def __init__(self, obs_space, action_space, config):
config = dict(ray.rllib.pg.pg.DEFAULT_CONFIG, **config)
self.config = config
# setup policy
@@ -36,7 +38,7 @@ class PGPolicyGraph(TFPolicyGraph):
]
self.is_training = tf.placeholder_with_default(True, ())
TFPolicyGraph.__init__(
self, self.sess, obs_input=self.x,
self, obs_space, action_space, self.sess, obs_input=self.x,
action_sampler=self.dist.sample(), loss=self.loss,
loss_inputs=self.loss_in, is_training=self.is_training)
self.sess.run(tf.global_variables_initializer())
-1
View File
@@ -88,7 +88,6 @@ DEFAULT_CONFIG = {
class PPOAgent(Agent):
_agent_name = "PPO"
_allow_unknown_subkeys = ["model", "tf_session_args", "env_config"]
_default_config = DEFAULT_CONFIG
@classmethod
@@ -48,6 +48,22 @@ class MockEnv(gym.Env):
return 0, 1, self.i >= self.episode_length, {}
class MockEnv2(gym.Env):
    """Mock env whose observation is the step counter.

    Each step returns a constant reward of 100, and the episode is done
    once `episode_length` steps have been taken.
    """

    def __init__(self, episode_length):
        self.episode_length = episode_length
        self.i = 0
        self.observation_space = gym.spaces.Discrete(100)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self):
        """Rewind the step counter and return it as the first observation."""
        self.i = 0
        return self.i

    def step(self, action):
        """Advance one step; the action is ignored."""
        self.i += 1
        observation = self.i
        reward = 100
        done = self.i >= self.episode_length
        return observation, reward, done, {}
class MockVectorEnv(VectorEnv):
def __init__(self, episode_length, num_envs):
self.envs = [
+217 -18
View File
@@ -2,12 +2,23 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gym
import random
import unittest
import ray
from ray.rllib.test.test_common_policy_evaluator import MockEnv
from ray.rllib.pg import PGAgent
from ray.rllib.pg.pg_policy_graph import PGPolicyGraph
from ray.rllib.dqn.dqn_policy_graph import DQNPolicyGraph
from ray.rllib.optimizers import LocalSyncOptimizer, \
LocalSyncReplayOptimizer, AsyncOptimizer
from ray.rllib.test.test_common_policy_evaluator import MockEnv, MockEnv2, \
MockPolicyGraph
from ray.rllib.utils.common_policy_evaluator import CommonPolicyEvaluator, \
collect_metrics
from ray.rllib.utils.async_vector_env import _MultiAgentEnvToAsync
from ray.rllib.utils.multi_agent_env import MultiAgentEnv
from ray.tune.registry import register_env
class BasicMultiAgent(MultiAgentEnv):
@@ -16,6 +27,8 @@ class BasicMultiAgent(MultiAgentEnv):
def __init__(self, num):
self.agents = [MockEnv(25) for _ in range(num)]
self.dones = set()
self.observation_space = gym.spaces.Discrete(2)
self.action_space = gym.spaces.Discrete(2)
def reset(self):
self.dones = set()
@@ -36,8 +49,13 @@ class RoundRobinMultiAgent(MultiAgentEnv):
On each step() of the env, only one agent takes an action."""
def __init__(self, num):
self.agents = [MockEnv(5) for _ in range(num)]
def __init__(self, num, increment_obs=False):
if increment_obs:
# Observations are 0, 1, 2, 3... etc. as time advances
self.agents = [MockEnv2(5) for _ in range(num)]
else:
# Observations are all zeros
self.agents = [MockEnv(5) for _ in range(num)]
self.dones = set()
self.last_obs = {}
self.last_rew = {}
@@ -45,24 +63,59 @@ class RoundRobinMultiAgent(MultiAgentEnv):
self.last_info = {}
self.i = 0
self.num = num
self.observation_space = gym.spaces.Discrete(2)
self.action_space = gym.spaces.Discrete(2)
def reset(self):
self.dones = set()
return {i: a.reset() for i, a in enumerate(self.agents)}
self.last_obs = {}
self.last_rew = {}
self.last_done = {}
self.last_info = {}
self.i = 0
for i, a in enumerate(self.agents):
self.last_obs[i] = a.reset()
self.last_rew[i] = None
self.last_done[i] = False
self.last_info[i] = {}
obs_dict = {self.i: self.last_obs[self.i]}
self.i = (self.i + 1) % self.num
return obs_dict
def step(self, action_dict):
assert len(self.dones) != len(self.agents)
for i, action in action_dict.items():
(self.last_obs[i], self.last_rew[i], self.last_done[i],
self.last_info[i]) = self.agents[i].step(action)
if self.last_done[i]:
obs = {self.i: self.last_obs[self.i]}
rew = {self.i: self.last_rew[self.i]}
done = {self.i: self.last_done[self.i]}
info = {self.i: self.last_info[self.i]}
if done[self.i]:
rew[self.i] = 0
self.dones.add(self.i)
self.i = (self.i + 1) % self.num
done["__all__"] = len(self.dones) == len(self.agents)
return obs, rew, done, info
class MultiCartpole(MultiAgentEnv):
def __init__(self, num):
self.agents = [gym.make("CartPole-v0") for _ in range(num)]
self.dones = set()
self.observation_space = self.agents[0].observation_space
self.action_space = self.agents[0].action_space
def reset(self):
self.dones = set()
return {i: a.reset() for i, a in enumerate(self.agents)}
def step(self, action_dict):
obs, rew, done, info = {}, {}, {}, {}
for i, action in action_dict.items():
obs[i], rew[i], done[i], info[i] = self.agents[i].step(action)
if done[i]:
self.dones.add(i)
obs = {self.i: self.last_obs[i]}
rew = {self.i: self.last_rew[i]}
done = {self.i: self.last_done[i]}
info = {self.i: self.last_info[i]}
self.i += 1
self.i %= self.num
done["__all__"] = len(self.dones) == len(self.agents)
return obs, rew, done, info
@@ -86,15 +139,15 @@ class TestMultiAgentEnv(unittest.TestCase):
def testRoundRobinMock(self):
env = RoundRobinMultiAgent(2)
obs = env.reset()
self.assertEqual(obs, {0: 0, 1: 0})
obs, rew, done, info = env.step({0: 0, 1: 0})
self.assertEqual(obs, {0: 0})
for _ in range(4):
for _ in range(5):
obs, rew, done, info = env.step({0: 0})
self.assertEqual(obs, {1: 0})
self.assertEqual(done["__all__"], False)
obs, rew, done, info = env.step({1: 0})
self.assertEqual(obs, {0: 0})
self.assertEqual(done["__all__"], False)
obs, rew, done, info = env.step({0: 0})
self.assertEqual(done["__all__"], True)
def testVectorizeBasic(self):
@@ -140,14 +193,160 @@ class TestMultiAgentEnv(unittest.TestCase):
def testVectorizeRoundRobin(self):
env = _MultiAgentEnvToAsync(lambda: RoundRobinMultiAgent(2), [], 2)
obs, rew, dones, _, _ = env.poll()
self.assertEqual(obs, {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
self.assertEqual(rew, {0: {0: None, 1: None}, 1: {0: None, 1: None}})
env.send_actions({0: {0: 0}, 1: {0: 0}})
obs, rew, dones, _, _ = env.poll()
self.assertEqual(obs, {0: {0: 0}, 1: {0: 0}})
self.assertEqual(rew, {0: {0: None}, 1: {0: None}})
env.send_actions({0: {0: 0}, 1: {0: 0}})
obs, rew, dones, _, _ = env.poll()
self.assertEqual(obs, {0: {1: 0}, 1: {1: 0}})
env.send_actions({0: {1: 0}, 1: {1: 0}})
obs, rew, dones, _, _ = env.poll()
self.assertEqual(obs, {0: {0: 0}, 1: {0: 0}})
def testMultiAgentSample(self):
    """Sample a 5-agent env whose agents are split across two policies."""
    act_space = gym.spaces.Discrete(2)
    obs_space = gym.spaces.Discrete(2)
    # Both policies use the same mock graph, each with its own empty config.
    policy_spec = {
        policy_id: (MockPolicyGraph, obs_space, act_space, {})
        for policy_id in ("p0", "p1")
    }
    evaluator = CommonPolicyEvaluator(
        env_creator=lambda _: BasicMultiAgent(5),
        policy_graph=policy_spec,
        # Agents are assigned to "p0" / "p1" by the parity of their id.
        policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
        batch_steps=50)
    batch = evaluator.sample()
    # `count` tracks env timesteps; per-policy counts track transitions.
    self.assertEqual(batch.count, 50)
    self.assertEqual(batch.policy_batches["p0"].count, 150)
    self.assertEqual(batch.policy_batches["p1"].count, 100)
    expected_t = list(range(25)) * 6
    self.assertEqual(
        batch.policy_batches["p0"]["t"].tolist(), expected_t)
def testMultiAgentSampleRoundRobin(self):
    """Sample a round-robin env where all agents share one policy."""
    act_space = gym.spaces.Discrete(2)
    obs_space = gym.spaces.Discrete(2)
    evaluator = CommonPolicyEvaluator(
        env_creator=lambda _: RoundRobinMultiAgent(5, increment_obs=True),
        policy_graph={"p0": (MockPolicyGraph, obs_space, act_space, {})},
        policy_mapping_fn=lambda agent_id: "p0",
        batch_steps=50)
    batch = evaluator.sample()
    self.assertEqual(batch.count, 50)
    # Agents are introduced into the env one at a time (round robin), so
    # some env steps do not yield a complete transition.
    self.assertEqual(batch.policy_batches["p0"].count, 42)
    # Expected first ten entries of each column, checked in this order.
    expected_prefixes = [
        ("obs", [0, 1, 2, 3, 4] * 2),
        ("new_obs", [1, 2, 3, 4, 5] * 2),
        ("rewards", [100, 100, 100, 100, 0] * 2),
        ("dones", [False, False, False, False, True] * 2),
        ("t", [4, 9, 14, 19, 24, 5, 10, 15, 20, 25]),
    ]
    for column, expected in expected_prefixes:
        self.assertEqual(
            batch.policy_batches["p0"][column].tolist()[:10], expected)
def testTrainMultiCartpoleSinglePolicy(self):
    """Train PG on a 10-agent cartpole env until the reward target is hit.

    Passes as soon as the mean episode reward reaches 50 per agent, and
    fails if that does not happen within 100 training iterations.
    """
    n = 10
    target_reward = 50 * n
    register_env("multi_cartpole", lambda _: MultiCartpole(n))
    trainer = PGAgent(env="multi_cartpole", config={"num_workers": 0})
    for iteration in range(100):
        result = trainer.train()
        print("Iteration {}, reward {}, timesteps {}".format(
            iteration, result.episode_reward_mean, result.timesteps_total))
        if result.episode_reward_mean >= target_reward:
            break
    else:
        raise Exception("failed to improve reward")
def _testWithOptimizer(self, optimizer_cls):
    """Train two policies on a 3-agent MultiCartpole env with an optimizer.

    Agents are bound alternately to policies "p1" and "p2". The optimizer
    is stepped up to 200 times; DQN policies have their target network
    updated every 20 steps.

    Arguments:
        optimizer_cls (class): Optimizer class to construct as
            `optimizer_cls({}, local_evaluator, remote_evaluators)`.

    Raises:
        Exception: If the mean episode reward never reaches 25 per agent.
    """
    n = 3
    env = gym.make("CartPole-v0")
    act_space = env.action_space
    obs_space = env.observation_space
    dqn_config = {"gamma": 0.95, "n_step": 3}
    if optimizer_cls == LocalSyncReplayOptimizer:
        # TODO: support replay with non-DQN graphs. Currently this can't
        # happen since the replay buffer doesn't encode extra fields like
        # "advantages" that PG uses.
        policies = {
            "p1": (DQNPolicyGraph, obs_space, act_space, dqn_config),
            "p2": (DQNPolicyGraph, obs_space, act_space, dqn_config),
        }
    else:
        policies = {
            # The DQN-specific options (e.g. n_step) only go to the DQN
            # graph; the PG graph runs with its defaults.
            "p1": (PGPolicyGraph, obs_space, act_space, {}),
            "p2": (DQNPolicyGraph, obs_space, act_space, dqn_config),
        }
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: MultiCartpole(n),
        policy_graph=policies,
        policy_mapping_fn=lambda agent_id: ["p1", "p2"][agent_id % 2],
        batch_steps=50)
    if optimizer_cls == AsyncOptimizer:
        # Only the async optimizer is given a remote evaluator here.
        remote_evs = [CommonPolicyEvaluator.as_remote().remote(
            env_creator=lambda _: MultiCartpole(n),
            policy_graph=policies,
            policy_mapping_fn=lambda agent_id: ["p1", "p2"][agent_id % 2],
            batch_steps=50)]
    else:
        remote_evs = []
    optimizer = optimizer_cls({}, ev, remote_evs)
    ev.foreach_policy(
        lambda p, _: p.set_epsilon(0.02)
        if isinstance(p, DQNPolicyGraph) else None)
    for i in range(200):
        optimizer.step()
        result = collect_metrics(ev, remote_evs)
        if i % 20 == 0:
            ev.foreach_policy(
                lambda p, _: p.update_target()
                if isinstance(p, DQNPolicyGraph) else None)
            print("Iter {}, rew {}".format(i, result.policy_reward_mean))
            print("Total reward", result.episode_reward_mean)
        if result.episode_reward_mean >= 25 * n:
            return
    # Dump the last metrics to help debug the failure.
    print(result)
    raise Exception("failed to improve reward")
def testMultiAgentSyncOptimizer(self):
    """Run the shared two-policy training check with LocalSyncOptimizer."""
    optimizer_cls = LocalSyncOptimizer
    self._testWithOptimizer(optimizer_cls)
def testMultiAgentAsyncOptimizer(self):
    """Run the shared two-policy training check with AsyncOptimizer."""
    optimizer_cls = AsyncOptimizer
    self._testWithOptimizer(optimizer_cls)
def testMultiAgentReplayOptimizer(self):
    """Run the shared two-policy check with LocalSyncReplayOptimizer."""
    optimizer_cls = LocalSyncReplayOptimizer
    self._testWithOptimizer(optimizer_cls)
def testTrainMultiCartpoleManyPolicies(self):
    """Train 20 PG policies on a 20-agent MultiCartpole env.

    Each agent is bound to a randomly chosen policy. Passes once the mean
    episode reward reaches 25 per agent; raises if that does not happen
    within 100 optimizer steps.
    """
    n = 20
    env = gym.make("CartPole-v0")
    act_space = env.action_space
    obs_space = env.observation_space
    policies = {
        "pg_{}".format(idx): (PGPolicyGraph, obs_space, act_space, {})
        for idx in range(20)
    }
    policy_ids = list(policies)
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: MultiCartpole(n),
        policy_graph=policies,
        policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
        batch_steps=100)
    optimizer = LocalSyncOptimizer({}, ev, [])
    target_reward = 25 * n
    for step in range(100):
        optimizer.step()
        result = collect_metrics(ev)
        print("Iteration {}, rew {}".format(
            step, result.policy_reward_mean))
        print("Total reward", result.episode_reward_mean)
        if result.episode_reward_mean >= target_reward:
            return
    raise Exception("failed to improve reward")
if __name__ == '__main__':
+4 -6
View File
@@ -284,14 +284,12 @@ class _MultiAgentEnvState(object):
self.reset()
def poll(self):
if self.last_obs is None:
raise ValueError("Need to send action after polling")
obs, rew, dones, info = (
self.last_obs, self.last_rewards, self.last_dones, self.last_infos)
self.last_obs = None
self.last_rewards = None
self.last_dones = None
self.last_infos = None
self.last_obs = {}
self.last_rewards = {}
self.last_dones = {"__all__": False}
self.last_infos = {}
return obs, rew, dones, info
def observe(self, obs, rewards, dones, infos):
+208 -47
View File
@@ -2,31 +2,38 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pickle
import collections
import gym
import numpy as np
import pickle
import tensorflow as tf
import ray
from ray.rllib.models import ModelCatalog
from ray.rllib.optimizers import MultiAgentBatch
from ray.rllib.optimizers.policy_evaluator import PolicyEvaluator
from ray.rllib.optimizers.sample_batch import MultiAgentBatch, \
DEFAULT_POLICY_ID
from ray.rllib.utils.async_vector_env import AsyncVectorEnv
from ray.rllib.utils.atari_wrappers import wrap_deepmind, is_atari
from ray.rllib.utils.compression import pack
from ray.rllib.utils.env_context import EnvContext
from ray.rllib.utils.filter import get_filter
from ray.rllib.utils.multi_agent_env import MultiAgentEnv
from ray.rllib.utils.policy_graph import PolicyGraph
from ray.rllib.utils.sampler import AsyncSampler, SyncSampler
from ray.rllib.utils.serving_env import ServingEnv
from ray.rllib.utils.tf_policy_graph import TFPolicyGraph
from ray.rllib.utils.tf_run_builder import TFRunBuilder
from ray.rllib.utils.vector_env import VectorEnv
from ray.tune.result import TrainingResult
def collect_metrics(local_evaluator, remote_evaluators):
def collect_metrics(local_evaluator, remote_evaluators=[]):
"""Gathers episode metrics from CommonPolicyEvaluator instances."""
episode_rewards = []
episode_lengths = []
policy_rewards = collections.defaultdict(list)
metric_lists = ray.get(
[a.apply.remote(lambda ev: ev.sampler.get_metrics())
for a in remote_evaluators])
@@ -35,6 +42,8 @@ def collect_metrics(local_evaluator, remote_evaluators):
for episode in metrics:
episode_lengths.append(episode.episode_length)
episode_rewards.append(episode.episode_reward)
for (_, policy_id), reward in episode.agent_rewards.items():
policy_rewards[policy_id].append(reward)
if episode_rewards:
min_reward = min(episode_rewards)
max_reward = max(episode_rewards)
@@ -45,19 +54,22 @@ def collect_metrics(local_evaluator, remote_evaluators):
avg_length = np.mean(episode_lengths)
timesteps = np.sum(episode_lengths)
for policy_id, rewards in policy_rewards.copy().items():
policy_rewards[policy_id] = np.mean(rewards)
return TrainingResult(
episode_reward_max=max_reward,
episode_reward_min=min_reward,
episode_reward_mean=avg_reward,
episode_len_mean=avg_length,
episodes_total=len(episode_lengths),
timesteps_this_iter=timesteps)
timesteps_this_iter=timesteps,
policy_reward_mean=dict(policy_rewards))
class CommonPolicyEvaluator(PolicyEvaluator):
"""Policy evaluator implementation that operates on a rllib.PolicyGraph.
TODO: multi-agent
TODO: multi-gpu
Examples:
@@ -65,9 +77,10 @@ class CommonPolicyEvaluator(PolicyEvaluator):
>>> evaluator = CommonPolicyEvaluator(
env_creator=lambda _: gym.make("CartPole-v0"),
policy_graph=PGPolicyGraph)
>>> print(evaluator.sample().keys())
{"obs": [[...]], "actions": [[...]], "rewards": [[...]],
"dones": [[...]], "new_obs": [[...]]}
>>> print(evaluator.sample())
SampleBatch({
"obs": [[...]], "actions": [[...]], "rewards": [[...]],
"dones": [[...]], "new_obs": [[...]]})
# Creating policy evaluators using optimizer_cls.make().
>>> optimizer = LocalSyncOptimizer.make(
@@ -78,6 +91,28 @@ class CommonPolicyEvaluator(PolicyEvaluator):
},
num_workers=10)
>>> for _ in range(10): optimizer.step()
# Creating a multi-agent policy evaluator
>>> evaluator = CommonPolicyEvaluator(
env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25),
policy_graph={
# Use an ensemble of two policies for car agents
"car_policy1":
(PGPolicyGraph, Box(...), Discrete(...), {"gamma": 0.99}),
"car_policy2":
(PGPolicyGraph, Box(...), Discrete(...), {"gamma": 0.95}),
# Use a single shared policy for all traffic lights
"traffic_light_policy":
(PGPolicyGraph, Box(...), Discrete(...), {}),
},
policy_mapping_fn=lambda agent_id:
random.choice(["car_policy1", "car_policy2"])
if agent_id.startswith("car_") else "traffic_light_policy")
>>> print(evaluator.sample())
MultiAgentBatch({
"car_policy1": SampleBatch(...),
"car_policy2": SampleBatch(...),
"traffic_light_policy": SampleBatch(...)})
"""
@classmethod
@@ -88,6 +123,7 @@ class CommonPolicyEvaluator(PolicyEvaluator):
self,
env_creator,
policy_graph,
policy_mapping_fn=None,
tf_session_creator=None,
batch_steps=100,
batch_mode="truncate_episodes",
@@ -99,14 +135,22 @@ class CommonPolicyEvaluator(PolicyEvaluator):
observation_filter="NoFilter",
env_config=None,
model_config=None,
policy_config=None):
policy_config=None,
worker_index=0):
"""Initialize a policy evaluator.
Arguments:
env_creator (func): Function that returns a gym.Env given an
env config dict.
policy_graph (class): A class implementing rllib.PolicyGraph or
rllib.TFPolicyGraph.
EnvContext wrapped configuration.
policy_graph (class|dict): Either a class implementing
PolicyGraph, or a dictionary of policy id strings to
(PolicyGraph, obs_space, action_space, config) tuples. If a
dict is specified, then we are in multi-agent mode and a
policy_mapping_fn should also be set.
policy_mapping_fn (func): A function that maps agent ids to
policy ids in multi-agent mode. This function will be called
each time a new agent appears in an episode, to bind that agent
to a policy for the duration of the episode.
tf_session_creator (func): A function that returns a TF session.
This is optional and only useful with TFPolicyGraph.
batch_steps (int): The target number of env transitions to include
@@ -138,19 +182,26 @@ class CommonPolicyEvaluator(PolicyEvaluator):
observation_filter (str): Name of observation filter to use.
env_config (dict): Config to pass to the env creator.
model_config (dict): Config to use when creating the policy model.
policy_config (dict): Config to pass to the policy.
policy_config (dict): Config to pass to the policy. In the
multi-agent case, this config will be merged with the
per-policy configs specified by `policy_graph`.
worker_index (int): For remote evaluators, this should be set to a
non-zero and unique value. This index is passed to created envs
through EnvContext so that envs can be configured per worker.
"""
env_config = env_config or {}
env_context = EnvContext(env_config or {}, worker_index)
policy_config = policy_config or {}
model_config = model_config or {}
policy_mapping_fn = (
policy_mapping_fn or (lambda agent_id: DEFAULT_POLICY_ID))
self.env_creator = env_creator
self.policy_graph = policy_graph
self.batch_steps = batch_steps
self.batch_mode = batch_mode
self.compress_observations = compress_observations
self.env = env_creator(env_config)
self.env = env_creator(env_context)
if isinstance(self.env, VectorEnv) or \
isinstance(self.env, ServingEnv) or \
isinstance(self.env, MultiAgentEnv) or \
@@ -169,32 +220,29 @@ class CommonPolicyEvaluator(PolicyEvaluator):
self.env = wrap(self.env)
def make_env():
return wrap(env_creator(env_config))
return wrap(env_creator(env_context))
if issubclass(policy_graph, TFPolicyGraph):
self.tf_sess = None
policy_dict = _validate_and_canonicalize(policy_graph, self.env)
if _has_tensorflow_graph(policy_dict):
with tf.Graph().as_default():
if tf_session_creator:
self.sess = tf_session_creator()
self.tf_sess = tf_session_creator()
else:
self.sess = tf.Session(config=tf.ConfigProto(
self.tf_sess = tf.Session(config=tf.ConfigProto(
gpu_options=tf.GPUOptions(allow_growth=True)))
with self.sess.as_default():
policy = policy_graph(
self.env.observation_space, self.env.action_space,
policy_config)
with self.tf_sess.as_default():
self.policy_map = self._build_policy_map(
policy_dict, policy_config)
else:
policy = policy_graph(
self.env.observation_space, self.env.action_space,
policy_config)
self.policy_map = self._build_policy_map(
policy_dict, policy_config)
self.policy_map = {
"default": policy
}
self.multiagent = self.policy_map.keys() != set(DEFAULT_POLICY_ID)
self.filters = {
# TODO(ekl) make the obs space dependent on policy
policy_id: get_filter(
observation_filter, self.env.observation_space.shape)
observation_filter, policy.observation_space.shape)
for (policy_id, policy) in self.policy_map.items()
}
@@ -218,15 +266,25 @@ class CommonPolicyEvaluator(PolicyEvaluator):
"Unsupported batch mode: {}".format(self.batch_mode))
if sample_async:
self.sampler = AsyncSampler(
self.async_env, self.policy_map, lambda agent_id: "default",
self.async_env, self.policy_map, policy_mapping_fn,
self.filters, batch_steps, horizon=episode_horizon,
pack=pack_episodes)
pack=pack_episodes, tf_sess=self.tf_sess)
self.sampler.start()
else:
self.sampler = SyncSampler(
self.async_env, self.policy_map, lambda agent_id: "default",
self.async_env, self.policy_map, policy_mapping_fn,
self.filters, batch_steps, horizon=episode_horizon,
pack=pack_episodes)
pack=pack_episodes, tf_sess=self.tf_sess)
def _build_policy_map(self, policy_dict, policy_config):
policy_map = {}
for name, (cls, obs_space, act_space, conf) in sorted(
policy_dict.items()):
merged_conf = policy_config.copy()
merged_conf.update(conf)
with tf.variable_scope(name):
policy_map[name] = cls(obs_space, act_space, merged_conf)
return policy_map
def sample(self):
"""Evaluate the current policies and return a batch of experiences.
@@ -254,10 +312,15 @@ class CommonPolicyEvaluator(PolicyEvaluator):
return batch
def for_policy(self, func):
"""Apply the given function to this evaluator's default policy."""
def for_policy(self, func, policy_id=DEFAULT_POLICY_ID):
"""Apply the given function to the specified policy graph."""
return func(self.policy_map["default"])
return func(self.policy_map[policy_id])
def foreach_policy(self, func):
    """Call `func(policy, policy_id)` for every policy in the policy map.

    Returns:
        list: The return values of `func`, in policy map iteration order.
    """
    results = []
    for policy_id, policy in self.policy_map.items():
        results.append(func(policy, policy_id))
    return results
def sync_filters(self, new_filters):
"""Changes self's filter to given and rebases any accumulated delta.
@@ -286,28 +349,126 @@ class CommonPolicyEvaluator(PolicyEvaluator):
return return_filters
def get_weights(self):
return self.policy_map["default"].get_weights()
return {
pid: policy.get_weights()
for pid, policy in self.policy_map.items()}
def set_weights(self, weights):
return self.policy_map["default"].set_weights(weights)
for pid, w in weights.items():
self.policy_map[pid].set_weights(w)
def compute_gradients(self, samples):
return self.policy_map["default"].compute_gradients(samples)
if isinstance(samples, MultiAgentBatch):
grad_out, info_out = {}, {}
if self.tf_sess is not None:
builder = TFRunBuilder(self.tf_sess, "compute_gradients")
for pid, batch in samples.policy_batches.items():
grad_out[pid], info_out[pid] = (
self.policy_map[pid].build_compute_gradients(
builder, batch))
grad_out = {k: builder.get(v) for k, v in grad_out.items()}
info_out = {k: builder.get(v) for k, v in info_out.items()}
else:
for pid, batch in samples.policy_batches.items():
grad_out[pid], info_out[pid] = (
self.policy_map[pid].compute_gradients(batch))
return grad_out, info_out
else:
return self.policy_map[DEFAULT_POLICY_ID].compute_gradients(
samples)
def apply_gradients(self, grads):
return self.policy_map["default"].apply_gradients(grads)
if isinstance(grads, dict):
if self.tf_sess is not None:
builder = TFRunBuilder(self.tf_sess, "apply_gradients")
outputs = {
pid: self.policy_map[pid].build_apply_gradients(
builder, grad)
for pid, grad in grads.items()
}
return {
k: builder.get(v) for k, v in outputs.items()
}
else:
return {
pid: self.policy_map[pid].apply_gradients(g)
for pid, g in grads.items()
}
else:
return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads)
def compute_apply(self, samples):
grad_fetch, apply_fetch = self.policy_map["default"].compute_apply(
samples)
return grad_fetch
if isinstance(samples, MultiAgentBatch):
info_out = {}
if self.tf_sess is not None:
builder = TFRunBuilder(self.tf_sess, "compute_apply")
for pid, batch in samples.policy_batches.items():
info_out[pid], _ = (
self.policy_map[pid].build_compute_apply(
builder, batch))
info_out = {k: builder.get(v) for k, v in info_out.items()}
else:
for pid, batch in samples.policy_batches.items():
info_out[pid], _ = (
self.policy_map[pid].compute_apply(batch))
return info_out
else:
grad_fetch, apply_fetch = (
self.policy_map[DEFAULT_POLICY_ID].compute_apply(samples))
return grad_fetch
def save(self):
filters = self.get_filters(flush_after=True)
state = self.policy_map["default"].get_state()
state = {
pid: self.policy_map[pid].get_state()
for pid in self.policy_map
}
return pickle.dumps({"filters": filters, "state": state})
def restore(self, objs):
objs = pickle.loads(objs)
self.sync_filters(objs["filters"])
self.policy_map["default"].set_state(objs["state"])
for pid, state in objs["state"].items():
self.policy_map[pid].set_state(state)
def _validate_and_canonicalize(policy_graph, env):
if isinstance(policy_graph, dict):
for k, v in policy_graph.items():
if not isinstance(k, str):
raise ValueError(
"policy_graph keys must be strs, got {}".format(type(k)))
if not isinstance(v, tuple) or len(v) != 4:
raise ValueError(
"policy_graph values must be tuples of "
"(cls, obs_space, action_space, config), got {}".format(v))
if not issubclass(v[0], PolicyGraph):
raise ValueError(
"policy_graph tuple value 0 must be a rllib.PolicyGraph "
"class, got {}".format(v[0]))
if not isinstance(v[1], gym.Space):
raise ValueError(
"policy_graph tuple value 1 (observation_space) must be a "
"gym.Space, got {}".format(type(v[1])))
if not isinstance(v[2], gym.Space):
raise ValueError(
"policy_graph tuple value 2 (action_space) must be a "
"gym.Space, got {}".format(type(v[2])))
if not isinstance(v[3], dict):
raise ValueError(
"policy_graph tuple value 3 (config) must be a dict, "
"got {}".format(type(v[3])))
return policy_graph
elif not issubclass(policy_graph, PolicyGraph):
raise ValueError("policy_graph must be a rllib.PolicyGraph class")
else:
return {
DEFAULT_POLICY_ID: (
policy_graph, env.observation_space, env.action_space, {})}
def _has_tensorflow_graph(policy_dict):
for policy, _, _, _ in policy_dict.values():
if issubclass(policy, TFPolicyGraph):
return True
return False
+22
View File
@@ -0,0 +1,22 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class EnvContext(dict):
    """Env configuration dict that also carries rllib metadata.

    The extra attributes can be used to parameterize environments per
    process. For example, one might use `worker_index` to control which
    data file an environment reads in on initialization.

    RLlib auto-sets these attributes when constructing registered envs.

    Attributes:
        worker_index (int): When there are multiple workers created, this
            uniquely identifies the worker the env is created in.
    """

    def __init__(self, env_config, worker_index):
        # The dict contents are the plain env config; the metadata is kept
        # as an instance attribute rather than as a key.
        super(EnvContext, self).__init__(env_config)
        self.worker_index = worker_index
+4
View File
@@ -15,6 +15,10 @@ class PolicyGraph(object):
find TFPolicyGraph simpler to implement. TFPolicyGraph also enables RLlib
to apply TensorFlow-specific optimizations such as fusing multiple policy
graphs and multi-GPU support.
Attributes:
observation_space (gym.Space): Observation space of the policy.
action_space (gym.Space): Action space of the policy.
"""
def __init__(self, observation_space, action_space, config):
+57 -20
View File
@@ -10,6 +10,7 @@ import threading
from ray.rllib.optimizers.sample_batch import MultiAgentSampleBatchBuilder, \
MultiAgentBatch
from ray.rllib.utils.async_vector_env import AsyncVectorEnv
from ray.rllib.utils.tf_run_builder import TFRunBuilder
RolloutMetrics = namedtuple(
@@ -30,7 +31,7 @@ class SyncSampler(object):
def __init__(
self, env, policies, policy_mapping_fn, obs_filters,
num_local_steps, horizon=None, pack=False):
num_local_steps, horizon=None, pack=False, tf_sess=None):
self.async_vector_env = AsyncVectorEnv.wrap_async(env)
self.num_local_steps = num_local_steps
self.horizon = horizon
@@ -39,7 +40,8 @@ class SyncSampler(object):
self._obs_filters = obs_filters
self.rollout_provider = _env_runner(
self.async_vector_env, self.policies, self.policy_mapping_fn,
self.num_local_steps, self.horizon, self._obs_filters, pack)
self.num_local_steps, self.horizon, self._obs_filters, pack,
tf_sess)
self.metrics_queue = queue.Queue()
def get_data(self):
@@ -68,7 +70,7 @@ class AsyncSampler(threading.Thread):
def __init__(
self, env, policies, policy_mapping_fn, obs_filters,
num_local_steps, horizon=None, pack=False):
num_local_steps, horizon=None, pack=False, tf_sess=None):
for _, f in obs_filters.items():
assert getattr(f, "is_concurrent", False), \
"Observation Filter must support concurrent updates."
@@ -83,6 +85,7 @@ class AsyncSampler(threading.Thread):
self._obs_filters = obs_filters
self.daemon = True
self.pack = pack
self.tf_sess = tf_sess
def run(self):
try:
@@ -94,7 +97,8 @@ class AsyncSampler(threading.Thread):
def _run(self):
rollout_provider = _env_runner(
self.async_vector_env, self.policies, self.policy_mapping_fn,
self.num_local_steps, self.horizon, self._obs_filters, self.pack)
self.num_local_steps, self.horizon, self._obs_filters, self.pack,
self.tf_sess)
while True:
# The timeout variable exists because apparently, if one worker
# dies, the other workers won't die with it, unless the timeout is
@@ -140,7 +144,7 @@ class AsyncSampler(threading.Thread):
def _env_runner(
async_vector_env, policies, policy_mapping_fn, num_local_steps,
horizon, obs_filters, pack):
horizon, obs_filters, pack, tf_sess=None):
"""This implements the common experience collection logic.
Args:
@@ -156,6 +160,8 @@ def _env_runner(
observations for the policy.
pack (bool): Whether to pack multiple episodes into each batch. This
guarantees batches will be exactly `num_local_steps` in size.
tf_sess (Session|None): Optional tensorflow session to use for batching
TF policy evaluations.
Yields:
rollout (SampleBatch): Object containing state, action, reward,
@@ -192,6 +198,9 @@ def _env_runner(
# Map of policy_id to list of PolicyEvalData
to_eval = defaultdict(list)
# Map of env_id -> agent_id -> action replies
actions_to_send = defaultdict(dict)
# For each environment
for env_id, agent_obs in unfiltered_obs.items():
new_episode = env_id not in active_episodes
@@ -209,11 +218,13 @@ def _env_runner(
dict(episode.agent_rewards))
else:
all_done = False
# At least send an empty dict if not done
actions_to_send[env_id]
# For each agent in the environment
for agent_id, raw_obs in agent_obs.items():
policy_id = episode.policy_for(agent_id)
filtered_obs = obs_filters[policy_id](raw_obs)
filtered_obs = _get_or_raise(obs_filters, policy_id)(raw_obs)
agent_done = bool(all_done or dones[env_id].get(agent_id))
if not agent_done:
to_eval[policy_id].append(
@@ -263,24 +274,40 @@ def _env_runner(
episode = active_episodes[env_id]
for agent_id, raw_obs in resetted_obs.items():
policy_id = episode.policy_for(agent_id)
filtered_obs = obs_filters[policy_id](raw_obs)
filtered_obs = _get_or_raise(
obs_filters, policy_id)(raw_obs)
episode.set_last_observation(agent_id, filtered_obs)
to_eval[policy_id].append(
PolicyEvalData(
env_id, agent_id, filtered_obs,
episode.rnn_state_for(agent_id)))
# Map of env_id -> agent_id -> action
action_dict = defaultdict(dict)
# TODO(ekl) fuse all policy evaluation into one TF run
# Batch eval policy actions if possible
if tf_sess:
builder = TFRunBuilder(tf_sess, "policy_eval")
else:
builder = None
eval_results = {}
rnn_in_cols = {}
for policy_id, eval_data in to_eval.items():
rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data])
actions, rnn_out_cols, pi_info_cols = \
policies[policy_id].compute_actions(
[t.obs for t in eval_data], rnn_in_cols, is_training=True)
rnn_in = _to_column_format([t.rnn_state for t in eval_data])
rnn_in_cols[policy_id] = rnn_in
policy = _get_or_raise(policies, policy_id)
if builder:
eval_results[policy_id] = policy.build_compute_actions(
builder, [t.obs for t in eval_data], rnn_in,
is_training=True)
else:
eval_results[policy_id] = policy.compute_actions(
[t.obs for t in eval_data], rnn_in, is_training=True)
if builder:
eval_results = {k: builder.get(v) for k, v in eval_results.items()}
# Record the policy eval results
for policy_id, eval_data in to_eval.items():
actions, rnn_out_cols, pi_info_cols = eval_results[policy_id]
# Add RNN state info
for f_i, column in enumerate(rnn_in_cols):
for f_i, column in enumerate(rnn_in_cols[policy_id]):
pi_info_cols["state_in_{}".format(f_i)] = column
for f_i, column in enumerate(rnn_out_cols):
pi_info_cols["state_out_{}".format(f_i)] = column
@@ -288,7 +315,7 @@ def _env_runner(
for i, action in enumerate(actions):
env_id = eval_data[i].env_id
agent_id = eval_data[i].agent_id
action_dict[env_id][agent_id] = action
actions_to_send[env_id][agent_id] = action
episode = active_episodes[env_id]
episode.set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
episode.set_last_pi_info(
@@ -302,7 +329,7 @@ def _env_runner(
# Return computed actions to ready envs. We also send to envs that have
# taken off-policy actions; those envs are free to ignore the action.
async_vector_env.send_actions(dict(action_dict))
async_vector_env.send_actions(dict(actions_to_send))
def _to_column_format(rnn_state_rows):
@@ -311,6 +338,14 @@ def _to_column_format(rnn_state_rows):
[row[i] for row in rnn_state_rows] for i in range(num_cols)]
def _get_or_raise(mapping, policy_id):
if policy_id not in mapping:
raise ValueError(
"Could not find policy for agent: agent policy id `{}` not "
"in policy map keys {}.".format(policy_id, mapping.keys()))
return mapping[policy_id]
class _MultiAgentEpisode(object):
def __init__(self, policies, policy_mapping_fn, batch_builder_factory):
self.batch_builder = batch_builder_factory()
@@ -327,8 +362,10 @@ class _MultiAgentEpisode(object):
def add_agent_rewards(self, reward_dict):
for agent_id, reward in reward_dict.items():
self.agent_rewards[agent_id] += reward
self.total_reward += reward
if reward is not None:
self.agent_rewards[
agent_id, self.policy_for(agent_id)] += reward
self.total_reward += reward
def policy_for(self, agent_id):
if agent_id not in self._agent_to_policy:
+2
View File
@@ -35,6 +35,8 @@ class ServingEnv(threading.Thread):
def __init__(self, action_space, observation_space, max_concurrent=100):
"""Initialize a serving env.
ServingEnv subclasses must call this during their __init__.
Arguments:
action_space (gym.Space): Action space of the env.
observation_space (gym.Space): Observation space of the env.
+61 -35
View File
@@ -6,6 +6,7 @@ import tensorflow as tf
import ray
from ray.rllib.utils.policy_graph import PolicyGraph
from ray.rllib.utils.tf_run_builder import TFRunBuilder
class TFPolicyGraph(PolicyGraph):
@@ -29,11 +30,15 @@ class TFPolicyGraph(PolicyGraph):
"""
def __init__(
self, sess, obs_input, action_sampler, loss, loss_inputs,
self, observation_space, action_space, sess, obs_input,
action_sampler, loss, loss_inputs,
is_training, state_inputs=None, state_outputs=None):
"""Initialize the policy.
Arguments:
observation_space (gym.Space): Observation space of the env.
action_space (gym.Space): Action space of the env.
sess (Session): TensorFlow session to use.
obs_input (Tensor): input placeholder for observations.
action_sampler (Tensor): Tensor for sampling an action.
loss (Tensor): scalar policy loss output tensor.
@@ -46,6 +51,8 @@ class TFPolicyGraph(PolicyGraph):
state_outputs (list): list of initial state values.
"""
self.observation_space = observation_space
self.action_space = action_space
self._sess = sess
self._obs_input = obs_input
self._sampler = action_sampler
@@ -55,7 +62,9 @@ class TFPolicyGraph(PolicyGraph):
self._state_inputs = state_inputs or []
self._state_outputs = state_outputs or []
self._optimizer = self.optimizer()
self._grads_and_vars = self.gradients(self._optimizer)
self._grads_and_vars = [
(g, v) for (g, v) in self.gradients(self._optimizer)
if g is not None]
self._grads = [g for (g, v) in self._grads_and_vars]
self._apply_op = self._optimizer.apply_gradients(self._grads_and_vars)
self._variables = ray.experimental.TensorFlowVariables(
@@ -64,21 +73,27 @@ class TFPolicyGraph(PolicyGraph):
assert len(self._state_inputs) == len(self._state_outputs) == \
len(self.get_initial_state())
def compute_actions(
self, obs_batch, state_batches=None, is_training=False):
def build_compute_actions(
self, builder, obs_batch, state_batches=None, is_training=False):
state_batches = state_batches or []
assert len(self._state_inputs) == len(state_batches), \
(self._state_inputs, state_batches)
feed_dict = self.extra_compute_action_feed_dict()
feed_dict[self._obs_input] = obs_batch
feed_dict[self._is_training] = is_training
for ph, value in zip(self._state_inputs, state_batches):
feed_dict[ph] = value
fetches = self._sess.run(
([self._sampler] + self._state_outputs +
[self.extra_compute_action_fetches()]), feed_dict=feed_dict)
builder.add_feed_dict(self.extra_compute_action_feed_dict())
builder.add_feed_dict({self._obs_input: obs_batch})
builder.add_feed_dict({self._is_training: is_training})
builder.add_feed_dict(dict(zip(self._state_inputs, state_batches)))
fetches = builder.add_fetches(
[self._sampler] + self._state_outputs +
[self.extra_compute_action_fetches()])
return fetches[0], fetches[1:-1], fetches[-1]
def compute_actions(
self, obs_batch, state_batches=None, is_training=False):
builder = TFRunBuilder(self._sess, "compute_actions")
fetches = self.build_compute_actions(
builder, obs_batch, state_batches, is_training)
return builder.get(fetches)
def _get_loss_inputs_dict(self, postprocessed_batch):
feed_dict = {}
for key, ph in self._loss_inputs:
@@ -90,37 +105,48 @@ class TFPolicyGraph(PolicyGraph):
feed_dict[ph] = postprocessed_batch[key]
return feed_dict
def compute_gradients(self, postprocessed_batch):
feed_dict = self.extra_compute_grad_feed_dict()
feed_dict[self._is_training] = True
feed_dict.update(self._get_loss_inputs_dict(postprocessed_batch))
fetches = self._sess.run(
[self._grads, self.extra_compute_grad_fetches()],
feed_dict=feed_dict)
def build_compute_gradients(self, builder, postprocessed_batch):
builder.add_feed_dict(self.extra_compute_grad_feed_dict())
builder.add_feed_dict({self._is_training: True})
builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch))
fetches = builder.add_fetches(
[self._grads, self.extra_compute_grad_fetches()])
return fetches[0], fetches[1]
def apply_gradients(self, gradients):
def compute_gradients(self, postprocessed_batch):
builder = TFRunBuilder(self._sess, "compute_gradients")
fetches = self.build_compute_gradients(builder, postprocessed_batch)
return builder.get(fetches)
def build_apply_gradients(self, builder, gradients):
assert len(gradients) == len(self._grads), (gradients, self._grads)
feed_dict = self.extra_apply_grad_feed_dict()
feed_dict[self._is_training] = True
for ph, value in zip(self._grads, gradients):
feed_dict[ph] = value
fetches = self._sess.run(
[self._apply_op, self.extra_apply_grad_fetches()],
feed_dict=feed_dict)
builder.add_feed_dict(self.extra_apply_grad_feed_dict())
builder.add_feed_dict({self._is_training: True})
builder.add_feed_dict(dict(zip(self._grads, gradients)))
fetches = builder.add_fetches(
[self._apply_op, self.extra_apply_grad_fetches()])
return fetches[1]
def compute_apply(self, postprocessed_batch):
feed_dict = self.extra_compute_grad_feed_dict()
feed_dict.update(self.extra_apply_grad_feed_dict())
feed_dict.update(self._get_loss_inputs_dict(postprocessed_batch))
feed_dict[self._is_training] = True
fetches = self._sess.run(
def apply_gradients(self, gradients):
builder = TFRunBuilder(self._sess, "apply_gradients")
fetches = self.build_apply_gradients(builder, gradients)
return builder.get(fetches)
def build_compute_apply(self, builder, postprocessed_batch):
builder.add_feed_dict(self.extra_compute_grad_feed_dict())
builder.add_feed_dict(self.extra_apply_grad_feed_dict())
builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch))
builder.add_feed_dict({self._is_training: True})
fetches = builder.add_fetches(
[self._apply_op, self.extra_compute_grad_fetches(),
self.extra_apply_grad_fetches()],
feed_dict=feed_dict)
self.extra_apply_grad_fetches()])
return fetches[1], fetches[2]
def compute_apply(self, postprocessed_batch):
builder = TFRunBuilder(self._sess, "compute_apply")
fetches = self.build_compute_apply(builder, postprocessed_batch)
return builder.get(fetches)
def get_weights(self):
return self._variables.get_flat()
+82
View File
@@ -0,0 +1,82 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
import tensorflow as tf
from tensorflow.python.client import timeline
class TFRunBuilder(object):
    """Used to incrementally build up a TensorFlow run.

    This is particularly useful for batching ops from multiple different
    policies in the multi-agent setting.

    Feeds and fetches are accumulated with `add_feed_dict` and
    `add_fetches`. The session run happens on the first call to `get`,
    after which nothing more can be added.
    """

    def __init__(self, session, debug_name):
        """Create an empty run.

        Arguments:
            session: Session object passed to `run_timeline` on execution.
            debug_name (str): Name passed to `run_timeline` for this run.
        """
        self.session = session
        self.debug_name = debug_name
        self.feed_dict = {}
        self.fetches = []
        # None until the run has executed; afterwards the fetched values.
        self._executed = None

    def add_feed_dict(self, feed_dict):
        """Merge `feed_dict` into the run; keys must not already be fed."""
        # Compare against None: an executed run with no fetches holds an
        # empty (falsy) result and must still count as executed.
        assert self._executed is None
        for k in feed_dict:
            assert k not in self.feed_dict
        self.feed_dict.update(feed_dict)

    def add_fetches(self, fetches):
        """Append `fetches` to the run.

        Returns:
            list: Integer handles, one per added fetch, to pass to `get`.
        """
        assert self._executed is None
        base_index = len(self.fetches)
        self.fetches.extend(fetches)
        return list(range(base_index, len(self.fetches)))

    def get(self, to_fetch):
        """Execute the run if needed and resolve handles to fetched values.

        Arguments:
            to_fetch (int|list|tuple): A handle returned by `add_fetches`,
                or a (possibly nested) list or tuple of handles.

        Returns:
            The fetched value(s), in the same list/tuple structure.

        Raises:
            ValueError: If `to_fetch` is not an int, list or tuple.
        """
        if self._executed is None:
            try:
                self._executed = run_timeline(
                    self.session, self.fetches, self.debug_name,
                    self.feed_dict, os.environ.get("TF_TIMELINE_DIR"))
            except Exception:
                print("Error fetching: {}, feed_dict={}".format(
                    self.fetches, self.feed_dict))
                # Bare raise keeps the original traceback on Python 2 too.
                raise
        if isinstance(to_fetch, int):
            return self._executed[to_fetch]
        elif isinstance(to_fetch, list):
            return [self.get(x) for x in to_fetch]
        elif isinstance(to_fetch, tuple):
            return tuple(self.get(x) for x in to_fetch)
        else:
            raise ValueError("Unsupported fetch type: {}".format(to_fetch))
# Number of timeline files written so far by this process; used to give
# each trace file a distinct name.
_count = 0


def run_timeline(sess, ops, debug_name, feed_dict=None, timeline_dir=None):
    """Run `ops` in `sess`, optionally writing a timeline trace file.

    Arguments:
        sess: Session object; `sess.run` is called once.
        ops: Fetches to pass to `sess.run`.
        debug_name (str): Included in the trace file name.
        feed_dict (dict|None): Feed dict for the run. None means empty.
        timeline_dir (str|None): If set, the run is fully traced and a
            Chrome-trace JSON file is written into this directory.

    Returns:
        The value returned by `sess.run`.
    """
    global _count
    if feed_dict is None:
        feed_dict = {}

    if timeline_dir:
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()
        start = time.time()
        fetches = sess.run(
            ops, options=run_options, run_metadata=run_metadata,
            feed_dict=feed_dict)
        trace = timeline.Timeline(step_stats=run_metadata.step_stats)
        outf = os.path.join(
            timeline_dir,
            "timeline-{}-{}-{}.json".format(debug_name, os.getpid(), _count))
        _count += 1
        elapsed = time.time() - start
        # Close the file deterministically instead of leaking the handle.
        with open(outf, "w") as trace_file:
            trace_file.write(trace.generate_chrome_trace_format())
        print(
            "Wrote tf timeline ({} s) to {}".format(
                elapsed, os.path.abspath(outf)))
    else:
        fetches = sess.run(ops, feed_dict=feed_dict)
    return fetches
+3
View File
@@ -43,6 +43,9 @@ TrainingResult = namedtuple(
# (Optional) The number of episodes total.
"episodes_total",
# (Optional) Per-policy reward information in multi-agent RL.
"policy_reward_mean",
# (Optional) The current training accuracy if applicable.
"mean_accuracy",