[rllib] Add starcraft multiagent env as example (#3542)

This commit is contained in:
Richard Liaw
2018-12-27 10:00:32 +08:00
committed by GitHub
parent b4f61dfd50
commit ac792d70c8
3 changed files with 203 additions and 0 deletions
@@ -0,0 +1,18 @@
StarCraft on RLlib
==================
This builds off the StarCraft env in https://github.com/oxwhirl/pymarl_alpha.
Temporary instructions
----------------------
To install, run
```
git clone https://github.com/oxwhirl/pymarl_alpha
mv pymarl_alpha ~/pymarl
cd ~/pymarl
install_sc1.sh
install_sc2.sh
export PYMARL_PATH="~/pymarl"
```
@@ -0,0 +1,32 @@
## Adapted from `https://github.com/oxwhirl/pymarl_alpha`.
env: sc2
env_args:
map_name: "3m_3m" # SC2 map name
difficulty: "7" # Very hard
move_amount: 2 # How much units are ordered to move per step
step_mul: 8 # How many frames are skiped per step
reward_sparse: False # Only +1/-1 reward for win/defeat (the rest of reward configs are ignored if True)
reward_only_positive: True # Reward is always positive
reward_negative_scale: 0.5 # How much to scale negative rewards, ignored if reward_only_positive=True
reward_death_value: 10 # Reward for killing an enemy unit and penalty for having an allied unit killed (if reward_only_poitive=False)
reward_scale: True # Whether or not to scale rewards before returning to agents
reward_scale_rate: 20 # If reward_scale=True, the agents receive the reward of (max_reward / reward_scale_rate), where max_reward is the maximum possible reward per episode
reward_win: 200 # Reward for win
reward_defeat: 0 # Reward for defeat (should be nonpositive)
state_last_action: True # Whether the last actions of units is a part of the state
obs_instead_of_state: False # Use combination of all agnets' observations as state
obs_own_health: True # Whether agents receive their own health as a part of observation
obs_all_health: True # Whether agents receive the health of all units (in the sight range) as a part of observataion
continuing_episode: False # Stop/continue episode after its termination
game_version: "4.1.2" # Ignored for Mac/Windows
save_replay_prefix: "" # Prefix of the replay to be saved
heuristic: False # Whether or not use a simple nonlearning hearistic as a controller
test_nepisode: 32
test_interval: 10000
log_interval: 2000
runner_log_interval: 2000
learner_log_interval: 2000
t_max: 2000000
@@ -0,0 +1,153 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
from gym.spaces import Discrete, Box, Dict, Tuple
import os
import sys
import tensorflow as tf
import tensorflow.contrib.slim as slim
import yaml
import ray
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.tune.registry import register_env
from ray.rllib.models import Model, ModelCatalog
from ray.rllib.models.misc import normc_initializer
from ray.rllib.agents.qmix import QMixAgent
from ray.rllib.agents.pg import PGAgent
from ray.rllib.agents.ppo import PPOAgent
from ray.tune.logger import pretty_print
class MaskedActionsModel(Model):
def _build_layers_v2(self, input_dict, num_outputs, options):
action_mask = input_dict["obs"]["action_mask"]
if num_outputs != action_mask.shape[1].value:
raise ValueError(
"This model assumes num outputs is equal to max avail actions",
num_outputs, action_mask)
# Standard FC net component.
last_layer = input_dict["obs"]["obs"]
hiddens = [256, 256]
for i, size in enumerate(hiddens):
label = "fc{}".format(i)
last_layer = slim.fully_connected(
last_layer,
size,
weights_initializer=normc_initializer(1.0),
activation_fn=tf.nn.tanh,
scope=label)
action_logits = slim.fully_connected(
last_layer,
num_outputs,
weights_initializer=normc_initializer(0.01),
activation_fn=None,
scope="fc_out")
# Mask out invalid actions (use tf.float32.min for stability)
inf_mask = tf.maximum(tf.log(action_mask), tf.float32.min)
masked_logits = inf_mask + action_logits
return masked_logits, last_layer
class SC2MultiAgentEnv(MultiAgentEnv):
"""RLlib Wrapper around StarCraft2."""
def __init__(self, override_cfg):
PYMARL_PATH = override_cfg.pop("pymarl_path")
os.environ["SC2PATH"] = os.path.join(PYMARL_PATH,
"3rdparty/StarCraftII")
sys.path.append(os.path.join(PYMARL_PATH, "src"))
from envs.starcraft2 import StarCraft2Env
curpath = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(curpath, "sc2.yaml")) as f:
pymarl_args = yaml.load(f)
pymarl_args.update(override_cfg)
pymarl_args["env_args"].setdefault("seed", 0)
self._starcraft_env = StarCraft2Env(**pymarl_args)
obs_size = self._starcraft_env.get_obs_size()
num_actions = self._starcraft_env.get_total_actions()
self.observation_space = Dict({
"action_mask": Box(0, 1, shape=(num_actions, )),
"obs": Box(-1, 1, shape=(obs_size, ))
})
self.action_space = Discrete(self._starcraft_env.get_total_actions())
def reset(self):
obs_list, state_list = self._starcraft_env.reset()
return_obs = {}
for i, obs in enumerate(obs_list):
return_obs[i] = {
"action_mask": self._starcraft_env.get_avail_agent_actions(i),
"obs": obs
}
return return_obs
def step(self, action_dict):
# TODO(rliaw): Check to handle missing agents, if any
actions = [action_dict[k] for k in sorted(action_dict)]
rew, done, info = self._starcraft_env.step(actions)
obs_list = self._starcraft_env.get_obs()
return_obs = {}
for i, obs in enumerate(obs_list):
return_obs[i] = {
"action_mask": self._starcraft_env.get_avail_agent_actions(i),
"obs": obs
}
rews = {i: rew / len(obs_list) for i in range(len(obs_list))}
dones = {i: done for i in range(len(obs_list))}
dones["__all__"] = done
infos = {i: info for i in range(len(obs_list))}
return return_obs, rews, dones, infos
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-iters", type=int, default=100)
parser.add_argument("--run", type=str, default="qmix")
args = parser.parse_args()
path_to_pymarl = os.environ.get("PYMARL_PATH",
os.path.expanduser("~/pymarl/"))
ray.init()
ModelCatalog.register_custom_model("mask_model", MaskedActionsModel)
register_env("starcraft", lambda cfg: SC2MultiAgentEnv(cfg))
agent_cfg = {
"observation_filter": "NoFilter",
"num_workers": 4,
"model": {
"custom_model": "mask_model",
},
"env_config": {
"pymarl_path": path_to_pymarl
}
}
if args.run.lower() == "qmix":
def grouped_sc2(cfg):
env = SC2MultiAgentEnv(cfg)
agent_list = list(range(env._starcraft_env.n_agents))
grouping = {
"group_1": agent_list,
}
obs_space = Tuple([env.observation_space for i in agent_list])
act_space = Tuple([env.action_space for i in agent_list])
return env.with_agent_groups(
grouping, obs_space=obs_space, act_space=act_space)
register_env("grouped_starcraft", grouped_sc2)
agent = QMixAgent(env="grouped_starcraft", config=agent_cfg)
elif args.run.lower() == "pg":
agent = PGAgent(env="starcraft", config=agent_cfg)
elif args.run.lower() == "ppo":
agent_cfg.update({"vf_share_layers": True})
agent = PPOAgent(env="starcraft", config=agent_cfg)
for i in range(args.num_iters):
print(pretty_print(agent.train()))