[Projects] Start multiple sessions via session start (#5740)

This commit is contained in:
Philipp Moritz
2019-09-22 01:36:23 -07:00
committed by GitHub
parent 1cfadf032e
commit 5f5873b182
4 changed files with 200 additions and 70 deletions
+34 -20
View File
@@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
import argparse
import copy
import json
import jsonschema
import os
@@ -49,49 +50,62 @@ class ProjectDefinition:
directory = os.path.join("~", self.config["name"], "")
return directory
def get_command_to_run(self, command=None, args=tuple()):
"""Get and format a command to run.
def get_command_info(self, command_name, args, shell, wildcards=False):
"""Get the shell command, parsed arguments and config for a command.
Args:
command (str): Name of the command to run. The command definition
should be available in project.yaml.
command_name (str): Name of the command to run. The command
definition should be available in project.yaml.
args (tuple): Tuple containing arguments to format the command
with.
wildcards (bool): If True, enable wildcards as arguments.
Returns:
The raw shell command to run, formatted with the given arguments.
The raw shell command to run with placeholders for the arguments.
The parsed argument dictonary, parsed with argparse.
The config dictionary of the command.
Raises:
ValueError: This exception is raised if the given command is not
found in project.yaml.
"""
if shell or not command_name:
return command_name, {}, {}
command_to_run = None
params = None
config = None
if command is None:
command = "default"
for command_definition in self.config["commands"]:
if command_definition["name"] == command:
if command_definition["name"] == command_name:
command_to_run = command_definition["command"]
params = command_definition.get("params", [])
config = command_definition.get("config", {})
if not command_to_run:
raise ValueError(
"Cannot find the command '{}' in commmands section of the "
"project file.".format(command))
"Cannot find the command named '{}' in commmands section "
"of the project file.".format(command_name))
# Build argument parser dynamically to parse parameter arguments.
parser = argparse.ArgumentParser(prog=command)
parser = argparse.ArgumentParser(prog=command_name)
# For argparse arguments that have a 'choices' list associated
# with them, save it in the following dictionary.
choices = {}
for param in params:
parser.add_argument(
"--" + param["name"],
required=True,
help=param.get("help"),
choices=param.get("choices"))
name = param.pop("name")
if wildcards and "choices" in param:
choices[name] = copy.deepcopy(param["choices"])
param["choices"] = param["choices"] + ["*"]
parser.add_argument("--" + name, **param)
result = parser.parse_args(list(args))
for key, val in result.__dict__.items():
command_to_run = command_to_run.replace("{{" + key + "}}", val)
parsed_args = parser.parse_args(list(args)).__dict__
return command_to_run
if wildcards:
for key, val in parsed_args.items():
if val == "*":
parsed_args[key] = choices[key]
return command_to_run, parsed_args, config
def git_repo(self):
return self.config.get("repo", None)
+14
View File
@@ -76,6 +76,20 @@
}
}
}
},
"config": {
"type": "object",
"items": {
"description": "Configuration options for the command",
"type": "object",
"properties": {
"tmux": {
"description": "If true, the command will be run inside of tmux",
"type": "boolean"
}
},
"additionalProperties": false
}
}
}
}
+115 -50
View File
@@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
import click
import copy
import jsonschema
import logging
import os
@@ -162,7 +163,8 @@ class SessionRunner(object):
self.session_name = session_name
# Check for features we don't support right now
project_environment = self.project_definition.config["environment"]
project_environment = self.project_definition.config.get(
"environment", {})
need_docker = ("dockerfile" in project_environment
or "dockerimage" in project_environment)
if need_docker:
@@ -193,7 +195,8 @@ class SessionRunner(object):
def setup_environment(self):
"""Set up the environment of the session."""
project_environment = self.project_definition.config["environment"]
project_environment = self.project_definition.config.get(
"environment", {})
if "requirements" in project_environment:
requirements_txt = project_environment["requirements"]
@@ -217,32 +220,7 @@ class SessionRunner(object):
for cmd in project_environment["shell"]:
self.execute_command(cmd)
def format_command(self, command, args, shell):
"""Validate and format a session command.
Args:
command (str, optional): Command from the project definition's
commands section to run, if any.
args (list): Arguments for the command to run.
shell (bool): If true, command is a shell command that should be
run directly.
Returns:
The formatted shell command to run.
Raises:
click.ClickException: This exception is raised if any error occurs.
"""
if shell:
return command
else:
try:
return self.project_definition.get_command_to_run(
command=command, args=args)
except ValueError as e:
raise click.ClickException(e)
def execute_command(self, cmd):
def execute_command(self, cmd, config={}):
"""Execute a shell command in the session.
Args:
@@ -256,7 +234,7 @@ class SessionRunner(object):
cmd=cmd,
docker=False,
screen=False,
tmux=False,
tmux=config.get("tmux", False),
stop=False,
start=False,
override_cluster_name=self.session_name,
@@ -264,13 +242,79 @@ class SessionRunner(object):
)
def format_command(command, parsed_args):
"""Substitute arguments into command.
Args:
command (str): Shell comand with argument placeholders.
parsed_args (dict): Dictionary that maps from argument names
to their value.
Returns:
Shell command with parameters from parsed_args substituted.
"""
for key, val in parsed_args.items():
command = command.replace("{{" + key + "}}", str(val))
return command
def get_session_runs(name, command, parsed_args):
"""Get a list of sessions to start.
Args:
command (str): Shell command with argument placeholders.
parsed_args (dict): Dictionary that maps from argument names
to their values.
Returns:
List of sessions to start, which are dictionaries with keys:
"name": Name of the session to start,
"command": Command to run after starting the session,
"num_steps": 4 if a command should be run, 3 if not.
"""
if not command:
return [{"name": name, "command": None, "num_steps": 3}]
# Try to find a wildcard argument (i.e. one that has a list of values)
# and give an error if there is more than one (currently unsupported).
wildcard_arg = None
for key, val in parsed_args.items():
if isinstance(val, list):
if not wildcard_arg:
wildcard_arg = key
else:
raise click.ClickException(
"More than one wildcard is not supported at the moment")
if not wildcard_arg:
session_run = {
"name": name,
"command": format_command(command, parsed_args),
"num_steps": 4
}
return [session_run]
else:
session_runs = []
for val in parsed_args[wildcard_arg]:
parsed_args = copy.deepcopy(parsed_args)
parsed_args[wildcard_arg] = val
session_run = {
"name": "{}-{}-{}".format(name, wildcard_arg, val),
"command": format_command(command, parsed_args),
"num_steps": 4
}
session_runs.append(session_run)
return session_runs
@session_cli.command(help="Attach to an existing cluster")
def attach():
@click.option("--tmux", help="Attach to tmux session", is_flag=True)
def attach(tmux):
project_definition = load_project_or_throw()
attach_cluster(
project_definition.cluster_yaml(),
start=False,
use_tmux=False,
use_tmux=tmux,
override_cluster_name=None,
new=False,
)
@@ -301,25 +345,39 @@ def stop(name):
is_flag=True)
@click.option("--name", help="A name to tag the session with.", default=None)
def session_start(command, args, shell, name):
runner = SessionRunner(session_name=name)
if shell or command:
# Get the actual command to run.
cmd = runner.format_command(command, args, shell)
num_steps = 4
else:
num_steps = 3
project_definition = load_project_or_throw()
logger.info("[1/{}] Creating cluster".format(num_steps))
runner.create_cluster()
logger.info("[2/{}] Syncing the project".format(num_steps))
runner.sync_files()
logger.info("[3/{}] Setting up environment".format(num_steps))
runner.setup_environment()
if not name:
name = project_definition.config["name"]
if shell or command:
# Run the actual command.
logger.info("[4/4] Running command")
runner.execute_command(cmd)
# Get the actual command to run. This also validates the command,
# which should be done before the cluster is started.
try:
command, parsed_args, config = project_definition.get_command_info(
command, args, shell, wildcards=True)
except ValueError as e:
raise click.ClickException(e)
session_runs = get_session_runs(name, command, parsed_args)
if len(session_runs) > 1 and not config.get("tmux", False):
logging.info("Using wildcards with tmux = False would not create "
"sessions in parallel, so we are overriding it with "
"tmux = True.")
config["tmux"] = True
for run in session_runs:
runner = SessionRunner(session_name=run["name"])
logger.info("[1/{}] Creating cluster".format(run["num_steps"]))
runner.create_cluster()
logger.info("[2/{}] Syncing the project".format(run["num_steps"]))
runner.sync_files()
logger.info("[3/{}] Setting up environment".format(run["num_steps"]))
runner.setup_environment()
if run["command"]:
# Run the actual command.
logger.info("[4/4] Running command")
runner.execute_command(run["command"], config)
@session_cli.command(
@@ -337,6 +395,13 @@ def session_start(command, args, shell, name):
@click.option(
"--name", help="Name of the session to run this command on", default=None)
def session_execute(command, args, shell, name):
project_definition = load_project_or_throw()
try:
command, parsed_args, config = project_definition.get_command_info(
command, args, shell, wildcards=False)
except ValueError as e:
raise click.ClickException(e)
runner = SessionRunner(session_name=name)
cmd = runner.format_command(command, args, shell)
runner.execute_command(cmd)
command = format_command(command, parsed_args)
runner.execute_command(command)
+37
View File
@@ -165,6 +165,10 @@ def test_session_execute_default_project():
assert expected_commands == commands_executed
result, mock_calls, test_dir = run_test_project(
"session-tests/project-pass", session_execute, ["--shell", "uptime"])
assert result.exit_code == 0
def test_session_start_docker_fail():
result, _, _ = run_test_project("session-tests/with-docker-fail",
@@ -200,3 +204,36 @@ def test_session_create_command():
if "Starting ray job with 1 and 2" in kwargs["cmd"]:
found_command = True
assert found_command
def test_session_create_multiple():
for args in [{"a": "*", "b": "2"}, {"a": "1", "b": "*"}]:
result, mock_calls, test_dir = run_test_project(
"session-tests/commands-test", session_start,
["first", "--a", args["a"], "--b", args["b"]])
loaded_project = ray.projects.ProjectDefinition(test_dir)
assert result.exit_code == 0
exec_cluster_call = mock_calls["exec_cluster"]
commands_executed = []
for _, kwargs in exec_cluster_call.call_args_list:
commands_executed.append(kwargs["cmd"].replace(
"cd {}; ".format(loaded_project.working_directory()), ""))
assert commands_executed.count("echo \"Setting up\"") == 2
if args["a"] == "*":
assert commands_executed.count(
"echo \"Starting ray job with 1 and 2\"") == 1
assert commands_executed.count(
"echo \"Starting ray job with 2 and 2\"") == 1
if args["b"] == "*":
assert commands_executed.count(
"echo \"Starting ray job with 1 and 1\"") == 1
assert commands_executed.count(
"echo \"Starting ray job with 1 and 2\"") == 1
# Using multiple wildcards shouldn't work
result, mock_calls, test_dir = run_test_project(
"session-tests/commands-test", session_start,
["first", "--a", "*", "--b", "*"])
assert result.exit_code == 1