mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 22:59:13 +08:00
[rllib] Q-Mix implementation (Q-Mix, VDN, IQN, and Ape-X variants) (#3548)
This commit is contained in:
@@ -10,6 +10,7 @@ import pickle
|
||||
import six
|
||||
import tempfile
|
||||
import tensorflow as tf
|
||||
import traceback
|
||||
from types import FunctionType
|
||||
|
||||
import ray
|
||||
@@ -546,6 +547,14 @@ def _register_if_needed(env_object):
|
||||
def get_agent_class(alg):
    """Look up the agent class registered under the name ``alg``.

    The lookup itself is delegated to ``_get_agent_class``. If that raises an
    ImportError, the class produced by
    ``ray.rllib.agents.mock._agent_import_failed`` is returned instead; it is
    given the formatted traceback of the failed import.
    """
    try:
        agent_cls = _get_agent_class(alg)
    except ImportError:
        # Imported lazily: only needed on the failure path.
        from ray.rllib.agents.mock import _agent_import_failed
        agent_cls = _agent_import_failed(traceback.format_exc())
    return agent_cls
|
||||
|
||||
|
||||
def _get_agent_class(alg):
|
||||
if alg == "DDPG":
|
||||
from ray.rllib.agents import ddpg
|
||||
return ddpg.DDPGAgent
|
||||
@@ -579,6 +588,12 @@ def get_agent_class(alg):
|
||||
elif alg == "IMPALA":
|
||||
from ray.rllib.agents import impala
|
||||
return impala.ImpalaAgent
|
||||
elif alg == "QMIX":
|
||||
from ray.rllib.agents import qmix
|
||||
return qmix.QMixAgent
|
||||
elif alg == "APEX_QMIX":
|
||||
from ray.rllib.agents import qmix
|
||||
return qmix.ApexQMixAgent
|
||||
elif alg == "script":
|
||||
from ray.tune import script_runner
|
||||
return script_runner.ScriptRunner
|
||||
|
||||
@@ -8,12 +8,6 @@ from ray.rllib.agents.ddpg.ddpg_policy_graph import DDPGPolicyGraph
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule
|
||||
|
||||
OPTIMIZER_SHARED_CONFIGS = [
|
||||
"buffer_size", "prioritized_replay", "prioritized_replay_alpha",
|
||||
"prioritized_replay_beta", "prioritized_replay_eps", "sample_batch_size",
|
||||
"train_batch_size", "learning_starts"
|
||||
]
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
DEFAULT_CONFIG = with_common_config({
|
||||
|
||||
@@ -117,12 +117,13 @@ class DQNAgent(Agent):
|
||||
_agent_name = "DQN"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy_graph = DQNPolicyGraph
|
||||
_optimizer_shared_configs = OPTIMIZER_SHARED_CONFIGS
|
||||
|
||||
@override(Agent)
|
||||
def _init(self):
|
||||
# Update effective batch size to include n-step
|
||||
adjusted_batch_size = max(self.config["sample_batch_size"],
|
||||
self.config["n_step"])
|
||||
self.config.get("n_step", 1))
|
||||
self.config["sample_batch_size"] = adjusted_batch_size
|
||||
|
||||
self.exploration0 = self._make_exploration_schedule(-1)
|
||||
@@ -131,7 +132,7 @@ class DQNAgent(Agent):
|
||||
for i in range(self.config["num_workers"])
|
||||
]
|
||||
|
||||
for k in OPTIMIZER_SHARED_CONFIGS:
|
||||
for k in self._optimizer_shared_configs:
|
||||
if self._agent_name != "DQN" and k in [
|
||||
"schedule_max_timesteps", "beta_annealing_fraction",
|
||||
"final_prioritized_replay_beta"
|
||||
|
||||
@@ -100,3 +100,16 @@ class _ParameterTuningAgent(_MockAgent):
|
||||
timesteps_this_iter=self.config["iter_timesteps"],
|
||||
time_this_iter_s=self.config["iter_time"],
|
||||
info={})
|
||||
|
||||
|
||||
def _agent_import_failed(trace):
    """Returns a dummy agent class for when PyTorch etc. is not installed.

    Arguments:
        trace: Message that the dummy agent raises as an ImportError from
            its ``_setup`` method (the caller in this file passes a
            formatted traceback).

    Returns:
        An ``Agent`` subclass whose ``_setup`` always raises
        ``ImportError(trace)``.
    """

    class _AgentImportFailed(Agent):
        _agent_name = "AgentImportFailed"
        _default_config = with_common_config({})

        def _setup(self, config):
            # Re-raise the captured import failure instead of setting up.
            raise ImportError(trace)

    return _AgentImportFailed
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
Code in this package is adapted from https://github.com/oxwhirl/pymarl_alpha.
|
||||
@@ -0,0 +1,8 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.qmix.qmix import QMixAgent, DEFAULT_CONFIG
|
||||
from ray.rllib.agents.qmix.apex import ApexQMixAgent
|
||||
|
||||
__all__ = ["QMixAgent", "ApexQMixAgent", "DEFAULT_CONFIG"]
|
||||
@@ -0,0 +1,55 @@
|
||||
"""Experimental: scalable Ape-X variant of QMIX"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.qmix.qmix import QMixAgent, DEFAULT_CONFIG as QMIX_CONFIG
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils import merge_dicts
|
||||
|
||||
# Default config for APEX_QMIX: the QMIX defaults with the overrides below
# merged on top via merge_dicts.
APEX_QMIX_DEFAULT_CONFIG = merge_dicts(
    QMIX_CONFIG,  # see also the options in qmix.py, which are also supported
    {
        "optimizer_class": "AsyncReplayOptimizer",
        # Optimizer-specific options, merged over QMIX's own "optimizer"
        # sub-dict.
        "optimizer": merge_dicts(
            QMIX_CONFIG["optimizer"],
            {
                "max_weight_sync_delay": 400,
                "num_replay_buffer_shards": 4,
                "batch_replay": True,  # required for RNN. Disables prio.
                "debug": False
            }),
        "num_gpus": 0,
        "num_workers": 32,
        "buffer_size": 2000000,
        "learning_starts": 50000,
        "train_batch_size": 512,
        "sample_batch_size": 50,
        # NOTE(review): this key is also set inside "optimizer" above;
        # confirm which of the two is actually read.
        "max_weight_sync_delay": 400,
        "target_network_update_freq": 500000,
        "timesteps_per_iteration": 25000,
        "per_worker_exploration": True,
        "min_iter_time_s": 30,
    },
)
|
||||
|
||||
|
||||
class ApexQMixAgent(QMixAgent):
    """QMIX agent driven by the Ape-X distributed policy optimizer.

    The defaults target one large machine (32 cores). To run on a large
    cluster, raise the `num_workers` config value.
    """

    _agent_name = "APEX_QMIX"
    _default_config = APEX_QMIX_DEFAULT_CONFIG

    @override(QMixAgent)
    def update_target_if_needed(self):
        """Sync the target networks once enough steps have been trained."""
        # Ape-X measures progress in trained steps rather than sampled steps.
        steps_since_update = (
            self.optimizer.num_steps_trained - self.last_target_update_ts)
        if steps_since_update <= self.config["target_network_update_freq"]:
            return

        self.local_evaluator.for_policy(lambda p: p.update_target())
        self.last_target_update_ts = self.optimizer.num_steps_trained
        self.num_target_updates += 1
|
||||
@@ -0,0 +1,64 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
|
||||
class VDNMixer(nn.Module):
    """Parameter-free mixer: the joint value is the sum of the agent values."""

    def __init__(self):
        nn.Module.__init__(self)

    def forward(self, agent_qs, batch):
        """Sum ``agent_qs`` over dimension 2, keeping that dimension.

        ``batch`` is accepted but not used.
        """
        return agent_qs.sum(dim=2, keepdim=True)
|
||||
|
||||
|
||||
class QMixer(nn.Module):
    """QMIX mixing network.

    Combines per-agent Q-values into one joint Q-value with a two-layer
    mixing network. The mixing weights are produced by linear hypernetworks
    of the state and passed through ``abs``, so they are non-negative.

    Arguments:
        n_agents: Number of agents whose Q-values are mixed.
        state_shape: Shape of the state; it is flattened to
            ``prod(state_shape)`` features.
        mixing_embed_dim: Width of the hidden mixing layer.
    """

    def __init__(self, n_agents, state_shape, mixing_embed_dim):
        super(QMixer, self).__init__()

        self.n_agents = n_agents
        self.embed_dim = mixing_embed_dim
        self.state_dim = int(np.prod(state_shape))

        # Hypernetworks producing the weights of the two mixing layers
        self.hyper_w_1 = nn.Linear(self.state_dim,
                                   self.embed_dim * self.n_agents)
        self.hyper_w_final = nn.Linear(self.state_dim, self.embed_dim)

        # State dependent bias for hidden layer
        self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim)

        # V(s) instead of a bias for the last layers
        self.V = nn.Sequential(
            nn.Linear(self.state_dim, self.embed_dim), nn.ReLU(),
            nn.Linear(self.embed_dim, 1))

    def forward(self, agent_qs, states):
        """Forward pass for the mixer.

        Arguments:
            agent_qs: Tensor of shape [B, T, n_agents]
            states: Tensor of shape [B, T, ...] whose trailing dimensions
                flatten to state_dim

        Returns:
            q_tot: Tensor of shape [B, T, 1]
        """
        bs = agent_qs.size(0)
        states = states.reshape(-1, self.state_dim)
        # reshape (not view), as for states above, so that non-contiguous
        # inputs are accepted as well.
        agent_qs = agent_qs.reshape(-1, 1, self.n_agents)
        # First layer
        w1 = th.abs(self.hyper_w_1(states))
        b1 = self.hyper_b_1(states)
        w1 = w1.view(-1, self.n_agents, self.embed_dim)
        b1 = b1.view(-1, 1, self.embed_dim)
        hidden = F.elu(th.bmm(agent_qs, w1) + b1)
        # Second layer
        w_final = th.abs(self.hyper_w_final(states))
        w_final = w_final.view(-1, self.embed_dim, 1)
        # State-dependent bias
        v = self.V(states).view(-1, 1, 1)
        # Compute final output
        y = th.bmm(hidden, w_final) + v
        # Reshape and return
        q_tot = y.view(bs, -1, 1)
        return q_tot
|
||||
@@ -0,0 +1,28 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# TODO(ekl) we should have common models for pytorch like we do for TF
|
||||
class RNNModel(nn.Module):
    """Recurrent Q-network: Linear -> ReLU -> GRUCell -> Linear."""

    def __init__(self, obs_size, rnn_hidden_dim, n_actions):
        super(RNNModel, self).__init__()
        self.rnn_hidden_dim = rnn_hidden_dim
        self.n_actions = n_actions
        self.fc1 = nn.Linear(obs_size, rnn_hidden_dim)
        self.rnn = nn.GRUCell(rnn_hidden_dim, rnn_hidden_dim)
        self.fc2 = nn.Linear(rnn_hidden_dim, n_actions)

    def init_hidden(self):
        """Return an all-zero hidden state of shape [1, rnn_hidden_dim]."""
        # Built from a model weight so it lives on the same device as the
        # model.
        return self.fc1.weight.new_zeros(1, self.rnn_hidden_dim)

    def forward(self, inputs, hidden_state):
        """Run one recurrent step.

        Returns:
            A ``(q, h)`` pair: the output of the final linear layer and the
            new GRU hidden state.
        """
        features = F.relu(self.fc1(inputs.float()))
        prev_hidden = hidden_state.reshape(-1, self.rnn_hidden_dim)
        next_hidden = self.rnn(features, prev_hidden)
        return self.fc2(next_hidden), next_hidden
|
||||
@@ -0,0 +1,92 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.agent import with_common_config
|
||||
from ray.rllib.agents.dqn.dqn import DQNAgent
|
||||
from ray.rllib.agents.qmix.qmix_policy_graph import QMixPolicyGraph
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
DEFAULT_CONFIG = with_common_config({
    # === QMix ===
    # Mixing network. Either "qmix", "vdn", or None
    "mixer": "qmix",
    # Size of the mixing network embedding
    "mixing_embed_dim": 32,
    # Whether to use Double_Q learning
    "double_q": True,
    # Optimize over complete episodes by default.
    "batch_mode": "complete_episodes",

    # === Exploration ===
    # Max num timesteps for annealing schedules. Exploration is annealed from
    # 1.0 to exploration_fraction over this number of timesteps scaled by
    # exploration_fraction
    # NOTE(review): "to exploration_fraction" above looks like a slip for
    # "to exploration_final_eps" -- confirm against the schedule code.
    "schedule_max_timesteps": 100000,
    # Number of env steps to optimize for before returning
    "timesteps_per_iteration": 1000,
    # Fraction of entire training period over which the exploration rate is
    # annealed
    "exploration_fraction": 0.1,
    # Final value of random action probability
    "exploration_final_eps": 0.02,
    # Update the target network every `target_network_update_freq` steps.
    "target_network_update_freq": 500,

    # === Replay buffer ===
    # Size of the replay buffer in steps.
    "buffer_size": 10000,

    # === Optimization ===
    # Learning rate for the RMSProp optimizer
    "lr": 0.0005,
    # RMSProp alpha
    "optim_alpha": 0.99,
    # RMSProp epsilon
    "optim_eps": 0.00001,
    # If not None, clip gradients during optimization at this value
    "grad_norm_clipping": 10,
    # How many steps of the model to sample before learning starts.
    "learning_starts": 1000,
    # Update the replay buffer with this many samples at once. Note that
    # this setting applies per-worker if num_workers > 1.
    "sample_batch_size": 4,
    # Size of a batch sampled from the replay buffer for training. Note that
    # if async_updates is set, then each worker returns gradients for a
    # batch of this size.
    "train_batch_size": 32,

    # === Parallelism ===
    # Number of workers for collecting samples with. This only makes sense
    # to increase if your environment is particularly slow to sample, or if
    # you're using the Async or Ape-X optimizers.
    "num_workers": 0,
    # Optimizer class to use.
    "optimizer_class": "SyncBatchReplayOptimizer",
    # Whether to use a distribution of epsilons across workers for exploration.
    "per_worker_exploration": False,
    # Whether to compute priorities on workers.
    "worker_side_prioritization": False,
    # Prevent iterations from going lower than this time span
    "min_iter_time_s": 1,

    # === Model ===
    "model": {
        # Read by QMixPolicyGraph as the hidden size of the agent RNN.
        "lstm_cell_size": 64,
        # Passed by QMixPolicyGraph to chop_into_sequences as max_seq_len.
        "max_seq_len": 999999,
    },
})
|
||||
# __sphinx_doc_end__
|
||||
# yapf: enable
|
||||
|
||||
|
||||
class QMixAgent(DQNAgent):
    """QMix implementation in PyTorch.

    Subclasses DQNAgent and overrides only the class-level settings below.
    """

    _agent_name = "QMIX"
    _default_config = DEFAULT_CONFIG
    _policy_graph = QMixPolicyGraph
    # Keys that DQNAgent._init loops over via self._optimizer_shared_configs;
    # presumably they are copied into the optimizer config -- confirm in
    # DQNAgent._init.
    _optimizer_shared_configs = [
        "learning_starts", "buffer_size", "train_batch_size"
    ]
|
||||
@@ -0,0 +1,411 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from gym.spaces import Tuple, Discrete, Dict
|
||||
import logging
|
||||
import numpy as np
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
from torch.optim import RMSprop
|
||||
from torch.distributions import Categorical
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer
|
||||
from ray.rllib.agents.qmix.model import RNNModel
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.models.action_dist import TupleActions
|
||||
from ray.rllib.models.pytorch.misc import var_to_np
|
||||
from ray.rllib.models.lstm import chop_into_sequences
|
||||
from ray.rllib.models.model import _unpack_obs
|
||||
from ray.rllib.models.preprocessors import get_preprocessor
|
||||
from ray.rllib.env.constants import GROUP_REWARDS
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QMixLoss(nn.Module):
    """Masked one-step TD loss over sequences of grouped-agent experience.

    Q-values come from ``model`` / ``target_model`` unrolled over time; if a
    mixer is given, the per-agent values are combined by ``mixer`` /
    ``target_mixer`` before the TD error is formed.
    """

    def __init__(self,
                 model,
                 target_model,
                 mixer,
                 target_mixer,
                 n_agents,
                 n_actions,
                 double_q=True,
                 gamma=0.99):
        nn.Module.__init__(self)
        self.model = model
        self.target_model = target_model
        self.mixer = mixer
        self.target_mixer = target_mixer
        self.n_agents = n_agents
        self.n_actions = n_actions
        self.double_q = double_q
        self.gamma = gamma

    def forward(self, rewards, actions, terminated, mask, obs, action_mask):
        """Forward pass of the loss.

        Arguments:
            rewards: Tensor of shape [B, T-1, n_agents]
            actions: Tensor of shape [B, T-1, n_agents]
            terminated: Tensor of shape [B, T-1, n_agents]
            mask: Tensor of shape [B, T-1, n_agents]
            obs: Tensor of shape [B, T, n_agents, obs_size]
            action_mask: Tensor of shape [B, T, n_agents, n_actions]

        Returns:
            Tuple ``(loss, mask, masked_td_error, chosen_action_qvals,
            targets)``, where ``mask`` has been expanded to the shape of the
            TD error and ``loss`` is the sum of squared masked TD errors
            divided by ``mask.sum()``.
        """

        B, T = obs.size(0), obs.size(1)

        # Calculate estimated Q-Values
        mac_out = []
        h = self.model.init_hidden().expand([B, self.n_agents, -1])
        for t in range(T):
            q, h = _mac(self.model, obs[:, t], h)
            mac_out.append(q)
        mac_out = th.stack(mac_out, dim=1)  # Concat over time

        # Pick the Q-Values for the actions taken -> [B, T-1, n_agents]
        chosen_action_qvals = th.gather(
            mac_out[:, :-1], dim=3, index=actions.unsqueeze(3)).squeeze(3)

        # Calculate the Q-Values necessary for the target
        target_mac_out = []
        target_h = self.target_model.init_hidden().expand(
            [B, self.n_agents, -1])
        for t in range(T):
            target_q, target_h = _mac(self.target_model, obs[:, t], target_h)
            target_mac_out.append(target_q)

        # We don't need the first timesteps Q-Value estimate for targets
        target_mac_out = th.stack(
            target_mac_out[1:], dim=1)  # Concat across time

        # Mask out unavailable actions
        target_mac_out[action_mask[:, 1:] == 0] = -9999999

        # Max over target Q-Values
        if self.double_q:
            # Get actions that maximise live Q (for double q-learning)
            # NOTE: this overwrites mac_out in place; chosen_action_qvals was
            # already gathered from it above.
            mac_out[action_mask == 0] = -9999999
            cur_max_actions = mac_out[:, 1:].max(dim=3, keepdim=True)[1]
            target_max_qvals = th.gather(target_mac_out, 3,
                                         cur_max_actions).squeeze(3)
        else:
            target_max_qvals = target_mac_out.max(dim=3)[0]

        # Mix
        if self.mixer is not None:
            # TODO(ekl) add support for handling global state? This is just
            # treating the stacked agent obs as the state.
            chosen_action_qvals = self.mixer(chosen_action_qvals, obs[:, :-1])
            target_max_qvals = self.target_mixer(target_max_qvals, obs[:, 1:])

        # Calculate 1-step Q-Learning targets
        targets = rewards + self.gamma * (1 - terminated) * target_max_qvals

        # Td-error
        td_error = (chosen_action_qvals - targets.detach())

        mask = mask.expand_as(td_error)

        # 0-out the targets that came from padded data
        masked_td_error = td_error * mask

        # Normal L2 loss, take mean over actual data
        loss = (masked_td_error**2).sum() / mask.sum()
        return loss, mask, masked_td_error, chosen_action_qvals, targets
|
||||
|
||||
|
||||
class QMixPolicyGraph(PolicyGraph):
    """QMix impl. Assumes homogeneous agents for now.

    You must use MultiAgentEnv.with_agent_groups() to group agents
    together for QMix. This creates the proper Tuple obs/action spaces and
    populates the '_group_rewards' info field.

    Action masking: to specify an action mask for individual agents, use a
    dict space with an action_mask key, e.g. {"obs": ob, "action_mask": mask}.
    The mask space must be `Box(0, 1, (n_actions,))`.
    """

    def __init__(self, obs_space, action_space, config):
        """Build the agent model, target model, mixers, loss and optimizer.

        Arguments:
            obs_space: Observation space exposing a Tuple `original_space`
                with one entry per grouped agent.
            action_space: Tuple of per-agent Discrete action spaces.
            config: Config overrides, applied on top of the QMIX
                DEFAULT_CONFIG.

        Raises:
            ValueError: if the spaces fail validation, a Dict agent obs
                space does not have exactly the keys "obs" and
                "action_mask", the action mask shape is not (n_actions,),
                or config["mixer"] is not one of None, "qmix", "vdn".
        """
        _validate(obs_space, action_space)
        config = dict(ray.rllib.agents.qmix.qmix.DEFAULT_CONFIG, **config)
        self.config = config
        self.observation_space = obs_space
        self.action_space = action_space
        self.n_agents = len(obs_space.original_space.spaces)
        self.n_actions = action_space.spaces[0].n
        self.h_size = config["model"]["lstm_cell_size"]

        agent_obs_space = obs_space.original_space.spaces[0]
        if isinstance(agent_obs_space, Dict):
            space_keys = set(agent_obs_space.spaces.keys())
            if space_keys != {"obs", "action_mask"}:
                raise ValueError(
                    "Dict obs space for agent must have keyset "
                    "['obs', 'action_mask'], got {}".format(space_keys))
            mask_shape = tuple(agent_obs_space.spaces["action_mask"].shape)
            if mask_shape != (self.n_actions, ):
                raise ValueError("Action mask shape must be {}, got {}".format(
                    (self.n_actions, ), mask_shape))
            self.has_action_mask = True
            self.obs_size = _get_size(agent_obs_space.spaces["obs"])
        else:
            self.has_action_mask = False
            self.obs_size = _get_size(agent_obs_space)

        self.model = RNNModel(self.obs_size, self.h_size, self.n_actions)
        self.target_model = RNNModel(self.obs_size, self.h_size,
                                     self.n_actions)

        # Setup the mixer network.
        # The global state is just the stacked agent observations for now.
        self.state_shape = [self.obs_size, self.n_agents]
        if config["mixer"] is None:
            self.mixer = None
            self.target_mixer = None
        elif config["mixer"] == "qmix":
            self.mixer = QMixer(self.n_agents, self.state_shape,
                                config["mixing_embed_dim"])
            self.target_mixer = QMixer(self.n_agents, self.state_shape,
                                       config["mixing_embed_dim"])
        elif config["mixer"] == "vdn":
            self.mixer = VDNMixer()
            self.target_mixer = VDNMixer()
        else:
            raise ValueError("Unknown mixer type {}".format(config["mixer"]))

        self.cur_epsilon = 1.0
        self.update_target()  # initial sync

        # Setup optimizer
        self.params = list(self.model.parameters())
        if self.mixer is not None:
            # The mixer's weights must be optimized (and clipped) together
            # with the agent model; otherwise the mixing network stays at
            # its random initialization.
            self.params += list(self.mixer.parameters())
        self.loss = QMixLoss(self.model, self.target_model, self.mixer,
                             self.target_mixer, self.n_agents, self.n_actions,
                             self.config["double_q"], self.config["gamma"])
        self.optimiser = RMSprop(
            params=self.params,
            lr=config["lr"],
            alpha=config["optim_alpha"],
            eps=config["optim_eps"])

    @override(PolicyGraph)
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        """Pick epsilon-greedy actions for every agent in the group.

        ``state_batches`` must hold one hidden-state batch per agent.

        Returns:
            Tuple of (TupleActions with one action array per agent, the new
            hidden states with the agent axis first, empty info dict).
        """
        obs_batch, action_mask = self._unpack_observation(obs_batch)
        assert len(state_batches) == self.n_agents, state_batches
        state_batches = np.stack(state_batches, axis=1)

        # Compute actions
        with th.no_grad():
            q_values, hiddens = _mac(self.model, th.from_numpy(obs_batch),
                                     th.from_numpy(state_batches))
            avail = th.from_numpy(action_mask).float()
            masked_q_values = q_values.clone()
            masked_q_values[avail == 0.0] = -float("inf")
            # epsilon-greedy action selector
            random_numbers = th.rand_like(q_values[:, :, 0])
            pick_random = (random_numbers < self.cur_epsilon).long()
            # Random actions are drawn only among the available ones
            random_actions = Categorical(avail).sample().long()
            actions = (pick_random * random_actions +
                       (1 - pick_random) * masked_q_values.max(dim=2)[1])
            actions = var_to_np(actions)
            hiddens = var_to_np(hiddens)

        return (TupleActions(list(actions.transpose([1, 0]))),
                hiddens.transpose([1, 0, 2]), {})

    @override(PolicyGraph)
    def compute_apply(self, samples):
        """Run one optimization step on a sample batch.

        Returns:
            Tuple of ({"stats": dict of scalar training stats}, {}).
        """
        obs_batch, action_mask = self._unpack_observation(samples["obs"])
        group_rewards = self._get_group_rewards(samples["infos"])

        # These will be padded to shape [B * T, ...]
        [rew, action_mask, act, dones, obs], initial_states, seq_lens = \
            chop_into_sequences(
                samples["eps_id"],
                samples["agent_index"], [
                    group_rewards, action_mask, samples["actions"],
                    samples["dones"], obs_batch
                ],
                [samples["state_in_{}".format(k)]
                 for k in range(self.n_agents)],
                max_seq_len=self.config["model"]["max_seq_len"],
                dynamic_max=True,
                _extra_padding=1)
        # TODO(ekl) adding 1 extra unit of padding here, since otherwise we
        # lose the terminating reward and the Q-values will be unanchored!
        B, T = len(seq_lens), max(seq_lens) + 1

        def to_batches(arr):
            new_shape = [B, T] + list(arr.shape[1:])
            return th.from_numpy(np.reshape(arr, new_shape))

        rewards = to_batches(rew)[:, :-1].float()
        actions = to_batches(act)[:, :-1].long()
        obs = to_batches(obs).reshape([B, T, self.n_agents,
                                       self.obs_size]).float()
        action_mask = to_batches(action_mask)

        # TODO(ekl) this treats group termination as individual termination
        terminated = to_batches(dones.astype(np.float32)).unsqueeze(2).expand(
            B, T, self.n_agents)[:, :-1]
        filled = (np.reshape(np.tile(np.arange(T), B), [B, T]) <
                  np.expand_dims(seq_lens, 1)).astype(np.float32)
        # clone(): the expanded view shares one memory location across the
        # agent axis, and newer PyTorch versions refuse in-place writes into
        # such overlapping views.
        mask = th.from_numpy(filled).unsqueeze(2).expand(
            B, T, self.n_agents)[:, :-1].clone()
        # Also mask out every step that follows a terminated step
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])

        # Compute loss
        loss_out, mask, masked_td_error, chosen_action_qvals, targets = \
            self.loss(rewards, actions, terminated, mask, obs, action_mask)

        # Optimise
        self.optimiser.zero_grad()
        loss_out.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(
            self.params, self.config["grad_norm_clipping"])
        self.optimiser.step()

        mask_elems = mask.sum().item()
        stats = {
            "loss": loss_out.item(),
            "grad_norm": grad_norm
            if isinstance(grad_norm, float) else grad_norm.item(),
            "td_error_abs": masked_td_error.abs().sum().item() / mask_elems,
            "q_taken_mean": (chosen_action_qvals * mask).sum().item() /
            mask_elems,
            "target_mean": (targets * mask).sum().item() / mask_elems,
        }
        return {"stats": stats}, {}

    @override(PolicyGraph)
    def get_initial_state(self):
        """Return one zero hidden-state array per agent."""
        return [
            self.model.init_hidden().numpy().squeeze()
            for _ in range(self.n_agents)
        ]

    @override(PolicyGraph)
    def get_weights(self):
        """Return the agent model weights (mixers are not included)."""
        return {"model": self.model.state_dict()}

    @override(PolicyGraph)
    def set_weights(self, weights):
        """Load agent model weights as produced by get_weights()."""
        self.model.load_state_dict(weights["model"])

    @override(PolicyGraph)
    def get_state(self):
        """Return models, mixers (None if unused) and the current epsilon."""
        has_mixer = self.mixer is not None
        return {
            "model": self.model.state_dict(),
            "target_model": self.target_model.state_dict(),
            "mixer": self.mixer.state_dict() if has_mixer else None,
            "target_mixer": self.target_mixer.state_dict()
            if has_mixer else None,
            "cur_epsilon": self.cur_epsilon,
        }

    @override(PolicyGraph)
    def set_state(self, state):
        """Restore a state dict produced by get_state()."""
        self.model.load_state_dict(state["model"])
        self.target_model.load_state_dict(state["target_model"])
        if state["mixer"] is not None:
            self.mixer.load_state_dict(state["mixer"])
            self.target_mixer.load_state_dict(state["target_mixer"])
        self.set_epsilon(state["cur_epsilon"])
        # NOTE(review): this re-syncs the targets from the live networks and
        # so overwrites the target weights loaded just above -- confirm that
        # this is intended.
        self.update_target()

    def update_target(self):
        """Copy the live model (and mixer, if any) into the target copies."""
        self.target_model.load_state_dict(self.model.state_dict())
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        logger.debug("Updated target networks")

    def set_epsilon(self, epsilon):
        """Set the random-action probability used by compute_actions()."""
        self.cur_epsilon = epsilon

    def _get_group_rewards(self, info_batch):
        """Extract per-agent rewards from the info dicts.

        Infos without the GROUP_REWARDS key contribute all-zero rewards.
        """
        group_rewards = np.array([
            info.get(GROUP_REWARDS, [0.0] * self.n_agents)
            for info in info_batch
        ])
        return group_rewards

    def _unpack_observation(self, obs_batch):
        """Split a packed obs batch into per-agent obs and action masks.

        Returns:
            obs: array of shape [len(obs_batch), n_agents, obs_size]
            action_mask: array of shape [len(obs_batch), n_agents,
                n_actions]; all ones when the env provides no action mask.
        """
        unpacked = _unpack_obs(
            np.array(obs_batch),
            self.observation_space.original_space,
            tensorlib=np)
        if self.has_action_mask:
            obs = np.concatenate(
                [o["obs"] for o in unpacked],
                axis=1).reshape([len(obs_batch), self.n_agents, self.obs_size])
            action_mask = np.concatenate(
                [o["action_mask"] for o in unpacked], axis=1).reshape(
                    [len(obs_batch), self.n_agents, self.n_actions])
        else:
            obs = np.concatenate(
                unpacked,
                axis=1).reshape([len(obs_batch), self.n_agents, self.obs_size])
            action_mask = np.ones(
                [len(obs_batch), self.n_agents, self.n_actions])
        return obs, action_mask
|
||||
|
||||
|
||||
def _validate(obs_space, action_space):
|
||||
if not hasattr(obs_space, "original_space") or \
|
||||
not isinstance(obs_space.original_space, Tuple):
|
||||
raise ValueError("Obs space must be a Tuple, got {}. Use ".format(
|
||||
obs_space) + "MultiAgentEnv.with_agent_groups() to group related "
|
||||
"agents for QMix.")
|
||||
if not isinstance(action_space, Tuple):
|
||||
raise ValueError(
|
||||
"Action space must be a Tuple, got {}. ".format(action_space) +
|
||||
"Use MultiAgentEnv.with_agent_groups() to group related "
|
||||
"agents for QMix.")
|
||||
if not isinstance(action_space.spaces[0], Discrete):
|
||||
raise ValueError(
|
||||
"QMix requires a discrete action space, got {}".format(
|
||||
action_space.spaces[0]))
|
||||
if len({str(x) for x in obs_space.original_space.spaces}) > 1:
|
||||
raise ValueError(
|
||||
"Implementation limitation: observations of grouped agents "
|
||||
"must be homogeneous, got {}".format(
|
||||
obs_space.original_space.spaces))
|
||||
if len({str(x) for x in action_space.spaces}) > 1:
|
||||
raise ValueError(
|
||||
"Implementation limitation: action space of grouped agents "
|
||||
"must be homogeneous, got {}".format(action_space.spaces))
|
||||
|
||||
|
||||
def _get_size(obs_space):
    """Return the ``size`` of the preprocessor built for ``obs_space``."""
    preprocessor_cls = get_preprocessor(obs_space)
    return preprocessor_cls(obs_space).size
|
||||
|
||||
|
||||
def _mac(model, obs, h):
|
||||
"""Forward pass of the multi-agent controller.
|
||||
|
||||
Arguments:
|
||||
model: Model that produces q-values for a 1d agent batch
|
||||
obs: Tensor of shape [B, n_agents, obs_size]
|
||||
h: Tensor of shape [B, n_agents, h_size]
|
||||
|
||||
Returns:
|
||||
q_vals: Tensor of shape [B, n_agents, n_actions]
|
||||
h: Tensor of shape [B, n_agents, h_size]
|
||||
"""
|
||||
B, n_agents = obs.size(0), obs.size(1)
|
||||
obs_flat = obs.reshape([B * n_agents, -1])
|
||||
h_flat = h.reshape([B * n_agents, -1])
|
||||
q_flat, h_flat = model.forward(obs_flat, h_flat)
|
||||
return q_flat.reshape([B, n_agents, -1]), h_flat.reshape([B, n_agents, -1])
|
||||
Reference in New Issue
Block a user