[tune] Prevent leak of magic keys in trial config (#9903)

Co-authored-by: Kai Fricke <kai@anyscale.com>
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
krfricke
2020-08-04 20:24:01 +02:00
committed by GitHub
parent bdc42f8dab
commit ef717ecda6
6 changed files with 48 additions and 33 deletions
+1
View File
@@ -190,6 +190,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
restore_path=spec.get("restore"),
trial_name_creator=spec.get("trial_name_creator"),
loggers=spec.get("loggers"),
log_to_file=spec.get("log_to_file"),
# str(None) doesn't create None
sync_to_driver_fn=spec.get("sync_to_driver"),
max_failures=args.max_failures,
+30
View File
@@ -1,6 +1,7 @@
import copy
import logging
import os
from typing import Sequence
from ray.tune.error import TuneError
from ray.tune.function_runner import detect_checkpoint_function
@@ -45,6 +46,31 @@ def _raise_on_durable(trainable_name, sync_to_driver, upload_dir):
"`upload_dir` must be provided.")
def _validate_log_to_file(log_to_file):
"""Validate ``tune.run``'s ``log_to_file`` parameter. Return
validated relative stdout and stderr filenames."""
if not log_to_file:
stdout_file = stderr_file = None
elif isinstance(log_to_file, bool) and log_to_file:
stdout_file = "stdout"
stderr_file = "stderr"
elif isinstance(log_to_file, str):
stdout_file = stderr_file = log_to_file
elif isinstance(log_to_file, Sequence):
if len(log_to_file) != 2:
raise ValueError(
"If you pass a Sequence to `log_to_file` it has to have "
"a length of 2 (for stdout and stderr, respectively). The "
"Sequence you passed has length {}.".format(len(log_to_file)))
stdout_file, stderr_file = log_to_file
else:
raise ValueError(
"You can pass a boolean, a string, or a Sequence of length 2 to "
"`log_to_file`, but you passed something else ({}).".format(
type(log_to_file)))
return stdout_file, stderr_file
class Experiment:
"""Tracks experiment specifications.
@@ -82,6 +108,7 @@ class Experiment:
upload_dir=None,
trial_name_creator=None,
loggers=None,
log_to_file=False,
sync_to_driver=None,
checkpoint_freq=0,
checkpoint_at_end=False,
@@ -133,6 +160,8 @@ class Experiment:
_raise_on_durable(self._run_identifier, sync_to_driver, upload_dir)
stdout_file, stderr_file = _validate_log_to_file(log_to_file)
spec = {
"run": self._run_identifier,
"stop": stopping_criteria,
@@ -145,6 +174,7 @@ class Experiment:
"remote_checkpoint_dir": self.remote_checkpoint_dir,
"trial_name_creator": trial_name_creator,
"loggers": loggers,
"log_to_file": (stdout_file, stderr_file),
"sync_to_driver": sync_to_driver,
"checkpoint_freq": checkpoint_freq,
"checkpoint_at_end": checkpoint_at_end,
+5 -1
View File
@@ -14,7 +14,7 @@ from ray.resource_spec import ResourceSpec
from ray.tune.durable_trainable import DurableTrainable
from ray.tune.error import AbortTrialExecution, TuneError
from ray.tune.logger import NoopLogger
from ray.tune.result import TRIAL_INFO, LOGDIR_PATH
from ray.tune.result import TRIAL_INFO, LOGDIR_PATH, STDOUT_FILE, STDERR_FILE
from ray.tune.resources import Resources
from ray.tune.trainable import TrainableUtil
from ray.tune.trial import Trial, Checkpoint, Location, TrialInfo
@@ -175,6 +175,10 @@ class RayTrialExecutor(TrialExecutor):
# configure the remote runner to use a noop-logger.
trial_config = copy.deepcopy(trial.config)
trial_config[TRIAL_INFO] = TrialInfo(trial)
stdout_file, stderr_file = trial.log_to_file
trial_config[STDOUT_FILE] = stdout_file
trial_config[STDERR_FILE] = stderr_file
kwargs = {
"config": trial_config,
"logger_creator": logger_creator,
+10
View File
@@ -1,3 +1,5 @@
from typing import Sequence
import ray.cloudpickle as cloudpickle
from collections import deque
import copy
@@ -175,6 +177,7 @@ class Trial:
restore_path=None,
trial_name_creator=None,
loggers=None,
log_to_file=None,
sync_to_driver_fn=None,
max_failures=0):
"""Initialize a new trial.
@@ -208,6 +211,13 @@ class Trial:
self.resources = resources or Resources(cpu=1, gpu=0)
self.stopping_criterion = stopping_criterion or {}
self.loggers = loggers
self.log_to_file = log_to_file
# Make sure `stdout_file, stderr_file = Trial.log_to_file` works
if not self.log_to_file or not isinstance(self.log_to_file, Sequence) \
or not len(self.log_to_file) == 2:
self.log_to_file = (None, None)
self.sync_to_driver_fn = sync_to_driver_fn
self.verbose = True
self.max_failures = max_failures
+1 -31
View File
@@ -1,10 +1,8 @@
import logging
from typing import Sequence
from ray.tune.error import TuneError
from ray.tune.experiment import convert_to_experiment_list, Experiment
from ray.tune.analysis import ExperimentAnalysis
from ray.tune.result import STDOUT_FILE, STDERR_FILE
from ray.tune.suggest import BasicVariantGenerator
from ray.tune.suggest.suggestion import Searcher, SearchGenerator
from ray.tune.trial import Trial
@@ -67,31 +65,6 @@ def _report_progress(runner, reporter, done=False):
reporter.report(trials, done, sched_debug_str, executor_debug_str)
def _validate_log_to_file(log_to_file):
"""Validate ``tune.run``'s ``log_to_file`` parameter. Return
validated relative stdout and stderr filenames."""
if not log_to_file:
stdout_file = stderr_file = None
elif isinstance(log_to_file, bool) and log_to_file:
stdout_file = "stdout"
stderr_file = "stderr"
elif isinstance(log_to_file, str):
stdout_file = stderr_file = log_to_file
elif isinstance(log_to_file, Sequence):
if len(log_to_file) != 2:
raise ValueError(
"If you pass a Sequence to `log_to_file` it has to have "
"a length of 2 (for stdout and stderr, respectively). The "
"Sequence you passed has length {}.".format(len(log_to_file)))
stdout_file, stderr_file = log_to_file
else:
raise ValueError(
"You can pass a boolean, a string, or a Sequence of length 2 to "
"`log_to_file`, but you passed something else ({}).".format(
type(log_to_file)))
return stdout_file, stderr_file
def run(run_or_experiment,
name=None,
stop=None,
@@ -289,10 +262,6 @@ def run(run_or_experiment,
else:
experiments = [run_or_experiment]
stdout_file, stderr_file = _validate_log_to_file(log_to_file)
config[STDOUT_FILE] = stdout_file
config[STDERR_FILE] = stderr_file
for i, exp in enumerate(experiments):
if not isinstance(exp, Experiment):
run_identifier = Experiment.register_if_needed(exp)
@@ -308,6 +277,7 @@ def run(run_or_experiment,
sync_to_driver=sync_to_driver,
trial_name_creator=trial_name_creator,
loggers=loggers,
log_to_file=log_to_file,
checkpoint_freq=checkpoint_freq,
checkpoint_at_end=checkpoint_at_end,
sync_on_checkpoint=sync_on_checkpoint,