From 603accf1c28fc08f8d0ab5351777bba288ab3df0 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Thu, 5 Nov 2020 17:55:38 +0100 Subject: [PATCH] [tune] logger refactor part 3: Add ExperimentLogger class (#11749) --- doc/source/tune/user-guide.rst | 11 +- python/ray/tune/config_parser.py | 1 - python/ray/tune/experiment.py | 11 ++ python/ray/tune/logger.py | 100 +++++++++++++++++- python/ray/tune/ray_trial_executor.py | 13 +-- python/ray/tune/schedulers/pbt.py | 2 +- python/ray/tune/tests/test_api.py | 26 ++--- python/ray/tune/tests/test_run_experiment.py | 81 +++++++++++--- python/ray/tune/tests/test_trial_runner.py | 2 +- .../tune/tests/test_trial_runner_callbacks.py | 63 ++++++++++- python/ray/tune/tests/test_trial_scheduler.py | 4 +- python/ray/tune/trial.py | 45 ++------ python/ray/tune/trial_executor.py | 5 +- python/ray/tune/trial_runner.py | 6 +- python/ray/tune/tune.py | 6 +- python/ray/tune/utils/callback.py | 83 ++++++++++++++- python/ray/tune/utils/util.py | 25 +++++ 17 files changed, 389 insertions(+), 95 deletions(-) diff --git a/doc/source/tune/user-guide.rst b/doc/source/tune/user-guide.rst index 7400cc327..e3648b31f 100644 --- a/doc/source/tune/user-guide.rst +++ b/doc/source/tune/user-guide.rst @@ -635,9 +635,18 @@ These are the environment variables Ray Tune currently considers: * **TUNE_CLUSTER_SSH_KEY**: SSH key used by the Tune driver process to connect to remote cluster machines for checkpoint syncing. If this is not set, ``~/ray_bootstrap_key.pem`` will be used. +* **TUNE_DISABLE_AUTO_CALLBACK_LOGGERS**: Ray Tune automatically adds a CSV and + JSON logger callback if they haven't been passed. Setting this variable to + `1` disables this automatic creation. Please note that this will most likely + affect analyzing your results after the tuning run. +* **TUNE_DISABLE_AUTO_CALLBACK_SYNCER**: Ray Tune automatically adds a + Syncer callback to sync logs and checkpoints between different nodes if none + has been passed. Setting this variable to `1` disables this automatic creation. + Please note that this will most likely affect advanced scheduling algorithms + like PopulationBasedTraining. * **TUNE_DISABLE_AUTO_INIT**: Disable automatically calling ``ray.init()`` if not attached to a Ray session. -* **TUNE_DISABLE_DATED_SUBDIR**: Tune automatically adds a date string to experiment +* **TUNE_DISABLE_DATED_SUBDIR**: Ray Tune automatically adds a date string to experiment directories when the name is not specified explicitly or the trainable isn't passed as a string. Setting this environment variable to ``1`` disables adding these date strings. * **TUNE_DISABLE_STRICT_METRIC_CHECKING**: When you report metrics to Tune via diff --git a/python/ray/tune/config_parser.py b/python/ray/tune/config_parser.py index c3c501974..0f75bd992 100644 --- a/python/ray/tune/config_parser.py +++ b/python/ray/tune/config_parser.py @@ -190,7 +190,6 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs): restore_path=spec.get("restore"), trial_name_creator=spec.get("trial_name_creator"), trial_dirname_creator=spec.get("trial_dirname_creator"), - loggers=spec.get("loggers"), log_to_file=spec.get("log_to_file"), # str(None) doesn't create None max_failures=args.max_failures, diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py index e7af1cb07..76c4b9c03 100644 --- a/python/ray/tune/experiment.py +++ b/python/ray/tune/experiment.py @@ -123,6 +123,17 @@ class Experiment: max_failures=0, restore=None): + if loggers is not None: + # Most users won't run into this as `tune.run()` does not pass + # the argument anymore. However, we will want to inform users + # if they instantiate their `Experiment` objects themselves. + raise ValueError( + "Passing `loggers` to an `Experiment` is deprecated. Use " + "an `ExperimentLogger` callback instead, e.g. by passing the " + "`Logger` classes to `tune.logger.LegacyExperimentLogger` and " + "passing this as part of the `callback` parameter to " + "`tune.run()`.") + config = config or {} if callable(run) and detect_checkpoint_function(run): if checkpoint_at_end: diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index 867fae618..6f363baca 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -5,10 +5,11 @@ import numpy as np import os import yaml -from typing import TYPE_CHECKING, Dict, List, Optional, Type +from typing import Iterable, TYPE_CHECKING, Dict, List, Optional, Type import ray.cloudpickle as cloudpickle +from ray.tune.callback import Callback from ray.tune.utils.util import SafeFallbackEncoder from ray.util.debug import log_once from ray.tune.result import (TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL, @@ -129,7 +130,8 @@ class JsonLogger(Logger): self.local_out.write(b) def flush(self): - self.local_out.flush() + if not self.local_out.closed: + self.local_out.flush() def close(self): self.local_out.close() @@ -181,7 +183,8 @@ class CSVLogger(Logger): self._file.flush() def flush(self): - self._file.flush() + if not self._file.closed: + self._file.flush() def close(self): self._file.close() @@ -361,6 +364,97 @@ class UnifiedLogger(Logger): _logger.flush() +class ExperimentLogger(Callback): + def log_trial_start(self, trial: "Trial"): + raise NotImplementedError + + def log_trial_restore(self, trial: "Trial"): + raise NotImplementedError + + def log_trial_save(self, trial: "Trial"): + raise NotImplementedError + + def log_trial_result(self, iteration: int, trial: "Trial", result: Dict): + raise NotImplementedError + + def log_trial_end(self, trial: "Trial", failed: bool = False): + raise NotImplementedError + + def on_trial_result(self, iteration: int, trials: List["Trial"], + trial: "Trial", result: Dict, **info): + self.log_trial_result(iteration, trial, result) + + def on_trial_start(self, iteration: int, trials: List["Trial"], + trial: "Trial", **info): + self.log_trial_start(trial) + + def on_trial_restore(self, iteration: int, trials: List["Trial"], + trial: "Trial", **info): + self.log_trial_restore(trial) + + def on_trial_save(self, iteration: int, trials: List["Trial"], + trial: "Trial", **info): + self.log_trial_save(trial) + + def on_trial_complete(self, iteration: int, trials: List["Trial"], + trial: "Trial", **info): + self.log_trial_end(trial, failed=False) + + def on_trial_error(self, iteration: int, trials: List["Trial"], + trial: "Trial", **info): + self.log_trial_end(trial, failed=True) + + +class LegacyExperimentLogger(ExperimentLogger): + """Supports logging to trial-specific `Logger` classes. + + Previously, Ray Tune logging was handled via `Logger` classes that have + been instantiated per-trial. This callback is a fallback to these + `Logger`-classes, instantiating each `Logger` class for each trial + and logging to them. + + Args: + logger_classes (Iterable[Type[Logger]]): Logger classes that should + be instantiated for each trial. + + """ + + def __init__(self, logger_classes: Iterable[Type[Logger]]): + self.logger_classes = list(logger_classes) + self._class_trial_loggers: Dict[Type[Logger], Dict["Trial", + Logger]] = {} + + def log_trial_start(self, trial: "Trial"): + trial.init_logdir() + + for logger_class in self.logger_classes: + trial_loggers = self._class_trial_loggers.get(logger_class, {}) + if trial not in trial_loggers: + logger = logger_class(trial.config, trial.logdir, trial) + trial_loggers[trial] = logger + self._class_trial_loggers[logger_class] = trial_loggers + + def log_trial_restore(self, trial: "Trial"): + for logger_class, trial_loggers in self._class_trial_loggers.items(): + if trial in trial_loggers: + trial_loggers[trial].flush() + + def log_trial_save(self, trial: "Trial"): + for logger_class, trial_loggers in self._class_trial_loggers.items(): + if trial in trial_loggers: + trial_loggers[trial].flush() + + def log_trial_result(self, iteration: int, trial: "Trial", result: Dict): + for logger_class, trial_loggers in self._class_trial_loggers.items(): + if trial in trial_loggers: + trial_loggers[trial].on_result(result) + + def log_trial_end(self, trial: "Trial", failed: bool = False): + for logger_class, trial_loggers in self._class_trial_loggers.items(): + if trial in trial_loggers: + trial_loggers[trial].close() + + def pretty_print(result): result = result.copy() result.update(config=None) # drop config from pretty print diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index 5ad050d25..e46ac6586 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -177,7 +177,7 @@ class RayTrialExecutor(TrialExecutor): self._update_avail_resources() def _setup_remote_runner(self, trial, reuse_allowed): - trial.init_logger() + trial.init_logdir() # We checkpoint metadata here to try mitigating logdir duplication self.try_checkpoint_metadata(trial) logger_creator = partial(noop_logger_creator, logdir=trial.logdir) @@ -297,8 +297,7 @@ class RayTrialExecutor(TrialExecutor): elif train and not trial.is_restoring: self._train(trial) - def _stop_trial(self, trial, error=False, error_msg=None, - stop_logger=True): + def _stop_trial(self, trial, error=False, error_msg=None): """Stops this trial. Stops this trial, releasing all allocating resources. If stopping the @@ -308,7 +307,6 @@ class RayTrialExecutor(TrialExecutor): Args: error (bool): Whether to mark this trial as terminated in error. error_msg (str): Optional error message. - stop_logger (bool): Whether to shut down the trial logger. """ self.set_status(trial, Trial.ERROR if error else Trial.TERMINATED) trial.set_location(Location()) @@ -329,8 +327,6 @@ class RayTrialExecutor(TrialExecutor): self.set_status(trial, Trial.ERROR) finally: trial.set_runner(None) - if stop_logger: - trial.close_logger() def start_trial(self, trial, checkpoint=None, train=True): """Starts the trial. @@ -365,11 +361,10 @@ class RayTrialExecutor(TrialExecutor): out = [rid for rid, t in dictionary.items() if t is item] return out - def stop_trial(self, trial, error=False, error_msg=None, stop_logger=True): + def stop_trial(self, trial, error=False, error_msg=None): """Only returns resources if resources allocated.""" prior_status = trial.status - self._stop_trial( - trial, error=error, error_msg=error_msg, stop_logger=stop_logger) + self._stop_trial(trial, error=error, error_msg=error_msg) if prior_status == Trial.RUNNING: logger.debug("Trial %s: Returning resources.", trial) self._return_resources(trial.resources) diff --git a/python/ray/tune/schedulers/pbt.py b/python/ray/tune/schedulers/pbt.py index 24a0d97cc..aead77cd7 100644 --- a/python/ray/tune/schedulers/pbt.py +++ b/python/ray/tune/schedulers/pbt.py @@ -574,7 +574,7 @@ class PopulationBasedTraining(FIFOScheduler): trial_executor.restore( trial, new_state.last_checkpoint, block=True) else: - trial_executor.stop_trial(trial, stop_logger=False) + trial_executor.stop_trial(trial) trial.config = new_config trial.experiment_tag = new_tag trial_executor.start_trial( diff --git a/python/ray/tune/tests/test_api.py b/python/ray/tune/tests/test_api.py index 121fb8adf..de2977210 100644 --- a/python/ray/tune/tests/test_api.py +++ b/python/ray/tune/tests/test_api.py @@ -12,7 +12,7 @@ from ray.rllib import _register_all from ray import tune from ray.tune import (DurableTrainable, Trainable, TuneError, Stopper, - EarlyStopping) + EarlyStopping, run) from ray.tune import register_env, register_trainable, run_experiments from ray.tune.schedulers import (TrialScheduler, FIFOScheduler, AsyncHyperBandScheduler) @@ -90,19 +90,19 @@ class TrainableFunctionApiTest(unittest.TestCase): class_trainable_name = "class_trainable" register_trainable(class_trainable_name, _WrappedTrainable) - trials = run_experiments( - { - "function_api": { - "run": _function_trainable, - "loggers": [FunctionAPILogger], - }, - "class_api": { - "run": class_trainable_name, - "loggers": [ClassAPILogger], - }, - }, + [trial1] = run( + _function_trainable, + loggers=[FunctionAPILogger], raise_on_failed_trial=False, - scheduler=MockScheduler()) + scheduler=MockScheduler()).trials + + [trial2] = run( + class_trainable_name, + loggers=[ClassAPILogger], + raise_on_failed_trial=False, + scheduler=MockScheduler()).trials + + trials = [trial1, trial2] # Ignore these fields NO_COMPARE_FIELDS = { diff --git a/python/ray/tune/tests/test_run_experiment.py b/python/ray/tune/tests/test_run_experiment.py index bf2f4cbcb..e95a0d0c9 100644 --- a/python/ray/tune/tests/test_run_experiment.py +++ b/python/ray/tune/tests/test_run_experiment.py @@ -7,7 +7,7 @@ from ray.rllib import _register_all from ray.tune.result import TIMESTEPS_TOTAL from ray.tune import Trainable, TuneError from ray.tune import register_trainable, run_experiments -from ray.tune.logger import Logger +from ray.tune.logger import LegacyExperimentLogger, Logger from ray.tune.experiment import Experiment from ray.tune.trial import Trial, ExportFormat @@ -173,21 +173,25 @@ class RunExperimentTest(unittest.TestCase): for trial in trials: self.assertEqual(trial.status, Trial.TERMINATED) - def testCustomLogger(self): + def testCustomLoggerNoAutoLogging(self): + """Does not create CSV/JSON logger callbacks automatically""" + os.environ["TUNE_DISABLE_AUTO_CALLBACK_LOGGERS"] = "1" + class CustomLogger(Logger): def on_result(self, result): with open(os.path.join(self.logdir, "test.log"), "w") as f: f.write("hi") - [trial] = run_experiments({ - "foo": { - "run": "__fake", - "stop": { - "training_iteration": 1 - }, - "loggers": [CustomLogger] - } - }) + [trial] = run_experiments( + { + "foo": { + "run": "__fake", + "stop": { + "training_iteration": 1 + } + } + }, + callbacks=[LegacyExperimentLogger(logger_classes=[CustomLogger])]) self.assertTrue(os.path.exists(os.path.join(trial.logdir, "test.log"))) self.assertFalse( os.path.exists(os.path.join(trial.logdir, "params.json"))) @@ -203,16 +207,65 @@ class RunExperimentTest(unittest.TestCase): self.assertTrue( os.path.exists(os.path.join(trial.logdir, "params.json"))) + [trial] = run_experiments( + { + "foo": { + "run": "__fake", + "stop": { + "training_iteration": 1 + } + } + }, + callbacks=[LegacyExperimentLogger(logger_classes=[])]) + self.assertFalse( + os.path.exists(os.path.join(trial.logdir, "params.json"))) + + def testCustomLoggerWithAutoLogging(self): + """Creates CSV/JSON logger callbacks automatically""" + if "TUNE_DISABLE_AUTO_CALLBACK_LOGGERS" in os.environ: + del os.environ["TUNE_DISABLE_AUTO_CALLBACK_LOGGERS"] + + class CustomLogger(Logger): + def on_result(self, result): + with open(os.path.join(self.logdir, "test.log"), "w") as f: + f.write("hi") + + [trial] = run_experiments( + { + "foo": { + "run": "__fake", + "stop": { + "training_iteration": 1 + } + } + }, + callbacks=[LegacyExperimentLogger(logger_classes=[CustomLogger])]) + self.assertTrue(os.path.exists(os.path.join(trial.logdir, "test.log"))) + self.assertTrue( + os.path.exists(os.path.join(trial.logdir, "params.json"))) + [trial] = run_experiments({ "foo": { "run": "__fake", "stop": { "training_iteration": 1 - }, - "loggers": [] + } } }) - self.assertFalse( + self.assertTrue( + os.path.exists(os.path.join(trial.logdir, "params.json"))) + + [trial] = run_experiments( + { + "foo": { + "run": "__fake", + "stop": { + "training_iteration": 1 + } + } + }, + callbacks=[LegacyExperimentLogger(logger_classes=[])]) + self.assertTrue( os.path.exists(os.path.join(trial.logdir, "params.json"))) def testCustomTrialString(self): diff --git a/python/ray/tune/tests/test_trial_runner.py b/python/ray/tune/tests/test_trial_runner.py index cf1232f7b..399802564 100644 --- a/python/ray/tune/tests/test_trial_runner.py +++ b/python/ray/tune/tests/test_trial_runner.py @@ -263,7 +263,7 @@ class TrialRunnerTest(unittest.TestCase): def on_trial_result(self, trial_runner, trial, result): if result["training_iteration"] == 1: executor = trial_runner.trial_executor - executor.stop_trial(trial, stop_logger=False) + executor.stop_trial(trial) trial.update_resources(2, 0) executor.start_trial(trial) return TrialScheduler.CONTINUE diff --git a/python/ray/tune/tests/test_trial_runner_callbacks.py b/python/ray/tune/tests/test_trial_runner_callbacks.py index 45551f801..f774f9e50 100644 --- a/python/ray/tune/tests/test_trial_runner_callbacks.py +++ b/python/ray/tune/tests/test_trial_runner_callbacks.py @@ -9,12 +9,16 @@ import ray from ray import tune from ray.rllib import _register_all from ray.tune.checkpoint_manager import Checkpoint +from ray.tune.logger import DEFAULT_LOGGERS, ExperimentLogger, \ + LegacyExperimentLogger from ray.tune.ray_trial_executor import RayTrialExecutor from ray.tune.result import TRAINING_ITERATION +from ray.tune.syncer import SyncConfig, SyncerCallback from ray.tune.trial import Trial -from ray.tune.callback import Callback from ray.tune.trial_runner import TrialRunner +from ray.tune import Callback +from ray.tune.utils.callback import create_default_callbacks class TestCallback(Callback): @@ -200,6 +204,63 @@ class TrialRunnerCallbacks(unittest.TestCase): self.callback.state["trial_complete"]["trial"].config["do"], "delay") + def testCallbackReordering(self): + """SyncerCallback should come after ExperimentLogger callbacks""" + + def get_positions(callbacks): + first_logger_pos = None + last_logger_pos = None + syncer_pos = None + for i, callback in enumerate(callbacks): + if isinstance(callback, ExperimentLogger): + if first_logger_pos is None: + first_logger_pos = i + last_logger_pos = i + elif isinstance(callback, SyncerCallback): + syncer_pos = i + return first_logger_pos, last_logger_pos, syncer_pos + + # Auto creation of loggers, no callbacks, no syncer + callbacks = create_default_callbacks(None, SyncConfig(), None) + first_logger_pos, last_logger_pos, syncer_pos = get_positions( + callbacks) + self.assertLess(last_logger_pos, syncer_pos) + + # Auto creation of loggers with callbacks + callbacks = create_default_callbacks([Callback()], SyncConfig(), None) + first_logger_pos, last_logger_pos, syncer_pos = get_positions( + callbacks) + self.assertLess(last_logger_pos, syncer_pos) + + # Auto creation of loggers with existing logger (but no CSV/JSON) + callbacks = create_default_callbacks([ExperimentLogger()], + SyncConfig(), None) + first_logger_pos, last_logger_pos, syncer_pos = get_positions( + callbacks) + self.assertLess(last_logger_pos, syncer_pos) + + # This should throw an error as the syncer comes before the logger + with self.assertRaises(ValueError): + callbacks = create_default_callbacks( + [SyncerCallback(None), + ExperimentLogger()], SyncConfig(), None) + + # This should be reordered but preserve the regular callback order + [mc1, mc2, mc3] = [Callback(), Callback(), Callback()] + # Has to be legacy logger to avoid logger callback creation + lc = LegacyExperimentLogger(logger_classes=DEFAULT_LOGGERS) + callbacks = create_default_callbacks([mc1, mc2, lc, mc3], SyncConfig(), + None) + print(callbacks) + first_logger_pos, last_logger_pos, syncer_pos = get_positions( + callbacks) + self.assertLess(last_logger_pos, syncer_pos) + self.assertLess(callbacks.index(mc1), callbacks.index(mc2)) + self.assertLess(callbacks.index(mc2), callbacks.index(mc3)) + self.assertLess(callbacks.index(lc), callbacks.index(mc3)) + # Syncer callback is appended + self.assertLess(callbacks.index(mc3), syncer_pos) + if __name__ == "__main__": import pytest diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index 2588a2e64..9f54e2937 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -217,10 +217,8 @@ class _MockTrialExecutor(TrialExecutor): trial.restored_checkpoint = checkpoint_obj.value trial.status = Trial.RUNNING - def stop_trial(self, trial, error=False, error_msg=None, stop_logger=True): + def stop_trial(self, trial, error=False, error_msg=None): trial.status = Trial.ERROR if error else Trial.TERMINATED - if stop_logger: - trial.logger_running = False def restore(self, trial, checkpoint=None, block=False): pass diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 686526d6e..938706234 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -12,10 +12,10 @@ import os from numbers import Number from ray.tune import TuneError from ray.tune.checkpoint_manager import Checkpoint, CheckpointManager -from ray.tune.logger import pretty_print, UnifiedLogger # NOTE(rkn): We import ray.tune.registry here instead of importing the names we # need because there are cyclic imports that may cause specific names to not # have been defined yet. See https://github.com/ray-project/ray/issues/1716. +from ray.tune.logger import pretty_print from ray.tune.registry import get_trainable_cls, validate_trainable from ray.tune.result import DEFAULT_RESULTS_DIR, DONE, TRAINING_ITERATION from ray.tune.resources import Resources, json_to_resources, resources_to_json @@ -189,7 +189,6 @@ class Trial: restore_path=None, trial_name_creator=None, trial_dirname_creator=None, - loggers=None, log_to_file=None, max_failures=0): """Initialize a new trial. @@ -222,7 +221,6 @@ class Trial: self.location = Location() self.resources = resources or Resources(cpu=1, gpu=0) self.stopping_criterion = stopping_criterion or {} - self.loggers = loggers self.log_to_file = log_to_file # Make sure `stdout_file, stderr_file = Trial.log_to_file` works @@ -250,7 +248,6 @@ class Trial: self.start_time = None self.logdir = None self.runner = None - self.result_logger = None self.last_debug = 0 self.error_file = None self.error_msg = None @@ -285,7 +282,6 @@ class Trial: self.extra_arg = None self._nonjson_fields = [ - "loggers", "results", "best_result", "param_config", @@ -350,22 +346,17 @@ class Trial: export_formats=self.export_formats, restore_path=self.restore_path, trial_name_creator=self.trial_name_creator, - loggers=self.loggers, log_to_file=self.log_to_file, max_failures=self.max_failures, ) - def init_logger(self): - """Init logger.""" - if not self.result_logger: - if not self.logdir: - self.logdir = create_logdir(self._generate_dirname(), - self.local_dir) - else: - os.makedirs(self.logdir, exist_ok=True) - - self.result_logger = UnifiedLogger( - self.config, self.logdir, trial=self, loggers=self.loggers) + def init_logdir(self): + """Init logdir.""" + if not self.logdir: + self.logdir = create_logdir(self._generate_dirname(), + self.local_dir) + else: + os.makedirs(self.logdir, exist_ok=True) def update_resources(self, cpu, gpu, **kwargs): """EXPERIMENTAL: Updates the resource requirements. @@ -395,12 +386,6 @@ class Trial: if self.start_time is None: self.start_time = time.time() - def close_logger(self): - """Closes logger.""" - if self.result_logger: - self.result_logger.close() - self.result_logger = None - def write_error_log(self, error_msg): if error_msg and self.logdir: self.num_failures += 1 @@ -479,7 +464,6 @@ class Trial: self.set_location(Location(result.get("node_ip"), result.get("pid"))) self.last_result = result self.last_update_time = time.time() - self.result_logger.on_result(self.last_result) for metric, value in flatten_dict(result).items(): if isinstance(value, Number): @@ -569,7 +553,7 @@ class Trial: def __getstate__(self): """Memento generator for Trial. - Sets RUNNING trials to PENDING, and flushes the result logger. + Sets RUNNING trials to PENDING. Note this can only occur if the trial holds a PERSISTENT checkpoint. """ assert self.checkpoint.storage == Checkpoint.PERSISTENT, ( @@ -582,19 +566,13 @@ class Trial: state["runner"] = None state["location"] = Location() - state["result_logger"] = None # Avoid waiting for events that will never occur on resume. state["resuming_from"] = None state["saving_to"] = None - if self.result_logger: - self.result_logger.flush() - state["__logger_started__"] = True - else: - state["__logger_started__"] = False + return copy.deepcopy(state) def __setstate__(self, state): - logger_started = state.pop("__logger_started__") state["resources"] = json_to_resources(state["resources"]) if state["status"] == Trial.RUNNING: @@ -604,5 +582,4 @@ class Trial: self.__dict__.update(state) validate_trainable(self.trainable_name) - if logger_started: - self.init_logger() + self.init_logdir() # Create logdir if it does not exist diff --git a/python/ray/tune/trial_executor.py b/python/ray/tune/trial_executor.py index 99bd62341..54ea5b71d 100644 --- a/python/ray/tune/trial_executor.py +++ b/python/ray/tune/trial_executor.py @@ -85,7 +85,7 @@ class TrialExecutor: raise NotImplementedError("Subclasses of TrialExecutor must provide " "start_trial() method") - def stop_trial(self, trial, error=False, error_msg=None, stop_logger=True): + def stop_trial(self, trial, error=False, error_msg=None): """Stops the trial. Stops this trial, releasing all allocating resources. @@ -95,7 +95,6 @@ class TrialExecutor: Args: error (bool): Whether to mark this trial as terminated in error. error_msg (str): Optional error message. - stop_logger (bool): Whether to shut down the trial logger. """ raise NotImplementedError("Subclasses of TrialExecutor must provide " "stop_trial() method") @@ -113,7 +112,7 @@ class TrialExecutor: assert trial.status == Trial.RUNNING, trial.status try: self.save(trial, Checkpoint.MEMORY) - self.stop_trial(trial, stop_logger=False) + self.stop_trial(trial) self.set_status(trial, Trial.PAUSED) except Exception: logger.exception("Error pausing runner.") diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index b06da19bd..400f42d8f 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -793,11 +793,7 @@ class TrialRunner: # Restore was unsuccessful, try again without checkpoint. trial.clear_checkpoint() self.trial_executor.stop_trial( - trial, - error=error_msg is not None, - error_msg=error_msg, - stop_logger=False) - trial.result_logger.flush() + trial, error=error_msg is not None, error_msg=error_msg) if self.trial_executor.has_resources(trial.resources): logger.info( "Trial %s: Attempting to restore " diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 9c3aefd14..0594d4d7e 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -311,7 +311,6 @@ def run( sync_to_driver=sync_config.sync_to_driver, trial_name_creator=trial_name_creator, trial_dirname_creator=trial_dirname_creator, - loggers=loggers, log_to_file=log_to_file, checkpoint_freq=checkpoint_freq, checkpoint_at_end=checkpoint_at_end, @@ -355,8 +354,9 @@ def run( "own `metric` and `mode` parameters. Either remove the arguments " "from your scheduler or from your call to `tune.run()`") - # Create syncer callbacks - callbacks = create_default_callbacks(callbacks, sync_config) + # Create logger and syncer callbacks + callbacks = create_default_callbacks( + callbacks, sync_config, loggers=loggers) runner = TrialRunner( search_alg=search_alg, diff --git a/python/ray/tune/utils/callback.py b/python/ray/tune/utils/callback.py index e54d9d842..c4b7275ba 100644 --- a/python/ray/tune/utils/callback.py +++ b/python/ray/tune/utils/callback.py @@ -3,16 +3,73 @@ from typing import List, Optional from ray.tune.callback import Callback from ray.tune.syncer import SyncConfig +from ray.tune.logger import CSVLogger, DEFAULT_LOGGERS, ExperimentLogger, \ + JsonLogger, LegacyExperimentLogger, Logger from ray.tune.syncer import SyncerCallback def create_default_callbacks(callbacks: Optional[List[Callback]], - sync_config: SyncConfig): + sync_config: SyncConfig, + loggers: Optional[List[Logger]]): callbacks = callbacks or [] + has_syncer_callback = False + has_csv_logger = False + has_json_logger = False - # Check if there is a SyncerCallback - has_syncer_callback = any(isinstance(c, SyncerCallback) for c in callbacks) + # Track syncer obj/index to move callback after loggers + last_logger_index = None + syncer_index = None + + if not loggers: + # If no logger callback and no `loggers` have been provided, + # add DEFAULT_LOGGERS. + if not any( + isinstance(callback, ExperimentLogger) + for callback in callbacks): + loggers = DEFAULT_LOGGERS + + # Create LegacyExperimentLogger for passed Logger classes + if loggers: + # Todo(krfricke): Deprecate `loggers` argument, print warning here. + add_loggers = [] + for trial_logger in loggers: + if isinstance(trial_logger, ExperimentLogger): + callbacks.append(trial_logger) + elif isinstance(trial_logger, type) and issubclass( + trial_logger, Logger): + add_loggers.append(trial_logger) + else: + raise ValueError( + f"Invalid value passed to `loggers` argument of " + f"`tune.run()`: {trial_logger}") + if add_loggers: + callbacks.append(LegacyExperimentLogger(add_loggers)) + + # Check if we have a CSV and JSON logger + for i, callback in enumerate(callbacks): + if isinstance(callback, LegacyExperimentLogger): + last_logger_index = i + if CSVLogger in callback.logger_classes: + has_csv_logger = True + if JsonLogger in callback.logger_classes: + has_json_logger = True + # Todo(krfricke): add checks for new ExperimentLogger classes + elif isinstance(callback, SyncerCallback): + syncer_index = i + has_syncer_callback = True + + # If CSV or JSON logger is missing, add + if os.environ.get("TUNE_DISABLE_AUTO_CALLBACK_LOGGERS", "0") != "1": + # Todo(krfricke): Switch to new ExperimentLogger classes + add_loggers = [] + if not has_csv_logger: + add_loggers.append(CSVLogger) + if not has_json_logger: + add_loggers.append(JsonLogger) + if add_loggers: + callbacks.append(LegacyExperimentLogger(add_loggers)) + last_logger_index = len(callbacks) - 1 # If no SyncerCallback was found, add if not has_syncer_callback and os.environ.get( @@ -20,5 +77,25 @@ def create_default_callbacks(callbacks: Optional[List[Callback]], syncer_callback = SyncerCallback( sync_function=sync_config.sync_to_driver) callbacks.append(syncer_callback) + syncer_index = len(callbacks) - 1 + + # Todo(krfricke): Maybe check if syncer comes after all loggers + if syncer_index is not None and last_logger_index is not None and \ + syncer_index < last_logger_index: + if (not has_csv_logger or not has_json_logger) and not loggers: + # Only raise the warning if the loggers were passed by the user. + # (I.e. don't warn if this was automatic behavior and they only + # passed a customer SyncerCallback). + raise ValueError( + "The `SyncerCallback` you passed to `tune.run()` came before " + "at least one `ExperimentLogger`. Syncing should be done " + "after writing logs. Please re-order the callbacks so that " + "the `SyncerCallback` comes after any `ExperimentLogger`.") + else: + # If these loggers were automatically created. just re-order + # the callbacks + syncer_obj = callbacks[syncer_index] + callbacks.pop(syncer_index) + callbacks.insert(last_logger_index, syncer_obj) return callbacks diff --git a/python/ray/tune/utils/util.py b/python/ray/tune/utils/util.py index c6328adfe..e7b22f956 100644 --- a/python/ray/tune/utils/util.py +++ b/python/ray/tune/utils/util.py @@ -6,6 +6,7 @@ import os import inspect import threading import time +import uuid from collections import defaultdict, deque, Mapping, Sequence from datetime import datetime from threading import Thread @@ -540,6 +541,30 @@ def detect_config_single(func): return use_config_single +def create_logdir(dirname: str, local_dir: str): + """Create an empty logdir with name `dirname` in `local_dir`. + + If `local_dir`/`dirname` already exists, a unique string is appended + to the dirname. + + Args: + dirname (str): Dirname to create in `local_dir` + local_dir (str): Root directory for the log dir + + Returns: full path to the newly created logdir. + """ + local_dir = os.path.expanduser(local_dir) + logdir = os.path.join(local_dir, dirname) + if os.path.exists(logdir): + old_dirname = dirname + dirname += "_" + uuid.uuid4().hex[:4] + logger.info(f"Creating a new dirname {dirname} because " + f"trial dirname '{old_dirname}' already exists.") + logdir = os.path.join(local_dir, dirname) + os.makedirs(logdir, exist_ok=True) + return logdir + + class SafeFallbackEncoder(json.JSONEncoder): def __init__(self, nan_str="null", **kwargs): super(SafeFallbackEncoder, self).__init__(**kwargs)