diff --git a/python/ray/rllib/examples/starcraft/README.rst b/python/ray/rllib/examples/starcraft/README.rst new file mode 100644 index 000000000..b10cc5d6f --- /dev/null +++ b/python/ray/rllib/examples/starcraft/README.rst @@ -0,0 +1,18 @@ +StarCraft on RLlib +================== + +This builds off the StarCraft env in https://github.com/oxwhirl/pymarl_alpha. + +Temporary instructions +---------------------- + +To install, run + +``` +git clone https://github.com/oxwhirl/pymarl_alpha +mv pymarl_alpha ~/pymarl +cd ~/pymarl +install_sc1.sh +install_sc2.sh +export PYMARL_PATH="~/pymarl" +``` diff --git a/python/ray/rllib/examples/starcraft/sc2.yaml b/python/ray/rllib/examples/starcraft/sc2.yaml new file mode 100644 index 000000000..db108e226 --- /dev/null +++ b/python/ray/rllib/examples/starcraft/sc2.yaml @@ -0,0 +1,32 @@ +## Adapted from `https://github.com/oxwhirl/pymarl_alpha`. + +env: sc2 + +env_args: + map_name: "3m_3m" # SC2 map name + difficulty: "7" # Very hard + move_amount: 2 # How much units are ordered to move per step + step_mul: 8 # How many frames are skiped per step + reward_sparse: False # Only +1/-1 reward for win/defeat (the rest of reward configs are ignored if True) + reward_only_positive: True # Reward is always positive + reward_negative_scale: 0.5 # How much to scale negative rewards, ignored if reward_only_positive=True + reward_death_value: 10 # Reward for killing an enemy unit and penalty for having an allied unit killed (if reward_only_poitive=False) + reward_scale: True # Whether or not to scale rewards before returning to agents + reward_scale_rate: 20 # If reward_scale=True, the agents receive the reward of (max_reward / reward_scale_rate), where max_reward is the maximum possible reward per episode + reward_win: 200 # Reward for win + reward_defeat: 0 # Reward for defeat (should be nonpositive) + state_last_action: True # Whether the last actions of units is a part of the state + obs_instead_of_state: False # Use combination of all agnets' observations as state + obs_own_health: True # Whether agents receive their own health as a part of observation + obs_all_health: True # Whether agents receive the health of all units (in the sight range) as a part of observataion + continuing_episode: False # Stop/continue episode after its termination + game_version: "4.1.2" # Ignored for Mac/Windows + save_replay_prefix: "" # Prefix of the replay to be saved + heuristic: False # Whether or not use a simple nonlearning hearistic as a controller + +test_nepisode: 32 +test_interval: 10000 +log_interval: 2000 +runner_log_interval: 2000 +learner_log_interval: 2000 +t_max: 2000000 diff --git a/python/ray/rllib/examples/starcraft/starcraft_env.py b/python/ray/rllib/examples/starcraft/starcraft_env.py new file mode 100644 index 000000000..7cfd3f266 --- /dev/null +++ b/python/ray/rllib/examples/starcraft/starcraft_env.py @@ -0,0 +1,153 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +from gym.spaces import Discrete, Box, Dict, Tuple +import os +import sys +import tensorflow as tf +import tensorflow.contrib.slim as slim +import yaml + +import ray +from ray.rllib.env.multi_agent_env import MultiAgentEnv +from ray.tune.registry import register_env +from ray.rllib.models import Model, ModelCatalog +from ray.rllib.models.misc import normc_initializer +from ray.rllib.agents.qmix import QMixAgent +from ray.rllib.agents.pg import PGAgent +from ray.rllib.agents.ppo import PPOAgent +from ray.tune.logger import pretty_print + + +class MaskedActionsModel(Model): + def _build_layers_v2(self, input_dict, num_outputs, options): + action_mask = input_dict["obs"]["action_mask"] + if num_outputs != action_mask.shape[1].value: + raise ValueError( + "This model assumes num outputs is equal to max avail actions", + num_outputs, action_mask) + + # Standard FC net component. + last_layer = input_dict["obs"]["obs"] + hiddens = [256, 256] + for i, size in enumerate(hiddens): + label = "fc{}".format(i) + last_layer = slim.fully_connected( + last_layer, + size, + weights_initializer=normc_initializer(1.0), + activation_fn=tf.nn.tanh, + scope=label) + action_logits = slim.fully_connected( + last_layer, + num_outputs, + weights_initializer=normc_initializer(0.01), + activation_fn=None, + scope="fc_out") + + # Mask out invalid actions (use tf.float32.min for stability) + inf_mask = tf.maximum(tf.log(action_mask), tf.float32.min) + masked_logits = inf_mask + action_logits + + return masked_logits, last_layer + + +class SC2MultiAgentEnv(MultiAgentEnv): + """RLlib Wrapper around StarCraft2.""" + + def __init__(self, override_cfg): + PYMARL_PATH = override_cfg.pop("pymarl_path") + os.environ["SC2PATH"] = os.path.join(PYMARL_PATH, + "3rdparty/StarCraftII") + sys.path.append(os.path.join(PYMARL_PATH, "src")) + from envs.starcraft2 import StarCraft2Env + curpath = os.path.dirname(os.path.abspath(__file__)) + with open(os.path.join(curpath, "sc2.yaml")) as f: + pymarl_args = yaml.load(f) + pymarl_args.update(override_cfg) + pymarl_args["env_args"].setdefault("seed", 0) + + self._starcraft_env = StarCraft2Env(**pymarl_args) + obs_size = self._starcraft_env.get_obs_size() + num_actions = self._starcraft_env.get_total_actions() + self.observation_space = Dict({ + "action_mask": Box(0, 1, shape=(num_actions, )), + "obs": Box(-1, 1, shape=(obs_size, )) + }) + self.action_space = Discrete(self._starcraft_env.get_total_actions()) + + def reset(self): + obs_list, state_list = self._starcraft_env.reset() + return_obs = {} + for i, obs in enumerate(obs_list): + return_obs[i] = { + "action_mask": self._starcraft_env.get_avail_agent_actions(i), + "obs": obs + } + return return_obs + + def step(self, action_dict): + # TODO(rliaw): Check to handle missing agents, if any + actions = [action_dict[k] for k in sorted(action_dict)] + rew, done, info = self._starcraft_env.step(actions) + obs_list = self._starcraft_env.get_obs() + return_obs = {} + for i, obs in enumerate(obs_list): + return_obs[i] = { + "action_mask": self._starcraft_env.get_avail_agent_actions(i), + "obs": obs + } + rews = {i: rew / len(obs_list) for i in range(len(obs_list))} + dones = {i: done for i in range(len(obs_list))} + dones["__all__"] = done + infos = {i: info for i in range(len(obs_list))} + return return_obs, rews, dones, infos + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--num-iters", type=int, default=100) + parser.add_argument("--run", type=str, default="qmix") + args = parser.parse_args() + + path_to_pymarl = os.environ.get("PYMARL_PATH", + os.path.expanduser("~/pymarl/")) + + ray.init() + ModelCatalog.register_custom_model("mask_model", MaskedActionsModel) + + register_env("starcraft", lambda cfg: SC2MultiAgentEnv(cfg)) + agent_cfg = { + "observation_filter": "NoFilter", + "num_workers": 4, + "model": { + "custom_model": "mask_model", + }, + "env_config": { + "pymarl_path": path_to_pymarl + } + } + if args.run.lower() == "qmix": + + def grouped_sc2(cfg): + env = SC2MultiAgentEnv(cfg) + agent_list = list(range(env._starcraft_env.n_agents)) + grouping = { + "group_1": agent_list, + } + obs_space = Tuple([env.observation_space for i in agent_list]) + act_space = Tuple([env.action_space for i in agent_list]) + return env.with_agent_groups( + grouping, obs_space=obs_space, act_space=act_space) + + register_env("grouped_starcraft", grouped_sc2) + agent = QMixAgent(env="grouped_starcraft", config=agent_cfg) + elif args.run.lower() == "pg": + agent = PGAgent(env="starcraft", config=agent_cfg) + elif args.run.lower() == "ppo": + agent_cfg.update({"vf_share_layers": True}) + agent = PPOAgent(env="starcraft", config=agent_cfg) + for i in range(args.num_iters): + print(pretty_print(agent.train()))