[rllib] [rfc] add contrib module and guideline for merging (#3565)

This adds guidelines for merging code into `rllib/contrib` vs `rllib/agents`. It also cleans up the agent import code to make registration easier.
This commit is contained in:
Eric Liang
2018-12-21 03:44:34 +09:00
committed by Richard Liaw
parent cf0c4745f4
commit 303883a3b6
17 changed files with 280 additions and 79 deletions
+5 -6
View File
@@ -31,12 +31,11 @@ def _setup_logger():
def _register_all():
for key in [
"PPO", "ES", "DQN", "APEX", "A3C", "PG", "DDPG", "APEX_DDPG",
"IMPALA", "ARS", "A2C", "QMIX", "APEX_QMIX", "__fake",
"__sigmoid_fake_data", "__parameter_tuning"
]:
from ray.rllib.agents.agent import get_agent_class
from ray.rllib.agents.registry import ALGORITHMS
from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS
for key in list(ALGORITHMS.keys()) + list(CONTRIBUTED_ALGORITHMS.keys(
)) + ["__fake", "__sigmoid_fake_data", "__parameter_tuning"]:
from ray.rllib.agents.registry import get_agent_class
register_trainable(key, get_agent_class(key))
-67
View File
@@ -10,7 +10,6 @@ import pickle
import six
import tempfile
import tensorflow as tf
import traceback
from types import FunctionType
import ray
@@ -542,69 +541,3 @@ def _register_if_needed(env_object):
name = env_object.__name__
register_env(name, lambda config: env_object(config))
return name
def get_agent_class(alg):
    """Look up the agent class for the algorithm name `alg`.

    Delegates to `_get_agent_class`. If that lookup raises ImportError,
    the formatted traceback is passed to `_agent_import_failed` and its
    result is returned in place of the real class.
    """
    try:
        agent_cls = _get_agent_class(alg)
    except ImportError:
        # Imported only on the failure path.
        from ray.rllib.agents.mock import _agent_import_failed
        agent_cls = _agent_import_failed(traceback.format_exc())
    return agent_cls
def _get_agent_class(alg):
    """Map an algorithm name to its agent class, importing it on demand.

    Every branch imports only the module it needs, so asking for one
    algorithm does not import the others.

    Args:
        alg (str): Algorithm name, e.g. "PPO", "script" or "__fake".

    Returns:
        The class associated with `alg`.

    Raises:
        Exception: If `alg` is not one of the known names.
    """
    # Names that live in the same module share one import.
    if alg in ("DDPG", "APEX_DDPG"):
        from ray.rllib.agents import ddpg
        return ddpg.DDPGAgent if alg == "DDPG" else ddpg.ApexDDPGAgent
    if alg == "PPO":
        from ray.rllib.agents import ppo
        return ppo.PPOAgent
    if alg == "ES":
        from ray.rllib.agents import es
        return es.ESAgent
    if alg == "ARS":
        from ray.rllib.agents import ars
        return ars.ARSAgent
    if alg in ("DQN", "APEX"):
        from ray.rllib.agents import dqn
        return dqn.DQNAgent if alg == "DQN" else dqn.ApexAgent
    if alg in ("A3C", "A2C"):
        from ray.rllib.agents import a3c
        return a3c.A3CAgent if alg == "A3C" else a3c.A2CAgent
    if alg == "PG":
        from ray.rllib.agents import pg
        return pg.PGAgent
    if alg == "IMPALA":
        from ray.rllib.agents import impala
        return impala.ImpalaAgent
    if alg in ("QMIX", "APEX_QMIX"):
        from ray.rllib.agents import qmix
        return qmix.QMixAgent if alg == "QMIX" else qmix.ApexQMixAgent
    if alg == "script":
        from ray.tune import script_runner
        return script_runner.ScriptRunner

    # Test-only agents, each imported by name from the mock module.
    if alg == "__fake":
        from ray.rllib.agents.mock import _MockAgent
        return _MockAgent
    if alg == "__sigmoid_fake_data":
        from ray.rllib.agents.mock import _SigmoidFakeData
        return _SigmoidFakeData
    if alg == "__parameter_tuning":
        from ray.rllib.agents.mock import _ParameterTuningAgent
        return _ParameterTuningAgent

    raise Exception("Unknown algorithm {}.".format(alg))
+122
View File
@@ -0,0 +1,122 @@
"""Registry of algorithm names for `rllib train --run=<alg_name>`"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import traceback
from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS
def _import_qmix():
    """Import `qmix` from `ray.rllib.agents` at call time; return `QMixAgent`."""
    from ray.rllib.agents import qmix as qmix_module
    agent_cls = qmix_module.QMixAgent
    return agent_cls
def _import_apex_qmix():
    """Import `qmix` from `ray.rllib.agents` at call time; return `ApexQMixAgent`."""
    from ray.rllib.agents import qmix as qmix_module
    agent_cls = qmix_module.ApexQMixAgent
    return agent_cls
def _import_ddpg():
    """Import `ddpg` from `ray.rllib.agents` at call time; return `DDPGAgent`."""
    from ray.rllib.agents import ddpg as ddpg_module
    agent_cls = ddpg_module.DDPGAgent
    return agent_cls
def _import_apex_ddpg():
    """Import `ddpg` from `ray.rllib.agents` at call time; return `ApexDDPGAgent`."""
    from ray.rllib.agents import ddpg as ddpg_module
    agent_cls = ddpg_module.ApexDDPGAgent
    return agent_cls
def _import_ppo():
    """Import `ppo` from `ray.rllib.agents` at call time; return `PPOAgent`."""
    from ray.rllib.agents import ppo as ppo_module
    agent_cls = ppo_module.PPOAgent
    return agent_cls
def _import_es():
    """Import `es` from `ray.rllib.agents` at call time; return `ESAgent`."""
    from ray.rllib.agents import es as es_module
    agent_cls = es_module.ESAgent
    return agent_cls
def _import_ars():
    """Import `ars` from `ray.rllib.agents` at call time; return `ARSAgent`."""
    from ray.rllib.agents import ars as ars_module
    agent_cls = ars_module.ARSAgent
    return agent_cls
def _import_dqn():
    """Import `dqn` from `ray.rllib.agents` at call time; return `DQNAgent`."""
    from ray.rllib.agents import dqn as dqn_module
    agent_cls = dqn_module.DQNAgent
    return agent_cls
def _import_apex():
    """Import `dqn` from `ray.rllib.agents` at call time; return `ApexAgent`."""
    from ray.rllib.agents import dqn as dqn_module
    agent_cls = dqn_module.ApexAgent
    return agent_cls
def _import_a3c():
    """Import `a3c` from `ray.rllib.agents` at call time; return `A3CAgent`."""
    from ray.rllib.agents import a3c as a3c_module
    agent_cls = a3c_module.A3CAgent
    return agent_cls
def _import_a2c():
    """Import `a3c` from `ray.rllib.agents` at call time; return `A2CAgent`."""
    from ray.rllib.agents import a3c as a3c_module
    agent_cls = a3c_module.A2CAgent
    return agent_cls
def _import_pg():
    """Import `pg` from `ray.rllib.agents` at call time; return `PGAgent`."""
    from ray.rllib.agents import pg as pg_module
    agent_cls = pg_module.PGAgent
    return agent_cls
def _import_impala():
    """Import `impala` from `ray.rllib.agents` at call time; return `ImpalaAgent`."""
    from ray.rllib.agents import impala as impala_module
    agent_cls = impala_module.ImpalaAgent
    return agent_cls
# Maps each built-in algorithm name (the value given to
# `rllib train --run=<alg_name>`) to a zero-argument function that imports
# and returns the matching agent class. Storing functions rather than
# classes means an algorithm's module is only imported when it is looked up.
ALGORITHMS = {
"DDPG": _import_ddpg,
"APEX_DDPG": _import_apex_ddpg,
"PPO": _import_ppo,
"ES": _import_es,
"ARS": _import_ars,
"DQN": _import_dqn,
"APEX": _import_apex,
"A3C": _import_a3c,
"A2C": _import_a2c,
"PG": _import_pg,
"IMPALA": _import_impala,
"QMIX": _import_qmix,
"APEX_QMIX": _import_apex_qmix,
}
def get_agent_class(alg):
    """Return the agent class registered under the name `alg`.

    The lookup itself is done by `_get_agent_class`. An ImportError raised
    during that lookup is not propagated: the formatted traceback is given
    to `_agent_import_failed`, and whatever it returns is returned here.
    """
    try:
        resolved = _get_agent_class(alg)
    except ImportError:
        # Only needed when the real import failed.
        from ray.rllib.agents.mock import _agent_import_failed
        return _agent_import_failed(traceback.format_exc())
    else:
        return resolved
def _get_agent_class(alg):
    """Resolve `alg` to an agent class via the registries or special names.

    Lookup order: `ALGORITHMS`, then `CONTRIBUTED_ALGORITHMS`, then the
    special names "script", "__fake", "__sigmoid_fake_data" and
    "__parameter_tuning".

    Args:
        alg (str): Algorithm name.

    Returns:
        The class produced by the matching registry entry or import.

    Raises:
        Exception: If `alg` matches none of the known names.
    """
    # Registry values are zero-argument importers; call the first match.
    for registry in (ALGORITHMS, CONTRIBUTED_ALGORITHMS):
        if alg in registry:
            return registry[alg]()

    if alg == "script":
        from ray.tune import script_runner
        return script_runner.ScriptRunner
    if alg == "__fake":
        from ray.rllib.agents.mock import _MockAgent
        return _MockAgent
    if alg == "__sigmoid_fake_data":
        from ray.rllib.agents.mock import _SigmoidFakeData
        return _SigmoidFakeData
    if alg == "__parameter_tuning":
        from ray.rllib.agents.mock import _ParameterTuningAgent
        return _ParameterTuningAgent

    raise Exception("Unknown algorithm {}.".format(alg))
+3
View File
@@ -0,0 +1,3 @@
Contributed algorithms, which can be run via `rllib train --run=contrib/<alg_name>`
See https://ray.readthedocs.io/en/latest/rllib-dev.html for guidelines.
@@ -0,0 +1,52 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from ray.rllib.agents.agent import Agent, with_common_config
from ray.rllib.utils.annotations import override
# yapf: disable
# __sphinx_doc_begin__
class RandomAgent(Agent):
    """Agent that samples actions from the env's action space; no learning.

    Config:
        rollouts_per_iteration: number of episodes run per `_train` call.
    """

    _agent_name = "RandomAgent"
    _default_config = with_common_config({
        "rollouts_per_iteration": 10,
    })

    @override(Agent)
    def _init(self):
        # Create the environment this agent steps through in `_train`.
        self.env = self.env_creator(self.config["env_config"])

    @override(Agent)
    def _train(self):
        """Run the configured number of episodes with sampled actions.

        Returns:
            dict with the mean episode reward ("episode_reward_mean") and
            the total number of env steps taken ("timesteps_this_iter").
        """
        num_rollouts = self.config["rollouts_per_iteration"]
        episode_returns = []
        total_steps = 0
        for _ in range(num_rollouts):
            self.env.reset()
            episode_return = 0.0
            finished = False
            while not finished:
                sampled_action = self.env.action_space.sample()
                # Observation and info are not used by this agent.
                _obs, step_reward, finished, _info = self.env.step(
                    sampled_action)
                episode_return += step_reward
                total_steps += 1
            episode_returns.append(episode_return)
        return {
            "episode_reward_mean": np.mean(episode_returns),
            "timesteps_this_iter": total_steps,
        }
# __sphinx_doc_end__
# don't enable yapf after, it's buggy here
if __name__ == "__main__":
    # Smoke test: one training iteration on CartPole-v0 must yield a mean
    # episode reward above 10.
    smoke_config = {"rollouts_per_iteration": 10}
    agent = RandomAgent(env="CartPole-v0", config=smoke_config)
    result = agent.train()
    mean_reward = result["episode_reward_mean"]
    assert mean_reward > 10, result
    print("Test: OK")
+15
View File
@@ -0,0 +1,15 @@
"""Registry of contributed algorithm names.

Entries are run via `rllib train --run=contrib/<alg_name>`.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def _import_random_agent():
    """Import and return `RandomAgent` at call time."""
    from ray.rllib.contrib.random_agent.random_agent import (
        RandomAgent as agent_cls)
    return agent_cls
# Maps each contributed algorithm name to a zero-argument function that
# imports and returns its agent class. Keys carry the "contrib/" prefix.
CONTRIBUTED_ALGORITHMS = {
"contrib/RandomAgent": _import_random_agent,
}
+1 -1
View File
@@ -11,7 +11,7 @@ import pickle
import gym
import ray
from ray.rllib.agents.agent import get_agent_class
from ray.rllib.agents.registry import get_agent_class
EXAMPLE_USAGE = """
Example Usage via RLlib CLI:
@@ -7,7 +7,7 @@ from __future__ import print_function
import numpy as np
import ray
from ray.rllib.agents.agent import get_agent_class
from ray.rllib.agents.registry import get_agent_class
def get_mean_action(alg, obs):
@@ -8,7 +8,7 @@ import numpy as np
import sys
import ray
from ray.rllib.agents.agent import get_agent_class
from ray.rllib.agents.registry import get_agent_class
from ray.rllib.test.test_multi_agent_env import MultiCartpole, MultiMountainCar
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.tune.registry import register_env
+1 -1
View File
@@ -339,4 +339,4 @@ class Trial(object):
identifier = self.trainable_name
if self.experiment_tag:
identifier += "_" + self.experiment_tag
return identifier
return identifier.replace("/", "_")