mirror of
https://github.com/wassname/ray.git
synced 2026-07-06 04:12:08 +08:00
[Projects] Start multiple sessions via session start (#5740)
This commit is contained in:
@@ -3,6 +3,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import json
|
||||
import jsonschema
|
||||
import os
|
||||
@@ -49,49 +50,62 @@ class ProjectDefinition:
|
||||
directory = os.path.join("~", self.config["name"], "")
|
||||
return directory
|
||||
|
||||
def get_command_to_run(self, command=None, args=tuple()):
|
||||
"""Get and format a command to run.
|
||||
def get_command_info(self, command_name, args, shell, wildcards=False):
|
||||
"""Get the shell command, parsed arguments and config for a command.
|
||||
|
||||
Args:
|
||||
command (str): Name of the command to run. The command definition
|
||||
should be available in project.yaml.
|
||||
command_name (str): Name of the command to run. The command
|
||||
definition should be available in project.yaml.
|
||||
args (tuple): Tuple containing arguments to format the command
|
||||
with.
|
||||
wildcards (bool): If True, enable wildcards as arguments.
|
||||
|
||||
Returns:
|
||||
The raw shell command to run, formatted with the given arguments.
|
||||
The raw shell command to run with placeholders for the arguments.
|
||||
The parsed argument dictonary, parsed with argparse.
|
||||
The config dictionary of the command.
|
||||
|
||||
Raises:
|
||||
ValueError: This exception is raised if the given command is not
|
||||
found in project.yaml.
|
||||
"""
|
||||
if shell or not command_name:
|
||||
return command_name, {}, {}
|
||||
|
||||
command_to_run = None
|
||||
params = None
|
||||
config = None
|
||||
|
||||
if command is None:
|
||||
command = "default"
|
||||
for command_definition in self.config["commands"]:
|
||||
if command_definition["name"] == command:
|
||||
if command_definition["name"] == command_name:
|
||||
command_to_run = command_definition["command"]
|
||||
params = command_definition.get("params", [])
|
||||
config = command_definition.get("config", {})
|
||||
if not command_to_run:
|
||||
raise ValueError(
|
||||
"Cannot find the command '{}' in commmands section of the "
|
||||
"project file.".format(command))
|
||||
"Cannot find the command named '{}' in commmands section "
|
||||
"of the project file.".format(command_name))
|
||||
|
||||
# Build argument parser dynamically to parse parameter arguments.
|
||||
parser = argparse.ArgumentParser(prog=command)
|
||||
parser = argparse.ArgumentParser(prog=command_name)
|
||||
# For argparse arguments that have a 'choices' list associated
|
||||
# with them, save it in the following dictionary.
|
||||
choices = {}
|
||||
for param in params:
|
||||
parser.add_argument(
|
||||
"--" + param["name"],
|
||||
required=True,
|
||||
help=param.get("help"),
|
||||
choices=param.get("choices"))
|
||||
name = param.pop("name")
|
||||
if wildcards and "choices" in param:
|
||||
choices[name] = copy.deepcopy(param["choices"])
|
||||
param["choices"] = param["choices"] + ["*"]
|
||||
parser.add_argument("--" + name, **param)
|
||||
|
||||
result = parser.parse_args(list(args))
|
||||
for key, val in result.__dict__.items():
|
||||
command_to_run = command_to_run.replace("{{" + key + "}}", val)
|
||||
parsed_args = parser.parse_args(list(args)).__dict__
|
||||
|
||||
return command_to_run
|
||||
if wildcards:
|
||||
for key, val in parsed_args.items():
|
||||
if val == "*":
|
||||
parsed_args[key] = choices[key]
|
||||
|
||||
return command_to_run, parsed_args, config
|
||||
|
||||
def git_repo(self):
|
||||
return self.config.get("repo", None)
|
||||
|
||||
@@ -76,6 +76,20 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"config": {
|
||||
"type": "object",
|
||||
"items": {
|
||||
"description": "Configuration options for the command",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tmux": {
|
||||
"description": "If true, the command will be run inside of tmux",
|
||||
"type": "boolean"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+115
-50
@@ -3,6 +3,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import click
|
||||
import copy
|
||||
import jsonschema
|
||||
import logging
|
||||
import os
|
||||
@@ -162,7 +163,8 @@ class SessionRunner(object):
|
||||
self.session_name = session_name
|
||||
|
||||
# Check for features we don't support right now
|
||||
project_environment = self.project_definition.config["environment"]
|
||||
project_environment = self.project_definition.config.get(
|
||||
"environment", {})
|
||||
need_docker = ("dockerfile" in project_environment
|
||||
or "dockerimage" in project_environment)
|
||||
if need_docker:
|
||||
@@ -193,7 +195,8 @@ class SessionRunner(object):
|
||||
|
||||
def setup_environment(self):
|
||||
"""Set up the environment of the session."""
|
||||
project_environment = self.project_definition.config["environment"]
|
||||
project_environment = self.project_definition.config.get(
|
||||
"environment", {})
|
||||
|
||||
if "requirements" in project_environment:
|
||||
requirements_txt = project_environment["requirements"]
|
||||
@@ -217,32 +220,7 @@ class SessionRunner(object):
|
||||
for cmd in project_environment["shell"]:
|
||||
self.execute_command(cmd)
|
||||
|
||||
def format_command(self, command, args, shell):
|
||||
"""Validate and format a session command.
|
||||
|
||||
Args:
|
||||
command (str, optional): Command from the project definition's
|
||||
commands section to run, if any.
|
||||
args (list): Arguments for the command to run.
|
||||
shell (bool): If true, command is a shell command that should be
|
||||
run directly.
|
||||
|
||||
Returns:
|
||||
The formatted shell command to run.
|
||||
|
||||
Raises:
|
||||
click.ClickException: This exception is raised if any error occurs.
|
||||
"""
|
||||
if shell:
|
||||
return command
|
||||
else:
|
||||
try:
|
||||
return self.project_definition.get_command_to_run(
|
||||
command=command, args=args)
|
||||
except ValueError as e:
|
||||
raise click.ClickException(e)
|
||||
|
||||
def execute_command(self, cmd):
|
||||
def execute_command(self, cmd, config={}):
|
||||
"""Execute a shell command in the session.
|
||||
|
||||
Args:
|
||||
@@ -256,7 +234,7 @@ class SessionRunner(object):
|
||||
cmd=cmd,
|
||||
docker=False,
|
||||
screen=False,
|
||||
tmux=False,
|
||||
tmux=config.get("tmux", False),
|
||||
stop=False,
|
||||
start=False,
|
||||
override_cluster_name=self.session_name,
|
||||
@@ -264,13 +242,79 @@ class SessionRunner(object):
|
||||
)
|
||||
|
||||
|
||||
def format_command(command, parsed_args):
|
||||
"""Substitute arguments into command.
|
||||
|
||||
Args:
|
||||
command (str): Shell comand with argument placeholders.
|
||||
parsed_args (dict): Dictionary that maps from argument names
|
||||
to their value.
|
||||
|
||||
Returns:
|
||||
Shell command with parameters from parsed_args substituted.
|
||||
"""
|
||||
for key, val in parsed_args.items():
|
||||
command = command.replace("{{" + key + "}}", str(val))
|
||||
return command
|
||||
|
||||
|
||||
def get_session_runs(name, command, parsed_args):
|
||||
"""Get a list of sessions to start.
|
||||
|
||||
Args:
|
||||
command (str): Shell command with argument placeholders.
|
||||
parsed_args (dict): Dictionary that maps from argument names
|
||||
to their values.
|
||||
|
||||
Returns:
|
||||
List of sessions to start, which are dictionaries with keys:
|
||||
"name": Name of the session to start,
|
||||
"command": Command to run after starting the session,
|
||||
"num_steps": 4 if a command should be run, 3 if not.
|
||||
"""
|
||||
if not command:
|
||||
return [{"name": name, "command": None, "num_steps": 3}]
|
||||
|
||||
# Try to find a wildcard argument (i.e. one that has a list of values)
|
||||
# and give an error if there is more than one (currently unsupported).
|
||||
wildcard_arg = None
|
||||
for key, val in parsed_args.items():
|
||||
if isinstance(val, list):
|
||||
if not wildcard_arg:
|
||||
wildcard_arg = key
|
||||
else:
|
||||
raise click.ClickException(
|
||||
"More than one wildcard is not supported at the moment")
|
||||
|
||||
if not wildcard_arg:
|
||||
session_run = {
|
||||
"name": name,
|
||||
"command": format_command(command, parsed_args),
|
||||
"num_steps": 4
|
||||
}
|
||||
return [session_run]
|
||||
else:
|
||||
session_runs = []
|
||||
for val in parsed_args[wildcard_arg]:
|
||||
parsed_args = copy.deepcopy(parsed_args)
|
||||
parsed_args[wildcard_arg] = val
|
||||
session_run = {
|
||||
"name": "{}-{}-{}".format(name, wildcard_arg, val),
|
||||
"command": format_command(command, parsed_args),
|
||||
"num_steps": 4
|
||||
}
|
||||
session_runs.append(session_run)
|
||||
return session_runs
|
||||
|
||||
|
||||
@session_cli.command(help="Attach to an existing cluster")
|
||||
def attach():
|
||||
@click.option("--tmux", help="Attach to tmux session", is_flag=True)
|
||||
def attach(tmux):
|
||||
project_definition = load_project_or_throw()
|
||||
attach_cluster(
|
||||
project_definition.cluster_yaml(),
|
||||
start=False,
|
||||
use_tmux=False,
|
||||
use_tmux=tmux,
|
||||
override_cluster_name=None,
|
||||
new=False,
|
||||
)
|
||||
@@ -301,25 +345,39 @@ def stop(name):
|
||||
is_flag=True)
|
||||
@click.option("--name", help="A name to tag the session with.", default=None)
|
||||
def session_start(command, args, shell, name):
|
||||
runner = SessionRunner(session_name=name)
|
||||
if shell or command:
|
||||
# Get the actual command to run.
|
||||
cmd = runner.format_command(command, args, shell)
|
||||
num_steps = 4
|
||||
else:
|
||||
num_steps = 3
|
||||
project_definition = load_project_or_throw()
|
||||
|
||||
logger.info("[1/{}] Creating cluster".format(num_steps))
|
||||
runner.create_cluster()
|
||||
logger.info("[2/{}] Syncing the project".format(num_steps))
|
||||
runner.sync_files()
|
||||
logger.info("[3/{}] Setting up environment".format(num_steps))
|
||||
runner.setup_environment()
|
||||
if not name:
|
||||
name = project_definition.config["name"]
|
||||
|
||||
if shell or command:
|
||||
# Run the actual command.
|
||||
logger.info("[4/4] Running command")
|
||||
runner.execute_command(cmd)
|
||||
# Get the actual command to run. This also validates the command,
|
||||
# which should be done before the cluster is started.
|
||||
try:
|
||||
command, parsed_args, config = project_definition.get_command_info(
|
||||
command, args, shell, wildcards=True)
|
||||
except ValueError as e:
|
||||
raise click.ClickException(e)
|
||||
session_runs = get_session_runs(name, command, parsed_args)
|
||||
|
||||
if len(session_runs) > 1 and not config.get("tmux", False):
|
||||
logging.info("Using wildcards with tmux = False would not create "
|
||||
"sessions in parallel, so we are overriding it with "
|
||||
"tmux = True.")
|
||||
config["tmux"] = True
|
||||
|
||||
for run in session_runs:
|
||||
runner = SessionRunner(session_name=run["name"])
|
||||
logger.info("[1/{}] Creating cluster".format(run["num_steps"]))
|
||||
runner.create_cluster()
|
||||
logger.info("[2/{}] Syncing the project".format(run["num_steps"]))
|
||||
runner.sync_files()
|
||||
logger.info("[3/{}] Setting up environment".format(run["num_steps"]))
|
||||
runner.setup_environment()
|
||||
|
||||
if run["command"]:
|
||||
# Run the actual command.
|
||||
logger.info("[4/4] Running command")
|
||||
runner.execute_command(run["command"], config)
|
||||
|
||||
|
||||
@session_cli.command(
|
||||
@@ -337,6 +395,13 @@ def session_start(command, args, shell, name):
|
||||
@click.option(
|
||||
"--name", help="Name of the session to run this command on", default=None)
|
||||
def session_execute(command, args, shell, name):
|
||||
project_definition = load_project_or_throw()
|
||||
try:
|
||||
command, parsed_args, config = project_definition.get_command_info(
|
||||
command, args, shell, wildcards=False)
|
||||
except ValueError as e:
|
||||
raise click.ClickException(e)
|
||||
|
||||
runner = SessionRunner(session_name=name)
|
||||
cmd = runner.format_command(command, args, shell)
|
||||
runner.execute_command(cmd)
|
||||
command = format_command(command, parsed_args)
|
||||
runner.execute_command(command)
|
||||
|
||||
@@ -165,6 +165,10 @@ def test_session_execute_default_project():
|
||||
|
||||
assert expected_commands == commands_executed
|
||||
|
||||
result, mock_calls, test_dir = run_test_project(
|
||||
"session-tests/project-pass", session_execute, ["--shell", "uptime"])
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
def test_session_start_docker_fail():
|
||||
result, _, _ = run_test_project("session-tests/with-docker-fail",
|
||||
@@ -200,3 +204,36 @@ def test_session_create_command():
|
||||
if "Starting ray job with 1 and 2" in kwargs["cmd"]:
|
||||
found_command = True
|
||||
assert found_command
|
||||
|
||||
|
||||
def test_session_create_multiple():
|
||||
for args in [{"a": "*", "b": "2"}, {"a": "1", "b": "*"}]:
|
||||
result, mock_calls, test_dir = run_test_project(
|
||||
"session-tests/commands-test", session_start,
|
||||
["first", "--a", args["a"], "--b", args["b"]])
|
||||
|
||||
loaded_project = ray.projects.ProjectDefinition(test_dir)
|
||||
assert result.exit_code == 0
|
||||
|
||||
exec_cluster_call = mock_calls["exec_cluster"]
|
||||
commands_executed = []
|
||||
for _, kwargs in exec_cluster_call.call_args_list:
|
||||
commands_executed.append(kwargs["cmd"].replace(
|
||||
"cd {}; ".format(loaded_project.working_directory()), ""))
|
||||
assert commands_executed.count("echo \"Setting up\"") == 2
|
||||
if args["a"] == "*":
|
||||
assert commands_executed.count(
|
||||
"echo \"Starting ray job with 1 and 2\"") == 1
|
||||
assert commands_executed.count(
|
||||
"echo \"Starting ray job with 2 and 2\"") == 1
|
||||
if args["b"] == "*":
|
||||
assert commands_executed.count(
|
||||
"echo \"Starting ray job with 1 and 1\"") == 1
|
||||
assert commands_executed.count(
|
||||
"echo \"Starting ray job with 1 and 2\"") == 1
|
||||
|
||||
# Using multiple wildcards shouldn't work
|
||||
result, mock_calls, test_dir = run_test_project(
|
||||
"session-tests/commands-test", session_start,
|
||||
["first", "--a", "*", "--b", "*"])
|
||||
assert result.exit_code == 1
|
||||
|
||||
Reference in New Issue
Block a user