From 4aa06918ae607db171cdde59cb01408a388bf41b Mon Sep 17 00:00:00 2001 From: "Matthew A. Wright" Date: Tue, 8 Oct 2019 13:18:07 -0700 Subject: [PATCH] Qmix on gpu and with non-stacked-obs environment state support (#5751) --- rllib/agents/qmix/qmix.py | 2 +- rllib/agents/qmix/qmix_policy.py | 288 ++++++++++++++++++--------- rllib/examples/centralized_critic.py | 7 +- rllib/examples/twostep_game.py | 107 +++++++--- 4 files changed, 281 insertions(+), 123 deletions(-) diff --git a/rllib/agents/qmix/qmix.py b/rllib/agents/qmix/qmix.py index 6a5bff9d6..a6cb70e2c 100644 --- a/rllib/agents/qmix/qmix.py +++ b/rllib/agents/qmix/qmix.py @@ -49,7 +49,7 @@ DEFAULT_CONFIG = with_common_config({ "buffer_size": 10000, # === Optimization === - # Learning rate for adam optimizer + # Learning rate for RMSProp optimizer "lr": 0.0005, # RMSProp alpha "optim_alpha": 0.99, diff --git a/rllib/agents/qmix/qmix_policy.py b/rllib/agents/qmix/qmix_policy.py index af8f08b53..05a348c9b 100644 --- a/rllib/agents/qmix/qmix_policy.py +++ b/rllib/agents/qmix/qmix_policy.py @@ -3,6 +3,7 @@ from __future__ import division from __future__ import print_function from gym.spaces import Tuple, Discrete, Dict +import os import logging import numpy as np import torch as th @@ -14,7 +15,7 @@ import ray from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer from ray.rllib.agents.qmix.model import RNNModel, _get_size from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY -from ray.rllib.policy.policy import Policy, TupleActions +from ray.rllib.policy.policy import TupleActions, Policy from ray.rllib.policy.rnn_sequencing import chop_into_sequences from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.models.catalog import ModelCatalog @@ -24,6 +25,9 @@ from ray.rllib.utils.annotations import override logger = logging.getLogger(__name__) +# if the obs space is Dict type, look for the global state under this key +ENV_STATE = "state" + class QMixLoss(nn.Module): def __init__(self, @@ -45,8 +49,17 @@ class QMixLoss(nn.Module): self.double_q = double_q self.gamma = gamma - def forward(self, rewards, actions, terminated, mask, obs, next_obs, - action_mask, next_action_mask): + def forward(self, + rewards, + actions, + terminated, + mask, + obs, + next_obs, + action_mask, + next_action_mask, + state=None, + next_state=None): """Forward pass of the loss. Arguments: @@ -58,8 +71,20 @@ class QMixLoss(nn.Module): next_obs: Tensor of shape [B, T, n_agents, obs_size] action_mask: Tensor of shape [B, T, n_agents, n_actions] next_action_mask: Tensor of shape [B, T, n_agents, n_actions] + state: Tensor of shape [B, T, state_dim] (optional) + next_state: Tensor of shape [B, T, state_dim] (optional) """ + # Assert either none or both of state and next_state are given + if state is None and next_state is None: + state = obs # default to state being all agents' observations + next_state = next_obs + elif (state is None) != (next_state is None): + raise ValueError("Expected either neither or both of `state` and " + "`next_state` to be given. Got: " + "\n`state` = {}\n`next_state` = {}".format( + state, next_state)) + # Calculate estimated Q-Values mac_out = _unroll_mac(self.model, obs) @@ -89,7 +114,7 @@ class QMixLoss(nn.Module): mac_out_tp1[ignore_action_tp1] = -np.inf # obtain best actions at t+1 according to policy NN - cur_max_actions = mac_out_tp1.max(dim=3, keepdim=True)[1] + cur_max_actions = mac_out_tp1.argmax(dim=3, keepdim=True) # use the target network to estimate the Q-values of policy # network's selected actions @@ -104,10 +129,8 @@ class QMixLoss(nn.Module): # Mix if self.mixer is not None: - # TODO(ekl) add support for handling global state? This is just - # treating the stacked agent obs as the state. - chosen_action_qvals = self.mixer(chosen_action_qvals, obs) - target_max_qvals = self.target_mixer(target_max_qvals, next_obs) + chosen_action_qvals = self.mixer(chosen_action_qvals, state) + target_max_qvals = self.target_mixer(target_max_qvals, next_state) # Calculate 1-step Q-Learning targets targets = rewards + self.gamma * (1 - terminated) * target_max_qvals @@ -146,24 +169,36 @@ class QMixTorchPolicy(Policy): self.n_agents = len(obs_space.original_space.spaces) self.n_actions = action_space.spaces[0].n self.h_size = config["model"]["lstm_cell_size"] + self.has_env_global_state = False + self.has_action_mask = False + self.device = (th.device("cuda") + if bool(os.environ.get("CUDA_VISIBLE_DEVICES", None)) + else th.device("cpu")) agent_obs_space = obs_space.original_space.spaces[0] if isinstance(agent_obs_space, Dict): space_keys = set(agent_obs_space.spaces.keys()) - if not {"obs", "action_mask"}.issubset(space_keys): + if "obs" not in space_keys: raise ValueError( - "Dict obs space for agent must have keyset " - "['obs', 'action_mask'], got {}".format(space_keys)) - mask_shape = tuple(agent_obs_space.spaces["action_mask"].shape) - if mask_shape != (self.n_actions, ): - raise ValueError("Action mask shape must be {}, got {}".format( - (self.n_actions, ), mask_shape)) - self.has_action_mask = True + "Dict obs space must have subspace labeled `obs`") self.obs_size = _get_size(agent_obs_space.spaces["obs"]) + if "action_mask" in space_keys: + mask_shape = tuple(agent_obs_space.spaces["action_mask"].shape) + if mask_shape != (self.n_actions, ): + raise ValueError( + "Action mask shape must be {}, got {}".format( + (self.n_actions, ), mask_shape)) + self.has_action_mask = True + if ENV_STATE in space_keys: + self.env_global_state_shape = _get_size( + agent_obs_space.spaces[ENV_STATE]) + self.has_env_global_state = True + else: + self.env_global_state_shape = (self.obs_size, self.n_agents) # The real agent obs space is nested inside the dict + config["model"]["full_obs_space"] = agent_obs_space agent_obs_space = agent_obs_space.spaces["obs"] else: - self.has_action_mask = False self.obs_size = _get_size(agent_obs_space) self.model = ModelCatalog.get_model_v2( @@ -173,7 +208,7 @@ class QMixTorchPolicy(Policy): config["model"], framework="torch", name="model", - default_model=RNNModel) + default_model=RNNModel).to(self.device) self.target_model = ModelCatalog.get_model_v2( agent_obs_space, @@ -182,22 +217,21 @@ class QMixTorchPolicy(Policy): config["model"], framework="torch", name="target_model", - default_model=RNNModel) + default_model=RNNModel).to(self.device) # Setup the mixer network. - # The global state is just the stacked agent observations for now. - self.state_shape = [self.obs_size, self.n_agents] if config["mixer"] is None: self.mixer = None self.target_mixer = None elif config["mixer"] == "qmix": - self.mixer = QMixer(self.n_agents, self.state_shape, - config["mixing_embed_dim"]) - self.target_mixer = QMixer(self.n_agents, self.state_shape, - config["mixing_embed_dim"]) + self.mixer = QMixer(self.n_agents, self.env_global_state_shape, + config["mixing_embed_dim"]).to(self.device) + self.target_mixer = QMixer( + self.n_agents, self.env_global_state_shape, + config["mixing_embed_dim"]).to(self.device) elif config["mixer"] == "vdn": - self.mixer = VDNMixer() - self.target_mixer = VDNMixer() + self.mixer = VDNMixer().to(self.device) + self.target_mixer = VDNMixer().to(self.device) else: raise ValueError("Unknown mixer type {}".format(config["mixer"])) @@ -226,14 +260,21 @@ class QMixTorchPolicy(Policy): info_batch=None, episodes=None, **kwargs): - obs_batch, action_mask = self._unpack_observation(obs_batch) + obs_batch, action_mask, _ = self._unpack_observation(obs_batch) + # We need to ensure we do not use the env global state + # to compute actions # Compute actions with th.no_grad(): q_values, hiddens = _mac( - self.model, th.from_numpy(obs_batch), - [th.from_numpy(np.array(s)) for s in state_batches]) - avail = th.from_numpy(action_mask).float() + self.model, + th.as_tensor(obs_batch, dtype=th.float, device=self.device), [ + th.as_tensor( + np.array(s), dtype=th.float, device=self.device) + for s in state_batches + ]) + avail = th.as_tensor( + action_mask, dtype=th.float, device=self.device) masked_q_values = q_values.clone() masked_q_values[avail == 0.0] = -float("inf") # epsilon-greedy action selector @@ -241,63 +282,81 @@ class QMixTorchPolicy(Policy): pick_random = (random_numbers < self.cur_epsilon).long() random_actions = Categorical(avail).sample().long() actions = (pick_random * random_actions + - (1 - pick_random) * masked_q_values.max(dim=2)[1]) - actions = actions.numpy() - hiddens = [s.numpy() for s in hiddens] + (1 - pick_random) * masked_q_values.argmax(dim=2)) + actions = actions.cpu().numpy() + hiddens = [s.cpu().numpy() for s in hiddens] return TupleActions(list(actions.transpose([1, 0]))), hiddens, {} @override(Policy) def learn_on_batch(self, samples): - obs_batch, action_mask = self._unpack_observation( + obs_batch, action_mask, env_global_state = self._unpack_observation( samples[SampleBatch.CUR_OBS]) - next_obs_batch, next_action_mask = self._unpack_observation( - samples[SampleBatch.NEXT_OBS]) + (next_obs_batch, next_action_mask, + next_env_global_state) = self._unpack_observation( + samples[SampleBatch.NEXT_OBS]) group_rewards = self._get_group_rewards(samples[SampleBatch.INFOS]) - # These will be padded to shape [B * T, ...] - [rew, action_mask, next_action_mask, act, dones, obs, next_obs], \ - initial_states, seq_lens = \ + input_list = [ + group_rewards, action_mask, next_action_mask, + samples[SampleBatch.ACTIONS], samples[SampleBatch.DONES], + obs_batch, next_obs_batch + ] + if self.has_env_global_state: + input_list.extend([env_global_state, next_env_global_state]) + + output_list, _, seq_lens = \ chop_into_sequences( samples[SampleBatch.EPS_ID], samples[SampleBatch.UNROLL_ID], - samples[SampleBatch.AGENT_INDEX], [ - group_rewards, action_mask, next_action_mask, - samples[SampleBatch.ACTIONS], samples[SampleBatch.DONES], - obs_batch, next_obs_batch - ], - [samples["state_in_{}".format(k)] - for k in range(len(self.get_initial_state()))], + samples[SampleBatch.AGENT_INDEX], + input_list, + [], # RNN states not used here max_seq_len=self.config["model"]["max_seq_len"], dynamic_max=True) + # These will be padded to shape [B * T, ...] + if self.has_env_global_state: + (rew, action_mask, next_action_mask, act, dones, obs, next_obs, + env_global_state, next_env_global_state) = output_list + else: + (rew, action_mask, next_action_mask, act, dones, obs, + next_obs) = output_list B, T = len(seq_lens), max(seq_lens) - def to_batches(arr): + def to_batches(arr, dtype): new_shape = [B, T] + list(arr.shape[1:]) - return th.from_numpy(np.reshape(arr, new_shape)) + return th.as_tensor( + np.reshape(arr, new_shape), dtype=dtype, device=self.device) - rewards = to_batches(rew).float() - actions = to_batches(act).long() - obs = to_batches(obs).reshape([B, T, self.n_agents, - self.obs_size]).float() - action_mask = to_batches(action_mask) - next_obs = to_batches(next_obs).reshape( - [B, T, self.n_agents, self.obs_size]).float() - next_action_mask = to_batches(next_action_mask) + rewards = to_batches(rew, th.float) + actions = to_batches(act, th.long) + obs = to_batches(obs, th.float).reshape( + [B, T, self.n_agents, self.obs_size]) + action_mask = to_batches(action_mask, th.float) + next_obs = to_batches(next_obs, th.float).reshape( + [B, T, self.n_agents, self.obs_size]) + next_action_mask = to_batches(next_action_mask, th.float) + if self.has_env_global_state: + env_global_state = to_batches(env_global_state, th.float) + next_env_global_state = to_batches(next_env_global_state, th.float) # TODO(ekl) this treats group termination as individual termination - terminated = to_batches(dones.astype(np.float32)).unsqueeze(2).expand( + terminated = to_batches(dones, th.float).unsqueeze(2).expand( B, T, self.n_agents) # Create mask for where index is < unpadded sequence length - filled = (np.reshape(np.tile(np.arange(T), B), [B, T]) < - np.expand_dims(seq_lens, 1)).astype(np.float32) - mask = th.from_numpy(filled).unsqueeze(2).expand(B, T, self.n_agents) + filled = np.reshape( + np.tile(np.arange(T, dtype=np.float32), B), + [B, T]) < np.expand_dims(seq_lens, 1) + mask = th.as_tensor( + filled, dtype=th.float, device=self.device).unsqueeze(2).expand( + B, T, self.n_agents) # Compute loss - loss_out, mask, masked_td_error, chosen_action_qvals, targets = \ - self.loss(rewards, actions, terminated, mask, obs, - next_obs, action_mask, next_action_mask) + loss_out, mask, masked_td_error, chosen_action_qvals, targets = ( + self.loss(rewards, actions, terminated, mask, obs, next_obs, + action_mask, next_action_mask, env_global_state, + next_env_global_state)) # Optimise self.optimiser.zero_grad() @@ -319,40 +378,43 @@ class QMixTorchPolicy(Policy): return {LEARNER_STATS_KEY: stats} @override(Policy) - def get_initial_state(self): + def get_initial_state(self): # initial RNN state return [ - s.expand([self.n_agents, -1]).numpy() + s.expand([self.n_agents, -1]).cpu().numpy() for s in self.model.get_initial_state() ] @override(Policy) def get_weights(self): - return {"model": self.model.state_dict()} - - @override(Policy) - def set_weights(self, weights): - self.model.load_state_dict(weights["model"]) - - @override(Policy) - def get_state(self): return { - "model": self.model.state_dict(), - "target_model": self.target_model.state_dict(), - "mixer": self.mixer.state_dict() if self.mixer else None, - "target_mixer": self.target_mixer.state_dict() + "model": self._cpu_dict(self.model.state_dict()), + "target_model": self._cpu_dict(self.target_model.state_dict()), + "mixer": self._cpu_dict(self.mixer.state_dict()) + if self.mixer else None, + "target_mixer": self._cpu_dict(self.target_mixer.state_dict()) if self.mixer else None, - "cur_epsilon": self.cur_epsilon, } + @override(Policy) + def set_weights(self, weights): + self.model.load_state_dict(self._device_dict(weights["model"])) + self.target_model.load_state_dict( + self._device_dict(weights["target_model"])) + if weights["mixer"] is not None: + self.mixer.load_state_dict(self._device_dict(weights["mixer"])) + self.target_mixer.load_state_dict( + self._device_dict(weights["target_mixer"])) + + @override(Policy) + def get_state(self): + state = self.get_weights() + state["cur_epsilon"] = self.cur_epsilon + return state + @override(Policy) def set_state(self, state): - self.model.load_state_dict(state["model"]) - self.target_model.load_state_dict(state["target_model"]) - if state["mixer"] is not None: - self.mixer.load_state_dict(state["mixer"]) - self.target_mixer.load_state_dict(state["target_mixer"]) + self.set_weights(state) self.set_epsilon(state["cur_epsilon"]) - self.update_target() def update_target(self): self.target_model.load_state_dict(self.model.state_dict()) @@ -370,15 +432,28 @@ class QMixTorchPolicy(Policy): ]) return group_rewards + def _device_dict(self, state_dict): + return { + k: th.as_tensor(v, device=self.device) + for k, v in state_dict.items() + } + + @staticmethod + def _cpu_dict(state_dict): + return {k: v.cpu().detach().numpy() for k, v in state_dict.items()} + def _unpack_observation(self, obs_batch): - """Unpacks the action mask / tuple obs from agent grouping. + """Unpacks the observation, action mask, and state (if present) + from agent grouping. Returns: - obs (Tensor): flattened obs tensor of shape [B, n_agents, obs_size] - mask (Tensor): action mask, if any + obs (np.ndarray): obs tensor of shape [B, n_agents, obs_size] + mask (np.ndarray): action mask, if any + state (np.ndarray or None): state tensor of shape [B, state_size] + or None if it is not in the batch """ unpacked = _unpack_obs( - np.array(obs_batch), + np.array(obs_batch, dtype=np.float32), self.observation_space.original_space, tensorlib=np) if self.has_action_mask: @@ -389,12 +464,22 @@ class QMixTorchPolicy(Policy): [o["action_mask"] for o in unpacked], axis=1).reshape( [len(obs_batch), self.n_agents, self.n_actions]) else: + if isinstance(unpacked[0], dict): + unpacked_obs = [u["obs"] for u in unpacked] + else: + unpacked_obs = unpacked obs = np.concatenate( - unpacked, + unpacked_obs, axis=1).reshape([len(obs_batch), self.n_agents, self.obs_size]) action_mask = np.ones( - [len(obs_batch), self.n_agents, self.n_actions]) - return obs, action_mask + [len(obs_batch), self.n_agents, self.n_actions], + dtype=np.float32) + + if self.has_env_global_state: + state = unpacked[0][ENV_STATE] + else: + state = None + return obs, action_mask, state def _validate(obs_space, action_space): @@ -436,9 +521,11 @@ def _mac(model, obs, h): h: Tensor of shape [B, n_agents, h_size] """ B, n_agents = obs.size(0), obs.size(1) - obs_flat = obs.reshape([B * n_agents, -1]) + if not isinstance(obs, dict): + obs = {"obs": obs} + obs_agents_as_batches = {k: _drop_agent_dim(v) for k, v in obs.items()} h_flat = [s.reshape([B * n_agents, -1]) for s in h] - q_flat, h_flat = model({"obs": obs_flat}, h_flat, None) + q_flat, h_flat = model(obs_agents_as_batches, h_flat, None) return q_flat.reshape( [B, n_agents, -1]), [s.reshape([B, n_agents, -1]) for s in h_flat] @@ -457,3 +544,16 @@ def _unroll_mac(model, obs_tensor): mac_out = th.stack(mac_out, dim=1) # Concat over time return mac_out + + +def _drop_agent_dim(T): + shape = list(T.shape) + B, n_agents = shape[0], shape[1] + return T.reshape([B * n_agents] + shape[2:]) + + +def _add_agent_dim(T, n_agents): + shape = list(T.shape) + B = shape[0] // n_agents + assert shape[0] % n_agents == 0 + return T.reshape([B, n_agents] + shape[1:]) diff --git a/rllib/examples/centralized_critic.py b/rllib/examples/centralized_critic.py index 6c8a450c4..d793adbf1 100644 --- a/rllib/examples/centralized_critic.py +++ b/rllib/examples/centralized_critic.py @@ -17,6 +17,7 @@ modifies the environment. import argparse import numpy as np +from gym.spaces import Discrete from ray import tune from ray.rllib.agents.ppo.ppo import PPOTrainer @@ -209,10 +210,8 @@ if __name__ == "__main__": "num_workers": 0, "multiagent": { "policies": { - "pol1": (None, TwoStepGame.observation_space, - TwoStepGame.action_space, {}), - "pol2": (None, TwoStepGame.observation_space, - TwoStepGame.action_space, {}), + "pol1": (None, Discrete(6), TwoStepGame.action_space, {}), + "pol2": (None, Discrete(6), TwoStepGame.action_space, {}), }, "policy_mapping_fn": lambda x: "pol1" if x == 0 else "pol2", }, diff --git a/rllib/examples/twostep_game.py b/rllib/examples/twostep_game.py index 3ddc0fd04..37e622390 100644 --- a/rllib/examples/twostep_game.py +++ b/rllib/examples/twostep_game.py @@ -14,13 +14,14 @@ from __future__ import division from __future__ import print_function import argparse -from gym.spaces import Tuple, Discrete +from gym.spaces import Tuple, MultiDiscrete, Dict, Discrete import numpy as np import ray from ray import tune from ray.tune import register_env, grid_search from ray.rllib.env.multi_agent_env import MultiAgentEnv +from ray.rllib.agents.qmix.qmix_policy import ENV_STATE parser = argparse.ArgumentParser() parser.add_argument("--stop", type=int, default=50000) @@ -30,20 +31,34 @@ parser.add_argument("--run", type=str, default="PG") class TwoStepGame(MultiAgentEnv): action_space = Discrete(2) - # Each agent gets a separate [3] obs space, to ensure that they can - # learn meaningfully different Q values even with a shared Q model. - observation_space = Discrete(6) - def __init__(self, env_config): self.state = None self.agent_1 = 0 self.agent_2 = 1 # MADDPG emits action logits instead of actual discrete actions self.actions_are_logits = env_config.get("actions_are_logits", False) + self.one_hot_state_encoding = env_config.get("one_hot_state_encoding", + False) + self.with_state = env_config.get("separate_state_space", False) + + if not self.one_hot_state_encoding: + self.observation_space = Discrete(6) + self.with_state = False + else: + # Each agent gets the full state (one-hot encoding of which of the + # three states are active) as input with the receiving agent's + # ID (1 or 2) concatenated onto the end. + if self.with_state: + self.observation_space = Dict({ + "obs": MultiDiscrete([2, 2, 2, 3]), + ENV_STATE: MultiDiscrete([2, 2, 2]) + }) + else: + self.observation_space = MultiDiscrete([2, 2, 2, 3]) def reset(self): - self.state = 0 - return {self.agent_1: self.state, self.agent_2: self.state + 3} + self.state = np.array([1, 0, 0]) + return self._obs() def step(self, action_dict): if self.actions_are_logits: @@ -52,16 +67,17 @@ class TwoStepGame(MultiAgentEnv): for k, v in action_dict.items() } - if self.state == 0: + state_index = np.flatnonzero(self.state) + if state_index == 0: action = action_dict[self.agent_1] assert action in [0, 1], action if action == 0: - self.state = 1 + self.state = np.array([0, 1, 0]) else: - self.state = 2 + self.state = np.array([0, 0, 1]) global_rew = 0 done = False - elif self.state == 1: + elif state_index == 1: global_rew = 7 done = True else: @@ -79,11 +95,41 @@ class TwoStepGame(MultiAgentEnv): self.agent_1: global_rew / 2.0, self.agent_2: global_rew / 2.0 } - obs = {self.agent_1: self.state, self.agent_2: self.state + 3} + obs = self._obs() dones = {"__all__": done} infos = {} return obs, rewards, dones, infos + def _obs(self): + if self.with_state: + return { + self.agent_1: { + "obs": self.agent_1_obs(), + ENV_STATE: self.state + }, + self.agent_2: { + "obs": self.agent_2_obs(), + ENV_STATE: self.state + } + } + else: + return { + self.agent_1: self.agent_1_obs(), + self.agent_2: self.agent_2_obs() + } + + def agent_1_obs(self): + if self.one_hot_state_encoding: + return np.concatenate([self.state, [1]]) + else: + return np.flatnonzero(self.state)[0] + + def agent_2_obs(self): + if self.one_hot_state_encoding: + return np.concatenate([self.state, [2]]) + else: + return np.flatnonzero(self.state)[0] + 3 + if __name__ == "__main__": args = parser.parse_args() @@ -92,8 +138,14 @@ if __name__ == "__main__": "group_1": [0, 1], } obs_space = Tuple([ - TwoStepGame.observation_space, - TwoStepGame.observation_space, + Dict({ + "obs": MultiDiscrete([2, 2, 2, 3]), + ENV_STATE: MultiDiscrete([2, 2, 2]) + }), + Dict({ + "obs": MultiDiscrete([2, 2, 2, 3]), + ENV_STATE: MultiDiscrete([2, 2, 2]) + }), ]) act_space = Tuple([ TwoStepGame.action_space, @@ -106,8 +158,8 @@ if __name__ == "__main__": if args.run == "contrib/MADDPG": obs_space_dict = { - "agent_1": TwoStepGame.observation_space, - "agent_2": TwoStepGame.observation_space, + "agent_1": Discrete(6), + "agent_2": Discrete(6), } act_space_dict = { "agent_1": TwoStepGame.action_space, @@ -120,14 +172,12 @@ if __name__ == "__main__": }, "multiagent": { "policies": { - "pol1": (None, TwoStepGame.observation_space, - TwoStepGame.action_space, { - "agent_id": 0, - }), - "pol2": (None, TwoStepGame.observation_space, - TwoStepGame.action_space, { - "agent_id": 1, - }), + "pol1": (None, Discrete(6), TwoStepGame.action_space, { + "agent_id": 0, + }), + "pol2": (None, Discrete(6), TwoStepGame.action_space, { + "agent_id": 1, + }), }, "policy_mapping_fn": lambda x: "pol1" if x == 0 else "pol2", }, @@ -137,9 +187,14 @@ if __name__ == "__main__": config = { "sample_batch_size": 4, "train_batch_size": 32, + "exploration_fraction": .4, "exploration_final_eps": 0.0, "num_workers": 0, "mixer": grid_search([None, "qmix", "vdn"]), + "env_config": { + "separate_state_space": True, + "one_hot_state_encoding": True + }, } group = True elif args.run == "APEX_QMIX": @@ -156,6 +211,10 @@ if __name__ == "__main__": "sample_batch_size": 32, "target_network_update_freq": 500, "timesteps_per_iteration": 1000, + "env_config": { + "separate_state_space": True, + "one_hot_state_encoding": True + }, } group = True else: