mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 07:53:50 +08:00
[rllib] RLlib CLI (#2375)
This commit is contained in:
@@ -3,7 +3,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.utils.atari_wrappers import wrap_deepmind
|
||||
from ray.rllib.env.atari_wrappers import wrap_deepmind
|
||||
|
||||
|
||||
def wrap_dqn(env, options, random_starts):
|
||||
|
||||
+45
-31
@@ -15,43 +15,51 @@ from ray.rllib.agents.agent import get_agent_class
|
||||
from ray.rllib.agents.dqn.common.wrappers import wrap_dqn
|
||||
from ray.rllib.models import ModelCatalog
|
||||
|
||||
|
||||
EXAMPLE_USAGE = """
|
||||
example usage:
|
||||
./rollout.py /tmp/ray/checkpoint_dir/checkpoint-0 --run DQN """
|
||||
"""--env CartPole-v0 --steps 1000000 --out rollouts.pkl
|
||||
Example Usage via RLlib CLI:
|
||||
rllib rollout /tmp/ray/checkpoint_dir/checkpoint-0 --run DQN
|
||||
--env CartPole-v0 --steps 1000000 --out rollouts.pkl
|
||||
|
||||
Example Usage via executable:
|
||||
./rollout.py /tmp/ray/checkpoint_dir/checkpoint-0 --run DQN
|
||||
--env CartPole-v0 --steps 1000000 --out rollouts.pkl
|
||||
"""
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
description="Roll out a reinforcement learning agent "
|
||||
"given a checkpoint.", epilog=EXAMPLE_USAGE)
|
||||
|
||||
parser.add_argument(
|
||||
"checkpoint", type=str, help="Checkpoint from which to roll out.")
|
||||
required_named = parser.add_argument_group("required named arguments")
|
||||
required_named.add_argument(
|
||||
"--run", type=str, required=True,
|
||||
help="The algorithm or model to train. This may refer to the name "
|
||||
"of a built-on algorithm (e.g. RLLib's DQN or PPO), or a "
|
||||
"user-defined trainable function or class registered in the "
|
||||
"tune registry.")
|
||||
required_named.add_argument(
|
||||
"--env", type=str, help="The gym environment to use.")
|
||||
parser.add_argument(
|
||||
"--no-render", default=False, action="store_const", const=True,
|
||||
help="Surpress rendering of the environment.")
|
||||
parser.add_argument(
|
||||
"--steps", default=None, help="Number of steps to roll out.")
|
||||
parser.add_argument(
|
||||
"--out", default=None, help="Output filename.")
|
||||
parser.add_argument(
|
||||
"--config", default="{}", type=json.loads,
|
||||
help="Algorithm-specific configuration (e.g. env, hyperparams). "
|
||||
"Surpresses loading of configuration from checkpoint.")
|
||||
def create_parser(parser_creator=None):
|
||||
parser_creator = parser_creator or argparse.ArgumentParser
|
||||
parser = parser_creator(
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
description="Roll out a reinforcement learning agent "
|
||||
"given a checkpoint.", epilog=EXAMPLE_USAGE)
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
parser.add_argument(
|
||||
"checkpoint", type=str, help="Checkpoint from which to roll out.")
|
||||
required_named = parser.add_argument_group("required named arguments")
|
||||
required_named.add_argument(
|
||||
"--run", type=str, required=True,
|
||||
help="The algorithm or model to train. This may refer to the name "
|
||||
"of a built-on algorithm (e.g. RLLib's DQN or PPO), or a "
|
||||
"user-defined trainable function or class registered in the "
|
||||
"tune registry.")
|
||||
required_named.add_argument(
|
||||
"--env", type=str, help="The gym environment to use.")
|
||||
parser.add_argument(
|
||||
"--no-render", default=False, action="store_const", const=True,
|
||||
help="Surpress rendering of the environment.")
|
||||
parser.add_argument(
|
||||
"--steps", default=None, help="Number of steps to roll out.")
|
||||
parser.add_argument(
|
||||
"--out", default=None, help="Output filename.")
|
||||
parser.add_argument(
|
||||
"--config", default="{}", type=json.loads,
|
||||
help="Algorithm-specific configuration (e.g. env, hyperparams). "
|
||||
"Surpresses loading of configuration from checkpoint.")
|
||||
return parser
|
||||
|
||||
|
||||
def run(args, parser):
|
||||
if not args.config:
|
||||
# Load configuration from file
|
||||
config_dir = os.path.dirname(args.checkpoint)
|
||||
@@ -100,3 +108,9 @@ if __name__ == "__main__":
|
||||
print("Episode reward", reward_total)
|
||||
if args.out is not None:
|
||||
pickle.dump(rollouts, open(args.out, "wb"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = create_parser()
|
||||
args = parser.parse_args()
|
||||
run(args, parser)
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
|
||||
from ray.rllib import train
|
||||
from ray.rllib import rollout
|
||||
|
||||
|
||||
EXAMPLE_USAGE = """
|
||||
Example usage for training:
|
||||
rllib train --run DQN --env CartPole-v0
|
||||
|
||||
Example usage for rollout:
|
||||
rllib rollout /tmp/ray/checkpoint_dir/checkpoint-0 --run DQN
|
||||
"""
|
||||
|
||||
|
||||
def cli():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Train or Run an RLlib Agent.",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog=EXAMPLE_USAGE)
|
||||
subcommand_group = parser.add_subparsers(
|
||||
help="Commands to train or run an RLlib agent.", dest="command")
|
||||
|
||||
# see _SubParsersAction.add_parser in
|
||||
# https://github.com/python/cpython/blob/master/Lib/argparse.py
|
||||
train_parser = train.create_parser(
|
||||
lambda **kwargs: subcommand_group.add_parser("train", **kwargs))
|
||||
rollout_parser = rollout.create_parser(
|
||||
lambda **kwargs: subcommand_group.add_parser("rollout", **kwargs))
|
||||
options = parser.parse_args()
|
||||
|
||||
if options.command == "train":
|
||||
train.run(options, train_parser)
|
||||
elif options.command == "rollout":
|
||||
rollout.run(options, rollout_parser)
|
||||
else:
|
||||
parser.print_help()
|
||||
+47
-35
@@ -5,7 +5,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import yaml
|
||||
|
||||
import ray
|
||||
@@ -14,50 +13,57 @@ from ray.tune.tune import _make_scheduler, run_experiments
|
||||
|
||||
|
||||
EXAMPLE_USAGE = """
|
||||
Training example:
|
||||
./train.py --run DQN --env CartPole-v0
|
||||
Training example via RLlib CLI:
|
||||
rllib train --run DQN --env CartPole-v0
|
||||
|
||||
Grid search example:
|
||||
Grid search example via RLlib CLI:
|
||||
rllib train -f tuned_examples/cartpole-grid-search-example.yaml
|
||||
|
||||
Grid search example via executable:
|
||||
./train.py -f tuned_examples/cartpole-grid-search-example.yaml
|
||||
|
||||
Note that -f overrides all other trial-specific command-line options.
|
||||
"""
|
||||
|
||||
|
||||
parser = make_parser(
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
description="Train a reinforcement learning agent.",
|
||||
epilog=EXAMPLE_USAGE)
|
||||
def create_parser(parser_creator=None):
|
||||
parser = make_parser(
|
||||
parser_creator=parser_creator,
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
description="Train a reinforcement learning agent.",
|
||||
epilog=EXAMPLE_USAGE)
|
||||
|
||||
# See also the base parser definition in ray/tune/config_parser.py
|
||||
parser.add_argument(
|
||||
"--redis-address", default=None, type=str,
|
||||
help="The Redis address of the cluster.")
|
||||
parser.add_argument(
|
||||
"--ray-num-cpus", default=None, type=int,
|
||||
help="--num-cpus to pass to Ray. This only has an affect in local mode.")
|
||||
parser.add_argument(
|
||||
"--ray-num-gpus", default=None, type=int,
|
||||
help="--num-gpus to pass to Ray. This only has an affect in local mode.")
|
||||
parser.add_argument(
|
||||
"--experiment-name", default="default", type=str,
|
||||
help="Name of the subdirectory under `local_dir` to put results in.")
|
||||
parser.add_argument(
|
||||
"--env", default=None, type=str, help="The gym environment to use.")
|
||||
parser.add_argument(
|
||||
"--queue-trials", action='store_true',
|
||||
help=(
|
||||
"Whether to queue trials when the cluster does not currently have "
|
||||
"enough resources to launch one. This should be set to True when "
|
||||
"running on an autoscaling cluster to enable automatic scale-up."))
|
||||
parser.add_argument(
|
||||
"-f", "--config-file", default=None, type=str,
|
||||
help="If specified, use config options from this file. Note that this "
|
||||
"overrides any trial-specific options set via flags above.")
|
||||
# See also the base parser definition in ray/tune/config_parser.py
|
||||
parser.add_argument(
|
||||
"--redis-address", default=None, type=str,
|
||||
help="The Redis address of the cluster.")
|
||||
parser.add_argument(
|
||||
"--ray-num-cpus", default=None, type=int,
|
||||
help="--num-cpus to pass to Ray."
|
||||
" This only has an affect in local mode.")
|
||||
parser.add_argument(
|
||||
"--ray-num-gpus", default=None, type=int,
|
||||
help="--num-gpus to pass to Ray."
|
||||
" This only has an affect in local mode.")
|
||||
parser.add_argument(
|
||||
"--experiment-name", default="default", type=str,
|
||||
help="Name of the subdirectory under `local_dir` to put results in.")
|
||||
parser.add_argument(
|
||||
"--env", default=None, type=str, help="The gym environment to use.")
|
||||
parser.add_argument(
|
||||
"--queue-trials", action='store_true',
|
||||
help=(
|
||||
"Whether to queue trials when the cluster does not currently have "
|
||||
"enough resources to launch one. This should be set to True when "
|
||||
"running on an autoscaling cluster to enable automatic scale-up."))
|
||||
parser.add_argument(
|
||||
"-f", "--config-file", default=None, type=str,
|
||||
help="If specified, use config options from this file. Note that this "
|
||||
"overrides any trial-specific options set via flags above.")
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args(sys.argv[1:])
|
||||
def run(args, parser):
|
||||
if args.config_file:
|
||||
with open(args.config_file) as f:
|
||||
experiments = yaml.load(f)
|
||||
@@ -91,3 +97,9 @@ if __name__ == "__main__":
|
||||
run_experiments(
|
||||
experiments, scheduler=_make_scheduler(args),
|
||||
queue_trials=args.queue_trials)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = create_parser()
|
||||
args = parser.parse_args()
|
||||
run(args, parser)
|
||||
|
||||
@@ -44,10 +44,19 @@ def _tune_error(msg):
|
||||
raise TuneError(msg)
|
||||
|
||||
|
||||
def make_parser(**kwargs):
|
||||
"""Returns a base argument parser for the ray.tune tool."""
|
||||
def make_parser(parser_creator=None, **kwargs):
|
||||
"""Returns a base argument parser for the ray.tune tool.
|
||||
|
||||
parser = argparse.ArgumentParser(**kwargs)
|
||||
Args:
|
||||
parser_creator: A constructor for the parser class.
|
||||
kwargs: Non-positional args to be passed into the
|
||||
parser class constructor.
|
||||
"""
|
||||
|
||||
if parser_creator:
|
||||
parser = parser_creator(**kwargs)
|
||||
else:
|
||||
parser = argparse.ArgumentParser(**kwargs)
|
||||
|
||||
# Note: keep this in sync with rllib/train.py
|
||||
parser.add_argument(
|
||||
|
||||
+6
-1
@@ -145,7 +145,12 @@ setup(
|
||||
],
|
||||
setup_requires=["cython >= 0.27, < 0.28"],
|
||||
extras_require=extras,
|
||||
entry_points={"console_scripts": ["ray=ray.scripts.scripts:main"]},
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"ray=ray.scripts.scripts:main",
|
||||
"rllib=ray.rllib.scripts:cli [rllib]"
|
||||
]
|
||||
},
|
||||
include_package_data=True,
|
||||
zip_safe=False,
|
||||
license="Apache 2.0")
|
||||
|
||||
Reference in New Issue
Block a user