mirror of
https://github.com/wassname/ray.git
synced 2026-07-04 22:46:48 +08:00
Qmix on gpu and with non-stacked-obs environment state support (#5751)
This commit is contained in:
committed by
Eric Liang
parent
42dd0fae96
commit
4aa06918ae
@@ -49,7 +49,7 @@ DEFAULT_CONFIG = with_common_config({
|
||||
"buffer_size": 10000,
|
||||
|
||||
# === Optimization ===
|
||||
# Learning rate for adam optimizer
|
||||
# Learning rate for RMSProp optimizer
|
||||
"lr": 0.0005,
|
||||
# RMSProp alpha
|
||||
"optim_alpha": 0.99,
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from gym.spaces import Tuple, Discrete, Dict
|
||||
import os
|
||||
import logging
|
||||
import numpy as np
|
||||
import torch as th
|
||||
@@ -14,7 +15,7 @@ import ray
|
||||
from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer
|
||||
from ray.rllib.agents.qmix.model import RNNModel, _get_size
|
||||
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
|
||||
from ray.rllib.policy.policy import Policy, TupleActions
|
||||
from ray.rllib.policy.policy import TupleActions, Policy
|
||||
from ray.rllib.policy.rnn_sequencing import chop_into_sequences
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
@@ -24,6 +25,9 @@ from ray.rllib.utils.annotations import override
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# if the obs space is Dict type, look for the global state under this key
|
||||
ENV_STATE = "state"
|
||||
|
||||
|
||||
class QMixLoss(nn.Module):
|
||||
def __init__(self,
|
||||
@@ -45,8 +49,17 @@ class QMixLoss(nn.Module):
|
||||
self.double_q = double_q
|
||||
self.gamma = gamma
|
||||
|
||||
def forward(self, rewards, actions, terminated, mask, obs, next_obs,
|
||||
action_mask, next_action_mask):
|
||||
def forward(self,
|
||||
rewards,
|
||||
actions,
|
||||
terminated,
|
||||
mask,
|
||||
obs,
|
||||
next_obs,
|
||||
action_mask,
|
||||
next_action_mask,
|
||||
state=None,
|
||||
next_state=None):
|
||||
"""Forward pass of the loss.
|
||||
|
||||
Arguments:
|
||||
@@ -58,8 +71,20 @@ class QMixLoss(nn.Module):
|
||||
next_obs: Tensor of shape [B, T, n_agents, obs_size]
|
||||
action_mask: Tensor of shape [B, T, n_agents, n_actions]
|
||||
next_action_mask: Tensor of shape [B, T, n_agents, n_actions]
|
||||
state: Tensor of shape [B, T, state_dim] (optional)
|
||||
next_state: Tensor of shape [B, T, state_dim] (optional)
|
||||
"""
|
||||
|
||||
# Assert either none or both of state and next_state are given
|
||||
if state is None and next_state is None:
|
||||
state = obs # default to state being all agents' observations
|
||||
next_state = next_obs
|
||||
elif (state is None) != (next_state is None):
|
||||
raise ValueError("Expected either neither or both of `state` and "
|
||||
"`next_state` to be given. Got: "
|
||||
"\n`state` = {}\n`next_state` = {}".format(
|
||||
state, next_state))
|
||||
|
||||
# Calculate estimated Q-Values
|
||||
mac_out = _unroll_mac(self.model, obs)
|
||||
|
||||
@@ -89,7 +114,7 @@ class QMixLoss(nn.Module):
|
||||
mac_out_tp1[ignore_action_tp1] = -np.inf
|
||||
|
||||
# obtain best actions at t+1 according to policy NN
|
||||
cur_max_actions = mac_out_tp1.max(dim=3, keepdim=True)[1]
|
||||
cur_max_actions = mac_out_tp1.argmax(dim=3, keepdim=True)
|
||||
|
||||
# use the target network to estimate the Q-values of policy
|
||||
# network's selected actions
|
||||
@@ -104,10 +129,8 @@ class QMixLoss(nn.Module):
|
||||
|
||||
# Mix
|
||||
if self.mixer is not None:
|
||||
# TODO(ekl) add support for handling global state? This is just
|
||||
# treating the stacked agent obs as the state.
|
||||
chosen_action_qvals = self.mixer(chosen_action_qvals, obs)
|
||||
target_max_qvals = self.target_mixer(target_max_qvals, next_obs)
|
||||
chosen_action_qvals = self.mixer(chosen_action_qvals, state)
|
||||
target_max_qvals = self.target_mixer(target_max_qvals, next_state)
|
||||
|
||||
# Calculate 1-step Q-Learning targets
|
||||
targets = rewards + self.gamma * (1 - terminated) * target_max_qvals
|
||||
@@ -146,24 +169,36 @@ class QMixTorchPolicy(Policy):
|
||||
self.n_agents = len(obs_space.original_space.spaces)
|
||||
self.n_actions = action_space.spaces[0].n
|
||||
self.h_size = config["model"]["lstm_cell_size"]
|
||||
self.has_env_global_state = False
|
||||
self.has_action_mask = False
|
||||
self.device = (th.device("cuda")
|
||||
if bool(os.environ.get("CUDA_VISIBLE_DEVICES", None))
|
||||
else th.device("cpu"))
|
||||
|
||||
agent_obs_space = obs_space.original_space.spaces[0]
|
||||
if isinstance(agent_obs_space, Dict):
|
||||
space_keys = set(agent_obs_space.spaces.keys())
|
||||
if not {"obs", "action_mask"}.issubset(space_keys):
|
||||
if "obs" not in space_keys:
|
||||
raise ValueError(
|
||||
"Dict obs space for agent must have keyset "
|
||||
"['obs', 'action_mask'], got {}".format(space_keys))
|
||||
mask_shape = tuple(agent_obs_space.spaces["action_mask"].shape)
|
||||
if mask_shape != (self.n_actions, ):
|
||||
raise ValueError("Action mask shape must be {}, got {}".format(
|
||||
(self.n_actions, ), mask_shape))
|
||||
self.has_action_mask = True
|
||||
"Dict obs space must have subspace labeled `obs`")
|
||||
self.obs_size = _get_size(agent_obs_space.spaces["obs"])
|
||||
if "action_mask" in space_keys:
|
||||
mask_shape = tuple(agent_obs_space.spaces["action_mask"].shape)
|
||||
if mask_shape != (self.n_actions, ):
|
||||
raise ValueError(
|
||||
"Action mask shape must be {}, got {}".format(
|
||||
(self.n_actions, ), mask_shape))
|
||||
self.has_action_mask = True
|
||||
if ENV_STATE in space_keys:
|
||||
self.env_global_state_shape = _get_size(
|
||||
agent_obs_space.spaces[ENV_STATE])
|
||||
self.has_env_global_state = True
|
||||
else:
|
||||
self.env_global_state_shape = (self.obs_size, self.n_agents)
|
||||
# The real agent obs space is nested inside the dict
|
||||
config["model"]["full_obs_space"] = agent_obs_space
|
||||
agent_obs_space = agent_obs_space.spaces["obs"]
|
||||
else:
|
||||
self.has_action_mask = False
|
||||
self.obs_size = _get_size(agent_obs_space)
|
||||
|
||||
self.model = ModelCatalog.get_model_v2(
|
||||
@@ -173,7 +208,7 @@ class QMixTorchPolicy(Policy):
|
||||
config["model"],
|
||||
framework="torch",
|
||||
name="model",
|
||||
default_model=RNNModel)
|
||||
default_model=RNNModel).to(self.device)
|
||||
|
||||
self.target_model = ModelCatalog.get_model_v2(
|
||||
agent_obs_space,
|
||||
@@ -182,22 +217,21 @@ class QMixTorchPolicy(Policy):
|
||||
config["model"],
|
||||
framework="torch",
|
||||
name="target_model",
|
||||
default_model=RNNModel)
|
||||
default_model=RNNModel).to(self.device)
|
||||
|
||||
# Setup the mixer network.
|
||||
# The global state is just the stacked agent observations for now.
|
||||
self.state_shape = [self.obs_size, self.n_agents]
|
||||
if config["mixer"] is None:
|
||||
self.mixer = None
|
||||
self.target_mixer = None
|
||||
elif config["mixer"] == "qmix":
|
||||
self.mixer = QMixer(self.n_agents, self.state_shape,
|
||||
config["mixing_embed_dim"])
|
||||
self.target_mixer = QMixer(self.n_agents, self.state_shape,
|
||||
config["mixing_embed_dim"])
|
||||
self.mixer = QMixer(self.n_agents, self.env_global_state_shape,
|
||||
config["mixing_embed_dim"]).to(self.device)
|
||||
self.target_mixer = QMixer(
|
||||
self.n_agents, self.env_global_state_shape,
|
||||
config["mixing_embed_dim"]).to(self.device)
|
||||
elif config["mixer"] == "vdn":
|
||||
self.mixer = VDNMixer()
|
||||
self.target_mixer = VDNMixer()
|
||||
self.mixer = VDNMixer().to(self.device)
|
||||
self.target_mixer = VDNMixer().to(self.device)
|
||||
else:
|
||||
raise ValueError("Unknown mixer type {}".format(config["mixer"]))
|
||||
|
||||
@@ -226,14 +260,21 @@ class QMixTorchPolicy(Policy):
|
||||
info_batch=None,
|
||||
episodes=None,
|
||||
**kwargs):
|
||||
obs_batch, action_mask = self._unpack_observation(obs_batch)
|
||||
obs_batch, action_mask, _ = self._unpack_observation(obs_batch)
|
||||
# We need to ensure we do not use the env global state
|
||||
# to compute actions
|
||||
|
||||
# Compute actions
|
||||
with th.no_grad():
|
||||
q_values, hiddens = _mac(
|
||||
self.model, th.from_numpy(obs_batch),
|
||||
[th.from_numpy(np.array(s)) for s in state_batches])
|
||||
avail = th.from_numpy(action_mask).float()
|
||||
self.model,
|
||||
th.as_tensor(obs_batch, dtype=th.float, device=self.device), [
|
||||
th.as_tensor(
|
||||
np.array(s), dtype=th.float, device=self.device)
|
||||
for s in state_batches
|
||||
])
|
||||
avail = th.as_tensor(
|
||||
action_mask, dtype=th.float, device=self.device)
|
||||
masked_q_values = q_values.clone()
|
||||
masked_q_values[avail == 0.0] = -float("inf")
|
||||
# epsilon-greedy action selector
|
||||
@@ -241,63 +282,81 @@ class QMixTorchPolicy(Policy):
|
||||
pick_random = (random_numbers < self.cur_epsilon).long()
|
||||
random_actions = Categorical(avail).sample().long()
|
||||
actions = (pick_random * random_actions +
|
||||
(1 - pick_random) * masked_q_values.max(dim=2)[1])
|
||||
actions = actions.numpy()
|
||||
hiddens = [s.numpy() for s in hiddens]
|
||||
(1 - pick_random) * masked_q_values.argmax(dim=2))
|
||||
actions = actions.cpu().numpy()
|
||||
hiddens = [s.cpu().numpy() for s in hiddens]
|
||||
|
||||
return TupleActions(list(actions.transpose([1, 0]))), hiddens, {}
|
||||
|
||||
@override(Policy)
|
||||
def learn_on_batch(self, samples):
|
||||
obs_batch, action_mask = self._unpack_observation(
|
||||
obs_batch, action_mask, env_global_state = self._unpack_observation(
|
||||
samples[SampleBatch.CUR_OBS])
|
||||
next_obs_batch, next_action_mask = self._unpack_observation(
|
||||
samples[SampleBatch.NEXT_OBS])
|
||||
(next_obs_batch, next_action_mask,
|
||||
next_env_global_state) = self._unpack_observation(
|
||||
samples[SampleBatch.NEXT_OBS])
|
||||
group_rewards = self._get_group_rewards(samples[SampleBatch.INFOS])
|
||||
|
||||
# These will be padded to shape [B * T, ...]
|
||||
[rew, action_mask, next_action_mask, act, dones, obs, next_obs], \
|
||||
initial_states, seq_lens = \
|
||||
input_list = [
|
||||
group_rewards, action_mask, next_action_mask,
|
||||
samples[SampleBatch.ACTIONS], samples[SampleBatch.DONES],
|
||||
obs_batch, next_obs_batch
|
||||
]
|
||||
if self.has_env_global_state:
|
||||
input_list.extend([env_global_state, next_env_global_state])
|
||||
|
||||
output_list, _, seq_lens = \
|
||||
chop_into_sequences(
|
||||
samples[SampleBatch.EPS_ID],
|
||||
samples[SampleBatch.UNROLL_ID],
|
||||
samples[SampleBatch.AGENT_INDEX], [
|
||||
group_rewards, action_mask, next_action_mask,
|
||||
samples[SampleBatch.ACTIONS], samples[SampleBatch.DONES],
|
||||
obs_batch, next_obs_batch
|
||||
],
|
||||
[samples["state_in_{}".format(k)]
|
||||
for k in range(len(self.get_initial_state()))],
|
||||
samples[SampleBatch.AGENT_INDEX],
|
||||
input_list,
|
||||
[], # RNN states not used here
|
||||
max_seq_len=self.config["model"]["max_seq_len"],
|
||||
dynamic_max=True)
|
||||
# These will be padded to shape [B * T, ...]
|
||||
if self.has_env_global_state:
|
||||
(rew, action_mask, next_action_mask, act, dones, obs, next_obs,
|
||||
env_global_state, next_env_global_state) = output_list
|
||||
else:
|
||||
(rew, action_mask, next_action_mask, act, dones, obs,
|
||||
next_obs) = output_list
|
||||
B, T = len(seq_lens), max(seq_lens)
|
||||
|
||||
def to_batches(arr):
|
||||
def to_batches(arr, dtype):
|
||||
new_shape = [B, T] + list(arr.shape[1:])
|
||||
return th.from_numpy(np.reshape(arr, new_shape))
|
||||
return th.as_tensor(
|
||||
np.reshape(arr, new_shape), dtype=dtype, device=self.device)
|
||||
|
||||
rewards = to_batches(rew).float()
|
||||
actions = to_batches(act).long()
|
||||
obs = to_batches(obs).reshape([B, T, self.n_agents,
|
||||
self.obs_size]).float()
|
||||
action_mask = to_batches(action_mask)
|
||||
next_obs = to_batches(next_obs).reshape(
|
||||
[B, T, self.n_agents, self.obs_size]).float()
|
||||
next_action_mask = to_batches(next_action_mask)
|
||||
rewards = to_batches(rew, th.float)
|
||||
actions = to_batches(act, th.long)
|
||||
obs = to_batches(obs, th.float).reshape(
|
||||
[B, T, self.n_agents, self.obs_size])
|
||||
action_mask = to_batches(action_mask, th.float)
|
||||
next_obs = to_batches(next_obs, th.float).reshape(
|
||||
[B, T, self.n_agents, self.obs_size])
|
||||
next_action_mask = to_batches(next_action_mask, th.float)
|
||||
if self.has_env_global_state:
|
||||
env_global_state = to_batches(env_global_state, th.float)
|
||||
next_env_global_state = to_batches(next_env_global_state, th.float)
|
||||
|
||||
# TODO(ekl) this treats group termination as individual termination
|
||||
terminated = to_batches(dones.astype(np.float32)).unsqueeze(2).expand(
|
||||
terminated = to_batches(dones, th.float).unsqueeze(2).expand(
|
||||
B, T, self.n_agents)
|
||||
|
||||
# Create mask for where index is < unpadded sequence length
|
||||
filled = (np.reshape(np.tile(np.arange(T), B), [B, T]) <
|
||||
np.expand_dims(seq_lens, 1)).astype(np.float32)
|
||||
mask = th.from_numpy(filled).unsqueeze(2).expand(B, T, self.n_agents)
|
||||
filled = np.reshape(
|
||||
np.tile(np.arange(T, dtype=np.float32), B),
|
||||
[B, T]) < np.expand_dims(seq_lens, 1)
|
||||
mask = th.as_tensor(
|
||||
filled, dtype=th.float, device=self.device).unsqueeze(2).expand(
|
||||
B, T, self.n_agents)
|
||||
|
||||
# Compute loss
|
||||
loss_out, mask, masked_td_error, chosen_action_qvals, targets = \
|
||||
self.loss(rewards, actions, terminated, mask, obs,
|
||||
next_obs, action_mask, next_action_mask)
|
||||
loss_out, mask, masked_td_error, chosen_action_qvals, targets = (
|
||||
self.loss(rewards, actions, terminated, mask, obs, next_obs,
|
||||
action_mask, next_action_mask, env_global_state,
|
||||
next_env_global_state))
|
||||
|
||||
# Optimise
|
||||
self.optimiser.zero_grad()
|
||||
@@ -319,40 +378,43 @@ class QMixTorchPolicy(Policy):
|
||||
return {LEARNER_STATS_KEY: stats}
|
||||
|
||||
@override(Policy)
|
||||
def get_initial_state(self):
|
||||
def get_initial_state(self): # initial RNN state
|
||||
return [
|
||||
s.expand([self.n_agents, -1]).numpy()
|
||||
s.expand([self.n_agents, -1]).cpu().numpy()
|
||||
for s in self.model.get_initial_state()
|
||||
]
|
||||
|
||||
@override(Policy)
|
||||
def get_weights(self):
|
||||
return {"model": self.model.state_dict()}
|
||||
|
||||
@override(Policy)
|
||||
def set_weights(self, weights):
|
||||
self.model.load_state_dict(weights["model"])
|
||||
|
||||
@override(Policy)
|
||||
def get_state(self):
|
||||
return {
|
||||
"model": self.model.state_dict(),
|
||||
"target_model": self.target_model.state_dict(),
|
||||
"mixer": self.mixer.state_dict() if self.mixer else None,
|
||||
"target_mixer": self.target_mixer.state_dict()
|
||||
"model": self._cpu_dict(self.model.state_dict()),
|
||||
"target_model": self._cpu_dict(self.target_model.state_dict()),
|
||||
"mixer": self._cpu_dict(self.mixer.state_dict())
|
||||
if self.mixer else None,
|
||||
"target_mixer": self._cpu_dict(self.target_mixer.state_dict())
|
||||
if self.mixer else None,
|
||||
"cur_epsilon": self.cur_epsilon,
|
||||
}
|
||||
|
||||
@override(Policy)
|
||||
def set_weights(self, weights):
|
||||
self.model.load_state_dict(self._device_dict(weights["model"]))
|
||||
self.target_model.load_state_dict(
|
||||
self._device_dict(weights["target_model"]))
|
||||
if weights["mixer"] is not None:
|
||||
self.mixer.load_state_dict(self._device_dict(weights["mixer"]))
|
||||
self.target_mixer.load_state_dict(
|
||||
self._device_dict(weights["target_mixer"]))
|
||||
|
||||
@override(Policy)
|
||||
def get_state(self):
|
||||
state = self.get_weights()
|
||||
state["cur_epsilon"] = self.cur_epsilon
|
||||
return state
|
||||
|
||||
@override(Policy)
|
||||
def set_state(self, state):
|
||||
self.model.load_state_dict(state["model"])
|
||||
self.target_model.load_state_dict(state["target_model"])
|
||||
if state["mixer"] is not None:
|
||||
self.mixer.load_state_dict(state["mixer"])
|
||||
self.target_mixer.load_state_dict(state["target_mixer"])
|
||||
self.set_weights(state)
|
||||
self.set_epsilon(state["cur_epsilon"])
|
||||
self.update_target()
|
||||
|
||||
def update_target(self):
|
||||
self.target_model.load_state_dict(self.model.state_dict())
|
||||
@@ -370,15 +432,28 @@ class QMixTorchPolicy(Policy):
|
||||
])
|
||||
return group_rewards
|
||||
|
||||
def _device_dict(self, state_dict):
    """Convert every value of a state dict to a tensor on ``self.device``.

    Arguments:
        state_dict (dict): mapping from parameter name to any value that
            ``th.as_tensor`` accepts (e.g. a numpy array or a tensor).

    Returns:
        dict: the same keys, each mapped to a tensor placed on
            ``self.device``.
    """
    return {
        name: th.as_tensor(value, device=self.device)
        for name, value in state_dict.items()
    }
|
||||
|
||||
@staticmethod
def _cpu_dict(state_dict):
    """Convert every tensor of a state dict to a numpy array on the CPU.

    Arguments:
        state_dict (dict): mapping from parameter name to a torch tensor.

    Returns:
        dict: the same keys, each mapped to the numpy array obtained by
            moving the tensor to the CPU and detaching it from autograd.
    """
    return {
        name: tensor.cpu().detach().numpy()
        for name, tensor in state_dict.items()
    }
|
||||
|
||||
def _unpack_observation(self, obs_batch):
|
||||
"""Unpacks the action mask / tuple obs from agent grouping.
|
||||
"""Unpacks the observation, action mask, and state (if present)
|
||||
from agent grouping.
|
||||
|
||||
Returns:
|
||||
obs (Tensor): flattened obs tensor of shape [B, n_agents, obs_size]
|
||||
mask (Tensor): action mask, if any
|
||||
obs (np.ndarray): obs tensor of shape [B, n_agents, obs_size]
|
||||
mask (np.ndarray): action mask, if any
|
||||
state (np.ndarray or None): state tensor of shape [B, state_size]
|
||||
or None if it is not in the batch
|
||||
"""
|
||||
unpacked = _unpack_obs(
|
||||
np.array(obs_batch),
|
||||
np.array(obs_batch, dtype=np.float32),
|
||||
self.observation_space.original_space,
|
||||
tensorlib=np)
|
||||
if self.has_action_mask:
|
||||
@@ -389,12 +464,22 @@ class QMixTorchPolicy(Policy):
|
||||
[o["action_mask"] for o in unpacked], axis=1).reshape(
|
||||
[len(obs_batch), self.n_agents, self.n_actions])
|
||||
else:
|
||||
if isinstance(unpacked[0], dict):
|
||||
unpacked_obs = [u["obs"] for u in unpacked]
|
||||
else:
|
||||
unpacked_obs = unpacked
|
||||
obs = np.concatenate(
|
||||
unpacked,
|
||||
unpacked_obs,
|
||||
axis=1).reshape([len(obs_batch), self.n_agents, self.obs_size])
|
||||
action_mask = np.ones(
|
||||
[len(obs_batch), self.n_agents, self.n_actions])
|
||||
return obs, action_mask
|
||||
[len(obs_batch), self.n_agents, self.n_actions],
|
||||
dtype=np.float32)
|
||||
|
||||
if self.has_env_global_state:
|
||||
state = unpacked[0][ENV_STATE]
|
||||
else:
|
||||
state = None
|
||||
return obs, action_mask, state
|
||||
|
||||
|
||||
def _validate(obs_space, action_space):
|
||||
@@ -436,9 +521,11 @@ def _mac(model, obs, h):
|
||||
h: Tensor of shape [B, n_agents, h_size]
|
||||
"""
|
||||
B, n_agents = obs.size(0), obs.size(1)
|
||||
obs_flat = obs.reshape([B * n_agents, -1])
|
||||
if not isinstance(obs, dict):
|
||||
obs = {"obs": obs}
|
||||
obs_agents_as_batches = {k: _drop_agent_dim(v) for k, v in obs.items()}
|
||||
h_flat = [s.reshape([B * n_agents, -1]) for s in h]
|
||||
q_flat, h_flat = model({"obs": obs_flat}, h_flat, None)
|
||||
q_flat, h_flat = model(obs_agents_as_batches, h_flat, None)
|
||||
return q_flat.reshape(
|
||||
[B, n_agents, -1]), [s.reshape([B, n_agents, -1]) for s in h_flat]
|
||||
|
||||
@@ -457,3 +544,16 @@ def _unroll_mac(model, obs_tensor):
|
||||
mac_out = th.stack(mac_out, dim=1) # Concat over time
|
||||
|
||||
return mac_out
|
||||
|
||||
|
||||
def _drop_agent_dim(T):
    """Merge the first two axes of ``T`` into a single leading axis.

    Arguments:
        T: object exposing ``shape`` and ``reshape`` (e.g. a tensor or
            numpy array) whose first two axes are [B, n_agents].

    Returns:
        ``T`` reshaped to [B * n_agents, ...]; trailing axes are unchanged.
    """
    dims = list(T.shape)
    merged = dims[0] * dims[1]
    return T.reshape([merged] + dims[2:])
|
||||
|
||||
|
||||
def _add_agent_dim(T, n_agents):
    """Split the leading axis of ``T`` into [B, n_agents].

    Arguments:
        T: object exposing ``shape`` and ``reshape`` (e.g. a tensor or
            numpy array) whose leading axis has size B * n_agents.
        n_agents (int): size of the agent axis to restore.

    Returns:
        ``T`` reshaped to [B, n_agents, ...]; trailing axes are unchanged.

    Raises:
        ValueError: if the leading axis is not a multiple of ``n_agents``.
    """
    dims = list(T.shape)
    # Validate explicitly (an ``assert`` would vanish under ``python -O``)
    # and before reshaping, so the caller gets a meaningful message.
    if dims[0] % n_agents != 0:
        raise ValueError(
            "Leading dimension {} is not divisible by n_agents={}".format(
                dims[0], n_agents))
    return T.reshape([dims[0] // n_agents, n_agents] + dims[1:])
|
||||
|
||||
@@ -17,6 +17,7 @@ modifies the environment.
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
from gym.spaces import Discrete
|
||||
|
||||
from ray import tune
|
||||
from ray.rllib.agents.ppo.ppo import PPOTrainer
|
||||
@@ -209,10 +210,8 @@ if __name__ == "__main__":
|
||||
"num_workers": 0,
|
||||
"multiagent": {
|
||||
"policies": {
|
||||
"pol1": (None, TwoStepGame.observation_space,
|
||||
TwoStepGame.action_space, {}),
|
||||
"pol2": (None, TwoStepGame.observation_space,
|
||||
TwoStepGame.action_space, {}),
|
||||
"pol1": (None, Discrete(6), TwoStepGame.action_space, {}),
|
||||
"pol2": (None, Discrete(6), TwoStepGame.action_space, {}),
|
||||
},
|
||||
"policy_mapping_fn": lambda x: "pol1" if x == 0 else "pol2",
|
||||
},
|
||||
|
||||
@@ -14,13 +14,14 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
from gym.spaces import Tuple, Discrete
|
||||
from gym.spaces import Tuple, MultiDiscrete, Dict, Discrete
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune import register_env, grid_search
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from ray.rllib.agents.qmix.qmix_policy import ENV_STATE
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--stop", type=int, default=50000)
|
||||
@@ -30,20 +31,34 @@ parser.add_argument("--run", type=str, default="PG")
|
||||
class TwoStepGame(MultiAgentEnv):
|
||||
action_space = Discrete(2)
|
||||
|
||||
# Each agent gets a separate [3] obs space, to ensure that they can
|
||||
# learn meaningfully different Q values even with a shared Q model.
|
||||
observation_space = Discrete(6)
|
||||
|
||||
def __init__(self, env_config):
    """Set up the two-agent game and choose its observation space.

    Arguments:
        env_config (dict): supports the keys
            "actions_are_logits" (default False),
            "one_hot_state_encoding" (default False) and
            "separate_state_space" (default False).

    Note that "separate_state_space" only takes effect together with
    "one_hot_state_encoding"; otherwise it is reset to False.
    """
    self.state = None
    self.agent_1 = 0
    self.agent_2 = 1
    # MADDPG emits action logits instead of actual discrete actions
    self.actions_are_logits = env_config.get("actions_are_logits", False)
    self.one_hot_state_encoding = env_config.get("one_hot_state_encoding",
                                                 False)
    self.with_state = env_config.get("separate_state_space", False)

    if not self.one_hot_state_encoding:
        self.observation_space = Discrete(6)
        # A separate state entry is only emitted with one-hot encoding.
        self.with_state = False
    elif self.with_state:
        # Each agent gets the full state (one-hot encoding of which of the
        # three states are active) as input with the receiving agent's
        # ID (1 or 2) concatenated onto the end. The global state is also
        # exposed on its own under the ENV_STATE key.
        self.observation_space = Dict({
            "obs": MultiDiscrete([2, 2, 2, 3]),
            ENV_STATE: MultiDiscrete([2, 2, 2])
        })
    else:
        # Same per-agent encoding as above, without the separate state.
        self.observation_space = MultiDiscrete([2, 2, 2, 3])
|
||||
|
||||
def reset(self):
|
||||
self.state = 0
|
||||
return {self.agent_1: self.state, self.agent_2: self.state + 3}
|
||||
self.state = np.array([1, 0, 0])
|
||||
return self._obs()
|
||||
|
||||
def step(self, action_dict):
|
||||
if self.actions_are_logits:
|
||||
@@ -52,16 +67,17 @@ class TwoStepGame(MultiAgentEnv):
|
||||
for k, v in action_dict.items()
|
||||
}
|
||||
|
||||
if self.state == 0:
|
||||
state_index = np.flatnonzero(self.state)
|
||||
if state_index == 0:
|
||||
action = action_dict[self.agent_1]
|
||||
assert action in [0, 1], action
|
||||
if action == 0:
|
||||
self.state = 1
|
||||
self.state = np.array([0, 1, 0])
|
||||
else:
|
||||
self.state = 2
|
||||
self.state = np.array([0, 0, 1])
|
||||
global_rew = 0
|
||||
done = False
|
||||
elif self.state == 1:
|
||||
elif state_index == 1:
|
||||
global_rew = 7
|
||||
done = True
|
||||
else:
|
||||
@@ -79,11 +95,41 @@ class TwoStepGame(MultiAgentEnv):
|
||||
self.agent_1: global_rew / 2.0,
|
||||
self.agent_2: global_rew / 2.0
|
||||
}
|
||||
obs = {self.agent_1: self.state, self.agent_2: self.state + 3}
|
||||
obs = self._obs()
|
||||
dones = {"__all__": done}
|
||||
infos = {}
|
||||
return obs, rewards, dones, infos
|
||||
|
||||
def _obs(self):
    """Build the per-agent observation dict for the current state.

    Returns:
        dict: keyed by ``self.agent_1`` and ``self.agent_2``. When
            ``self.with_state`` is set, each value is a dict holding the
            agent's own observation under "obs" and ``self.state`` under
            the ENV_STATE key; otherwise each value is the agent's
            observation itself.
    """
    first = self.agent_1_obs()
    second = self.agent_2_obs()
    if not self.with_state:
        return {self.agent_1: first, self.agent_2: second}
    return {
        self.agent_1: {
            "obs": first,
            ENV_STATE: self.state
        },
        self.agent_2: {
            "obs": second,
            ENV_STATE: self.state
        }
    }
|
||||
|
||||
def agent_1_obs(self):
    """Return the observation for agent 1.

    Returns:
        With one-hot encoding: ``self.state`` with the agent ID 1
        appended. Otherwise: the index of the first non-zero entry of
        ``self.state``.
    """
    if not self.one_hot_state_encoding:
        return np.flatnonzero(self.state)[0]
    return np.concatenate([self.state, [1]])
|
||||
|
||||
def agent_2_obs(self):
    """Return the observation for agent 2.

    Returns:
        With one-hot encoding: ``self.state`` with the agent ID 2
        appended. Otherwise: the index of the first non-zero entry of
        ``self.state`` offset by 3, so that it does not collide with the
        values agent 1 observes.
    """
    if not self.one_hot_state_encoding:
        return np.flatnonzero(self.state)[0] + 3
    return np.concatenate([self.state, [2]])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
@@ -92,8 +138,14 @@ if __name__ == "__main__":
|
||||
"group_1": [0, 1],
|
||||
}
|
||||
obs_space = Tuple([
|
||||
TwoStepGame.observation_space,
|
||||
TwoStepGame.observation_space,
|
||||
Dict({
|
||||
"obs": MultiDiscrete([2, 2, 2, 3]),
|
||||
ENV_STATE: MultiDiscrete([2, 2, 2])
|
||||
}),
|
||||
Dict({
|
||||
"obs": MultiDiscrete([2, 2, 2, 3]),
|
||||
ENV_STATE: MultiDiscrete([2, 2, 2])
|
||||
}),
|
||||
])
|
||||
act_space = Tuple([
|
||||
TwoStepGame.action_space,
|
||||
@@ -106,8 +158,8 @@ if __name__ == "__main__":
|
||||
|
||||
if args.run == "contrib/MADDPG":
|
||||
obs_space_dict = {
|
||||
"agent_1": TwoStepGame.observation_space,
|
||||
"agent_2": TwoStepGame.observation_space,
|
||||
"agent_1": Discrete(6),
|
||||
"agent_2": Discrete(6),
|
||||
}
|
||||
act_space_dict = {
|
||||
"agent_1": TwoStepGame.action_space,
|
||||
@@ -120,14 +172,12 @@ if __name__ == "__main__":
|
||||
},
|
||||
"multiagent": {
|
||||
"policies": {
|
||||
"pol1": (None, TwoStepGame.observation_space,
|
||||
TwoStepGame.action_space, {
|
||||
"agent_id": 0,
|
||||
}),
|
||||
"pol2": (None, TwoStepGame.observation_space,
|
||||
TwoStepGame.action_space, {
|
||||
"agent_id": 1,
|
||||
}),
|
||||
"pol1": (None, Discrete(6), TwoStepGame.action_space, {
|
||||
"agent_id": 0,
|
||||
}),
|
||||
"pol2": (None, Discrete(6), TwoStepGame.action_space, {
|
||||
"agent_id": 1,
|
||||
}),
|
||||
},
|
||||
"policy_mapping_fn": lambda x: "pol1" if x == 0 else "pol2",
|
||||
},
|
||||
@@ -137,9 +187,14 @@ if __name__ == "__main__":
|
||||
config = {
|
||||
"sample_batch_size": 4,
|
||||
"train_batch_size": 32,
|
||||
"exploration_fraction": .4,
|
||||
"exploration_final_eps": 0.0,
|
||||
"num_workers": 0,
|
||||
"mixer": grid_search([None, "qmix", "vdn"]),
|
||||
"env_config": {
|
||||
"separate_state_space": True,
|
||||
"one_hot_state_encoding": True
|
||||
},
|
||||
}
|
||||
group = True
|
||||
elif args.run == "APEX_QMIX":
|
||||
@@ -156,6 +211,10 @@ if __name__ == "__main__":
|
||||
"sample_batch_size": 32,
|
||||
"target_network_update_freq": 500,
|
||||
"timesteps_per_iteration": 1000,
|
||||
"env_config": {
|
||||
"separate_state_space": True,
|
||||
"one_hot_state_encoding": True
|
||||
},
|
||||
}
|
||||
group = True
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user