[rllib] Q-Mix implementation (Q-Mix, VDN, IQN, and Ape-X variants) (#3548)

This commit is contained in:
Eric Liang
2018-12-18 10:40:01 -08:00
committed by GitHub
parent bc4aa85ea3
commit db0dee573e
35 changed files with 1339 additions and 71 deletions
+2 -2
View File
@@ -33,8 +33,8 @@ def _register_all():
for key in [
"PPO", "ES", "DQN", "APEX", "A3C", "PG", "DDPG", "APEX_DDPG",
"IMPALA", "ARS", "A2C", "__fake", "__sigmoid_fake_data",
"__parameter_tuning"
"IMPALA", "ARS", "A2C", "QMIX", "APEX_QMIX", "__fake",
"__sigmoid_fake_data", "__parameter_tuning"
]:
from ray.rllib.agents.agent import get_agent_class
register_trainable(key, get_agent_class(key))
+15
View File
@@ -10,6 +10,7 @@ import pickle
import six
import tempfile
import tensorflow as tf
import traceback
from types import FunctionType
import ray
@@ -546,6 +547,14 @@ def _register_if_needed(env_object):
def get_agent_class(alg):
"""Returns the class of a known agent given its name."""
try:
return _get_agent_class(alg)
except ImportError:
from ray.rllib.agents.mock import _agent_import_failed
return _agent_import_failed(traceback.format_exc())
def _get_agent_class(alg):
if alg == "DDPG":
from ray.rllib.agents import ddpg
return ddpg.DDPGAgent
@@ -579,6 +588,12 @@ def get_agent_class(alg):
elif alg == "IMPALA":
from ray.rllib.agents import impala
return impala.ImpalaAgent
elif alg == "QMIX":
from ray.rllib.agents import qmix
return qmix.QMixAgent
elif alg == "APEX_QMIX":
from ray.rllib.agents import qmix
return qmix.ApexQMixAgent
elif alg == "script":
from ray.tune import script_runner
return script_runner.ScriptRunner
-6
View File
@@ -8,12 +8,6 @@ from ray.rllib.agents.ddpg.ddpg_policy_graph import DDPGPolicyGraph
from ray.rllib.utils.annotations import override
from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule
OPTIMIZER_SHARED_CONFIGS = [
"buffer_size", "prioritized_replay", "prioritized_replay_alpha",
"prioritized_replay_beta", "prioritized_replay_eps", "sample_batch_size",
"train_batch_size", "learning_starts"
]
# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
+3 -2
View File
@@ -117,12 +117,13 @@ class DQNAgent(Agent):
_agent_name = "DQN"
_default_config = DEFAULT_CONFIG
_policy_graph = DQNPolicyGraph
_optimizer_shared_configs = OPTIMIZER_SHARED_CONFIGS
@override(Agent)
def _init(self):
# Update effective batch size to include n-step
adjusted_batch_size = max(self.config["sample_batch_size"],
self.config["n_step"])
self.config.get("n_step", 1))
self.config["sample_batch_size"] = adjusted_batch_size
self.exploration0 = self._make_exploration_schedule(-1)
@@ -131,7 +132,7 @@ class DQNAgent(Agent):
for i in range(self.config["num_workers"])
]
for k in OPTIMIZER_SHARED_CONFIGS:
for k in self._optimizer_shared_configs:
if self._agent_name != "DQN" and k in [
"schedule_max_timesteps", "beta_annealing_fraction",
"final_prioritized_replay_beta"
+13
View File
@@ -100,3 +100,16 @@ class _ParameterTuningAgent(_MockAgent):
timesteps_this_iter=self.config["iter_timesteps"],
time_this_iter_s=self.config["iter_time"],
info={})
def _agent_import_failed(trace):
"""Returns dummy agent class for if PyTorch etc. is not installed."""
class _AgentImportFailed(Agent):
_agent_name = "AgentImportFailed"
_default_config = with_common_config({})
def _setup(self, config):
raise ImportError(trace)
return _AgentImportFailed
+1
View File
@@ -0,0 +1 @@
Code in this package is adapted from https://github.com/oxwhirl/pymarl_alpha.
+8
View File
@@ -0,0 +1,8 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.qmix.qmix import QMixAgent, DEFAULT_CONFIG
from ray.rllib.agents.qmix.apex import ApexQMixAgent
__all__ = ["QMixAgent", "ApexQMixAgent", "DEFAULT_CONFIG"]
+55
View File
@@ -0,0 +1,55 @@
"""Experimental: scalable Ape-X variant of QMIX"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.qmix.qmix import QMixAgent, DEFAULT_CONFIG as QMIX_CONFIG
from ray.rllib.utils.annotations import override
from ray.rllib.utils import merge_dicts
APEX_QMIX_DEFAULT_CONFIG = merge_dicts(
QMIX_CONFIG, # see also the options in qmix.py, which are also supported
{
"optimizer_class": "AsyncReplayOptimizer",
"optimizer": merge_dicts(
QMIX_CONFIG["optimizer"],
{
"max_weight_sync_delay": 400,
"num_replay_buffer_shards": 4,
"batch_replay": True, # required for RNN. Disables prio.
"debug": False
}),
"num_gpus": 0,
"num_workers": 32,
"buffer_size": 2000000,
"learning_starts": 50000,
"train_batch_size": 512,
"sample_batch_size": 50,
"max_weight_sync_delay": 400,
"target_network_update_freq": 500000,
"timesteps_per_iteration": 25000,
"per_worker_exploration": True,
"min_iter_time_s": 30,
},
)
class ApexQMixAgent(QMixAgent):
"""QMIX variant that uses the Ape-X distributed policy optimizer.
By default, this is configured for a large single node (32 cores). For
running in a large cluster, increase the `num_workers` config var.
"""
_agent_name = "APEX_QMIX"
_default_config = APEX_QMIX_DEFAULT_CONFIG
@override(QMixAgent)
def update_target_if_needed(self):
# Ape-X updates based on num steps trained, not sampled
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
self.config["target_network_update_freq"]:
self.local_evaluator.for_policy(lambda p: p.update_target())
self.last_target_update_ts = self.optimizer.num_steps_trained
self.num_target_updates += 1
+64
View File
@@ -0,0 +1,64 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class VDNMixer(nn.Module):
def __init__(self):
super(VDNMixer, self).__init__()
def forward(self, agent_qs, batch):
return th.sum(agent_qs, dim=2, keepdim=True)
class QMixer(nn.Module):
def __init__(self, n_agents, state_shape, mixing_embed_dim):
super(QMixer, self).__init__()
self.n_agents = n_agents
self.embed_dim = mixing_embed_dim
self.state_dim = int(np.prod(state_shape))
self.hyper_w_1 = nn.Linear(self.state_dim,
self.embed_dim * self.n_agents)
self.hyper_w_final = nn.Linear(self.state_dim, self.embed_dim)
# State dependent bias for hidden layer
self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim)
# V(s) instead of a bias for the last layers
self.V = nn.Sequential(
nn.Linear(self.state_dim, self.embed_dim), nn.ReLU(),
nn.Linear(self.embed_dim, 1))
def forward(self, agent_qs, states):
"""Forward pass for the mixer.
Arguments:
agent_qs: Tensor of shape [B, T, n_agents, n_actions]
states: Tensor of shape [B, T, state_dim]
"""
bs = agent_qs.size(0)
states = states.reshape(-1, self.state_dim)
agent_qs = agent_qs.view(-1, 1, self.n_agents)
# First layer
w1 = th.abs(self.hyper_w_1(states))
b1 = self.hyper_b_1(states)
w1 = w1.view(-1, self.n_agents, self.embed_dim)
b1 = b1.view(-1, 1, self.embed_dim)
hidden = F.elu(th.bmm(agent_qs, w1) + b1)
# Second layer
w_final = th.abs(self.hyper_w_final(states))
w_final = w_final.view(-1, self.embed_dim, 1)
# State-dependent bias
v = self.V(states).view(-1, 1, 1)
# Compute final output
y = th.bmm(hidden, w_final) + v
# Reshape and return
q_tot = y.view(bs, -1, 1)
return q_tot
+28
View File
@@ -0,0 +1,28 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from torch import nn
import torch.nn.functional as F
# TODO(ekl) we should have common models for pytorch like we do for TF
class RNNModel(nn.Module):
def __init__(self, obs_size, rnn_hidden_dim, n_actions):
nn.Module.__init__(self)
self.rnn_hidden_dim = rnn_hidden_dim
self.n_actions = n_actions
self.fc1 = nn.Linear(obs_size, rnn_hidden_dim)
self.rnn = nn.GRUCell(rnn_hidden_dim, rnn_hidden_dim)
self.fc2 = nn.Linear(rnn_hidden_dim, n_actions)
def init_hidden(self):
# make hidden states on same device as model
return self.fc1.weight.new(1, self.rnn_hidden_dim).zero_()
def forward(self, inputs, hidden_state):
x = F.relu(self.fc1(inputs.float()))
h_in = hidden_state.reshape(-1, self.rnn_hidden_dim)
h = self.rnn(x, h_in)
q = self.fc2(h)
return q, h
+92
View File
@@ -0,0 +1,92 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.agent import with_common_config
from ray.rllib.agents.dqn.dqn import DQNAgent
from ray.rllib.agents.qmix.qmix_policy_graph import QMixPolicyGraph
# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
# === QMix ===
# Mixing network. Either "qmix", "vdn", or None
"mixer": "qmix",
# Size of the mixing network embedding
"mixing_embed_dim": 32,
# Whether to use Double_Q learning
"double_q": True,
# Optimize over complete episodes by default.
"batch_mode": "complete_episodes",
# === Exploration ===
# Max num timesteps for annealing schedules. Exploration is annealed from
# 1.0 to exploration_fraction over this number of timesteps scaled by
# exploration_fraction
"schedule_max_timesteps": 100000,
# Number of env steps to optimize for before returning
"timesteps_per_iteration": 1000,
# Fraction of entire training period over which the exploration rate is
# annealed
"exploration_fraction": 0.1,
# Final value of random action probability
"exploration_final_eps": 0.02,
# Update the target network every `target_network_update_freq` steps.
"target_network_update_freq": 500,
# === Replay buffer ===
# Size of the replay buffer in steps.
"buffer_size": 10000,
# === Optimization ===
# Learning rate for adam optimizer
"lr": 0.0005,
# RMSProp alpha
"optim_alpha": 0.99,
# RMSProp epsilon
"optim_eps": 0.00001,
# If not None, clip gradients during optimization at this value
"grad_norm_clipping": 10,
# How many steps of the model to sample before learning starts.
"learning_starts": 1000,
# Update the replay buffer with this many samples at once. Note that
# this setting applies per-worker if num_workers > 1.
"sample_batch_size": 4,
# Size of a batched sampled from replay buffer for training. Note that
# if async_updates is set, then each worker returns gradients for a
# batch of this size.
"train_batch_size": 32,
# === Parallelism ===
# Number of workers for collecting samples with. This only makes sense
# to increase if your environment is particularly slow to sample, or if
# you"re using the Async or Ape-X optimizers.
"num_workers": 0,
# Optimizer class to use.
"optimizer_class": "SyncBatchReplayOptimizer",
# Whether to use a distribution of epsilons across workers for exploration.
"per_worker_exploration": False,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
# Prevent iterations from going lower than this time span
"min_iter_time_s": 1,
# === Model ===
"model": {
"lstm_cell_size": 64,
"max_seq_len": 999999,
},
})
# __sphinx_doc_end__
# yapf: enable
class QMixAgent(DQNAgent):
"""QMix implementation in PyTorch."""
_agent_name = "QMIX"
_default_config = DEFAULT_CONFIG
_policy_graph = QMixPolicyGraph
_optimizer_shared_configs = [
"learning_starts", "buffer_size", "train_batch_size"
]
@@ -0,0 +1,411 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from gym.spaces import Tuple, Discrete, Dict
import logging
import numpy as np
import torch as th
import torch.nn as nn
from torch.optim import RMSprop
from torch.distributions import Categorical
import ray
from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer
from ray.rllib.agents.qmix.model import RNNModel
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.models.action_dist import TupleActions
from ray.rllib.models.pytorch.misc import var_to_np
from ray.rllib.models.lstm import chop_into_sequences
from ray.rllib.models.model import _unpack_obs
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.env.constants import GROUP_REWARDS
from ray.rllib.utils.annotations import override
logger = logging.getLogger(__name__)
class QMixLoss(nn.Module):
def __init__(self,
model,
target_model,
mixer,
target_mixer,
n_agents,
n_actions,
double_q=True,
gamma=0.99):
nn.Module.__init__(self)
self.model = model
self.target_model = target_model
self.mixer = mixer
self.target_mixer = target_mixer
self.n_agents = n_agents
self.n_actions = n_actions
self.double_q = double_q
self.gamma = gamma
def forward(self, rewards, actions, terminated, mask, obs, action_mask):
"""Forward pass of the loss.
Arguments:
rewards: Tensor of shape [B, T-1, n_agents]
actions: Tensor of shape [B, T-1, n_agents]
terminated: Tensor of shape [B, T-1, n_agents]
mask: Tensor of shape [B, T-1, n_agents]
obs: Tensor of shape [B, T, n_agents, obs_size]
action_mask: Tensor of shape [B, T, n_agents, n_actions]
"""
B, T = obs.size(0), obs.size(1)
# Calculate estimated Q-Values
mac_out = []
h = self.model.init_hidden().expand([B, self.n_agents, -1])
for t in range(T):
q, h = _mac(self.model, obs[:, t], h)
mac_out.append(q)
mac_out = th.stack(mac_out, dim=1) # Concat over time
# Pick the Q-Values for the actions taken -> [B * n_agents, T-1]
chosen_action_qvals = th.gather(
mac_out[:, :-1], dim=3, index=actions.unsqueeze(3)).squeeze(3)
# Calculate the Q-Values necessary for the target
target_mac_out = []
target_h = self.target_model.init_hidden().expand(
[B, self.n_agents, -1])
for t in range(T):
target_q, target_h = _mac(self.target_model, obs[:, t], target_h)
target_mac_out.append(target_q)
# We don't need the first timesteps Q-Value estimate for targets
target_mac_out = th.stack(
target_mac_out[1:], dim=1) # Concat across time
# Mask out unavailable actions
target_mac_out[action_mask[:, 1:] == 0] = -9999999
# Max over target Q-Values
if self.double_q:
# Get actions that maximise live Q (for double q-learning)
mac_out[action_mask == 0] = -9999999
cur_max_actions = mac_out[:, 1:].max(dim=3, keepdim=True)[1]
target_max_qvals = th.gather(target_mac_out, 3,
cur_max_actions).squeeze(3)
else:
target_max_qvals = target_mac_out.max(dim=3)[0]
# Mix
if self.mixer is not None:
# TODO(ekl) add support for handling global state? This is just
# treating the stacked agent obs as the state.
chosen_action_qvals = self.mixer(chosen_action_qvals, obs[:, :-1])
target_max_qvals = self.target_mixer(target_max_qvals, obs[:, 1:])
# Calculate 1-step Q-Learning targets
targets = rewards + self.gamma * (1 - terminated) * target_max_qvals
# Td-error
td_error = (chosen_action_qvals - targets.detach())
mask = mask.expand_as(td_error)
# 0-out the targets that came from padded data
masked_td_error = td_error * mask
# Normal L2 loss, take mean over actual data
loss = (masked_td_error**2).sum() / mask.sum()
return loss, mask, masked_td_error, chosen_action_qvals, targets
class QMixPolicyGraph(PolicyGraph):
"""QMix impl. Assumes homogeneous agents for now.
You must use MultiAgentEnv.with_agent_groups() to group agents
together for QMix. This creates the proper Tuple obs/action spaces and
populates the '_group_rewards' info field.
Action masking: to specify an action mask for individual agents, use a
dict space with an action_mask key, e.g. {"obs": ob, "action_mask": mask}.
The mask space must be `Box(0, 1, (n_actions,))`.
"""
def __init__(self, obs_space, action_space, config):
_validate(obs_space, action_space)
config = dict(ray.rllib.agents.qmix.qmix.DEFAULT_CONFIG, **config)
self.config = config
self.observation_space = obs_space
self.action_space = action_space
self.n_agents = len(obs_space.original_space.spaces)
self.n_actions = action_space.spaces[0].n
self.h_size = config["model"]["lstm_cell_size"]
agent_obs_space = obs_space.original_space.spaces[0]
if isinstance(agent_obs_space, Dict):
space_keys = set(agent_obs_space.spaces.keys())
if space_keys != {"obs", "action_mask"}:
raise ValueError(
"Dict obs space for agent must have keyset "
"['obs', 'action_mask'], got {}".format(space_keys))
mask_shape = tuple(agent_obs_space.spaces["action_mask"].shape)
if mask_shape != (self.n_actions, ):
raise ValueError("Action mask shape must be {}, got {}".format(
(self.n_actions, ), mask_shape))
self.has_action_mask = True
self.obs_size = _get_size(agent_obs_space.spaces["obs"])
else:
self.has_action_mask = False
self.obs_size = _get_size(agent_obs_space)
self.model = RNNModel(self.obs_size, self.h_size, self.n_actions)
self.target_model = RNNModel(self.obs_size, self.h_size,
self.n_actions)
# Setup the mixer network.
# The global state is just the stacked agent observations for now.
self.state_shape = [self.obs_size, self.n_agents]
if config["mixer"] is None:
self.mixer = None
self.target_mixer = None
elif config["mixer"] == "qmix":
self.mixer = QMixer(self.n_agents, self.state_shape,
config["mixing_embed_dim"])
self.target_mixer = QMixer(self.n_agents, self.state_shape,
config["mixing_embed_dim"])
elif config["mixer"] == "vdn":
self.mixer = VDNMixer()
self.target_mixer = VDNMixer()
else:
raise ValueError("Unknown mixer type {}".format(config["mixer"]))
self.cur_epsilon = 1.0
self.update_target() # initial sync
# Setup optimizer
self.params = list(self.model.parameters())
self.loss = QMixLoss(self.model, self.target_model, self.mixer,
self.target_mixer, self.n_agents, self.n_actions,
self.config["double_q"], self.config["gamma"])
self.optimiser = RMSprop(
params=self.params,
lr=config["lr"],
alpha=config["optim_alpha"],
eps=config["optim_eps"])
@override(PolicyGraph)
def compute_actions(self,
obs_batch,
state_batches=None,
prev_action_batch=None,
prev_reward_batch=None,
info_batch=None,
episodes=None,
**kwargs):
obs_batch, action_mask = self._unpack_observation(obs_batch)
assert len(state_batches) == self.n_agents, state_batches
state_batches = np.stack(state_batches, axis=1)
# Compute actions
with th.no_grad():
q_values, hiddens = _mac(self.model, th.from_numpy(obs_batch),
th.from_numpy(state_batches))
avail = th.from_numpy(action_mask).float()
masked_q_values = q_values.clone()
masked_q_values[avail == 0.0] = -float("inf")
# epsilon-greedy action selector
random_numbers = th.rand_like(q_values[:, :, 0])
pick_random = (random_numbers < self.cur_epsilon).long()
random_actions = Categorical(avail).sample().long()
actions = (pick_random * random_actions +
(1 - pick_random) * masked_q_values.max(dim=2)[1])
actions = var_to_np(actions)
hiddens = var_to_np(hiddens)
return (TupleActions(list(actions.transpose([1, 0]))),
hiddens.transpose([1, 0, 2]), {})
@override(PolicyGraph)
def compute_apply(self, samples):
obs_batch, action_mask = self._unpack_observation(samples["obs"])
group_rewards = self._get_group_rewards(samples["infos"])
# These will be padded to shape [B * T, ...]
[rew, action_mask, act, dones, obs], initial_states, seq_lens = \
chop_into_sequences(
samples["eps_id"],
samples["agent_index"], [
group_rewards, action_mask, samples["actions"],
samples["dones"], obs_batch
],
[samples["state_in_{}".format(k)]
for k in range(self.n_agents)],
max_seq_len=self.config["model"]["max_seq_len"],
dynamic_max=True,
_extra_padding=1)
# TODO(ekl) adding 1 extra unit of padding here, since otherwise we
# lose the terminating reward and the Q-values will be unanchored!
B, T = len(seq_lens), max(seq_lens) + 1
def to_batches(arr):
new_shape = [B, T] + list(arr.shape[1:])
return th.from_numpy(np.reshape(arr, new_shape))
rewards = to_batches(rew)[:, :-1].float()
actions = to_batches(act)[:, :-1].long()
obs = to_batches(obs).reshape([B, T, self.n_agents,
self.obs_size]).float()
action_mask = to_batches(action_mask)
# TODO(ekl) this treats group termination as individual termination
terminated = to_batches(dones.astype(np.float32)).unsqueeze(2).expand(
B, T, self.n_agents)[:, :-1]
filled = (np.reshape(np.tile(np.arange(T), B), [B, T]) <
np.expand_dims(seq_lens, 1)).astype(np.float32)
mask = th.from_numpy(filled).unsqueeze(2).expand(B, T,
self.n_agents)[:, :-1]
mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
# Compute loss
loss_out, mask, masked_td_error, chosen_action_qvals, targets = \
self.loss(rewards, actions, terminated, mask, obs, action_mask)
# Optimise
self.optimiser.zero_grad()
loss_out.backward()
grad_norm = th.nn.utils.clip_grad_norm_(
self.params, self.config["grad_norm_clipping"])
self.optimiser.step()
mask_elems = mask.sum().item()
stats = {
"loss": loss_out.item(),
"grad_norm": grad_norm
if isinstance(grad_norm, float) else grad_norm.item(),
"td_error_abs": masked_td_error.abs().sum().item() / mask_elems,
"q_taken_mean": (chosen_action_qvals * mask).sum().item() /
mask_elems,
"target_mean": (targets * mask).sum().item() / mask_elems,
}
return {"stats": stats}, {}
@override(PolicyGraph)
def get_initial_state(self):
return [
self.model.init_hidden().numpy().squeeze()
for _ in range(self.n_agents)
]
@override(PolicyGraph)
def get_weights(self):
return {"model": self.model.state_dict()}
@override(PolicyGraph)
def set_weights(self, weights):
self.model.load_state_dict(weights["model"])
@override(PolicyGraph)
def get_state(self):
return {
"model": self.model.state_dict(),
"target_model": self.target_model.state_dict(),
"mixer": self.mixer.state_dict() if self.mixer else None,
"target_mixer": self.target_mixer.state_dict()
if self.mixer else None,
"cur_epsilon": self.cur_epsilon,
}
@override(PolicyGraph)
def set_state(self, state):
self.model.load_state_dict(state["model"])
self.target_model.load_state_dict(state["target_model"])
if state["mixer"] is not None:
self.mixer.load_state_dict(state["mixer"])
self.target_mixer.load_state_dict(state["target_mixer"])
self.set_epsilon(state["cur_epsilon"])
self.update_target()
def update_target(self):
self.target_model.load_state_dict(self.model.state_dict())
if self.mixer is not None:
self.target_mixer.load_state_dict(self.mixer.state_dict())
logger.debug("Updated target networks")
def set_epsilon(self, epsilon):
self.cur_epsilon = epsilon
def _get_group_rewards(self, info_batch):
group_rewards = np.array([
info.get(GROUP_REWARDS, [0.0] * self.n_agents)
for info in info_batch
])
return group_rewards
def _unpack_observation(self, obs_batch):
unpacked = _unpack_obs(
np.array(obs_batch),
self.observation_space.original_space,
tensorlib=np)
if self.has_action_mask:
obs = np.concatenate(
[o["obs"] for o in unpacked],
axis=1).reshape([len(obs_batch), self.n_agents, self.obs_size])
action_mask = np.concatenate(
[o["action_mask"] for o in unpacked], axis=1).reshape(
[len(obs_batch), self.n_agents, self.n_actions])
else:
obs = np.concatenate(
unpacked,
axis=1).reshape([len(obs_batch), self.n_agents, self.obs_size])
action_mask = np.ones(
[len(obs_batch), self.n_agents, self.n_actions])
return obs, action_mask
def _validate(obs_space, action_space):
if not hasattr(obs_space, "original_space") or \
not isinstance(obs_space.original_space, Tuple):
raise ValueError("Obs space must be a Tuple, got {}. Use ".format(
obs_space) + "MultiAgentEnv.with_agent_groups() to group related "
"agents for QMix.")
if not isinstance(action_space, Tuple):
raise ValueError(
"Action space must be a Tuple, got {}. ".format(action_space) +
"Use MultiAgentEnv.with_agent_groups() to group related "
"agents for QMix.")
if not isinstance(action_space.spaces[0], Discrete):
raise ValueError(
"QMix requires a discrete action space, got {}".format(
action_space.spaces[0]))
if len({str(x) for x in obs_space.original_space.spaces}) > 1:
raise ValueError(
"Implementation limitation: observations of grouped agents "
"must be homogeneous, got {}".format(
obs_space.original_space.spaces))
if len({str(x) for x in action_space.spaces}) > 1:
raise ValueError(
"Implementation limitation: action space of grouped agents "
"must be homogeneous, got {}".format(action_space.spaces))
def _get_size(obs_space):
return get_preprocessor(obs_space)(obs_space).size
def _mac(model, obs, h):
"""Forward pass of the multi-agent controller.
Arguments:
model: Model that produces q-values for a 1d agent batch
obs: Tensor of shape [B, n_agents, obs_size]
h: Tensor of shape [B, n_agents, h_size]
Returns:
q_vals: Tensor of shape [B, n_agents, n_actions]
h: Tensor of shape [B, n_agents, h_size]
"""
B, n_agents = obs.size(0), obs.size(1)
obs_flat = obs.reshape([B * n_agents, -1])
h_flat = h.reshape([B * n_agents, -1])
q_flat, h_flat = model.forward(obs_flat, h_flat)
return q_flat.reshape([B, n_agents, -1]), h_flat.reshape([B, n_agents, -1])
+3 -3
View File
@@ -304,9 +304,9 @@ class _MultiAgentEnvToAsync(AsyncVectorEnv):
raise ValueError(
"Key set for obs and rewards must be the same: "
"{} vs {}".format(obs.keys(), rewards.keys()))
if set(obs.keys()) != set(infos.keys()):
raise ValueError("Key set for obs and infos must be the same: "
"{} vs {}".format(obs.keys(), infos.keys()))
if set(infos).difference(set(obs)):
raise ValueError("Key set for infos must be a subset of obs: "
"{} vs {}".format(infos.keys(), obs.keys()))
if dones["__all__"]:
self.dones.add(env_id)
self.env_states[env_id].observe(obs, rewards, dones, infos)
+19
View File
@@ -0,0 +1,19 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# info key for the individual rewards of an agent, for example:
# info: {
# group_1: {
# _group_rewards: [5, -1, 1], # 3 agents in this group
# }
# }
GROUP_REWARDS = "_group_rewards"
# info key for the individual infos of an agent, for example:
# info: {
# group_1: {
# _group_infos: [{"foo": ...}, {}], # 2 agents in this group
# }
# }
GROUP_INFO = "_group_info"
+107
View File
@@ -0,0 +1,107 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
from ray.rllib.env.constants import GROUP_REWARDS, GROUP_INFO
from ray.rllib.env.multi_agent_env import MultiAgentEnv
# TODO(ekl) we should add some unit tests for this
class _GroupAgentsWrapper(MultiAgentEnv):
"""Wraps a MultiAgentEnv environment with agents grouped as specified.
See multi_agent_env.py for the specification of groups.
This API is experimental.
"""
def __init__(self, env, groups, obs_space=None, act_space=None):
"""Wrap an existing multi-agent env to group agents together.
See MultiAgentEnv.with_agent_groups() for usage info.
Arguments:
env (MultiAgentEnv): env to wrap
groups (dict): Grouping spec as documented in MultiAgentEnv
obs_space (Space): Optional observation space for the grouped
env. Must be a tuple space.
act_space (Space): Optional action space for the grouped env.
Must be a tuple space.
"""
self.env = env
self.groups = groups
self.agent_id_to_group = {}
for group_id, agent_ids in groups.items():
for agent_id in agent_ids:
if agent_id in self.agent_id_to_group:
raise ValueError(
"Agent id {} is in multiple groups".format(
agent_id, groups))
self.agent_id_to_group[agent_id] = group_id
if obs_space is not None:
self.observation_space = obs_space
if act_space is not None:
self.action_space = act_space
def reset(self):
obs = self.env.reset()
return self._group_items(obs)
def step(self, action_dict):
# Ungroup and send actions
action_dict = self._ungroup_items(action_dict)
obs, rewards, dones, infos = self.env.step(action_dict)
# Apply grouping transforms to the env outputs
obs = self._group_items(obs)
rewards = self._group_items(
rewards, agg_fn=lambda gvals: list(gvals.values()))
dones = self._group_items(
dones, agg_fn=lambda gvals: all(gvals.values()))
infos = self._group_items(
infos, agg_fn=lambda gvals: {GROUP_INFO: list(gvals.values())})
# Aggregate rewards, but preserve the original values in infos
for agent_id, rew in rewards.items():
if isinstance(rew, list):
rewards[agent_id] = sum(rew)
if agent_id not in infos:
infos[agent_id] = {}
infos[agent_id][GROUP_REWARDS] = rew
return obs, rewards, dones, infos
def _ungroup_items(self, items):
out = {}
for agent_id, value in items.items():
if agent_id in self.groups:
assert len(value) == len(self.groups[agent_id]), \
(agent_id, value, self.groups)
for a, v in zip(self.groups[agent_id], value):
out[a] = v
else:
out[agent_id] = value
return out
def _group_items(self, items, agg_fn=lambda gvals: list(gvals.values())):
grouped_items = {}
for agent_id, item in items.items():
if agent_id in self.agent_id_to_group:
group_id = self.agent_id_to_group[agent_id]
if group_id in grouped_items:
continue # already added
group_out = OrderedDict()
for a in self.groups[group_id]:
if a in items:
group_out[a] = items[a]
else:
raise ValueError(
"Missing member of group {}: {}: {}".format(
group_id, a, items))
grouped_items[group_id] = agg_fn(group_out)
else:
grouped_items[agent_id] = item
return grouped_items
+50 -4
View File
@@ -30,9 +30,14 @@ class MultiAgentEnv(object):
}
>>> print(dones)
{
"car_0": False,
"car_1": True,
"__all__": False,
"car_0": False, # car_0 is still running
"car_1": True, # car_1 is done
"__all__": False, # the env is not done
}
>>> print(infos)
{
"car_0": {}, # info for car_0
"car_1": {}, # info for car_1
}
"""
@@ -57,6 +62,47 @@ class MultiAgentEnv(object):
episode is just started, the value will be None.
dones (dict): Done values for each ready agent. The special key
"__all__" (required) is used to indicate env termination.
infos (dict): Info values for each ready agent.
infos (dict): Optional info values for each agent id.
"""
raise NotImplementedError
# yapf: disable
# __grouping_doc_begin__
def with_agent_groups(self, groups, obs_space=None, act_space=None):
"""Convenience method for grouping together agents in this env.
An agent group is a list of agent ids that are mapped to a single
logical agent. All agents of the group must act at the same time in the
environment. The grouped agent exposes Tuple action and observation
spaces that are the concatenated action and obs spaces of the
individual agents.
The rewards of all the agents in a group are summed. The individual
agent rewards are available under the "individual_rewards" key of the
group info return.
Agent grouping is required to leverage algorithms such as Q-Mix.
This API is experimental.
Arguments:
groups (dict): Mapping from group id to a list of the agent ids
of group members. If an agent id is not present in any group
value, it will be left ungrouped.
obs_space (Space): Optional observation space for the grouped
env. Must be a tuple space.
act_space (Space): Optional action space for the grouped env.
Must be a tuple space.
Examples:
>>> env = YourMultiAgentEnv(...)
>>> grouped_env = env.with_agent_groups(env, {
... "group1": ["agent1", "agent2", "agent3"],
... "group2": ["agent4", "agent5"],
... })
"""
from ray.rllib.env.group_agents_wrapper import _GroupAgentsWrapper
return _GroupAgentsWrapper(self, groups, obs_space, act_space)
# __grouping_doc_end__
# yapf: enable
+10 -2
View File
@@ -42,7 +42,9 @@ class PolicyGraph(object):
state_batches,
prev_action_batch=None,
prev_reward_batch=None,
episodes=None):
info_batch=None,
episodes=None,
**kwargs):
"""Compute actions for the current policy.
Arguments:
@@ -50,9 +52,11 @@ class PolicyGraph(object):
state_batches (list): list of RNN state input batches, if any
prev_action_batch (np.ndarray): batch of previous action values
prev_reward_batch (np.ndarray): batch of previous rewards
info_batch (info): batch of info objects
episodes (list): MultiAgentEpisode for each obs in obs_batch.
This provides access to all of the internal episode state,
which may be useful for model-based or multiagent algorithms.
kwargs: forward compatibility placeholder
Returns:
actions (np.ndarray): batch of output actions, with shape like
@@ -69,7 +73,9 @@ class PolicyGraph(object):
state,
prev_action_batch=None,
prev_reward_batch=None,
episode=None):
info_batch=None,
episode=None,
**kwargs):
"""Unbatched version of compute_actions.
Arguments:
@@ -77,9 +83,11 @@ class PolicyGraph(object):
state_batches (list): list of RNN state inputs, if any
prev_action_batch (np.ndarray): batch of previous action values
prev_reward_batch (np.ndarray): batch of previous rewards
info_batch (list): batch of info objects
episode (MultiAgentEpisode): this provides access to all of the
internal episode state, which may be useful for model-based or
multi-agent algorithms.
kwargs: forward compatibility placeholder
Returns:
actions (obj): single action
+12 -20
View File
@@ -24,20 +24,13 @@ RolloutMetrics = namedtuple(
"RolloutMetrics",
["episode_length", "episode_reward", "agent_rewards", "custom_metrics"])
PolicyEvalData = namedtuple(
"PolicyEvalData",
["env_id", "agent_id", "obs", "rnn_state", "prev_action", "prev_reward"])
PolicyEvalData = namedtuple("PolicyEvalData", [
"env_id", "agent_id", "obs", "info", "rnn_state", "prev_action",
"prev_reward"
])
class SyncSampler(object):
"""This class interacts with the environment and tells it what to do.
Note that batch_size is only a unit of measure here. Batches can
accumulate and the gradient can be calculated on up to 5 batches.
This class provides data on invocation, rather than on a separate
thread."""
def __init__(self,
env,
policies,
@@ -94,11 +87,6 @@ class SyncSampler(object):
class AsyncSampler(threading.Thread):
"""This class interacts with the environment and tells it what to do.
Note that batch_size is only a unit of measure here. Batches can
accumulate and the gradient can be calculated on up to 5 batches."""
def __init__(self,
env,
policies,
@@ -361,17 +349,18 @@ def _process_observations(async_vector_env, policies, batch_builder_pool,
if not agent_done:
to_eval[policy_id].append(
PolicyEvalData(env_id, agent_id, filtered_obs,
infos[env_id].get(agent_id, {}),
episode.rnn_state_for(agent_id),
episode.last_action_for(agent_id),
rewards[env_id][agent_id] or 0.0))
last_observation = episode.last_observation_for(agent_id)
episode._set_last_observation(agent_id, filtered_obs)
episode._set_last_info(agent_id, infos[env_id][agent_id])
episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))
# Record transition info if applicable
if last_observation is not None and \
infos[env_id][agent_id].get("training_enabled", True):
if (last_observation is not None and infos[env_id].get(
agent_id, {}).get("training_enabled", True)):
episode.batch_builder.add_values(
agent_id,
policy_id,
@@ -384,7 +373,7 @@ def _process_observations(async_vector_env, policies, batch_builder_pool,
prev_actions=episode.prev_action_for(agent_id),
prev_rewards=episode.prev_reward_for(agent_id),
dones=agent_done,
infos=infos[env_id][agent_id],
infos=infos[env_id].get(agent_id, {}),
new_obs=filtered_obs,
**episode.last_pi_info_for(agent_id))
@@ -435,6 +424,7 @@ def _process_observations(async_vector_env, policies, batch_builder_pool,
to_eval[policy_id].append(
PolicyEvalData(
env_id, agent_id, filtered_obs,
episode.last_info_for(agent_id) or {},
episode.rnn_state_for(agent_id),
np.zeros_like(
_flatten_action(policy.action_space.sample())),
@@ -462,6 +452,7 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
policy = _get_or_raise(policies, policy_id)
if builder and (policy.compute_actions.__code__ is
TFPolicyGraph.compute_actions.__code__):
# TODO(ekl): how can we make info batch available to TF code?
pending_fetches[policy_id] = policy._build_compute_actions(
builder, [t.obs for t in eval_data],
rnn_in_cols,
@@ -473,6 +464,7 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
rnn_in_cols,
prev_action_batch=[t.prev_action for t in eval_data],
prev_reward_batch=[t.prev_reward for t in eval_data],
info_batch=[t.info for t in eval_data],
episodes=[active_episodes[t.env_id] for t in eval_data])
if builder:
for k, v in pending_fetches.items():
@@ -153,7 +153,9 @@ class TFPolicyGraph(PolicyGraph):
state_batches=None,
prev_action_batch=None,
prev_reward_batch=None,
episodes=None):
info_batch=None,
episodes=None,
**kwargs):
builder = TFRunBuilder(self._sess, "compute_actions")
fetches = self._build_compute_actions(builder, obs_batch,
state_batches, prev_action_batch,
@@ -63,7 +63,9 @@ class TorchPolicyGraph(PolicyGraph):
state_batches=None,
prev_action_batch=None,
prev_reward_batch=None,
episodes=None):
info_batch=None,
episodes=None,
**kwargs):
if state_batches:
raise NotImplementedError("Torch RNN support")
with self.lock:
+117
View File
@@ -0,0 +1,117 @@
"""The two-step game from QMIX: https://arxiv.org/pdf/1803.11485.pdf"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
from gym.spaces import Tuple, Discrete
import ray
from ray.tune import register_env, run_experiments, grid_search
from ray.rllib.env.multi_agent_env import MultiAgentEnv
parser = argparse.ArgumentParser()
parser.add_argument("--stop", type=int, default=50000)
parser.add_argument("--run", type=str, default="QMIX")
class TwoStepGame(MultiAgentEnv):
action_space = Discrete(2)
# Each agent gets a separate [3] obs space, to ensure that they can
# learn meaningfully different Q values even with a shared Q model.
observation_space = Discrete(6)
def __init__(self, env_config):
self.state = None
def reset(self):
self.state = 0
return {"agent_1": self.state, "agent_2": self.state + 3}
def step(self, action_dict):
if self.state == 0:
action = action_dict["agent_1"]
assert action in [0, 1], action
if action == 0:
self.state = 1
else:
self.state = 2
global_rew = 0
done = False
elif self.state == 1:
global_rew = 7
done = True
else:
if action_dict["agent_1"] == 0 and action_dict["agent_2"] == 0:
global_rew = 0
elif action_dict["agent_1"] == 1 and action_dict["agent_2"] == 1:
global_rew = 8
else:
global_rew = 1
done = True
rewards = {"agent_1": global_rew / 2.0, "agent_2": global_rew / 2.0}
obs = {"agent_1": self.state, "agent_2": self.state + 3}
dones = {"__all__": done}
infos = {}
return obs, rewards, dones, infos
if __name__ == "__main__":
args = parser.parse_args()
grouping = {
"group_1": ["agent_1", "agent_2"],
}
obs_space = Tuple([
TwoStepGame.observation_space,
TwoStepGame.observation_space,
])
act_space = Tuple([
TwoStepGame.action_space,
TwoStepGame.action_space,
])
register_env(
"grouped_twostep",
lambda config: TwoStepGame(config).with_agent_groups(
grouping, obs_space=obs_space, act_space=act_space))
if args.run == "QMIX":
config = {
"sample_batch_size": 4,
"train_batch_size": 32,
"exploration_final_eps": 0.0,
"num_workers": 0,
"mixer": grid_search([None, "qmix", "vdn"]),
}
elif args.run == "APEX_QMIX":
config = {
"num_gpus": 0,
"num_workers": 2,
"optimizer": {
"num_replay_buffer_shards": 1,
},
"min_iter_time_s": 3,
"buffer_size": 1000,
"learning_starts": 1000,
"train_batch_size": 128,
"sample_batch_size": 32,
"target_network_update_freq": 500,
"timesteps_per_iteration": 1000,
}
else:
config = {}
ray.init()
run_experiments({
"two_step": {
"run": args.run,
"env": "grouped_twostep",
"stop": {
"timesteps_total": args.stop,
},
"config": config,
},
})
+4 -2
View File
@@ -123,7 +123,8 @@ def chop_into_sequences(episode_ids,
feature_columns,
state_columns,
max_seq_len,
dynamic_max=True):
dynamic_max=True,
_extra_padding=0):
"""Truncate and pad experiences into fixed-length sequences.
Arguments:
@@ -136,6 +137,7 @@ def chop_into_sequences(episode_ids,
dynamic_max (bool): Whether to dynamically shrink the max seq len.
For example, if max len is 20 and the actual max seq len in the
data is 7, it will be shrunk to 7.
_extra_padding (int): Add extra padding to the end of sequences.
Returns:
f_pad (list): Padded feature columns. These will be of shape
@@ -177,7 +179,7 @@ def chop_into_sequences(episode_ids,
# Dynamically shrink max len as needed to optimize memory usage
if dynamic_max:
max_seq_len = max(seq_lens)
max_seq_len = max(seq_lens) + _extra_padding
feature_sequences = []
for f in feature_columns:
+15 -3
View File
@@ -168,7 +168,15 @@ def _restore_original_dimensions(input_dict, obs_space):
return input_dict
def _unpack_obs(obs, space):
def _unpack_obs(obs, space, tensorlib=tf):
"""Unpack a flattened Dict or Tuple observation array/tensor.
Arguments:
obs: The flattened observation tensor
space: The original space prior to flattening
tensorlib: The library used to unflatten (reshape) the array/tensor
"""
if (isinstance(space, gym.spaces.Dict)
or isinstance(space, gym.spaces.Tuple)):
prep = get_preprocessor(space)(space)
@@ -186,14 +194,18 @@ def _unpack_obs(obs, space):
offset += p.size
u.append(
_unpack_obs(
tf.reshape(obs_slice, [-1] + list(p.shape)), v))
tensorlib.reshape(obs_slice, [-1] + list(p.shape)),
v,
tensorlib=tensorlib))
else:
u = OrderedDict()
for p, (k, v) in zip(prep.preprocessors, space.spaces.items()):
obs_slice = obs[:, offset:offset + p.size]
offset += p.size
u[k] = _unpack_obs(
tf.reshape(obs_slice, [-1] + list(p.shape)), v)
tensorlib.reshape(obs_slice, [-1] + list(p.shape)),
v,
tensorlib=tensorlib)
return u
else:
return obs
+10 -3
View File
@@ -5,10 +5,17 @@ from ray.rllib.optimizers.async_gradients_optimizer import \
AsyncGradientsOptimizer
from ray.rllib.optimizers.sync_samples_optimizer import SyncSamplesOptimizer
from ray.rllib.optimizers.sync_replay_optimizer import SyncReplayOptimizer
from ray.rllib.optimizers.sync_batch_replay_optimizer import \
SyncBatchReplayOptimizer
from ray.rllib.optimizers.multi_gpu_optimizer import LocalMultiGPUOptimizer
__all__ = [
"PolicyOptimizer", "AsyncReplayOptimizer", "AsyncSamplesOptimizer",
"AsyncGradientsOptimizer", "SyncSamplesOptimizer", "SyncReplayOptimizer",
"LocalMultiGPUOptimizer"
"PolicyOptimizer",
"AsyncReplayOptimizer",
"AsyncSamplesOptimizer",
"AsyncGradientsOptimizer",
"SyncSamplesOptimizer",
"SyncReplayOptimizer",
"LocalMultiGPUOptimizer",
"SyncBatchReplayOptimizer",
]
@@ -36,6 +36,11 @@ class AsyncReplayOptimizer(PolicyOptimizer):
This class coordinates the data transfers between the learner thread,
remote evaluators (Ape-X actors), and replay buffer actors.
This has two modes of operation:
- normal replay: replays independent samples.
- batch replay: simplified mode where entire sample batches are
replayed. This supports RNNs, but not prioritization.
This optimizer requires that policy evaluators return an additional
"td_error" array in the info return of compute_gradients(). This error
term will be used for sample prioritization."""
@@ -52,9 +57,11 @@ class AsyncReplayOptimizer(PolicyOptimizer):
sample_batch_size=50,
num_replay_buffer_shards=1,
max_weight_sync_delay=400,
debug=False):
debug=False,
batch_replay=False):
self.debug = debug
self.batch_replay = batch_replay
self.replay_starts = learning_starts
self.prioritized_replay_beta = prioritized_replay_beta
self.prioritized_replay_eps = prioritized_replay_eps
@@ -63,7 +70,11 @@ class AsyncReplayOptimizer(PolicyOptimizer):
self.learner = LearnerThread(self.local_evaluator)
self.learner.start()
self.replay_actors = create_colocated(ReplayActor, [
if self.batch_replay:
replay_cls = BatchReplayActor
else:
replay_cls = ReplayActor
self.replay_actors = create_colocated(replay_cls, [
num_replay_buffer_shards,
learning_starts,
buffer_size,
@@ -101,6 +112,7 @@ class AsyncReplayOptimizer(PolicyOptimizer):
@override(PolicyOptimizer)
def step(self):
assert self.learner.is_alive()
assert len(self.remote_evaluators) > 0
start = time.time()
sample_timesteps, train_timesteps = self._step()
@@ -299,6 +311,54 @@ class ReplayActor(object):
return stat
@ray.remote(num_cpus=0)
class BatchReplayActor(object):
"""The batch replay version of the replay actor.
This allows for RNN models, but ignores prioritization params.
"""
def __init__(self, num_shards, learning_starts, buffer_size,
train_batch_size, prioritized_replay_alpha,
prioritized_replay_beta, prioritized_replay_eps):
self.replay_starts = learning_starts // num_shards
self.buffer_size = buffer_size // num_shards
self.train_batch_size = train_batch_size
self.buffer = []
# Metrics
self.num_added = 0
self.cur_size = 0
def get_host(self):
return os.uname()[1]
def add_batch(self, batch):
# Handle everything as if multiagent
if isinstance(batch, SampleBatch):
batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)
self.buffer.append(batch)
self.cur_size += batch.count
self.num_added += batch.count
while self.cur_size > self.buffer_size:
self.cur_size -= self.buffer.pop(0).count
def replay(self):
if self.num_added < self.replay_starts:
return None
return random.choice(self.buffer)
def update_priorities(self, prio_dict):
pass
def stats(self, debug=False):
stat = {
"cur_size": self.cur_size,
"num_added": self.num_added,
}
return stat
class LearnerThread(threading.Thread):
"""Background thread that updates the local model from replay data.
@@ -334,8 +394,8 @@ class LearnerThread(threading.Thread):
grad_out = self.local_evaluator.compute_apply(replay)
for pid, info in grad_out.items():
prio_dict[pid] = (
replay.policy_batches[pid]["batch_indexes"],
info["td_error"])
replay.policy_batches[pid].data.get("batch_indexes"),
info.get("td_error"))
if "stats" in info:
self.stats[pid] = info["stats"]
self.outqueue.put((ra, prio_dict, replay.count))
@@ -0,0 +1,101 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import ray
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
MultiAgentBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.timer import TimerStat
class SyncBatchReplayOptimizer(PolicyOptimizer):
"""Variant of the sync replay optimizer that replays entire batches.
This enables RNN support. Does not currently support prioritization."""
@override(PolicyOptimizer)
def _init(self,
learning_starts=1000,
buffer_size=10000,
train_batch_size=32):
self.replay_starts = learning_starts
self.max_buffer_size = buffer_size
self.train_batch_size = train_batch_size
assert self.max_buffer_size >= self.replay_starts
# List of buffered sample batches
self.replay_buffer = []
self.buffer_size = 0
# Stats
self.update_weights_timer = TimerStat()
self.sample_timer = TimerStat()
self.grad_timer = TimerStat()
self.learner_stats = {}
@override(PolicyOptimizer)
def step(self):
with self.update_weights_timer:
if self.remote_evaluators:
weights = ray.put(self.local_evaluator.get_weights())
for e in self.remote_evaluators:
e.set_weights.remote(weights)
with self.sample_timer:
if self.remote_evaluators:
batches = ray.get(
[e.sample.remote() for e in self.remote_evaluators])
else:
batches = [self.local_evaluator.sample()]
# Handle everything as if multiagent
tmp = []
for batch in batches:
if isinstance(batch, SampleBatch):
batch = MultiAgentBatch({
DEFAULT_POLICY_ID: batch
}, batch.count)
tmp.append(batch)
batches = tmp
for batch in batches:
self.replay_buffer.append(batch)
self.num_steps_sampled += batch.count
self.buffer_size += batch.count
while self.buffer_size > self.max_buffer_size:
evicted = self.replay_buffer.pop(0)
self.buffer_size -= evicted.count
if self.num_steps_sampled >= self.replay_starts:
self._optimize()
@override(PolicyOptimizer)
def stats(self):
return dict(
PolicyOptimizer.stats(self), **{
"sample_time_ms": round(1000 * self.sample_timer.mean, 3),
"grad_time_ms": round(1000 * self.grad_timer.mean, 3),
"update_time_ms": round(1000 * self.update_weights_timer.mean,
3),
"opt_peak_throughput": round(self.grad_timer.mean_throughput,
3),
"opt_samples": round(self.grad_timer.mean_units_processed, 3),
"learner": self.learner_stats,
})
def _optimize(self):
samples = [random.choice(self.replay_buffer)]
while sum(s.count for s in samples) < self.train_batch_size:
samples.append(random.choice(self.replay_buffer))
samples = SampleBatch.concat_samples(samples)
with self.grad_timer:
info_dict = self.local_evaluator.compute_apply(samples)
for policy_id, info in info_dict.items():
if "stats" in info:
self.learner_stats[policy_id] = info["stats"]
self.grad_timer.push_units_processed(samples.count)
self.num_steps_trained += samples.count
@@ -13,7 +13,6 @@ from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
MultiAgentBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.compression import pack_if_needed
from ray.rllib.utils.filter import RunningStat
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.schedules import LinearSchedule
@@ -54,7 +53,6 @@ class SyncReplayOptimizer(PolicyOptimizer):
self.sample_timer = TimerStat()
self.replay_timer = TimerStat()
self.grad_timer = TimerStat()
self.throughput = RunningStat()
self.learner_stats = {}
# Set up replay buffer
@@ -159,13 +157,13 @@ class SyncReplayOptimizer(PolicyOptimizer):
dones) = replay_buffer.sample(self.train_batch_size)
weights = np.ones_like(rewards)
batch_indexes = -np.ones_like(rewards)
samples[policy_id] = SampleBatch({
"obs": obses_t,
"actions": actions,
"rewards": rewards,
"new_obs": obses_tp1,
"dones": dones,
"weights": weights,
"batch_indexes": batch_indexes
})
samples[policy_id] = SampleBatch({
"obs": obses_t,
"actions": actions,
"rewards": rewards,
"new_obs": obses_tp1,
"dones": dones,
"weights": weights,
"batch_indexes": batch_indexes
})
return MultiAgentBatch(samples, self.train_batch_size)
@@ -0,0 +1,68 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from gym.spaces import Tuple, Discrete, Dict, Box
import ray
from ray.tune import register_env
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.agents.qmix import QMixAgent
class AvailActionsTestEnv(MultiAgentEnv):
action_space = Discrete(10)
observation_space = Dict({
"obs": Discrete(3),
"action_mask": Box(0, 1, (10, )),
})
def __init__(self, env_config):
self.state = None
self.avail = env_config["avail_action"]
self.action_mask = [0] * 10
self.action_mask[env_config["avail_action"]] = 1
def reset(self):
self.state = 0
return {
"agent_1": {
"obs": self.state,
"action_mask": self.action_mask
}
}
def step(self, action_dict):
if self.state > 0:
assert action_dict["agent_1"] == self.avail, \
"Failed to obey available actions mask!"
self.state += 1
rewards = {"agent_1": 1}
obs = {"agent_1": {"obs": 0, "action_mask": self.action_mask}}
dones = {"__all__": self.state > 20}
return obs, rewards, dones, {}
if __name__ == "__main__":
grouping = {
"group_1": ["agent_1"], # trivial grouping for testing
}
obs_space = Tuple([AvailActionsTestEnv.observation_space])
act_space = Tuple([AvailActionsTestEnv.action_space])
register_env(
"action_mask_test",
lambda config: AvailActionsTestEnv(config).with_agent_groups(
grouping, obs_space=obs_space, act_space=act_space))
ray.init()
agent = QMixAgent(
env="action_mask_test",
config={
"num_envs_per_worker": 5, # test with vectorization on
"env_config": {
"avail_action": 3,
},
})
for _ in range(5):
agent.train() # OK if it doesn't trip the action assertion error
assert agent.train()["episode_reward_mean"] == 21.0
@@ -339,7 +339,8 @@ class TestMultiAgentEnv(unittest.TestCase):
state_batches,
prev_action_batch=None,
prev_reward_batch=None,
episodes=None):
episodes=None,
**kwargs):
return [0] * len(obs_batch), [[h] * len(obs_batch)], {}
def get_initial_state(self):
@@ -363,7 +364,8 @@ class TestMultiAgentEnv(unittest.TestCase):
state_batches,
prev_action_batch=None,
prev_reward_batch=None,
episodes=None):
episodes=None,
**kwargs):
# Pretend we did a model-based rollout and want to return
# the extra trajectory.
builder = episodes[0].new_batch_builder()
@@ -25,7 +25,8 @@ class MockPolicyGraph(PolicyGraph):
state_batches,
prev_action_batch=None,
prev_reward_batch=None,
episodes=None):
episodes=None,
**kwargs):
return [0] * len(obs_batch), [], {}
def postprocess_trajectory(self,
@@ -42,7 +43,8 @@ class BadPolicyGraph(PolicyGraph):
state_batches,
prev_action_batch=None,
prev_reward_batch=None,
episodes=None):
episodes=None,
**kwargs):
raise Exception("intentional error")
def postprocess_trajectory(self,