[rllib] Q-Mix implementation (Q-Mix, VDN, IQN, and Ape-X variants) (#3548)

This commit is contained in:
Eric Liang
2018-12-18 10:40:01 -08:00
committed by GitHub
parent bc4aa85ea3
commit db0dee573e
35 changed files with 1339 additions and 71 deletions
+15
View File
@@ -10,6 +10,7 @@ import pickle
import six
import tempfile
import tensorflow as tf
import traceback
from types import FunctionType
import ray
@@ -546,6 +547,14 @@ def _register_if_needed(env_object):
def get_agent_class(alg):
"""Returns the class of a known agent given its name."""
try:
return _get_agent_class(alg)
except ImportError:
from ray.rllib.agents.mock import _agent_import_failed
return _agent_import_failed(traceback.format_exc())
def _get_agent_class(alg):
if alg == "DDPG":
from ray.rllib.agents import ddpg
return ddpg.DDPGAgent
@@ -579,6 +588,12 @@ def get_agent_class(alg):
elif alg == "IMPALA":
from ray.rllib.agents import impala
return impala.ImpalaAgent
elif alg == "QMIX":
from ray.rllib.agents import qmix
return qmix.QMixAgent
elif alg == "APEX_QMIX":
from ray.rllib.agents import qmix
return qmix.ApexQMixAgent
elif alg == "script":
from ray.tune import script_runner
return script_runner.ScriptRunner
-6
View File
@@ -8,12 +8,6 @@ from ray.rllib.agents.ddpg.ddpg_policy_graph import DDPGPolicyGraph
from ray.rllib.utils.annotations import override
from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule
OPTIMIZER_SHARED_CONFIGS = [
"buffer_size", "prioritized_replay", "prioritized_replay_alpha",
"prioritized_replay_beta", "prioritized_replay_eps", "sample_batch_size",
"train_batch_size", "learning_starts"
]
# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
+3 -2
View File
@@ -117,12 +117,13 @@ class DQNAgent(Agent):
_agent_name = "DQN"
_default_config = DEFAULT_CONFIG
_policy_graph = DQNPolicyGraph
_optimizer_shared_configs = OPTIMIZER_SHARED_CONFIGS
@override(Agent)
def _init(self):
# Update effective batch size to include n-step
adjusted_batch_size = max(self.config["sample_batch_size"],
self.config["n_step"])
self.config.get("n_step", 1))
self.config["sample_batch_size"] = adjusted_batch_size
self.exploration0 = self._make_exploration_schedule(-1)
@@ -131,7 +132,7 @@ class DQNAgent(Agent):
for i in range(self.config["num_workers"])
]
for k in OPTIMIZER_SHARED_CONFIGS:
for k in self._optimizer_shared_configs:
if self._agent_name != "DQN" and k in [
"schedule_max_timesteps", "beta_annealing_fraction",
"final_prioritized_replay_beta"
+13
View File
@@ -100,3 +100,16 @@ class _ParameterTuningAgent(_MockAgent):
timesteps_this_iter=self.config["iter_timesteps"],
time_this_iter_s=self.config["iter_time"],
info={})
def _agent_import_failed(trace):
"""Returns dummy agent class for if PyTorch etc. is not installed."""
class _AgentImportFailed(Agent):
_agent_name = "AgentImportFailed"
_default_config = with_common_config({})
def _setup(self, config):
raise ImportError(trace)
return _AgentImportFailed
+1
View File
@@ -0,0 +1 @@
Code in this package is adapted from https://github.com/oxwhirl/pymarl_alpha.
+8
View File
@@ -0,0 +1,8 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.qmix.qmix import QMixAgent, DEFAULT_CONFIG
from ray.rllib.agents.qmix.apex import ApexQMixAgent
__all__ = ["QMixAgent", "ApexQMixAgent", "DEFAULT_CONFIG"]
+55
View File
@@ -0,0 +1,55 @@
"""Experimental: scalable Ape-X variant of QMIX"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.qmix.qmix import QMixAgent, DEFAULT_CONFIG as QMIX_CONFIG
from ray.rllib.utils.annotations import override
from ray.rllib.utils import merge_dicts
APEX_QMIX_DEFAULT_CONFIG = merge_dicts(
QMIX_CONFIG, # see also the options in qmix.py, which are also supported
{
"optimizer_class": "AsyncReplayOptimizer",
"optimizer": merge_dicts(
QMIX_CONFIG["optimizer"],
{
"max_weight_sync_delay": 400,
"num_replay_buffer_shards": 4,
"batch_replay": True, # required for RNN. Disables prio.
"debug": False
}),
"num_gpus": 0,
"num_workers": 32,
"buffer_size": 2000000,
"learning_starts": 50000,
"train_batch_size": 512,
"sample_batch_size": 50,
"max_weight_sync_delay": 400,
"target_network_update_freq": 500000,
"timesteps_per_iteration": 25000,
"per_worker_exploration": True,
"min_iter_time_s": 30,
},
)
class ApexQMixAgent(QMixAgent):
"""QMIX variant that uses the Ape-X distributed policy optimizer.
By default, this is configured for a large single node (32 cores). For
running in a large cluster, increase the `num_workers` config var.
"""
_agent_name = "APEX_QMIX"
_default_config = APEX_QMIX_DEFAULT_CONFIG
@override(QMixAgent)
def update_target_if_needed(self):
# Ape-X updates based on num steps trained, not sampled
if self.optimizer.num_steps_trained - self.last_target_update_ts > \
self.config["target_network_update_freq"]:
self.local_evaluator.for_policy(lambda p: p.update_target())
self.last_target_update_ts = self.optimizer.num_steps_trained
self.num_target_updates += 1
+64
View File
@@ -0,0 +1,64 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class VDNMixer(nn.Module):
def __init__(self):
super(VDNMixer, self).__init__()
def forward(self, agent_qs, batch):
return th.sum(agent_qs, dim=2, keepdim=True)
class QMixer(nn.Module):
def __init__(self, n_agents, state_shape, mixing_embed_dim):
super(QMixer, self).__init__()
self.n_agents = n_agents
self.embed_dim = mixing_embed_dim
self.state_dim = int(np.prod(state_shape))
self.hyper_w_1 = nn.Linear(self.state_dim,
self.embed_dim * self.n_agents)
self.hyper_w_final = nn.Linear(self.state_dim, self.embed_dim)
# State dependent bias for hidden layer
self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim)
# V(s) instead of a bias for the last layers
self.V = nn.Sequential(
nn.Linear(self.state_dim, self.embed_dim), nn.ReLU(),
nn.Linear(self.embed_dim, 1))
def forward(self, agent_qs, states):
"""Forward pass for the mixer.
Arguments:
agent_qs: Tensor of shape [B, T, n_agents, n_actions]
states: Tensor of shape [B, T, state_dim]
"""
bs = agent_qs.size(0)
states = states.reshape(-1, self.state_dim)
agent_qs = agent_qs.view(-1, 1, self.n_agents)
# First layer
w1 = th.abs(self.hyper_w_1(states))
b1 = self.hyper_b_1(states)
w1 = w1.view(-1, self.n_agents, self.embed_dim)
b1 = b1.view(-1, 1, self.embed_dim)
hidden = F.elu(th.bmm(agent_qs, w1) + b1)
# Second layer
w_final = th.abs(self.hyper_w_final(states))
w_final = w_final.view(-1, self.embed_dim, 1)
# State-dependent bias
v = self.V(states).view(-1, 1, 1)
# Compute final output
y = th.bmm(hidden, w_final) + v
# Reshape and return
q_tot = y.view(bs, -1, 1)
return q_tot
+28
View File
@@ -0,0 +1,28 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from torch import nn
import torch.nn.functional as F
# TODO(ekl) we should have common models for pytorch like we do for TF
class RNNModel(nn.Module):
def __init__(self, obs_size, rnn_hidden_dim, n_actions):
nn.Module.__init__(self)
self.rnn_hidden_dim = rnn_hidden_dim
self.n_actions = n_actions
self.fc1 = nn.Linear(obs_size, rnn_hidden_dim)
self.rnn = nn.GRUCell(rnn_hidden_dim, rnn_hidden_dim)
self.fc2 = nn.Linear(rnn_hidden_dim, n_actions)
def init_hidden(self):
# make hidden states on same device as model
return self.fc1.weight.new(1, self.rnn_hidden_dim).zero_()
def forward(self, inputs, hidden_state):
x = F.relu(self.fc1(inputs.float()))
h_in = hidden_state.reshape(-1, self.rnn_hidden_dim)
h = self.rnn(x, h_in)
q = self.fc2(h)
return q, h
+92
View File
@@ -0,0 +1,92 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.agent import with_common_config
from ray.rllib.agents.dqn.dqn import DQNAgent
from ray.rllib.agents.qmix.qmix_policy_graph import QMixPolicyGraph
# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
# === QMix ===
# Mixing network. Either "qmix", "vdn", or None
"mixer": "qmix",
# Size of the mixing network embedding
"mixing_embed_dim": 32,
# Whether to use Double_Q learning
"double_q": True,
# Optimize over complete episodes by default.
"batch_mode": "complete_episodes",
# === Exploration ===
# Max num timesteps for annealing schedules. Exploration is annealed from
# 1.0 to exploration_fraction over this number of timesteps scaled by
# exploration_fraction
"schedule_max_timesteps": 100000,
# Number of env steps to optimize for before returning
"timesteps_per_iteration": 1000,
# Fraction of entire training period over which the exploration rate is
# annealed
"exploration_fraction": 0.1,
# Final value of random action probability
"exploration_final_eps": 0.02,
# Update the target network every `target_network_update_freq` steps.
"target_network_update_freq": 500,
# === Replay buffer ===
# Size of the replay buffer in steps.
"buffer_size": 10000,
# === Optimization ===
# Learning rate for adam optimizer
"lr": 0.0005,
# RMSProp alpha
"optim_alpha": 0.99,
# RMSProp epsilon
"optim_eps": 0.00001,
# If not None, clip gradients during optimization at this value
"grad_norm_clipping": 10,
# How many steps of the model to sample before learning starts.
"learning_starts": 1000,
# Update the replay buffer with this many samples at once. Note that
# this setting applies per-worker if num_workers > 1.
"sample_batch_size": 4,
# Size of a batched sampled from replay buffer for training. Note that
# if async_updates is set, then each worker returns gradients for a
# batch of this size.
"train_batch_size": 32,
# === Parallelism ===
# Number of workers for collecting samples with. This only makes sense
# to increase if your environment is particularly slow to sample, or if
# you"re using the Async or Ape-X optimizers.
"num_workers": 0,
# Optimizer class to use.
"optimizer_class": "SyncBatchReplayOptimizer",
# Whether to use a distribution of epsilons across workers for exploration.
"per_worker_exploration": False,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
# Prevent iterations from going lower than this time span
"min_iter_time_s": 1,
# === Model ===
"model": {
"lstm_cell_size": 64,
"max_seq_len": 999999,
},
})
# __sphinx_doc_end__
# yapf: enable
class QMixAgent(DQNAgent):
"""QMix implementation in PyTorch."""
_agent_name = "QMIX"
_default_config = DEFAULT_CONFIG
_policy_graph = QMixPolicyGraph
_optimizer_shared_configs = [
"learning_starts", "buffer_size", "train_batch_size"
]
@@ -0,0 +1,411 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from gym.spaces import Tuple, Discrete, Dict
import logging
import numpy as np
import torch as th
import torch.nn as nn
from torch.optim import RMSprop
from torch.distributions import Categorical
import ray
from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer
from ray.rllib.agents.qmix.model import RNNModel
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.models.action_dist import TupleActions
from ray.rllib.models.pytorch.misc import var_to_np
from ray.rllib.models.lstm import chop_into_sequences
from ray.rllib.models.model import _unpack_obs
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.env.constants import GROUP_REWARDS
from ray.rllib.utils.annotations import override
logger = logging.getLogger(__name__)
class QMixLoss(nn.Module):
def __init__(self,
model,
target_model,
mixer,
target_mixer,
n_agents,
n_actions,
double_q=True,
gamma=0.99):
nn.Module.__init__(self)
self.model = model
self.target_model = target_model
self.mixer = mixer
self.target_mixer = target_mixer
self.n_agents = n_agents
self.n_actions = n_actions
self.double_q = double_q
self.gamma = gamma
def forward(self, rewards, actions, terminated, mask, obs, action_mask):
"""Forward pass of the loss.
Arguments:
rewards: Tensor of shape [B, T-1, n_agents]
actions: Tensor of shape [B, T-1, n_agents]
terminated: Tensor of shape [B, T-1, n_agents]
mask: Tensor of shape [B, T-1, n_agents]
obs: Tensor of shape [B, T, n_agents, obs_size]
action_mask: Tensor of shape [B, T, n_agents, n_actions]
"""
B, T = obs.size(0), obs.size(1)
# Calculate estimated Q-Values
mac_out = []
h = self.model.init_hidden().expand([B, self.n_agents, -1])
for t in range(T):
q, h = _mac(self.model, obs[:, t], h)
mac_out.append(q)
mac_out = th.stack(mac_out, dim=1) # Concat over time
# Pick the Q-Values for the actions taken -> [B * n_agents, T-1]
chosen_action_qvals = th.gather(
mac_out[:, :-1], dim=3, index=actions.unsqueeze(3)).squeeze(3)
# Calculate the Q-Values necessary for the target
target_mac_out = []
target_h = self.target_model.init_hidden().expand(
[B, self.n_agents, -1])
for t in range(T):
target_q, target_h = _mac(self.target_model, obs[:, t], target_h)
target_mac_out.append(target_q)
# We don't need the first timesteps Q-Value estimate for targets
target_mac_out = th.stack(
target_mac_out[1:], dim=1) # Concat across time
# Mask out unavailable actions
target_mac_out[action_mask[:, 1:] == 0] = -9999999
# Max over target Q-Values
if self.double_q:
# Get actions that maximise live Q (for double q-learning)
mac_out[action_mask == 0] = -9999999
cur_max_actions = mac_out[:, 1:].max(dim=3, keepdim=True)[1]
target_max_qvals = th.gather(target_mac_out, 3,
cur_max_actions).squeeze(3)
else:
target_max_qvals = target_mac_out.max(dim=3)[0]
# Mix
if self.mixer is not None:
# TODO(ekl) add support for handling global state? This is just
# treating the stacked agent obs as the state.
chosen_action_qvals = self.mixer(chosen_action_qvals, obs[:, :-1])
target_max_qvals = self.target_mixer(target_max_qvals, obs[:, 1:])
# Calculate 1-step Q-Learning targets
targets = rewards + self.gamma * (1 - terminated) * target_max_qvals
# Td-error
td_error = (chosen_action_qvals - targets.detach())
mask = mask.expand_as(td_error)
# 0-out the targets that came from padded data
masked_td_error = td_error * mask
# Normal L2 loss, take mean over actual data
loss = (masked_td_error**2).sum() / mask.sum()
return loss, mask, masked_td_error, chosen_action_qvals, targets
class QMixPolicyGraph(PolicyGraph):
"""QMix impl. Assumes homogeneous agents for now.
You must use MultiAgentEnv.with_agent_groups() to group agents
together for QMix. This creates the proper Tuple obs/action spaces and
populates the '_group_rewards' info field.
Action masking: to specify an action mask for individual agents, use a
dict space with an action_mask key, e.g. {"obs": ob, "action_mask": mask}.
The mask space must be `Box(0, 1, (n_actions,))`.
"""
def __init__(self, obs_space, action_space, config):
_validate(obs_space, action_space)
config = dict(ray.rllib.agents.qmix.qmix.DEFAULT_CONFIG, **config)
self.config = config
self.observation_space = obs_space
self.action_space = action_space
self.n_agents = len(obs_space.original_space.spaces)
self.n_actions = action_space.spaces[0].n
self.h_size = config["model"]["lstm_cell_size"]
agent_obs_space = obs_space.original_space.spaces[0]
if isinstance(agent_obs_space, Dict):
space_keys = set(agent_obs_space.spaces.keys())
if space_keys != {"obs", "action_mask"}:
raise ValueError(
"Dict obs space for agent must have keyset "
"['obs', 'action_mask'], got {}".format(space_keys))
mask_shape = tuple(agent_obs_space.spaces["action_mask"].shape)
if mask_shape != (self.n_actions, ):
raise ValueError("Action mask shape must be {}, got {}".format(
(self.n_actions, ), mask_shape))
self.has_action_mask = True
self.obs_size = _get_size(agent_obs_space.spaces["obs"])
else:
self.has_action_mask = False
self.obs_size = _get_size(agent_obs_space)
self.model = RNNModel(self.obs_size, self.h_size, self.n_actions)
self.target_model = RNNModel(self.obs_size, self.h_size,
self.n_actions)
# Setup the mixer network.
# The global state is just the stacked agent observations for now.
self.state_shape = [self.obs_size, self.n_agents]
if config["mixer"] is None:
self.mixer = None
self.target_mixer = None
elif config["mixer"] == "qmix":
self.mixer = QMixer(self.n_agents, self.state_shape,
config["mixing_embed_dim"])
self.target_mixer = QMixer(self.n_agents, self.state_shape,
config["mixing_embed_dim"])
elif config["mixer"] == "vdn":
self.mixer = VDNMixer()
self.target_mixer = VDNMixer()
else:
raise ValueError("Unknown mixer type {}".format(config["mixer"]))
self.cur_epsilon = 1.0
self.update_target() # initial sync
# Setup optimizer
self.params = list(self.model.parameters())
self.loss = QMixLoss(self.model, self.target_model, self.mixer,
self.target_mixer, self.n_agents, self.n_actions,
self.config["double_q"], self.config["gamma"])
self.optimiser = RMSprop(
params=self.params,
lr=config["lr"],
alpha=config["optim_alpha"],
eps=config["optim_eps"])
@override(PolicyGraph)
def compute_actions(self,
obs_batch,
state_batches=None,
prev_action_batch=None,
prev_reward_batch=None,
info_batch=None,
episodes=None,
**kwargs):
obs_batch, action_mask = self._unpack_observation(obs_batch)
assert len(state_batches) == self.n_agents, state_batches
state_batches = np.stack(state_batches, axis=1)
# Compute actions
with th.no_grad():
q_values, hiddens = _mac(self.model, th.from_numpy(obs_batch),
th.from_numpy(state_batches))
avail = th.from_numpy(action_mask).float()
masked_q_values = q_values.clone()
masked_q_values[avail == 0.0] = -float("inf")
# epsilon-greedy action selector
random_numbers = th.rand_like(q_values[:, :, 0])
pick_random = (random_numbers < self.cur_epsilon).long()
random_actions = Categorical(avail).sample().long()
actions = (pick_random * random_actions +
(1 - pick_random) * masked_q_values.max(dim=2)[1])
actions = var_to_np(actions)
hiddens = var_to_np(hiddens)
return (TupleActions(list(actions.transpose([1, 0]))),
hiddens.transpose([1, 0, 2]), {})
@override(PolicyGraph)
def compute_apply(self, samples):
obs_batch, action_mask = self._unpack_observation(samples["obs"])
group_rewards = self._get_group_rewards(samples["infos"])
# These will be padded to shape [B * T, ...]
[rew, action_mask, act, dones, obs], initial_states, seq_lens = \
chop_into_sequences(
samples["eps_id"],
samples["agent_index"], [
group_rewards, action_mask, samples["actions"],
samples["dones"], obs_batch
],
[samples["state_in_{}".format(k)]
for k in range(self.n_agents)],
max_seq_len=self.config["model"]["max_seq_len"],
dynamic_max=True,
_extra_padding=1)
# TODO(ekl) adding 1 extra unit of padding here, since otherwise we
# lose the terminating reward and the Q-values will be unanchored!
B, T = len(seq_lens), max(seq_lens) + 1
def to_batches(arr):
new_shape = [B, T] + list(arr.shape[1:])
return th.from_numpy(np.reshape(arr, new_shape))
rewards = to_batches(rew)[:, :-1].float()
actions = to_batches(act)[:, :-1].long()
obs = to_batches(obs).reshape([B, T, self.n_agents,
self.obs_size]).float()
action_mask = to_batches(action_mask)
# TODO(ekl) this treats group termination as individual termination
terminated = to_batches(dones.astype(np.float32)).unsqueeze(2).expand(
B, T, self.n_agents)[:, :-1]
filled = (np.reshape(np.tile(np.arange(T), B), [B, T]) <
np.expand_dims(seq_lens, 1)).astype(np.float32)
mask = th.from_numpy(filled).unsqueeze(2).expand(B, T,
self.n_agents)[:, :-1]
mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
# Compute loss
loss_out, mask, masked_td_error, chosen_action_qvals, targets = \
self.loss(rewards, actions, terminated, mask, obs, action_mask)
# Optimise
self.optimiser.zero_grad()
loss_out.backward()
grad_norm = th.nn.utils.clip_grad_norm_(
self.params, self.config["grad_norm_clipping"])
self.optimiser.step()
mask_elems = mask.sum().item()
stats = {
"loss": loss_out.item(),
"grad_norm": grad_norm
if isinstance(grad_norm, float) else grad_norm.item(),
"td_error_abs": masked_td_error.abs().sum().item() / mask_elems,
"q_taken_mean": (chosen_action_qvals * mask).sum().item() /
mask_elems,
"target_mean": (targets * mask).sum().item() / mask_elems,
}
return {"stats": stats}, {}
@override(PolicyGraph)
def get_initial_state(self):
return [
self.model.init_hidden().numpy().squeeze()
for _ in range(self.n_agents)
]
@override(PolicyGraph)
def get_weights(self):
return {"model": self.model.state_dict()}
@override(PolicyGraph)
def set_weights(self, weights):
self.model.load_state_dict(weights["model"])
@override(PolicyGraph)
def get_state(self):
return {
"model": self.model.state_dict(),
"target_model": self.target_model.state_dict(),
"mixer": self.mixer.state_dict() if self.mixer else None,
"target_mixer": self.target_mixer.state_dict()
if self.mixer else None,
"cur_epsilon": self.cur_epsilon,
}
@override(PolicyGraph)
def set_state(self, state):
self.model.load_state_dict(state["model"])
self.target_model.load_state_dict(state["target_model"])
if state["mixer"] is not None:
self.mixer.load_state_dict(state["mixer"])
self.target_mixer.load_state_dict(state["target_mixer"])
self.set_epsilon(state["cur_epsilon"])
self.update_target()
def update_target(self):
self.target_model.load_state_dict(self.model.state_dict())
if self.mixer is not None:
self.target_mixer.load_state_dict(self.mixer.state_dict())
logger.debug("Updated target networks")
def set_epsilon(self, epsilon):
self.cur_epsilon = epsilon
def _get_group_rewards(self, info_batch):
group_rewards = np.array([
info.get(GROUP_REWARDS, [0.0] * self.n_agents)
for info in info_batch
])
return group_rewards
def _unpack_observation(self, obs_batch):
unpacked = _unpack_obs(
np.array(obs_batch),
self.observation_space.original_space,
tensorlib=np)
if self.has_action_mask:
obs = np.concatenate(
[o["obs"] for o in unpacked],
axis=1).reshape([len(obs_batch), self.n_agents, self.obs_size])
action_mask = np.concatenate(
[o["action_mask"] for o in unpacked], axis=1).reshape(
[len(obs_batch), self.n_agents, self.n_actions])
else:
obs = np.concatenate(
unpacked,
axis=1).reshape([len(obs_batch), self.n_agents, self.obs_size])
action_mask = np.ones(
[len(obs_batch), self.n_agents, self.n_actions])
return obs, action_mask
def _validate(obs_space, action_space):
if not hasattr(obs_space, "original_space") or \
not isinstance(obs_space.original_space, Tuple):
raise ValueError("Obs space must be a Tuple, got {}. Use ".format(
obs_space) + "MultiAgentEnv.with_agent_groups() to group related "
"agents for QMix.")
if not isinstance(action_space, Tuple):
raise ValueError(
"Action space must be a Tuple, got {}. ".format(action_space) +
"Use MultiAgentEnv.with_agent_groups() to group related "
"agents for QMix.")
if not isinstance(action_space.spaces[0], Discrete):
raise ValueError(
"QMix requires a discrete action space, got {}".format(
action_space.spaces[0]))
if len({str(x) for x in obs_space.original_space.spaces}) > 1:
raise ValueError(
"Implementation limitation: observations of grouped agents "
"must be homogeneous, got {}".format(
obs_space.original_space.spaces))
if len({str(x) for x in action_space.spaces}) > 1:
raise ValueError(
"Implementation limitation: action space of grouped agents "
"must be homogeneous, got {}".format(action_space.spaces))
def _get_size(obs_space):
return get_preprocessor(obs_space)(obs_space).size
def _mac(model, obs, h):
"""Forward pass of the multi-agent controller.
Arguments:
model: Model that produces q-values for a 1d agent batch
obs: Tensor of shape [B, n_agents, obs_size]
h: Tensor of shape [B, n_agents, h_size]
Returns:
q_vals: Tensor of shape [B, n_agents, n_actions]
h: Tensor of shape [B, n_agents, h_size]
"""
B, n_agents = obs.size(0), obs.size(1)
obs_flat = obs.reshape([B * n_agents, -1])
h_flat = h.reshape([B * n_agents, -1])
q_flat, h_flat = model.forward(obs_flat, h_flat)
return q_flat.reshape([B, n_agents, -1]), h_flat.reshape([B, n_agents, -1])