mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 21:13:54 +08:00
[rllib] [rfc] add contrib module and guideline for merging (#3565)
This adds guidelines for merging code into `rllib/contrib` vs `rllib/agents`. Also, clean up the agent import code to make registration easier.
This commit is contained in:
@@ -31,12 +31,11 @@ def _setup_logger():
|
||||
|
||||
def _register_all():
|
||||
|
||||
for key in [
|
||||
"PPO", "ES", "DQN", "APEX", "A3C", "PG", "DDPG", "APEX_DDPG",
|
||||
"IMPALA", "ARS", "A2C", "QMIX", "APEX_QMIX", "__fake",
|
||||
"__sigmoid_fake_data", "__parameter_tuning"
|
||||
]:
|
||||
from ray.rllib.agents.agent import get_agent_class
|
||||
from ray.rllib.agents.registry import ALGORITHMS
|
||||
from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS
|
||||
for key in list(ALGORITHMS.keys()) + list(CONTRIBUTED_ALGORITHMS.keys(
|
||||
)) + ["__fake", "__sigmoid_fake_data", "__parameter_tuning"]:
|
||||
from ray.rllib.agents.registry import get_agent_class
|
||||
register_trainable(key, get_agent_class(key))
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ import pickle
|
||||
import six
|
||||
import tempfile
|
||||
import tensorflow as tf
|
||||
import traceback
|
||||
from types import FunctionType
|
||||
|
||||
import ray
|
||||
@@ -542,69 +541,3 @@ def _register_if_needed(env_object):
|
||||
name = env_object.__name__
|
||||
register_env(name, lambda config: env_object(config))
|
||||
return name
|
||||
|
||||
|
||||
def get_agent_class(alg):
|
||||
"""Returns the class of a known agent given its name."""
|
||||
|
||||
try:
|
||||
return _get_agent_class(alg)
|
||||
except ImportError:
|
||||
from ray.rllib.agents.mock import _agent_import_failed
|
||||
return _agent_import_failed(traceback.format_exc())
|
||||
|
||||
|
||||
def _get_agent_class(alg):
|
||||
if alg == "DDPG":
|
||||
from ray.rllib.agents import ddpg
|
||||
return ddpg.DDPGAgent
|
||||
elif alg == "APEX_DDPG":
|
||||
from ray.rllib.agents import ddpg
|
||||
return ddpg.ApexDDPGAgent
|
||||
elif alg == "PPO":
|
||||
from ray.rllib.agents import ppo
|
||||
return ppo.PPOAgent
|
||||
elif alg == "ES":
|
||||
from ray.rllib.agents import es
|
||||
return es.ESAgent
|
||||
elif alg == "ARS":
|
||||
from ray.rllib.agents import ars
|
||||
return ars.ARSAgent
|
||||
elif alg == "DQN":
|
||||
from ray.rllib.agents import dqn
|
||||
return dqn.DQNAgent
|
||||
elif alg == "APEX":
|
||||
from ray.rllib.agents import dqn
|
||||
return dqn.ApexAgent
|
||||
elif alg == "A3C":
|
||||
from ray.rllib.agents import a3c
|
||||
return a3c.A3CAgent
|
||||
elif alg == "A2C":
|
||||
from ray.rllib.agents import a3c
|
||||
return a3c.A2CAgent
|
||||
elif alg == "PG":
|
||||
from ray.rllib.agents import pg
|
||||
return pg.PGAgent
|
||||
elif alg == "IMPALA":
|
||||
from ray.rllib.agents import impala
|
||||
return impala.ImpalaAgent
|
||||
elif alg == "QMIX":
|
||||
from ray.rllib.agents import qmix
|
||||
return qmix.QMixAgent
|
||||
elif alg == "APEX_QMIX":
|
||||
from ray.rllib.agents import qmix
|
||||
return qmix.ApexQMixAgent
|
||||
elif alg == "script":
|
||||
from ray.tune import script_runner
|
||||
return script_runner.ScriptRunner
|
||||
elif alg == "__fake":
|
||||
from ray.rllib.agents.mock import _MockAgent
|
||||
return _MockAgent
|
||||
elif alg == "__sigmoid_fake_data":
|
||||
from ray.rllib.agents.mock import _SigmoidFakeData
|
||||
return _SigmoidFakeData
|
||||
elif alg == "__parameter_tuning":
|
||||
from ray.rllib.agents.mock import _ParameterTuningAgent
|
||||
return _ParameterTuningAgent
|
||||
else:
|
||||
raise Exception(("Unknown algorithm {}.").format(alg))
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
"""Registry of algorithm names for `rllib train --run=<alg_name>`"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import traceback
|
||||
|
||||
from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS
|
||||
|
||||
|
||||
def _import_qmix():
|
||||
from ray.rllib.agents import qmix
|
||||
return qmix.QMixAgent
|
||||
|
||||
|
||||
def _import_apex_qmix():
|
||||
from ray.rllib.agents import qmix
|
||||
return qmix.ApexQMixAgent
|
||||
|
||||
|
||||
def _import_ddpg():
|
||||
from ray.rllib.agents import ddpg
|
||||
return ddpg.DDPGAgent
|
||||
|
||||
|
||||
def _import_apex_ddpg():
|
||||
from ray.rllib.agents import ddpg
|
||||
return ddpg.ApexDDPGAgent
|
||||
|
||||
|
||||
def _import_ppo():
|
||||
from ray.rllib.agents import ppo
|
||||
return ppo.PPOAgent
|
||||
|
||||
|
||||
def _import_es():
|
||||
from ray.rllib.agents import es
|
||||
return es.ESAgent
|
||||
|
||||
|
||||
def _import_ars():
|
||||
from ray.rllib.agents import ars
|
||||
return ars.ARSAgent
|
||||
|
||||
|
||||
def _import_dqn():
|
||||
from ray.rllib.agents import dqn
|
||||
return dqn.DQNAgent
|
||||
|
||||
|
||||
def _import_apex():
|
||||
from ray.rllib.agents import dqn
|
||||
return dqn.ApexAgent
|
||||
|
||||
|
||||
def _import_a3c():
|
||||
from ray.rllib.agents import a3c
|
||||
return a3c.A3CAgent
|
||||
|
||||
|
||||
def _import_a2c():
|
||||
from ray.rllib.agents import a3c
|
||||
return a3c.A2CAgent
|
||||
|
||||
|
||||
def _import_pg():
|
||||
from ray.rllib.agents import pg
|
||||
return pg.PGAgent
|
||||
|
||||
|
||||
def _import_impala():
|
||||
from ray.rllib.agents import impala
|
||||
return impala.ImpalaAgent
|
||||
|
||||
|
||||
ALGORITHMS = {
|
||||
"DDPG": _import_ddpg,
|
||||
"APEX_DDPG": _import_apex_ddpg,
|
||||
"PPO": _import_ppo,
|
||||
"ES": _import_es,
|
||||
"ARS": _import_ars,
|
||||
"DQN": _import_dqn,
|
||||
"APEX": _import_apex,
|
||||
"A3C": _import_a3c,
|
||||
"A2C": _import_a2c,
|
||||
"PG": _import_pg,
|
||||
"IMPALA": _import_impala,
|
||||
"QMIX": _import_qmix,
|
||||
"APEX_QMIX": _import_apex_qmix,
|
||||
}
|
||||
|
||||
|
||||
def get_agent_class(alg):
|
||||
"""Returns the class of a known agent given its name."""
|
||||
|
||||
try:
|
||||
return _get_agent_class(alg)
|
||||
except ImportError:
|
||||
from ray.rllib.agents.mock import _agent_import_failed
|
||||
return _agent_import_failed(traceback.format_exc())
|
||||
|
||||
|
||||
def _get_agent_class(alg):
|
||||
if alg in ALGORITHMS:
|
||||
return ALGORITHMS[alg]()
|
||||
elif alg in CONTRIBUTED_ALGORITHMS:
|
||||
return CONTRIBUTED_ALGORITHMS[alg]()
|
||||
elif alg == "script":
|
||||
from ray.tune import script_runner
|
||||
return script_runner.ScriptRunner
|
||||
elif alg == "__fake":
|
||||
from ray.rllib.agents.mock import _MockAgent
|
||||
return _MockAgent
|
||||
elif alg == "__sigmoid_fake_data":
|
||||
from ray.rllib.agents.mock import _SigmoidFakeData
|
||||
return _SigmoidFakeData
|
||||
elif alg == "__parameter_tuning":
|
||||
from ray.rllib.agents.mock import _ParameterTuningAgent
|
||||
return _ParameterTuningAgent
|
||||
else:
|
||||
raise Exception(("Unknown algorithm {}.").format(alg))
|
||||
@@ -0,0 +1,3 @@
|
||||
Contributed algorithms, which can be run via `rllib train --run=contrib/<alg_name>`
|
||||
|
||||
See https://ray.readthedocs.io/en/latest/rllib-dev.html for guidelines.
|
||||
@@ -0,0 +1,52 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.agents.agent import Agent, with_common_config
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
class RandomAgent(Agent):
|
||||
"""Agent that takes random actions and never learns."""
|
||||
|
||||
_agent_name = "RandomAgent"
|
||||
_default_config = with_common_config({
|
||||
"rollouts_per_iteration": 10,
|
||||
})
|
||||
|
||||
@override(Agent)
|
||||
def _init(self):
|
||||
self.env = self.env_creator(self.config["env_config"])
|
||||
|
||||
@override(Agent)
|
||||
def _train(self):
|
||||
rewards = []
|
||||
steps = 0
|
||||
for _ in range(self.config["rollouts_per_iteration"]):
|
||||
obs = self.env.reset()
|
||||
done = False
|
||||
reward = 0.0
|
||||
while not done:
|
||||
action = self.env.action_space.sample()
|
||||
obs, r, done, info = self.env.step(action)
|
||||
reward += r
|
||||
steps += 1
|
||||
rewards.append(reward)
|
||||
return {
|
||||
"episode_reward_mean": np.mean(rewards),
|
||||
"timesteps_this_iter": steps,
|
||||
}
|
||||
# __sphinx_doc_end__
|
||||
# don't enable yapf after, it's buggy here
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
agent = RandomAgent(
|
||||
env="CartPole-v0", config={"rollouts_per_iteration": 10})
|
||||
result = agent.train()
|
||||
assert result["episode_reward_mean"] > 10, result
|
||||
print("Test: OK")
|
||||
@@ -0,0 +1,15 @@
|
||||
"""Registry of algorithm names for `rllib train --run=<alg_name>`"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
def _import_random_agent():
|
||||
from ray.rllib.contrib.random_agent.random_agent import RandomAgent
|
||||
return RandomAgent
|
||||
|
||||
|
||||
CONTRIBUTED_ALGORITHMS = {
|
||||
"contrib/RandomAgent": _import_random_agent,
|
||||
}
|
||||
@@ -11,7 +11,7 @@ import pickle
|
||||
|
||||
import gym
|
||||
import ray
|
||||
from ray.rllib.agents.agent import get_agent_class
|
||||
from ray.rllib.agents.registry import get_agent_class
|
||||
|
||||
EXAMPLE_USAGE = """
|
||||
Example Usage via RLlib CLI:
|
||||
|
||||
@@ -7,7 +7,7 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
import ray
|
||||
|
||||
from ray.rllib.agents.agent import get_agent_class
|
||||
from ray.rllib.agents.registry import get_agent_class
|
||||
|
||||
|
||||
def get_mean_action(alg, obs):
|
||||
|
||||
@@ -8,7 +8,7 @@ import numpy as np
|
||||
import sys
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.agent import get_agent_class
|
||||
from ray.rllib.agents.registry import get_agent_class
|
||||
from ray.rllib.test.test_multi_agent_env import MultiCartpole, MultiMountainCar
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
@@ -339,4 +339,4 @@ class Trial(object):
|
||||
identifier = self.trainable_name
|
||||
if self.experiment_tag:
|
||||
identifier += "_" + self.experiment_tag
|
||||
return identifier
|
||||
return identifier.replace("/", "_")
|
||||
|
||||
Reference in New Issue
Block a user