mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 20:39:52 +08:00
[tune] Prevent leak of magic keys in trial config (#9903)
Co-authored-by: Kai Fricke <kai@anyscale.com> Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
@@ -190,6 +190,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
|
||||
restore_path=spec.get("restore"),
|
||||
trial_name_creator=spec.get("trial_name_creator"),
|
||||
loggers=spec.get("loggers"),
|
||||
log_to_file=spec.get("log_to_file"),
|
||||
# str(None) doesn't create None
|
||||
sync_to_driver_fn=spec.get("sync_to_driver"),
|
||||
max_failures=args.max_failures,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
from typing import Sequence
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.function_runner import detect_checkpoint_function
|
||||
@@ -45,6 +46,31 @@ def _raise_on_durable(trainable_name, sync_to_driver, upload_dir):
|
||||
"`upload_dir` must be provided.")
|
||||
|
||||
|
||||
def _validate_log_to_file(log_to_file):
|
||||
"""Validate ``tune.run``'s ``log_to_file`` parameter. Return
|
||||
validated relative stdout and stderr filenames."""
|
||||
if not log_to_file:
|
||||
stdout_file = stderr_file = None
|
||||
elif isinstance(log_to_file, bool) and log_to_file:
|
||||
stdout_file = "stdout"
|
||||
stderr_file = "stderr"
|
||||
elif isinstance(log_to_file, str):
|
||||
stdout_file = stderr_file = log_to_file
|
||||
elif isinstance(log_to_file, Sequence):
|
||||
if len(log_to_file) != 2:
|
||||
raise ValueError(
|
||||
"If you pass a Sequence to `log_to_file` it has to have "
|
||||
"a length of 2 (for stdout and stderr, respectively). The "
|
||||
"Sequence you passed has length {}.".format(len(log_to_file)))
|
||||
stdout_file, stderr_file = log_to_file
|
||||
else:
|
||||
raise ValueError(
|
||||
"You can pass a boolean, a string, or a Sequence of length 2 to "
|
||||
"`log_to_file`, but you passed something else ({}).".format(
|
||||
type(log_to_file)))
|
||||
return stdout_file, stderr_file
|
||||
|
||||
|
||||
class Experiment:
|
||||
"""Tracks experiment specifications.
|
||||
|
||||
@@ -82,6 +108,7 @@ class Experiment:
|
||||
upload_dir=None,
|
||||
trial_name_creator=None,
|
||||
loggers=None,
|
||||
log_to_file=False,
|
||||
sync_to_driver=None,
|
||||
checkpoint_freq=0,
|
||||
checkpoint_at_end=False,
|
||||
@@ -133,6 +160,8 @@ class Experiment:
|
||||
|
||||
_raise_on_durable(self._run_identifier, sync_to_driver, upload_dir)
|
||||
|
||||
stdout_file, stderr_file = _validate_log_to_file(log_to_file)
|
||||
|
||||
spec = {
|
||||
"run": self._run_identifier,
|
||||
"stop": stopping_criteria,
|
||||
@@ -145,6 +174,7 @@ class Experiment:
|
||||
"remote_checkpoint_dir": self.remote_checkpoint_dir,
|
||||
"trial_name_creator": trial_name_creator,
|
||||
"loggers": loggers,
|
||||
"log_to_file": (stdout_file, stderr_file),
|
||||
"sync_to_driver": sync_to_driver,
|
||||
"checkpoint_freq": checkpoint_freq,
|
||||
"checkpoint_at_end": checkpoint_at_end,
|
||||
|
||||
@@ -14,7 +14,7 @@ from ray.resource_spec import ResourceSpec
|
||||
from ray.tune.durable_trainable import DurableTrainable
|
||||
from ray.tune.error import AbortTrialExecution, TuneError
|
||||
from ray.tune.logger import NoopLogger
|
||||
from ray.tune.result import TRIAL_INFO, LOGDIR_PATH
|
||||
from ray.tune.result import TRIAL_INFO, LOGDIR_PATH, STDOUT_FILE, STDERR_FILE
|
||||
from ray.tune.resources import Resources
|
||||
from ray.tune.trainable import TrainableUtil
|
||||
from ray.tune.trial import Trial, Checkpoint, Location, TrialInfo
|
||||
@@ -175,6 +175,10 @@ class RayTrialExecutor(TrialExecutor):
|
||||
# configure the remote runner to use a noop-logger.
|
||||
trial_config = copy.deepcopy(trial.config)
|
||||
trial_config[TRIAL_INFO] = TrialInfo(trial)
|
||||
|
||||
stdout_file, stderr_file = trial.log_to_file
|
||||
trial_config[STDOUT_FILE] = stdout_file
|
||||
trial_config[STDERR_FILE] = stderr_file
|
||||
kwargs = {
|
||||
"config": trial_config,
|
||||
"logger_creator": logger_creator,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Sequence
|
||||
|
||||
import ray.cloudpickle as cloudpickle
|
||||
from collections import deque
|
||||
import copy
|
||||
@@ -175,6 +177,7 @@ class Trial:
|
||||
restore_path=None,
|
||||
trial_name_creator=None,
|
||||
loggers=None,
|
||||
log_to_file=None,
|
||||
sync_to_driver_fn=None,
|
||||
max_failures=0):
|
||||
"""Initialize a new trial.
|
||||
@@ -208,6 +211,13 @@ class Trial:
|
||||
self.resources = resources or Resources(cpu=1, gpu=0)
|
||||
self.stopping_criterion = stopping_criterion or {}
|
||||
self.loggers = loggers
|
||||
|
||||
self.log_to_file = log_to_file
|
||||
# Make sure `stdout_file, stderr_file = Trial.log_to_file` works
|
||||
if not self.log_to_file or not isinstance(self.log_to_file, Sequence) \
|
||||
or not len(self.log_to_file) == 2:
|
||||
self.log_to_file = (None, None)
|
||||
|
||||
self.sync_to_driver_fn = sync_to_driver_fn
|
||||
self.verbose = True
|
||||
self.max_failures = max_failures
|
||||
|
||||
+1
-31
@@ -1,10 +1,8 @@
|
||||
import logging
|
||||
from typing import Sequence
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.experiment import convert_to_experiment_list, Experiment
|
||||
from ray.tune.analysis import ExperimentAnalysis
|
||||
from ray.tune.result import STDOUT_FILE, STDERR_FILE
|
||||
from ray.tune.suggest import BasicVariantGenerator
|
||||
from ray.tune.suggest.suggestion import Searcher, SearchGenerator
|
||||
from ray.tune.trial import Trial
|
||||
@@ -67,31 +65,6 @@ def _report_progress(runner, reporter, done=False):
|
||||
reporter.report(trials, done, sched_debug_str, executor_debug_str)
|
||||
|
||||
|
||||
def _validate_log_to_file(log_to_file):
|
||||
"""Validate ``tune.run``'s ``log_to_file`` parameter. Return
|
||||
validated relative stdout and stderr filenames."""
|
||||
if not log_to_file:
|
||||
stdout_file = stderr_file = None
|
||||
elif isinstance(log_to_file, bool) and log_to_file:
|
||||
stdout_file = "stdout"
|
||||
stderr_file = "stderr"
|
||||
elif isinstance(log_to_file, str):
|
||||
stdout_file = stderr_file = log_to_file
|
||||
elif isinstance(log_to_file, Sequence):
|
||||
if len(log_to_file) != 2:
|
||||
raise ValueError(
|
||||
"If you pass a Sequence to `log_to_file` it has to have "
|
||||
"a length of 2 (for stdout and stderr, respectively). The "
|
||||
"Sequence you passed has length {}.".format(len(log_to_file)))
|
||||
stdout_file, stderr_file = log_to_file
|
||||
else:
|
||||
raise ValueError(
|
||||
"You can pass a boolean, a string, or a Sequence of length 2 to "
|
||||
"`log_to_file`, but you passed something else ({}).".format(
|
||||
type(log_to_file)))
|
||||
return stdout_file, stderr_file
|
||||
|
||||
|
||||
def run(run_or_experiment,
|
||||
name=None,
|
||||
stop=None,
|
||||
@@ -289,10 +262,6 @@ def run(run_or_experiment,
|
||||
else:
|
||||
experiments = [run_or_experiment]
|
||||
|
||||
stdout_file, stderr_file = _validate_log_to_file(log_to_file)
|
||||
config[STDOUT_FILE] = stdout_file
|
||||
config[STDERR_FILE] = stderr_file
|
||||
|
||||
for i, exp in enumerate(experiments):
|
||||
if not isinstance(exp, Experiment):
|
||||
run_identifier = Experiment.register_if_needed(exp)
|
||||
@@ -308,6 +277,7 @@ def run(run_or_experiment,
|
||||
sync_to_driver=sync_to_driver,
|
||||
trial_name_creator=trial_name_creator,
|
||||
loggers=loggers,
|
||||
log_to_file=log_to_file,
|
||||
checkpoint_freq=checkpoint_freq,
|
||||
checkpoint_at_end=checkpoint_at_end,
|
||||
sync_on_checkpoint=sync_on_checkpoint,
|
||||
|
||||
Reference in New Issue
Block a user