mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 23:08:32 +08:00
[rllib] Q-Mix implementation (Q-Mix, VDN, IQN, and Ape-X variants) (#3548)
This commit is contained in:
@@ -33,8 +33,8 @@ def _register_all():
|
||||
|
||||
for key in [
|
||||
"PPO", "ES", "DQN", "APEX", "A3C", "PG", "DDPG", "APEX_DDPG",
|
||||
"IMPALA", "ARS", "A2C", "__fake", "__sigmoid_fake_data",
|
||||
"__parameter_tuning"
|
||||
"IMPALA", "ARS", "A2C", "QMIX", "APEX_QMIX", "__fake",
|
||||
"__sigmoid_fake_data", "__parameter_tuning"
|
||||
]:
|
||||
from ray.rllib.agents.agent import get_agent_class
|
||||
register_trainable(key, get_agent_class(key))
|
||||
|
||||
@@ -10,6 +10,7 @@ import pickle
|
||||
import six
|
||||
import tempfile
|
||||
import tensorflow as tf
|
||||
import traceback
|
||||
from types import FunctionType
|
||||
|
||||
import ray
|
||||
@@ -546,6 +547,14 @@ def _register_if_needed(env_object):
|
||||
def get_agent_class(alg):
    """Return the agent class registered under the name ``alg``.

    The lookup itself is delegated to ``_get_agent_class``. If that lookup
    raises ``ImportError``, the formatted traceback is handed to
    ``_agent_import_failed`` and whatever that helper returns is returned
    instead, so the import failure does not propagate out of this call.
    """
    try:
        agent_cls = _get_agent_class(alg)
    except ImportError:
        # Imported lazily: only needed on the failure path.
        from ray.rllib.agents.mock import _agent_import_failed

        agent_cls = _agent_import_failed(traceback.format_exc())
    return agent_cls
|
||||
|
||||
|
||||
def _get_agent_class(alg):
|
||||
if alg == "DDPG":
|
||||
from ray.rllib.agents import ddpg
|
||||
return ddpg.DDPGAgent
|
||||
@@ -579,6 +588,12 @@ def get_agent_class(alg):
|
||||
elif alg == "IMPALA":
|
||||
from ray.rllib.agents import impala
|
||||
return impala.ImpalaAgent
|
||||
elif alg == "QMIX":
|
||||
from ray.rllib.agents import qmix
|
||||
return qmix.QMixAgent
|
||||
elif alg == "APEX_QMIX":
|
||||
from ray.rllib.agents import qmix
|
||||
return qmix.ApexQMixAgent
|
||||
elif alg == "script":
|
||||
from ray.tune import script_runner
|
||||
return script_runner.ScriptRunner
|
||||
|
||||
@@ -8,12 +8,6 @@ from ray.rllib.agents.ddpg.ddpg_policy_graph import DDPGPolicyGraph
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule
|
||||
|
||||
OPTIMIZER_SHARED_CONFIGS = [
|
||||
"buffer_size", "prioritized_replay", "prioritized_replay_alpha",
|
||||
"prioritized_replay_beta", "prioritized_replay_eps", "sample_batch_size",
|
||||
"train_batch_size", "learning_starts"
|
||||
]
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
DEFAULT_CONFIG = with_common_config({
|
||||
|
||||
@@ -117,12 +117,13 @@ class DQNAgent(Agent):
|
||||
_agent_name = "DQN"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy_graph = DQNPolicyGraph
|
||||
_optimizer_shared_configs = OPTIMIZER_SHARED_CONFIGS
|
||||
|
||||
@override(Agent)
|
||||
def _init(self):
|
||||
# Update effective batch size to include n-step
|
||||
adjusted_batch_size = max(self.config["sample_batch_size"],
|
||||
self.config["n_step"])
|
||||
self.config.get("n_step", 1))
|
||||
self.config["sample_batch_size"] = adjusted_batch_size
|
||||
|
||||
self.exploration0 = self._make_exploration_schedule(-1)
|
||||
@@ -131,7 +132,7 @@ class DQNAgent(Agent):
|
||||
for i in range(self.config["num_workers"])
|
||||
]
|
||||
|
||||
for k in OPTIMIZER_SHARED_CONFIGS:
|
||||
for k in self._optimizer_shared_configs:
|
||||
if self._agent_name != "DQN" and k in [
|
||||
"schedule_max_timesteps", "beta_annealing_fraction",
|
||||
"final_prioritized_replay_beta"
|
||||
|
||||
@@ -100,3 +100,16 @@ class _ParameterTuningAgent(_MockAgent):
|
||||
timesteps_this_iter=self.config["iter_timesteps"],
|
||||
time_this_iter_s=self.config["iter_time"],
|
||||
info={})
|
||||
|
||||
|
||||
def _agent_import_failed(trace):
    """Build a placeholder agent class for an agent that failed to import.

    Used e.g. when PyTorch is not installed.

    Arguments:
        trace: Text used as the message of the ImportError raised by the
            placeholder's ``_setup``.

    Returns:
        An ``Agent`` subclass whose ``_setup`` always raises ImportError.
    """

    class _AgentImportFailed(Agent):
        _default_config = with_common_config({})
        _agent_name = "AgentImportFailed"

        def _setup(self, config):
            # Re-surface the captured import failure.
            raise ImportError(trace)

    placeholder_cls = _AgentImportFailed
    return placeholder_cls
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
Code in this package is adapted from https://github.com/oxwhirl/pymarl_alpha.
|
||||
@@ -0,0 +1,8 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.qmix.qmix import QMixAgent, DEFAULT_CONFIG
|
||||
from ray.rllib.agents.qmix.apex import ApexQMixAgent
|
||||
|
||||
__all__ = ["QMixAgent", "ApexQMixAgent", "DEFAULT_CONFIG"]
|
||||
@@ -0,0 +1,55 @@
|
||||
"""Experimental: scalable Ape-X variant of QMIX"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.qmix.qmix import QMixAgent, DEFAULT_CONFIG as QMIX_CONFIG
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils import merge_dicts
|
||||
|
||||
APEX_QMIX_DEFAULT_CONFIG = merge_dicts(
|
||||
QMIX_CONFIG, # see also the options in qmix.py, which are also supported
|
||||
{
|
||||
"optimizer_class": "AsyncReplayOptimizer",
|
||||
"optimizer": merge_dicts(
|
||||
QMIX_CONFIG["optimizer"],
|
||||
{
|
||||
"max_weight_sync_delay": 400,
|
||||
"num_replay_buffer_shards": 4,
|
||||
"batch_replay": True, # required for RNN. Disables prio.
|
||||
"debug": False
|
||||
}),
|
||||
"num_gpus": 0,
|
||||
"num_workers": 32,
|
||||
"buffer_size": 2000000,
|
||||
"learning_starts": 50000,
|
||||
"train_batch_size": 512,
|
||||
"sample_batch_size": 50,
|
||||
"max_weight_sync_delay": 400,
|
||||
"target_network_update_freq": 500000,
|
||||
"timesteps_per_iteration": 25000,
|
||||
"per_worker_exploration": True,
|
||||
"min_iter_time_s": 30,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class ApexQMixAgent(QMixAgent):
    """QMIX variant that uses the Ape-X distributed policy optimizer.

    The default configuration targets one large node (32 cores). When
    running on a bigger cluster, raise the `num_workers` config var.
    """

    _default_config = APEX_QMIX_DEFAULT_CONFIG
    _agent_name = "APEX_QMIX"

    @override(QMixAgent)
    def update_target_if_needed(self):
        """Sync the target networks once enough steps have been trained.

        Ape-X measures the update interval in steps trained (read from the
        optimizer), not in steps sampled.
        """
        steps_since_update = (
            self.optimizer.num_steps_trained - self.last_target_update_ts)
        if steps_since_update <= self.config["target_network_update_freq"]:
            return

        self.local_evaluator.for_policy(lambda p: p.update_target())
        self.last_target_update_ts = self.optimizer.num_steps_trained
        self.num_target_updates += 1
|
||||
@@ -0,0 +1,64 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
|
||||
class VDNMixer(nn.Module):
    """Parameter-free mixer that adds up the per-agent Q-values."""

    def __init__(self):
        nn.Module.__init__(self)

    def forward(self, agent_qs, batch):
        """Sum ``agent_qs`` over dimension 2, keeping that dimension.

        Arguments:
            agent_qs: Tensor to reduce; dimension 2 is summed out.
            batch: Not used by this mixer; accepted so the call signature
                matches the other mixer.
        """
        return agent_qs.sum(dim=2, keepdim=True)
|
||||
|
||||
|
||||
class QMixer(nn.Module):
    """Monotonic mixing network that combines per-agent Q-values.

    The mixing weights are produced by hypernetworks conditioned on the
    state and passed through ``abs()``, so they are non-negative.
    """

    def __init__(self, n_agents, state_shape, mixing_embed_dim):
        """Create the hypernetworks.

        Arguments:
            n_agents (int): Number of agent Q-values mixed per timestep.
            state_shape: Shape of the state; flattened to
                ``prod(state_shape)`` features.
            mixing_embed_dim (int): Width of the hidden mixing layer.
        """
        super(QMixer, self).__init__()

        self.n_agents = n_agents
        self.embed_dim = mixing_embed_dim
        self.state_dim = int(np.prod(state_shape))

        self.hyper_w_1 = nn.Linear(self.state_dim,
                                   self.embed_dim * self.n_agents)
        self.hyper_w_final = nn.Linear(self.state_dim, self.embed_dim)

        # State dependent bias for hidden layer
        self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim)

        # V(s) instead of a bias for the last layers
        self.V = nn.Sequential(
            nn.Linear(self.state_dim, self.embed_dim), nn.ReLU(),
            nn.Linear(self.embed_dim, 1))

    def forward(self, agent_qs, states):
        """Forward pass for the mixer.

        Arguments:
            agent_qs: Tensor of shape [B, T, n_agents] holding one Q-value
                per agent (already selected for an action).
            states: Tensor with B * T * state_dim elements, e.g. of shape
                [B, T, state_dim].

        Returns:
            q_tot: Tensor of shape [B, T, 1].
        """
        bs = agent_qs.size(0)
        # reshape() rather than view(): inputs are not guaranteed to be
        # contiguous, and view() raises on incompatible strides.
        states = states.reshape(-1, self.state_dim)
        agent_qs = agent_qs.reshape(-1, 1, self.n_agents)
        # First layer
        w1 = th.abs(self.hyper_w_1(states))
        b1 = self.hyper_b_1(states)
        w1 = w1.view(-1, self.n_agents, self.embed_dim)
        b1 = b1.view(-1, 1, self.embed_dim)
        hidden = F.elu(th.bmm(agent_qs, w1) + b1)
        # Second layer
        w_final = th.abs(self.hyper_w_final(states))
        w_final = w_final.view(-1, self.embed_dim, 1)
        # State-dependent bias
        v = self.V(states).view(-1, 1, 1)
        # Compute final output
        y = th.bmm(hidden, w_final) + v
        # Reshape and return
        q_tot = y.view(bs, -1, 1)
        return q_tot
|
||||
@@ -0,0 +1,28 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# TODO(ekl) we should have common models for pytorch like we do for TF
|
||||
class RNNModel(nn.Module):
    """Per-agent recurrent Q-network: Linear -> ReLU -> GRUCell -> Linear."""

    def __init__(self, obs_size, rnn_hidden_dim, n_actions):
        """Create the layers.

        Arguments:
            obs_size (int): Number of input features per observation.
            rnn_hidden_dim (int): Size of the GRU hidden state.
            n_actions (int): Number of Q-values produced per input row.
        """
        super(RNNModel, self).__init__()
        self.rnn_hidden_dim = rnn_hidden_dim
        self.n_actions = n_actions
        self.fc1 = nn.Linear(obs_size, rnn_hidden_dim)
        self.rnn = nn.GRUCell(rnn_hidden_dim, rnn_hidden_dim)
        self.fc2 = nn.Linear(rnn_hidden_dim, n_actions)

    def init_hidden(self):
        """Return a zero hidden state of shape [1, rnn_hidden_dim]."""
        # Derived from a model weight so that dtype and device match it.
        return self.fc1.weight.new_zeros(1, self.rnn_hidden_dim)

    def forward(self, inputs, hidden_state):
        """Run one recurrent step.

        Arguments:
            inputs: Tensor of shape [N, obs_size]; cast to float.
            hidden_state: Tensor reshaped to [N, rnn_hidden_dim].

        Returns:
            (q, h): Q-values of shape [N, n_actions] and the new hidden
            state of shape [N, rnn_hidden_dim].
        """
        features = F.relu(self.fc1(inputs.float()))
        prev_h = hidden_state.reshape(-1, self.rnn_hidden_dim)
        next_h = self.rnn(features, prev_h)
        return self.fc2(next_h), next_h
|
||||
@@ -0,0 +1,92 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.agent import with_common_config
|
||||
from ray.rllib.agents.dqn.dqn import DQNAgent
|
||||
from ray.rllib.agents.qmix.qmix_policy_graph import QMixPolicyGraph
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
DEFAULT_CONFIG = with_common_config({
|
||||
# === QMix ===
|
||||
# Mixing network. Either "qmix", "vdn", or None
|
||||
"mixer": "qmix",
|
||||
# Size of the mixing network embedding
|
||||
"mixing_embed_dim": 32,
|
||||
# Whether to use Double_Q learning
|
||||
"double_q": True,
|
||||
# Optimize over complete episodes by default.
|
||||
"batch_mode": "complete_episodes",
|
||||
|
||||
# === Exploration ===
|
||||
# Max num timesteps for annealing schedules. Exploration is annealed from
|
||||
# 1.0 to exploration_final_eps over this number of timesteps scaled by
|
||||
# exploration_fraction
|
||||
"schedule_max_timesteps": 100000,
|
||||
# Number of env steps to optimize for before returning
|
||||
"timesteps_per_iteration": 1000,
|
||||
# Fraction of entire training period over which the exploration rate is
|
||||
# annealed
|
||||
"exploration_fraction": 0.1,
|
||||
# Final value of random action probability
|
||||
"exploration_final_eps": 0.02,
|
||||
# Update the target network every `target_network_update_freq` steps.
|
||||
"target_network_update_freq": 500,
|
||||
|
||||
# === Replay buffer ===
|
||||
# Size of the replay buffer in steps.
|
||||
"buffer_size": 10000,
|
||||
|
||||
# === Optimization ===
|
||||
# Learning rate for RMSProp optimizer
|
||||
"lr": 0.0005,
|
||||
# RMSProp alpha
|
||||
"optim_alpha": 0.99,
|
||||
# RMSProp epsilon
|
||||
"optim_eps": 0.00001,
|
||||
# If not None, clip gradients during optimization at this value
|
||||
"grad_norm_clipping": 10,
|
||||
# How many steps of the model to sample before learning starts.
|
||||
"learning_starts": 1000,
|
||||
# Update the replay buffer with this many samples at once. Note that
|
||||
# this setting applies per-worker if num_workers > 1.
|
||||
"sample_batch_size": 4,
|
||||
# Size of a batch sampled from the replay buffer for training. Note that
|
||||
# if async_updates is set, then each worker returns gradients for a
|
||||
# batch of this size.
|
||||
"train_batch_size": 32,
|
||||
|
||||
# === Parallelism ===
|
||||
# Number of workers for collecting samples with. This only makes sense
|
||||
# to increase if your environment is particularly slow to sample, or if
|
||||
# you're using the Async or Ape-X optimizers.
|
||||
"num_workers": 0,
|
||||
# Optimizer class to use.
|
||||
"optimizer_class": "SyncBatchReplayOptimizer",
|
||||
# Whether to use a distribution of epsilons across workers for exploration.
|
||||
"per_worker_exploration": False,
|
||||
# Whether to compute priorities on workers.
|
||||
"worker_side_prioritization": False,
|
||||
# Prevent iterations from going lower than this time span
|
||||
"min_iter_time_s": 1,
|
||||
|
||||
# === Model ===
|
||||
"model": {
|
||||
"lstm_cell_size": 64,
|
||||
"max_seq_len": 999999,
|
||||
},
|
||||
})
|
||||
# __sphinx_doc_end__
|
||||
# yapf: enable
|
||||
|
||||
|
||||
class QMixAgent(DQNAgent):
    """QMix implementation in PyTorch."""

    _agent_name = "QMIX"
    _policy_graph = QMixPolicyGraph
    _default_config = DEFAULT_CONFIG

    # Overrides the DQNAgent attribute of the same name with a shorter
    # list of config keys.
    _optimizer_shared_configs = [
        "learning_starts",
        "buffer_size",
        "train_batch_size",
    ]
|
||||
@@ -0,0 +1,411 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from gym.spaces import Tuple, Discrete, Dict
|
||||
import logging
|
||||
import numpy as np
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
from torch.optim import RMSprop
|
||||
from torch.distributions import Categorical
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer
|
||||
from ray.rllib.agents.qmix.model import RNNModel
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.models.action_dist import TupleActions
|
||||
from ray.rllib.models.pytorch.misc import var_to_np
|
||||
from ray.rllib.models.lstm import chop_into_sequences
|
||||
from ray.rllib.models.model import _unpack_obs
|
||||
from ray.rllib.models.preprocessors import get_preprocessor
|
||||
from ray.rllib.env.constants import GROUP_REWARDS
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QMixLoss(nn.Module):
    """1-step TD loss over unrolled per-agent Q-values, optionally mixed."""

    def __init__(self,
                 model,
                 target_model,
                 mixer,
                 target_mixer,
                 n_agents,
                 n_actions,
                 double_q=True,
                 gamma=0.99):
        """Store the networks and hyperparameters used by ``forward``.

        Arguments:
            model: Live per-agent Q-network (provides ``init_hidden``).
            target_model: Target copy of ``model``.
            mixer: Mixing network applied to the chosen Q-values, or None
                to skip mixing.
            target_mixer: Mixing network applied to the target Q-values;
                only used when ``mixer`` is not None.
            n_agents (int): Number of agents in the group.
            n_actions (int): Number of discrete actions per agent.
            double_q (bool): Pick target actions with the live network.
            gamma (float): Discount factor.
        """
        nn.Module.__init__(self)
        self.model = model
        self.target_model = target_model
        self.mixer = mixer
        self.target_mixer = target_mixer
        self.n_agents = n_agents
        self.n_actions = n_actions
        self.double_q = double_q
        self.gamma = gamma

    def forward(self, rewards, actions, terminated, mask, obs, action_mask):
        """Forward pass of the loss.

        Arguments:
            rewards: Tensor of shape [B, T-1, n_agents]
            actions: Tensor of shape [B, T-1, n_agents]
            terminated: Tensor of shape [B, T-1, n_agents]
            mask: Tensor of shape [B, T-1, n_agents]
            obs: Tensor of shape [B, T, n_agents, obs_size]
            action_mask: Tensor of shape [B, T, n_agents, n_actions]

        Returns:
            Tuple ``(loss, mask, masked_td_error, chosen_action_qvals,
            targets)`` where ``mask`` has been expanded to the shape of the
            TD error.
        """

        B, T = obs.size(0), obs.size(1)

        # Calculate estimated Q-Values
        mac_out = []
        h = self.model.init_hidden().expand([B, self.n_agents, -1])
        for t in range(T):
            q, h = _mac(self.model, obs[:, t], h)
            mac_out.append(q)
        mac_out = th.stack(mac_out, dim=1)  # Concat over time

        # Pick the Q-Values for the actions taken -> [B, T-1, n_agents]
        chosen_action_qvals = th.gather(
            mac_out[:, :-1], dim=3, index=actions.unsqueeze(3)).squeeze(3)

        # Calculate the Q-Values necessary for the target
        target_mac_out = []
        target_h = self.target_model.init_hidden().expand(
            [B, self.n_agents, -1])
        for t in range(T):
            target_q, target_h = _mac(self.target_model, obs[:, t], target_h)
            target_mac_out.append(target_q)

        # We don't need the first timesteps Q-Value estimate for targets
        target_mac_out = th.stack(
            target_mac_out[1:], dim=1)  # Concat across time

        # Mask out unavailable actions
        target_mac_out[action_mask[:, 1:] == 0] = -9999999

        # Max over target Q-Values
        if self.double_q:
            # Get actions that maximise live Q (for double q-learning).
            # Mask a detached copy: writing into mac_out in place would
            # modify a tensor that the gather() above took a view of, which
            # autograd can reject during backward(). Only argmax indices
            # are needed here, so no gradient is lost.
            live_q = mac_out.clone().detach()
            live_q[action_mask == 0] = -9999999
            cur_max_actions = live_q[:, 1:].max(dim=3, keepdim=True)[1]
            target_max_qvals = th.gather(target_mac_out, 3,
                                         cur_max_actions).squeeze(3)
        else:
            target_max_qvals = target_mac_out.max(dim=3)[0]

        # Mix
        if self.mixer is not None:
            # TODO(ekl) add support for handling global state? This is just
            # treating the stacked agent obs as the state.
            chosen_action_qvals = self.mixer(chosen_action_qvals, obs[:, :-1])
            target_max_qvals = self.target_mixer(target_max_qvals, obs[:, 1:])

        # Calculate 1-step Q-Learning targets
        targets = rewards + self.gamma * (1 - terminated) * target_max_qvals

        # Td-error
        td_error = (chosen_action_qvals - targets.detach())

        mask = mask.expand_as(td_error)

        # 0-out the targets that came from padded data
        masked_td_error = td_error * mask

        # Normal L2 loss, take mean over actual data
        loss = (masked_td_error**2).sum() / mask.sum()
        return loss, mask, masked_td_error, chosen_action_qvals, targets
|
||||
|
||||
|
||||
class QMixPolicyGraph(PolicyGraph):
    """QMix impl. Assumes homogeneous agents for now.

    You must use MultiAgentEnv.with_agent_groups() to group agents
    together for QMix. This creates the proper Tuple obs/action spaces and
    populates the '_group_rewards' info field.

    Action masking: to specify an action mask for individual agents, use a
    dict space with an action_mask key, e.g. {"obs": ob, "action_mask": mask}.
    The mask space must be `Box(0, 1, (n_actions,))`.
    """

    def __init__(self, obs_space, action_space, config):
        """Build the live/target networks, the mixer and the optimizer.

        Arguments:
            obs_space: Grouped observation space; its ``original_space``
                must be a Tuple with one entry per agent.
            action_space: Tuple of per-agent Discrete action spaces.
            config (dict): Overrides applied on top of the QMix
                DEFAULT_CONFIG.

        Raises:
            ValueError: If the spaces are not valid for QMix, the per-agent
                Dict space is malformed, or the mixer type is unknown.
        """
        _validate(obs_space, action_space)
        config = dict(ray.rllib.agents.qmix.qmix.DEFAULT_CONFIG, **config)
        self.config = config
        self.observation_space = obs_space
        self.action_space = action_space
        self.n_agents = len(obs_space.original_space.spaces)
        self.n_actions = action_space.spaces[0].n
        self.h_size = config["model"]["lstm_cell_size"]

        agent_obs_space = obs_space.original_space.spaces[0]
        if isinstance(agent_obs_space, Dict):
            space_keys = set(agent_obs_space.spaces.keys())
            if space_keys != {"obs", "action_mask"}:
                raise ValueError(
                    "Dict obs space for agent must have keyset "
                    "['obs', 'action_mask'], got {}".format(space_keys))
            mask_shape = tuple(agent_obs_space.spaces["action_mask"].shape)
            if mask_shape != (self.n_actions, ):
                raise ValueError("Action mask shape must be {}, got {}".format(
                    (self.n_actions, ), mask_shape))
            self.has_action_mask = True
            self.obs_size = _get_size(agent_obs_space.spaces["obs"])
        else:
            self.has_action_mask = False
            self.obs_size = _get_size(agent_obs_space)

        self.model = RNNModel(self.obs_size, self.h_size, self.n_actions)
        self.target_model = RNNModel(self.obs_size, self.h_size,
                                     self.n_actions)

        # Setup the mixer network.
        # The global state is just the stacked agent observations for now.
        self.state_shape = [self.obs_size, self.n_agents]
        if config["mixer"] is None:
            self.mixer = None
            self.target_mixer = None
        elif config["mixer"] == "qmix":
            self.mixer = QMixer(self.n_agents, self.state_shape,
                                config["mixing_embed_dim"])
            self.target_mixer = QMixer(self.n_agents, self.state_shape,
                                       config["mixing_embed_dim"])
        elif config["mixer"] == "vdn":
            self.mixer = VDNMixer()
            self.target_mixer = VDNMixer()
        else:
            raise ValueError("Unknown mixer type {}".format(config["mixer"]))

        self.cur_epsilon = 1.0
        self.update_target()  # initial sync

        # Setup optimizer
        self.params = list(self.model.parameters())
        if self.mixer is not None:
            # The mixer's weights must be optimized (and gradient-clipped)
            # too; otherwise the mixing network stays at its random
            # initialization. VDNMixer has no parameters, so this adds
            # nothing in that case.
            self.params += list(self.mixer.parameters())
        self.loss = QMixLoss(self.model, self.target_model, self.mixer,
                             self.target_mixer, self.n_agents, self.n_actions,
                             self.config["double_q"], self.config["gamma"])
        self.optimiser = RMSprop(
            params=self.params,
            lr=config["lr"],
            alpha=config["optim_alpha"],
            eps=config["optim_eps"])

    @override(PolicyGraph)
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        """Pick epsilon-greedy actions for every agent in the group.

        ``state_batches`` must hold one hidden-state batch per agent.
        Returns the TupleActions, the new hidden states (one per agent)
        and an empty info dict.
        """
        obs_batch, action_mask = self._unpack_observation(obs_batch)
        assert len(state_batches) == self.n_agents, state_batches
        state_batches = np.stack(state_batches, axis=1)

        # Compute actions
        with th.no_grad():
            q_values, hiddens = _mac(self.model, th.from_numpy(obs_batch),
                                     th.from_numpy(state_batches))
            avail = th.from_numpy(action_mask).float()
            masked_q_values = q_values.clone()
            masked_q_values[avail == 0.0] = -float("inf")
            # epsilon-greedy action selector
            random_numbers = th.rand_like(q_values[:, :, 0])
            pick_random = (random_numbers < self.cur_epsilon).long()
            # Random actions are drawn only among the available ones.
            random_actions = Categorical(avail).sample().long()
            actions = (pick_random * random_actions +
                       (1 - pick_random) * masked_q_values.max(dim=2)[1])
            actions = var_to_np(actions)
            hiddens = var_to_np(hiddens)

        return (TupleActions(list(actions.transpose([1, 0]))),
                hiddens.transpose([1, 0, 2]), {})

    @override(PolicyGraph)
    def compute_apply(self, samples):
        """Run one optimization step on a sample batch.

        Returns a ``({"stats": ...}, {})`` tuple with scalar training stats.
        """
        obs_batch, action_mask = self._unpack_observation(samples["obs"])
        group_rewards = self._get_group_rewards(samples["infos"])

        # These will be padded to shape [B * T, ...]
        [rew, action_mask, act, dones, obs], initial_states, seq_lens = \
            chop_into_sequences(
                samples["eps_id"],
                samples["agent_index"], [
                    group_rewards, action_mask, samples["actions"],
                    samples["dones"], obs_batch
                ],
                [samples["state_in_{}".format(k)]
                 for k in range(self.n_agents)],
                max_seq_len=self.config["model"]["max_seq_len"],
                dynamic_max=True,
                _extra_padding=1)
        # TODO(ekl) adding 1 extra unit of padding here, since otherwise we
        # lose the terminating reward and the Q-values will be unanchored!
        B, T = len(seq_lens), max(seq_lens) + 1

        def to_batches(arr):
            new_shape = [B, T] + list(arr.shape[1:])
            return th.from_numpy(np.reshape(arr, new_shape))

        rewards = to_batches(rew)[:, :-1].float()
        actions = to_batches(act)[:, :-1].long()
        obs = to_batches(obs).reshape([B, T, self.n_agents,
                                       self.obs_size]).float()
        action_mask = to_batches(action_mask)

        # TODO(ekl) this treats group termination as individual termination
        terminated = to_batches(dones.astype(np.float32)).unsqueeze(2).expand(
            B, T, self.n_agents)[:, :-1]
        filled = (np.reshape(np.tile(np.arange(T), B), [B, T]) <
                  np.expand_dims(seq_lens, 1)).astype(np.float32)
        # clone(): expand() yields a view whose agent dimension aliases a
        # single memory location, and the in-place update below must not
        # write into such a view.
        mask = th.from_numpy(filled).unsqueeze(2).expand(
            B, T, self.n_agents)[:, :-1].clone()
        # Also mask every step that follows a terminated step.
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])

        # Compute loss
        loss_out, mask, masked_td_error, chosen_action_qvals, targets = \
            self.loss(rewards, actions, terminated, mask, obs, action_mask)

        # Optimise
        self.optimiser.zero_grad()
        loss_out.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(
            self.params, self.config["grad_norm_clipping"])
        self.optimiser.step()

        mask_elems = mask.sum().item()
        stats = {
            "loss": loss_out.item(),
            "grad_norm": grad_norm
            if isinstance(grad_norm, float) else grad_norm.item(),
            "td_error_abs": masked_td_error.abs().sum().item() / mask_elems,
            "q_taken_mean": (chosen_action_qvals * mask).sum().item() /
            mask_elems,
            "target_mean": (targets * mask).sum().item() / mask_elems,
        }
        return {"stats": stats}, {}

    @override(PolicyGraph)
    def get_initial_state(self):
        """Return one zero hidden state (numpy array) per agent."""
        return [
            self.model.init_hidden().numpy().squeeze()
            for _ in range(self.n_agents)
        ]

    @override(PolicyGraph)
    def get_weights(self):
        """Return the live model's state dict under the "model" key."""
        return {"model": self.model.state_dict()}

    @override(PolicyGraph)
    def set_weights(self, weights):
        """Load the live model from ``weights["model"]``."""
        self.model.load_state_dict(weights["model"])

    @override(PolicyGraph)
    def get_state(self):
        """Return all network state dicts plus the current epsilon."""
        has_mixer = self.mixer is not None
        return {
            "model": self.model.state_dict(),
            "target_model": self.target_model.state_dict(),
            "mixer": self.mixer.state_dict() if has_mixer else None,
            "target_mixer": self.target_mixer.state_dict()
            if has_mixer else None,
            "cur_epsilon": self.cur_epsilon,
        }

    @override(PolicyGraph)
    def set_state(self, state):
        """Restore from ``get_state`` output, then re-sync the targets."""
        self.model.load_state_dict(state["model"])
        self.target_model.load_state_dict(state["target_model"])
        if state["mixer"] is not None:
            self.mixer.load_state_dict(state["mixer"])
            self.target_mixer.load_state_dict(state["target_mixer"])
        self.set_epsilon(state["cur_epsilon"])
        # NOTE: this overwrites the target weights loaded just above with
        # the live weights.
        self.update_target()

    def update_target(self):
        """Copy the live model (and mixer, if any) into the target nets."""
        self.target_model.load_state_dict(self.model.state_dict())
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        logger.debug("Updated target networks")

    def set_epsilon(self, epsilon):
        """Set the random-action probability used by compute_actions."""
        self.cur_epsilon = epsilon

    def _get_group_rewards(self, info_batch):
        """Collect per-agent rewards from the info dicts.

        Infos lacking the group rewards key contribute zeros.
        """
        group_rewards = np.array([
            info.get(GROUP_REWARDS, [0.0] * self.n_agents)
            for info in info_batch
        ])
        return group_rewards

    def _unpack_observation(self, obs_batch):
        """Split a flattened obs batch into per-agent obs and action masks.

        Returns:
            obs: Array of shape [len(obs_batch), n_agents, obs_size].
            action_mask: Array of shape [len(obs_batch), n_agents,
                n_actions]; all ones when the env supplies no mask.
        """
        unpacked = _unpack_obs(
            np.array(obs_batch),
            self.observation_space.original_space,
            tensorlib=np)
        if self.has_action_mask:
            obs = np.concatenate(
                [o["obs"] for o in unpacked],
                axis=1).reshape([len(obs_batch), self.n_agents, self.obs_size])
            action_mask = np.concatenate(
                [o["action_mask"] for o in unpacked], axis=1).reshape(
                    [len(obs_batch), self.n_agents, self.n_actions])
        else:
            obs = np.concatenate(
                unpacked,
                axis=1).reshape([len(obs_batch), self.n_agents, self.obs_size])
            action_mask = np.ones(
                [len(obs_batch), self.n_agents, self.n_actions])
        return obs, action_mask
|
||||
|
||||
|
||||
def _validate(obs_space, action_space):
    """Check that the spaces describe a homogeneous agent group.

    Raises:
        ValueError: If the original obs space or the action space is not a
            Tuple, the first action space is not Discrete, or the per-agent
            obs/action spaces are not all identical (compared via ``str``).
    """
    group_hint = ("Use MultiAgentEnv.with_agent_groups() to group related "
                  "agents for QMix.")

    has_tuple_obs = (hasattr(obs_space, "original_space")
                     and isinstance(obs_space.original_space, Tuple))
    if not has_tuple_obs:
        raise ValueError(
            "Obs space must be a Tuple, got {}. ".format(obs_space) +
            group_hint)

    if not isinstance(action_space, Tuple):
        raise ValueError(
            "Action space must be a Tuple, got {}. ".format(action_space) +
            group_hint)

    first_action_space = action_space.spaces[0]
    if not isinstance(first_action_space, Discrete):
        raise ValueError(
            "QMix requires a discrete action space, got {}".format(
                first_action_space))

    agent_obs_spaces = obs_space.original_space.spaces
    if len(set(map(str, agent_obs_spaces))) > 1:
        raise ValueError(
            "Implementation limitation: observations of grouped agents "
            "must be homogeneous, got {}".format(agent_obs_spaces))

    if len(set(map(str, action_space.spaces))) > 1:
        raise ValueError(
            "Implementation limitation: action space of grouped agents "
            "must be homogeneous, got {}".format(action_space.spaces))
|
||||
|
||||
|
||||
def _get_size(obs_space):
    """Return the ``size`` of the preprocessor built for ``obs_space``."""
    preprocessor_cls = get_preprocessor(obs_space)
    return preprocessor_cls(obs_space).size
|
||||
|
||||
|
||||
def _mac(model, obs, h):
    """Forward pass of the multi-agent controller.

    The batch and agent dimensions are folded together so that the model
    sees a flat batch of agents, then unfolded again on the way out.

    Arguments:
        model: Model that produces q-values for a 1d agent batch
        obs: Tensor of shape [B, n_agents, obs_size]
        h: Tensor of shape [B, n_agents, h_size]

    Returns:
        q_vals: Tensor of shape [B, n_agents, n_actions]
        h: Tensor of shape [B, n_agents, h_size]
    """
    batch_size, n_agents = obs.size(0), obs.size(1)
    flat_rows = batch_size * n_agents
    q_flat, h_flat = model.forward(
        obs.reshape([flat_rows, -1]), h.reshape([flat_rows, -1]))
    unfolded = [batch_size, n_agents, -1]
    return q_flat.reshape(unfolded), h_flat.reshape(unfolded)
|
||||
+3
-3
@@ -304,9 +304,9 @@ class _MultiAgentEnvToAsync(AsyncVectorEnv):
|
||||
raise ValueError(
|
||||
"Key set for obs and rewards must be the same: "
|
||||
"{} vs {}".format(obs.keys(), rewards.keys()))
|
||||
if set(obs.keys()) != set(infos.keys()):
|
||||
raise ValueError("Key set for obs and infos must be the same: "
|
||||
"{} vs {}".format(obs.keys(), infos.keys()))
|
||||
if set(infos).difference(set(obs)):
|
||||
raise ValueError("Key set for infos must be a subset of obs: "
|
||||
"{} vs {}".format(infos.keys(), obs.keys()))
|
||||
if dones["__all__"]:
|
||||
self.dones.add(env_id)
|
||||
self.env_states[env_id].observe(obs, rewards, dones, infos)
|
||||
|
||||
Vendored
+19
@@ -0,0 +1,19 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# info key for the individual rewards of an agent, for example:
|
||||
# info: {
|
||||
# group_1: {
|
||||
# _group_rewards: [5, -1, 1], # 3 agents in this group
|
||||
# }
|
||||
# }
|
||||
GROUP_REWARDS = "_group_rewards"
|
||||
|
||||
# info key for the individual infos of an agent, for example:
|
||||
# info: {
|
||||
# group_1: {
|
||||
# _group_info: [{"foo": ...}, {}], # 2 agents in this group
|
||||
# }
|
||||
# }
|
||||
GROUP_INFO = "_group_info"
|
||||
+107
@@ -0,0 +1,107 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
from ray.rllib.env.constants import GROUP_REWARDS, GROUP_INFO
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
|
||||
|
||||
# TODO(ekl) we should add some unit tests for this
|
||||
class _GroupAgentsWrapper(MultiAgentEnv):
|
||||
"""Wraps a MultiAgentEnv environment with agents grouped as specified.
|
||||
|
||||
See multi_agent_env.py for the specification of groups.
|
||||
|
||||
This API is experimental.
|
||||
"""
|
||||
|
||||
def __init__(self, env, groups, obs_space=None, act_space=None):
    """Wrap an existing multi-agent env to group agents together.

    See MultiAgentEnv.with_agent_groups() for usage info.

    Arguments:
        env (MultiAgentEnv): env to wrap
        groups (dict): Grouping spec as documented in MultiAgentEnv
        obs_space (Space): Optional observation space for the grouped
            env. Must be a tuple space.
        act_space (Space): Optional action space for the grouped env.
            Must be a tuple space.

    Raises:
        ValueError: If an agent id is listed in more than one group.
    """

    self.env = env
    self.groups = groups
    # Reverse index: agent id -> id of the group that contains it.
    self.agent_id_to_group = {}
    for group_id, agent_ids in groups.items():
        for agent_id in agent_ids:
            if agent_id in self.agent_id_to_group:
                # Include the grouping spec: it was passed to format()
                # before, but the message had no placeholder for it.
                raise ValueError(
                    "Agent id {} is in multiple groups: {}".format(
                        agent_id, groups))
            self.agent_id_to_group[agent_id] = group_id
    # The space attributes are only set when explicitly provided.
    if obs_space is not None:
        self.observation_space = obs_space
    if act_space is not None:
        self.action_space = act_space
|
||||
|
||||
def reset(self):
|
||||
obs = self.env.reset()
|
||||
return self._group_items(obs)
|
||||
|
||||
def step(self, action_dict):
|
||||
# Ungroup and send actions
|
||||
action_dict = self._ungroup_items(action_dict)
|
||||
obs, rewards, dones, infos = self.env.step(action_dict)
|
||||
|
||||
# Apply grouping transforms to the env outputs
|
||||
obs = self._group_items(obs)
|
||||
rewards = self._group_items(
|
||||
rewards, agg_fn=lambda gvals: list(gvals.values()))
|
||||
dones = self._group_items(
|
||||
dones, agg_fn=lambda gvals: all(gvals.values()))
|
||||
infos = self._group_items(
|
||||
infos, agg_fn=lambda gvals: {GROUP_INFO: list(gvals.values())})
|
||||
|
||||
# Aggregate rewards, but preserve the original values in infos
|
||||
for agent_id, rew in rewards.items():
|
||||
if isinstance(rew, list):
|
||||
rewards[agent_id] = sum(rew)
|
||||
if agent_id not in infos:
|
||||
infos[agent_id] = {}
|
||||
infos[agent_id][GROUP_REWARDS] = rew
|
||||
|
||||
return obs, rewards, dones, infos
|
||||
|
||||
def _ungroup_items(self, items):
|
||||
out = {}
|
||||
for agent_id, value in items.items():
|
||||
if agent_id in self.groups:
|
||||
assert len(value) == len(self.groups[agent_id]), \
|
||||
(agent_id, value, self.groups)
|
||||
for a, v in zip(self.groups[agent_id], value):
|
||||
out[a] = v
|
||||
else:
|
||||
out[agent_id] = value
|
||||
return out
|
||||
|
||||
def _group_items(self, items, agg_fn=lambda gvals: list(gvals.values())):
|
||||
grouped_items = {}
|
||||
for agent_id, item in items.items():
|
||||
if agent_id in self.agent_id_to_group:
|
||||
group_id = self.agent_id_to_group[agent_id]
|
||||
if group_id in grouped_items:
|
||||
continue # already added
|
||||
group_out = OrderedDict()
|
||||
for a in self.groups[group_id]:
|
||||
if a in items:
|
||||
group_out[a] = items[a]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Missing member of group {}: {}: {}".format(
|
||||
group_id, a, items))
|
||||
grouped_items[group_id] = agg_fn(group_out)
|
||||
else:
|
||||
grouped_items[agent_id] = item
|
||||
return grouped_items
|
||||
+50
-4
@@ -30,9 +30,14 @@ class MultiAgentEnv(object):
|
||||
}
|
||||
>>> print(dones)
|
||||
{
|
||||
"car_0": False,
|
||||
"car_1": True,
|
||||
"__all__": False,
|
||||
"car_0": False, # car_0 is still running
|
||||
"car_1": True, # car_1 is done
|
||||
"__all__": False, # the env is not done
|
||||
}
|
||||
>>> print(infos)
|
||||
{
|
||||
"car_0": {}, # info for car_0
|
||||
"car_1": {}, # info for car_1
|
||||
}
|
||||
"""
|
||||
|
||||
@@ -57,6 +62,47 @@ class MultiAgentEnv(object):
|
||||
episode is just started, the value will be None.
|
||||
dones (dict): Done values for each ready agent. The special key
|
||||
"__all__" (required) is used to indicate env termination.
|
||||
infos (dict): Info values for each ready agent.
|
||||
infos (dict): Optional info values for each agent id.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
# yapf: disable
|
||||
# __grouping_doc_begin__
|
||||
def with_agent_groups(self, groups, obs_space=None, act_space=None):
    """Convenience method for grouping together agents in this env.

    An agent group is a list of agent ids that are mapped to a single
    logical agent. All agents of the group must act at the same time in the
    environment. The grouped agent exposes Tuple action and observation
    spaces that are the concatenated action and obs spaces of the
    individual agents.

    The rewards of all the agents in a group are summed. The individual
    agent rewards are available under the GROUP_REWARDS key (see
    ray.rllib.env.constants) of the group info return.

    Agent grouping is required to leverage algorithms such as Q-Mix.

    This API is experimental.

    Arguments:
        groups (dict): Mapping from group id to a list of the agent ids
            of group members. If an agent id is not present in any group
            value, it will be left ungrouped.
        obs_space (Space): Optional observation space for the grouped
            env. Must be a tuple space.
        act_space (Space): Optional action space for the grouped env.
            Must be a tuple space.

    Returns:
        A new env object that wraps this env and applies the grouping.

    Examples:
        >>> env = YourMultiAgentEnv(...)
        >>> grouped_env = env.with_agent_groups({
        ...     "group1": ["agent1", "agent2", "agent3"],
        ...     "group2": ["agent4", "agent5"],
        ... })
    """

    # Imported here rather than at module level; the wrapper module itself
    # imports MultiAgentEnv from this module.
    from ray.rllib.env.group_agents_wrapper import _GroupAgentsWrapper
    return _GroupAgentsWrapper(self, groups, obs_space, act_space)
|
||||
# __grouping_doc_end__
|
||||
# yapf: enable
|
||||
|
||||
@@ -42,7 +42,9 @@ class PolicyGraph(object):
|
||||
state_batches,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
episodes=None):
|
||||
info_batch=None,
|
||||
episodes=None,
|
||||
**kwargs):
|
||||
"""Compute actions for the current policy.
|
||||
|
||||
Arguments:
|
||||
@@ -50,9 +52,11 @@ class PolicyGraph(object):
|
||||
state_batches (list): list of RNN state input batches, if any
|
||||
prev_action_batch (np.ndarray): batch of previous action values
|
||||
prev_reward_batch (np.ndarray): batch of previous rewards
|
||||
info_batch (list): batch of info objects
|
||||
episodes (list): MultiAgentEpisode for each obs in obs_batch.
|
||||
This provides access to all of the internal episode state,
|
||||
which may be useful for model-based or multiagent algorithms.
|
||||
kwargs: forward compatibility placeholder
|
||||
|
||||
Returns:
|
||||
actions (np.ndarray): batch of output actions, with shape like
|
||||
@@ -69,7 +73,9 @@ class PolicyGraph(object):
|
||||
state,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
episode=None):
|
||||
info_batch=None,
|
||||
episode=None,
|
||||
**kwargs):
|
||||
"""Unbatched version of compute_actions.
|
||||
|
||||
Arguments:
|
||||
@@ -77,9 +83,11 @@ class PolicyGraph(object):
|
||||
state_batches (list): list of RNN state inputs, if any
|
||||
prev_action_batch (np.ndarray): batch of previous action values
|
||||
prev_reward_batch (np.ndarray): batch of previous rewards
|
||||
info_batch (list): batch of info objects
|
||||
episode (MultiAgentEpisode): this provides access to all of the
|
||||
internal episode state, which may be useful for model-based or
|
||||
multi-agent algorithms.
|
||||
kwargs: forward compatibility placeholder
|
||||
|
||||
Returns:
|
||||
actions (obj): single action
|
||||
|
||||
@@ -24,20 +24,13 @@ RolloutMetrics = namedtuple(
|
||||
"RolloutMetrics",
|
||||
["episode_length", "episode_reward", "agent_rewards", "custom_metrics"])
|
||||
|
||||
PolicyEvalData = namedtuple(
|
||||
"PolicyEvalData",
|
||||
["env_id", "agent_id", "obs", "rnn_state", "prev_action", "prev_reward"])
|
||||
PolicyEvalData = namedtuple("PolicyEvalData", [
|
||||
"env_id", "agent_id", "obs", "info", "rnn_state", "prev_action",
|
||||
"prev_reward"
|
||||
])
|
||||
|
||||
|
||||
class SyncSampler(object):
|
||||
"""This class interacts with the environment and tells it what to do.
|
||||
|
||||
Note that batch_size is only a unit of measure here. Batches can
|
||||
accumulate and the gradient can be calculated on up to 5 batches.
|
||||
|
||||
This class provides data on invocation, rather than on a separate
|
||||
thread."""
|
||||
|
||||
def __init__(self,
|
||||
env,
|
||||
policies,
|
||||
@@ -94,11 +87,6 @@ class SyncSampler(object):
|
||||
|
||||
|
||||
class AsyncSampler(threading.Thread):
|
||||
"""This class interacts with the environment and tells it what to do.
|
||||
|
||||
Note that batch_size is only a unit of measure here. Batches can
|
||||
accumulate and the gradient can be calculated on up to 5 batches."""
|
||||
|
||||
def __init__(self,
|
||||
env,
|
||||
policies,
|
||||
@@ -361,17 +349,18 @@ def _process_observations(async_vector_env, policies, batch_builder_pool,
|
||||
if not agent_done:
|
||||
to_eval[policy_id].append(
|
||||
PolicyEvalData(env_id, agent_id, filtered_obs,
|
||||
infos[env_id].get(agent_id, {}),
|
||||
episode.rnn_state_for(agent_id),
|
||||
episode.last_action_for(agent_id),
|
||||
rewards[env_id][agent_id] or 0.0))
|
||||
|
||||
last_observation = episode.last_observation_for(agent_id)
|
||||
episode._set_last_observation(agent_id, filtered_obs)
|
||||
episode._set_last_info(agent_id, infos[env_id][agent_id])
|
||||
episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))
|
||||
|
||||
# Record transition info if applicable
|
||||
if last_observation is not None and \
|
||||
infos[env_id][agent_id].get("training_enabled", True):
|
||||
if (last_observation is not None and infos[env_id].get(
|
||||
agent_id, {}).get("training_enabled", True)):
|
||||
episode.batch_builder.add_values(
|
||||
agent_id,
|
||||
policy_id,
|
||||
@@ -384,7 +373,7 @@ def _process_observations(async_vector_env, policies, batch_builder_pool,
|
||||
prev_actions=episode.prev_action_for(agent_id),
|
||||
prev_rewards=episode.prev_reward_for(agent_id),
|
||||
dones=agent_done,
|
||||
infos=infos[env_id][agent_id],
|
||||
infos=infos[env_id].get(agent_id, {}),
|
||||
new_obs=filtered_obs,
|
||||
**episode.last_pi_info_for(agent_id))
|
||||
|
||||
@@ -435,6 +424,7 @@ def _process_observations(async_vector_env, policies, batch_builder_pool,
|
||||
to_eval[policy_id].append(
|
||||
PolicyEvalData(
|
||||
env_id, agent_id, filtered_obs,
|
||||
episode.last_info_for(agent_id) or {},
|
||||
episode.rnn_state_for(agent_id),
|
||||
np.zeros_like(
|
||||
_flatten_action(policy.action_space.sample())),
|
||||
@@ -462,6 +452,7 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
|
||||
policy = _get_or_raise(policies, policy_id)
|
||||
if builder and (policy.compute_actions.__code__ is
|
||||
TFPolicyGraph.compute_actions.__code__):
|
||||
# TODO(ekl): how can we make info batch available to TF code?
|
||||
pending_fetches[policy_id] = policy._build_compute_actions(
|
||||
builder, [t.obs for t in eval_data],
|
||||
rnn_in_cols,
|
||||
@@ -473,6 +464,7 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
|
||||
rnn_in_cols,
|
||||
prev_action_batch=[t.prev_action for t in eval_data],
|
||||
prev_reward_batch=[t.prev_reward for t in eval_data],
|
||||
info_batch=[t.info for t in eval_data],
|
||||
episodes=[active_episodes[t.env_id] for t in eval_data])
|
||||
if builder:
|
||||
for k, v in pending_fetches.items():
|
||||
|
||||
@@ -153,7 +153,9 @@ class TFPolicyGraph(PolicyGraph):
|
||||
state_batches=None,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
episodes=None):
|
||||
info_batch=None,
|
||||
episodes=None,
|
||||
**kwargs):
|
||||
builder = TFRunBuilder(self._sess, "compute_actions")
|
||||
fetches = self._build_compute_actions(builder, obs_batch,
|
||||
state_batches, prev_action_batch,
|
||||
|
||||
@@ -63,7 +63,9 @@ class TorchPolicyGraph(PolicyGraph):
|
||||
state_batches=None,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
episodes=None):
|
||||
info_batch=None,
|
||||
episodes=None,
|
||||
**kwargs):
|
||||
if state_batches:
|
||||
raise NotImplementedError("Torch RNN support")
|
||||
with self.lock:
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
"""The two-step game from QMIX: https://arxiv.org/pdf/1803.11485.pdf"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
from gym.spaces import Tuple, Discrete
|
||||
|
||||
import ray
|
||||
from ray.tune import register_env, run_experiments, grid_search
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--stop", type=int, default=50000)
|
||||
parser.add_argument("--run", type=str, default="QMIX")
|
||||
|
||||
|
||||
class TwoStepGame(MultiAgentEnv):
    """Cooperative two-agent, two-step game.

    In the first step agent_1's action selects which second-step state is
    entered. The second step then pays out a shared reward and ends the
    episode: a fixed 7 in state 1, or 0 / 1 / 8 in state 2 depending on
    the joint action of both agents.
    """

    action_space = Discrete(2)

    # Each agent gets a separate [3] obs space, to ensure that they can
    # learn meaningfully different Q values even with a shared Q model.
    observation_space = Discrete(6)

    def __init__(self, env_config):
        # env_config is accepted for the env creator API but is not used.
        self.state = None

    def _observations(self):
        """Return the per-agent view of the current state.

        agent_2 sees the state shifted by 3, so the two agents never share
        an observation value.
        """
        return {"agent_1": self.state, "agent_2": self.state + 3}

    def reset(self):
        """Start a new episode in state 0 and return the initial obs."""
        self.state = 0
        return self._observations()

    def step(self, action_dict):
        """Advance the game by one step.

        Returns:
            Tuple of (obs, rewards, dones, infos). The global reward is
            split evenly between the two agents; infos is always empty.
        """
        if self.state == 0:
            # First step: agent_1 alone picks the branch, no reward yet.
            first_move = action_dict["agent_1"]
            assert first_move in [0, 1], first_move
            self.state = 1 if first_move == 0 else 2
            global_rew, done = 0, False
        elif self.state == 1:
            # Branch 1 pays the same amount regardless of the actions.
            global_rew, done = 7, True
        else:
            # Branch 2 pays according to the joint action.
            done = True
            if action_dict["agent_1"] == 0 and action_dict["agent_2"] == 0:
                global_rew = 0
            elif action_dict["agent_1"] == 1 and action_dict["agent_2"] == 1:
                global_rew = 8
            else:
                global_rew = 1

        share = global_rew / 2.0
        rewards = {"agent_1": share, "agent_2": share}
        obs = self._observations()
        dones = {"__all__": done}
        infos = {}
        return obs, rewards, dones, infos
|
||||
|
||||
|
||||
if __name__ == "__main__":
    args = parser.parse_args()

    # Both agents are merged into a single logical agent, "group_1".
    grouping = {
        "group_1": ["agent_1", "agent_2"],
    }
    # The grouped spaces are Tuples with one entry per group member.
    obs_space = Tuple([TwoStepGame.observation_space] * 2)
    act_space = Tuple([TwoStepGame.action_space] * 2)

    def make_grouped_env(config):
        """Env creator: build the game and apply the agent grouping."""
        return TwoStepGame(config).with_agent_groups(
            grouping, obs_space=obs_space, act_space=act_space)

    register_env("grouped_twostep", make_grouped_env)

    # Pick the algorithm-specific settings for the requested --run value.
    if args.run == "QMIX":
        config = {
            "sample_batch_size": 4,
            "train_batch_size": 32,
            "exploration_final_eps": 0.0,
            "num_workers": 0,
            "mixer": grid_search([None, "qmix", "vdn"]),
        }
    elif args.run == "APEX_QMIX":
        config = {
            "num_gpus": 0,
            "num_workers": 2,
            "optimizer": {
                "num_replay_buffer_shards": 1,
            },
            "min_iter_time_s": 3,
            "buffer_size": 1000,
            "learning_starts": 1000,
            "train_batch_size": 128,
            "sample_batch_size": 32,
            "target_network_update_freq": 500,
            "timesteps_per_iteration": 1000,
        }
    else:
        # Any other algorithm runs with its default config.
        config = {}

    experiment = {
        "run": args.run,
        "env": "grouped_twostep",
        "stop": {
            "timesteps_total": args.stop,
        },
        "config": config,
    }

    ray.init()
    run_experiments({"two_step": experiment})
|
||||
@@ -123,7 +123,8 @@ def chop_into_sequences(episode_ids,
|
||||
feature_columns,
|
||||
state_columns,
|
||||
max_seq_len,
|
||||
dynamic_max=True):
|
||||
dynamic_max=True,
|
||||
_extra_padding=0):
|
||||
"""Truncate and pad experiences into fixed-length sequences.
|
||||
|
||||
Arguments:
|
||||
@@ -136,6 +137,7 @@ def chop_into_sequences(episode_ids,
|
||||
dynamic_max (bool): Whether to dynamically shrink the max seq len.
|
||||
For example, if max len is 20 and the actual max seq len in the
|
||||
data is 7, it will be shrunk to 7.
|
||||
_extra_padding (int): Add extra padding to the end of sequences.
|
||||
|
||||
Returns:
|
||||
f_pad (list): Padded feature columns. These will be of shape
|
||||
@@ -177,7 +179,7 @@ def chop_into_sequences(episode_ids,
|
||||
|
||||
# Dynamically shrink max len as needed to optimize memory usage
|
||||
if dynamic_max:
|
||||
max_seq_len = max(seq_lens)
|
||||
max_seq_len = max(seq_lens) + _extra_padding
|
||||
|
||||
feature_sequences = []
|
||||
for f in feature_columns:
|
||||
|
||||
@@ -168,7 +168,15 @@ def _restore_original_dimensions(input_dict, obs_space):
|
||||
return input_dict
|
||||
|
||||
|
||||
def _unpack_obs(obs, space):
|
||||
def _unpack_obs(obs, space, tensorlib=tf):
|
||||
"""Unpack a flattened Dict or Tuple observation array/tensor.
|
||||
|
||||
Arguments:
|
||||
obs: The flattened observation tensor
|
||||
space: The original space prior to flattening
|
||||
tensorlib: The library used to unflatten (reshape) the array/tensor
|
||||
"""
|
||||
|
||||
if (isinstance(space, gym.spaces.Dict)
|
||||
or isinstance(space, gym.spaces.Tuple)):
|
||||
prep = get_preprocessor(space)(space)
|
||||
@@ -186,14 +194,18 @@ def _unpack_obs(obs, space):
|
||||
offset += p.size
|
||||
u.append(
|
||||
_unpack_obs(
|
||||
tf.reshape(obs_slice, [-1] + list(p.shape)), v))
|
||||
tensorlib.reshape(obs_slice, [-1] + list(p.shape)),
|
||||
v,
|
||||
tensorlib=tensorlib))
|
||||
else:
|
||||
u = OrderedDict()
|
||||
for p, (k, v) in zip(prep.preprocessors, space.spaces.items()):
|
||||
obs_slice = obs[:, offset:offset + p.size]
|
||||
offset += p.size
|
||||
u[k] = _unpack_obs(
|
||||
tf.reshape(obs_slice, [-1] + list(p.shape)), v)
|
||||
tensorlib.reshape(obs_slice, [-1] + list(p.shape)),
|
||||
v,
|
||||
tensorlib=tensorlib)
|
||||
return u
|
||||
else:
|
||||
return obs
|
||||
|
||||
@@ -5,10 +5,17 @@ from ray.rllib.optimizers.async_gradients_optimizer import \
|
||||
AsyncGradientsOptimizer
|
||||
from ray.rllib.optimizers.sync_samples_optimizer import SyncSamplesOptimizer
|
||||
from ray.rllib.optimizers.sync_replay_optimizer import SyncReplayOptimizer
|
||||
from ray.rllib.optimizers.sync_batch_replay_optimizer import \
|
||||
SyncBatchReplayOptimizer
|
||||
from ray.rllib.optimizers.multi_gpu_optimizer import LocalMultiGPUOptimizer
|
||||
|
||||
__all__ = [
|
||||
"PolicyOptimizer", "AsyncReplayOptimizer", "AsyncSamplesOptimizer",
|
||||
"AsyncGradientsOptimizer", "SyncSamplesOptimizer", "SyncReplayOptimizer",
|
||||
"LocalMultiGPUOptimizer"
|
||||
"PolicyOptimizer",
|
||||
"AsyncReplayOptimizer",
|
||||
"AsyncSamplesOptimizer",
|
||||
"AsyncGradientsOptimizer",
|
||||
"SyncSamplesOptimizer",
|
||||
"SyncReplayOptimizer",
|
||||
"LocalMultiGPUOptimizer",
|
||||
"SyncBatchReplayOptimizer",
|
||||
]
|
||||
|
||||
@@ -36,6 +36,11 @@ class AsyncReplayOptimizer(PolicyOptimizer):
|
||||
This class coordinates the data transfers between the learner thread,
|
||||
remote evaluators (Ape-X actors), and replay buffer actors.
|
||||
|
||||
This has two modes of operation:
|
||||
- normal replay: replays independent samples.
|
||||
- batch replay: simplified mode where entire sample batches are
|
||||
replayed. This supports RNNs, but not prioritization.
|
||||
|
||||
This optimizer requires that policy evaluators return an additional
|
||||
"td_error" array in the info return of compute_gradients(). This error
|
||||
term will be used for sample prioritization."""
|
||||
@@ -52,9 +57,11 @@ class AsyncReplayOptimizer(PolicyOptimizer):
|
||||
sample_batch_size=50,
|
||||
num_replay_buffer_shards=1,
|
||||
max_weight_sync_delay=400,
|
||||
debug=False):
|
||||
debug=False,
|
||||
batch_replay=False):
|
||||
|
||||
self.debug = debug
|
||||
self.batch_replay = batch_replay
|
||||
self.replay_starts = learning_starts
|
||||
self.prioritized_replay_beta = prioritized_replay_beta
|
||||
self.prioritized_replay_eps = prioritized_replay_eps
|
||||
@@ -63,7 +70,11 @@ class AsyncReplayOptimizer(PolicyOptimizer):
|
||||
self.learner = LearnerThread(self.local_evaluator)
|
||||
self.learner.start()
|
||||
|
||||
self.replay_actors = create_colocated(ReplayActor, [
|
||||
if self.batch_replay:
|
||||
replay_cls = BatchReplayActor
|
||||
else:
|
||||
replay_cls = ReplayActor
|
||||
self.replay_actors = create_colocated(replay_cls, [
|
||||
num_replay_buffer_shards,
|
||||
learning_starts,
|
||||
buffer_size,
|
||||
@@ -101,6 +112,7 @@ class AsyncReplayOptimizer(PolicyOptimizer):
|
||||
|
||||
@override(PolicyOptimizer)
|
||||
def step(self):
|
||||
assert self.learner.is_alive()
|
||||
assert len(self.remote_evaluators) > 0
|
||||
start = time.time()
|
||||
sample_timesteps, train_timesteps = self._step()
|
||||
@@ -299,6 +311,54 @@ class ReplayActor(object):
|
||||
return stat
|
||||
|
||||
|
||||
@ray.remote(num_cpus=0)
class BatchReplayActor(object):
    """The batch replay version of the replay actor.

    Whole sample batches are stored and replayed as-is. This allows for
    RNN models, but ignores prioritization params.
    """

    def __init__(self, num_shards, learning_starts, buffer_size,
                 train_batch_size, prioritized_replay_alpha,
                 prioritized_replay_beta, prioritized_replay_eps):
        """Create one shard of the batch replay buffer.

        learning_starts and buffer_size are divided by num_shards to get
        this shard's share. The prioritized_replay_* arguments are accepted
        but not used by this class.
        """
        self.replay_starts = learning_starts // num_shards
        self.buffer_size = buffer_size // num_shards
        self.train_batch_size = train_batch_size
        self.buffer = []

        # Metrics
        self.num_added = 0
        self.cur_size = 0

    def get_host(self):
        """Return the node name reported by os.uname()."""
        return os.uname()[1]

    def add_batch(self, batch):
        """Append a batch, evicting the oldest batches when over capacity."""
        # Handle everything as if multiagent
        if isinstance(batch, SampleBatch):
            batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)
        self.buffer.append(batch)
        self.cur_size += batch.count
        self.num_added += batch.count
        # Evict oldest-first, but never the batch that was just added: a
        # single batch larger than the shard capacity would otherwise evict
        # itself and leave nothing to replay.
        while self.cur_size > self.buffer_size and len(self.buffer) > 1:
            self.cur_size -= self.buffer.pop(0).count

    def replay(self):
        """Return a uniformly chosen stored batch, or None if not ready.

        None is returned until replay_starts steps have been added, and
        whenever the buffer holds no batches (random.choice would raise
        IndexError on an empty list).
        """
        if self.num_added < self.replay_starts or not self.buffer:
            return None
        return random.choice(self.buffer)

    def update_priorities(self, prio_dict):
        """No-op: batch replay does not support prioritization."""
        pass

    def stats(self, debug=False):
        """Return the current buffer size and total steps added."""
        stat = {
            "cur_size": self.cur_size,
            "num_added": self.num_added,
        }
        return stat
|
||||
|
||||
|
||||
class LearnerThread(threading.Thread):
|
||||
"""Background thread that updates the local model from replay data.
|
||||
|
||||
@@ -334,8 +394,8 @@ class LearnerThread(threading.Thread):
|
||||
grad_out = self.local_evaluator.compute_apply(replay)
|
||||
for pid, info in grad_out.items():
|
||||
prio_dict[pid] = (
|
||||
replay.policy_batches[pid]["batch_indexes"],
|
||||
info["td_error"])
|
||||
replay.policy_batches[pid].data.get("batch_indexes"),
|
||||
info.get("td_error"))
|
||||
if "stats" in info:
|
||||
self.stats[pid] = info["stats"]
|
||||
self.outqueue.put((ra, prio_dict, replay.count))
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import random
|
||||
|
||||
import ray
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
|
||||
MultiAgentBatch
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
|
||||
|
||||
class SyncBatchReplayOptimizer(PolicyOptimizer):
    """Variant of the sync replay optimizer that replays entire batches.

    This enables RNN support. Does not currently support prioritization."""

    @override(PolicyOptimizer)
    def _init(self,
              learning_starts=1000,
              buffer_size=10000,
              train_batch_size=32):
        """Set up the replay buffer and timers.

        Arguments:
            learning_starts (int): Number of sampled steps required before
                optimization begins.
            buffer_size (int): Capacity of the replay buffer, in steps.
            train_batch_size (int): Minimum number of steps replayed per
                optimization pass.
        """
        self.replay_starts = learning_starts
        self.max_buffer_size = buffer_size
        self.train_batch_size = train_batch_size
        assert self.max_buffer_size >= self.replay_starts

        # List of buffered sample batches
        self.replay_buffer = []
        self.buffer_size = 0

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.learner_stats = {}

    @override(PolicyOptimizer)
    def step(self):
        """Sync weights, collect new batches, then optimize if warmed up."""
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                batches = ray.get(
                    [e.sample.remote() for e in self.remote_evaluators])
            else:
                batches = [self.local_evaluator.sample()]

            # Handle everything as if multiagent
            tmp = []
            for batch in batches:
                if isinstance(batch, SampleBatch):
                    batch = MultiAgentBatch({
                        DEFAULT_POLICY_ID: batch
                    }, batch.count)
                tmp.append(batch)
            batches = tmp

            for batch in batches:
                self.replay_buffer.append(batch)
                self.num_steps_sampled += batch.count
                self.buffer_size += batch.count
                # Evict oldest-first, but always keep the newest batch: a
                # batch larger than the whole buffer would otherwise evict
                # itself and leave _optimize() nothing to sample from.
                while (self.buffer_size > self.max_buffer_size
                       and len(self.replay_buffer) > 1):
                    evicted = self.replay_buffer.pop(0)
                    self.buffer_size -= evicted.count

        if self.num_steps_sampled >= self.replay_starts:
            self._optimize()

    @override(PolicyOptimizer)
    def stats(self):
        """Return the base optimizer stats plus timing and learner stats."""
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
                "learner": self.learner_stats,
            })

    def _optimize(self):
        """Replay random whole batches and apply one update to the policy.

        Batches are drawn uniformly with replacement until at least
        train_batch_size steps have been gathered.
        """
        samples = [random.choice(self.replay_buffer)]
        # Running total, so the list is not re-summed on every draw.
        num_steps = samples[0].count
        while num_steps < self.train_batch_size:
            picked = random.choice(self.replay_buffer)
            samples.append(picked)
            num_steps += picked.count
        samples = SampleBatch.concat_samples(samples)
        with self.grad_timer:
            info_dict = self.local_evaluator.compute_apply(samples)
            for policy_id, info in info_dict.items():
                if "stats" in info:
                    self.learner_stats[policy_id] = info["stats"]
            self.grad_timer.push_units_processed(samples.count)
        self.num_steps_trained += samples.count
|
||||
@@ -13,7 +13,6 @@ from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
|
||||
MultiAgentBatch
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.compression import pack_if_needed
|
||||
from ray.rllib.utils.filter import RunningStat
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.rllib.utils.schedules import LinearSchedule
|
||||
|
||||
@@ -54,7 +53,6 @@ class SyncReplayOptimizer(PolicyOptimizer):
|
||||
self.sample_timer = TimerStat()
|
||||
self.replay_timer = TimerStat()
|
||||
self.grad_timer = TimerStat()
|
||||
self.throughput = RunningStat()
|
||||
self.learner_stats = {}
|
||||
|
||||
# Set up replay buffer
|
||||
@@ -159,13 +157,13 @@ class SyncReplayOptimizer(PolicyOptimizer):
|
||||
dones) = replay_buffer.sample(self.train_batch_size)
|
||||
weights = np.ones_like(rewards)
|
||||
batch_indexes = -np.ones_like(rewards)
|
||||
samples[policy_id] = SampleBatch({
|
||||
"obs": obses_t,
|
||||
"actions": actions,
|
||||
"rewards": rewards,
|
||||
"new_obs": obses_tp1,
|
||||
"dones": dones,
|
||||
"weights": weights,
|
||||
"batch_indexes": batch_indexes
|
||||
})
|
||||
samples[policy_id] = SampleBatch({
|
||||
"obs": obses_t,
|
||||
"actions": actions,
|
||||
"rewards": rewards,
|
||||
"new_obs": obses_tp1,
|
||||
"dones": dones,
|
||||
"weights": weights,
|
||||
"batch_indexes": batch_indexes
|
||||
})
|
||||
return MultiAgentBatch(samples, self.train_batch_size)
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from gym.spaces import Tuple, Discrete, Dict, Box
|
||||
|
||||
import ray
|
||||
from ray.tune import register_env
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from ray.rllib.agents.qmix import QMixAgent
|
||||
|
||||
|
||||
class AvailActionsTestEnv(MultiAgentEnv):
    """Single-agent test env in which only one action is ever available.

    Every observation carries an action mask with exactly one entry set.
    After the first step, taking any other action trips an assertion.
    """

    action_space = Discrete(10)
    observation_space = Dict({
        "obs": Discrete(3),
        "action_mask": Box(0, 1, (10, )),
    })

    def __init__(self, env_config):
        self.state = None
        self.avail = env_config["avail_action"]
        # Mask with only the configured action enabled.
        mask = [0] * 10
        mask[env_config["avail_action"]] = 1
        self.action_mask = mask

    def _agent_obs(self, obs_value):
        """Build agent_1's observation dict around the given obs value."""
        return {
            "agent_1": {
                "obs": obs_value,
                "action_mask": self.action_mask
            }
        }

    def reset(self):
        """Reset the step counter and return the initial observation."""
        self.state = 0
        return self._agent_obs(self.state)

    def step(self, action_dict):
        """Check the action against the mask and advance one step.

        The action taken on the very first step of an episode is not
        checked. Each step gives reward 1; the episode ends once more than
        20 steps have been taken.
        """
        if self.state > 0:
            assert action_dict["agent_1"] == self.avail, \
                "Failed to obey available actions mask!"
        self.state += 1
        episode_over = self.state > 20
        return (self._agent_obs(0), {"agent_1": 1},
                {"__all__": episode_over}, {})
|
||||
|
||||
|
||||
if __name__ == "__main__":
    grouping = {
        "group_1": ["agent_1"],  # trivial grouping for testing
    }
    # One-element Tuple spaces, matching the one-member group.
    obs_space = Tuple([AvailActionsTestEnv.observation_space])
    act_space = Tuple([AvailActionsTestEnv.action_space])

    def make_grouped_env(config):
        """Env creator: build the test env and apply the grouping."""
        return AvailActionsTestEnv(config).with_agent_groups(
            grouping, obs_space=obs_space, act_space=act_space)

    register_env("action_mask_test", make_grouped_env)

    ray.init()
    agent_config = {
        "num_envs_per_worker": 5,  # test with vectorization on
        "env_config": {
            "avail_action": 3,
        },
    }
    agent = QMixAgent(env="action_mask_test", config=agent_config)
    for _ in range(5):
        agent.train()  # OK if it doesn't trip the action assertion error
    # Each episode runs 21 steps with reward 1 per step.
    assert agent.train()["episode_reward_mean"] == 21.0
|
||||
@@ -339,7 +339,8 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
state_batches,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
episodes=None):
|
||||
episodes=None,
|
||||
**kwargs):
|
||||
return [0] * len(obs_batch), [[h] * len(obs_batch)], {}
|
||||
|
||||
def get_initial_state(self):
|
||||
@@ -363,7 +364,8 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
state_batches,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
episodes=None):
|
||||
episodes=None,
|
||||
**kwargs):
|
||||
# Pretend we did a model-based rollout and want to return
|
||||
# the extra trajectory.
|
||||
builder = episodes[0].new_batch_builder()
|
||||
|
||||
@@ -25,7 +25,8 @@ class MockPolicyGraph(PolicyGraph):
|
||||
state_batches,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
episodes=None):
|
||||
episodes=None,
|
||||
**kwargs):
|
||||
return [0] * len(obs_batch), [], {}
|
||||
|
||||
def postprocess_trajectory(self,
|
||||
@@ -42,7 +43,8 @@ class BadPolicyGraph(PolicyGraph):
|
||||
state_batches,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
episodes=None):
|
||||
episodes=None,
|
||||
**kwargs):
|
||||
raise Exception("intentional error")
|
||||
|
||||
def postprocess_trajectory(self,
|
||||
|
||||
Reference in New Issue
Block a user