diff --git a/python/ray/tune/analysis/experiment_analysis.py b/python/ray/tune/analysis/experiment_analysis.py index 5d3c53895..463d34190 100644 --- a/python/ray/tune/analysis/experiment_analysis.py +++ b/python/ray/tune/analysis/experiment_analysis.py @@ -5,6 +5,7 @@ from numbers import Number from typing import Any, Dict, List, Optional, Tuple from ray.tune.utils import flatten_dict +from ray.tune.utils.serialization import TuneFunctionDecoder from ray.tune.utils.util import is_nan_or_inf try: @@ -338,12 +339,16 @@ class ExperimentAnalysis(Analysis): raise ValueError( "{} is not a valid file.".format(experiment_checkpoint_path)) with open(experiment_checkpoint_path) as f: - _experiment_state = json.load(f) + _experiment_state = json.load(f, cls=TuneFunctionDecoder) self._experiment_state = _experiment_state if "checkpoints" not in _experiment_state: raise TuneError("Experiment state invalid; no checkpoints found.") - self._checkpoints = _experiment_state["checkpoints"] + self._checkpoints = [ + json.loads(cp, cls=TuneFunctionDecoder) + if isinstance(cp, str) else cp + for cp in _experiment_state["checkpoints"] + ] self.trials = trials super(ExperimentAnalysis, self).__init__( diff --git a/python/ray/tune/automl/search_policy.py b/python/ray/tune/automl/search_policy.py index 519d3f571..07a41f66c 100644 --- a/python/ray/tune/automl/search_policy.py +++ b/python/ray/tune/automl/search_policy.py @@ -114,6 +114,8 @@ class AutoMLSearcher(SearchAlgorithm): trial.param_config = param_config trial.extra_arg = extra_arg + trial.invalidate_json_state() + trials.append(trial) self._running_trials[trial.trial_id] = trial @@ -142,6 +144,7 @@ class AutoMLSearcher(SearchAlgorithm): or result[self.reward_attr] \ > trial.best_result[self.reward_attr]: trial.best_result = result + trial.invalidate_json_state() # Update job's best trial if self.best_trial is None \ diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index e46ac6586..2174abfdc 
100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -404,8 +404,8 @@ class RayTrialExecutor(TrialExecutor): Returns: True if `reset_config` is successful else False. """ - trial.experiment_tag = new_experiment_tag - trial.config = new_config + trial.set_experiment_tag(new_experiment_tag) + trial.set_config(new_config) trainable = trial.runner with self._change_working_directory(trial): with warn_if_slow("reset"): diff --git a/python/ray/tune/schedulers/pbt.py b/python/ray/tune/schedulers/pbt.py index aead77cd7..20d12819a 100644 --- a/python/ray/tune/schedulers/pbt.py +++ b/python/ray/tune/schedulers/pbt.py @@ -557,8 +557,8 @@ class PopulationBasedTraining(FIFOScheduler): raise TuneError("Trials should be paused here only if in " "synchronous mode. If you encounter this error" " please raise an issue on Ray Github.") - trial.config = new_config - trial.experiment_tag = new_tag + trial.set_experiment_tag(new_tag) + trial.set_config(new_config) trial.on_checkpoint(new_state.last_checkpoint) else: # If trial is running, we first try to reset it. @@ -575,8 +575,8 @@ class PopulationBasedTraining(FIFOScheduler): trial, new_state.last_checkpoint, block=True) else: trial_executor.stop_trial(trial) - trial.config = new_config - trial.experiment_tag = new_tag + trial.set_experiment_tag(new_tag) + trial.set_config(new_config) trial_executor.start_trial( trial, new_state.last_checkpoint, train=False) @@ -761,7 +761,7 @@ class PopulationBasedTrainingReplay(FIFOScheduler): "No replay policy found and trial initialized without a " "valid config. 
Either pass a `config` argument to `tune.run()`" "or consider not using PBT replay for this run.") - self._trial.config = self.config + self._trial.set_config(self.config) def on_trial_result(self, trial_runner: "trial_runner.TrialRunner", trial: Trial, result: Dict) -> str: @@ -800,8 +800,8 @@ class PopulationBasedTrainingReplay(FIFOScheduler): trial_executor.restore(trial, checkpoint, block=True) else: trial_executor.stop_trial(trial, stop_logger=False) - trial.config = new_config - trial.experiment_tag = new_tag + trial.set_experiment_tag(new_tag) + trial.set_config(new_config) trial_executor.start_trial(trial, checkpoint, train=False) self.current_config = new_config diff --git a/python/ray/tune/tests/test_experiment_analysis_mem.py b/python/ray/tune/tests/test_experiment_analysis_mem.py index db95254fe..0ac4463e0 100644 --- a/python/ray/tune/tests/test_experiment_analysis_mem.py +++ b/python/ray/tune/tests/test_experiment_analysis_mem.py @@ -12,6 +12,7 @@ import ray from ray.tune import (run, Trainable, sample_from, Analysis, ExperimentAnalysis, grid_search) from ray.tune.utils.mock import MyTrainableClass +from ray.tune.utils.serialization import TuneFunctionEncoder class ExperimentAnalysisInMemorySuite(unittest.TestCase): @@ -54,7 +55,8 @@ class ExperimentAnalysisInMemorySuite(unittest.TestCase): def tearDown(self): shutil.rmtree(self.test_dir, ignore_errors=True) - def testInit(self): + def testInitLegacy(self): + """Should still work if checkpoints are not json strings""" experiment_checkpoint_path = os.path.join(self.test_dir, "experiment_state.json") checkpoint_data = { @@ -71,6 +73,27 @@ class ExperimentAnalysisInMemorySuite(unittest.TestCase): self.assertEqual(len(experiment_analysis._checkpoints), 1) self.assertTrue(experiment_analysis.trials is None) + def testInit(self): + experiment_checkpoint_path = os.path.join(self.test_dir, + "experiment_state.json") + checkpoint_data = { + "checkpoints": [ + json.dumps( + { + "trainable_name": "MockTrainable", 
+ "logdir": "/mock/test/MockTrainable_0_id=3_2020-07-12" + }, + cls=TuneFunctionEncoder) + ] + } + + with open(experiment_checkpoint_path, "w") as f: + f.write(json.dumps(checkpoint_data)) + + experiment_analysis = ExperimentAnalysis(experiment_checkpoint_path) + self.assertEqual(len(experiment_analysis._checkpoints), 1) + self.assertTrue(experiment_analysis.trials is None) + def testInitException(self): experiment_checkpoint_path = os.path.join(self.test_dir, "mock.json") with pytest.raises(ValueError): diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index af4e7bc50..2e9465c83 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -1,3 +1,4 @@ +import json from typing import Sequence import ray.cloudpickle as cloudpickle @@ -18,6 +19,7 @@ from ray.tune.checkpoint_manager import Checkpoint, CheckpointManager from ray.tune.registry import get_trainable_cls, validate_trainable from ray.tune.result import DEFAULT_RESULTS_DIR, DONE, TRAINING_ITERATION from ray.tune.resources import Resources, json_to_resources, resources_to_json +from ray.tune.utils.serialization import TuneFunctionEncoder from ray.tune.utils.trainable import TrainableUtil from ray.tune.utils import date_str, flatten_dict from ray.utils import binary_to_hex, hex_to_binary @@ -294,6 +296,9 @@ class Trial: raise ValueError(f"Trial dirname must not contain '/'. " "Got {self.custom_dirname}") + self._state_json = None + self._state_valid = False + @property def node_ip(self): return self.location.hostname @@ -355,6 +360,7 @@ class Trial: self.local_dir) else: os.makedirs(self.logdir, exist_ok=True) + self.invalidate_json_state() def update_resources(self, cpu, gpu, **kwargs): """EXPERIMENTAL: Updates the resource requirements. 
@@ -367,15 +373,20 @@ class Trial: if self.status is Trial.RUNNING: raise ValueError("Cannot update resources while Trial is running.") self.resources = Resources(cpu, gpu, **kwargs) + self.invalidate_json_state() def set_runner(self, runner): self.runner = runner self.checkpoint_manager.delete = checkpoint_deleter( self._trainable_name(), runner) + # No need to invalidate state cache: runner is not stored in json + # self.invalidate_json_state() def set_location(self, location): """Sets the location of the trial.""" self.location = location + # No need to invalidate state cache: location is not stored in json + # self.invalidate_json_state() def set_status(self, status): """Sets the status of the trial.""" @@ -383,6 +394,15 @@ class Trial: if status == Trial.RUNNING: if self.start_time is None: self.start_time = time.time() + self.invalidate_json_state() + + def set_config(self, config): + self.config = config + self.invalidate_json_state() + + def set_experiment_tag(self, experiment_tag): + self.experiment_tag = experiment_tag + self.invalidate_json_state() def write_error_log(self, error_msg): if error_msg and self.logdir: @@ -393,6 +413,7 @@ class Trial: self.num_failures, date_str())) f.write(error_msg + "\n") self.error_msg = error_msg + self.invalidate_json_state() def should_stop(self, result): """Whether the given result meets this trial's stopping criteria.""" @@ -426,6 +447,7 @@ class Trial: def clear_checkpoint(self): self.checkpoint.value = None self.restoring_from = None + self.invalidate_json_state() def on_checkpoint(self, checkpoint): """Hook for handling checkpoints taken by the Trainable. @@ -434,12 +456,14 @@ class Trial: checkpoint (Checkpoint): Checkpoint taken. 
""" self.checkpoint_manager.on_checkpoint(checkpoint) + self.invalidate_json_state() def on_restore(self): """Handles restoration completion.""" assert self.is_restoring self.last_result = self.restoring_from.result self.restoring_from = None + self.invalidate_json_state() def should_recover(self): """Returns whether the trial qualifies for retrying. @@ -492,6 +516,7 @@ class Trial: self.metric_analysis[metric][key] = sum( self.metric_n_steps[metric][str(n)]) / len( self.metric_n_steps[metric][str(n)]) + self.invalidate_json_state() def get_trainable_cls(self): return get_trainable_cls(self.trainable_name) @@ -541,6 +566,17 @@ class Trial: generated_dirname += f"_{date_str()}" return generated_dirname.replace("/", "_") + def invalidate_json_state(self): + self._state_valid = False + + def get_json_state(self) -> str: + if not self._state_json or not self._state_valid: + json_state = json.dumps( + self.__getstate__(), indent=2, cls=TuneFunctionEncoder) + self._state_json = json_state + self._state_valid = True + return self._state_json + def __getstate__(self): """Memento generator for Trial. @@ -558,9 +594,12 @@ class Trial: state["runner"] = None state["location"] = Location() # Avoid waiting for events that will never occur on resume. - state["resuming_from"] = None + state["restoring_from"] = None state["saving_to"] = None + state["_state_json"] = None + state["_state_valid"] = False + return copy.deepcopy(state) def __setstate__(self, state): diff --git a/python/ray/tune/trial_executor.py b/python/ray/tune/trial_executor.py index 54ea5b71d..02258e17f 100644 --- a/python/ray/tune/trial_executor.py +++ b/python/ray/tune/trial_executor.py @@ -26,6 +26,7 @@ class TrialExecutor: """ self._queue_trials = queue_trials self._cached_trial_state = {} + self._trials_to_cache = set() def set_status(self, trial, status): """Sets status and checkpoints metadata if needed. 
@@ -59,14 +60,18 @@ class TrialExecutor: return try: logger.debug("Trial %s: Saving trial metadata.", trial) - self._cached_trial_state[trial.trial_id] = trial.__getstate__() + # Lazy cache trials + self._trials_to_cache.add(trial) except Exception: logger.exception("Trial %s: Error checkpointing trial metadata.", trial) def get_checkpoints(self): """Returns a copy of mapping of the trial ID to pickled metadata.""" - return self._cached_trial_state.copy() + for trial in self._trials_to_cache: + self._cached_trial_state[trial.trial_id] = trial.get_json_state() + self._trials_to_cache.clear() + return self._cached_trial_state def has_resources(self, resources): """Returns whether this runner has at least the specified resources.""" diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index b5340b2ec..0fee16613 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -5,10 +5,8 @@ import logging import os import time import traceback -import types import warnings -import ray.cloudpickle as cloudpickle from ray.services import get_node_ip_address from ray.tune import TuneError from ray.tune.callback import CallbackList @@ -22,8 +20,9 @@ from ray.tune.schedulers import FIFOScheduler, TrialScheduler from ray.tune.suggest import BasicVariantGenerator from ray.tune.utils import warn_if_slow, flatten_dict, env_integer from ray.tune.utils.log import Verbosity, has_verbosity +from ray.tune.utils.serialization import TuneFunctionDecoder, \ + TuneFunctionEncoder from ray.tune.web_server import TuneServer -from ray.utils import binary_to_hex, hex_to_binary from ray.util.debug import log_once MAX_DEBUG_TRIALS = 20 @@ -40,37 +39,6 @@ def _find_newest_ckpt(ckpt_dir): return max(full_paths) -class _TuneFunctionEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, types.FunctionType): - return self._to_cloudpickle(obj) - try: - return super(_TuneFunctionEncoder, self).default(obj) - except Exception: - 
logger.debug("Unable to encode. Falling back to cloudpickle.") - return self._to_cloudpickle(obj) - - def _to_cloudpickle(self, obj): - return { - "_type": "CLOUDPICKLE_FALLBACK", - "value": binary_to_hex(cloudpickle.dumps(obj)) - } - - -class _TuneFunctionDecoder(json.JSONDecoder): - def __init__(self, *args, **kwargs): - json.JSONDecoder.__init__( - self, object_hook=self.object_hook, *args, **kwargs) - - def object_hook(self, obj): - if obj.get("_type") == "CLOUDPICKLE_FALLBACK": - return self._from_cloudpickle(obj) - return obj - - def _from_cloudpickle(self, obj): - return cloudpickle.loads(hex_to_binary(obj["value"])) - - class TrialRunner: """A TrialRunner implements the event loop for scheduling trials on Ray. @@ -310,7 +278,7 @@ class TrialRunner: tmp_file_name = os.path.join(self._local_checkpoint_dir, ".tmp_checkpoint") with open(tmp_file_name, "w") as f: - json.dump(runner_state, f, indent=2, cls=_TuneFunctionEncoder) + json.dump(runner_state, f, indent=2, cls=TuneFunctionEncoder) os.replace(tmp_file_name, self.checkpoint_file) self._search_alg.save_to_dir( @@ -330,7 +298,7 @@ class TrialRunner: """ newest_ckpt_path = _find_newest_ckpt(self._local_checkpoint_dir) with open(newest_ckpt_path, "r") as f: - runner_state = json.load(f, cls=_TuneFunctionDecoder) + runner_state = json.load(f, cls=TuneFunctionDecoder) self.checkpoint_file = newest_ckpt_path logger.warning("".join([ @@ -343,8 +311,14 @@ class TrialRunner: if self._search_alg.has_checkpoint(self._local_checkpoint_dir): self._search_alg.restore_from_dir(self._local_checkpoint_dir) + checkpoints = [ + json.loads(cp, cls=TuneFunctionDecoder) + if isinstance(cp, str) else cp + for cp in runner_state["checkpoints"] + ] + trials = [] - for trial_cp in runner_state["checkpoints"]: + for trial_cp in checkpoints: new_trial = Trial(trial_cp["trainable_name"]) new_trial.__setstate__(trial_cp) trials += [new_trial] diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index e809fa5ce..b7ce7eb79 
100644
--- a/python/ray/tune/tune.py
+++ b/python/ray/tune/tune.py
@@ -248,7 +248,7 @@ def run(
             trial (of ERROR state) when the experiments complete.
         callbacks (list): List of callbacks that will be called at different
             times in the training loop. Must be instances of the
-            ``ray.tune.trial_runner.Callback`` class. If not passed,
+            ``ray.tune.callback.Callback`` class. If not passed,
             `LoggerCallback` and `SyncerCallback` callbacks are automatically
             added.
 
diff --git a/python/ray/tune/utils/serialization.py b/python/ray/tune/utils/serialization.py
new file mode 100644
index 000000000..de6061647
--- /dev/null
+++ b/python/ray/tune/utils/serialization.py
@@ -0,0 +1,67 @@
+import json
+import logging
+import types
+
+from ray import cloudpickle as cloudpickle
+from ray.utils import binary_to_hex, hex_to_binary
+
+logger = logging.getLogger(__name__)
+
+
+class TuneFunctionEncoder(json.JSONEncoder):
+    """JSON encoder that falls back to cloudpickle for unsupported objects.
+
+    Functions, and any other object the standard encoder rejects, are
+    written as ``{"_type": "CLOUDPICKLE_FALLBACK", "value": ...}`` where
+    ``value`` is ``binary_to_hex(cloudpickle.dumps(obj))``.
+    ``TuneFunctionDecoder`` reverses this encoding.
+    """
+
+    def default(self, obj):
+        """Return a JSON-serializable stand-in for ``obj``."""
+        # Functions are never JSON-serializable, so skip the base encoder.
+        if isinstance(obj, types.FunctionType):
+            return self._to_cloudpickle(obj)
+        try:
+            return super(TuneFunctionEncoder, self).default(obj)
+        except Exception:
+            # Deliberate best effort: anything the base encoder cannot
+            # handle is pickled instead of failing the whole dump.
+            logger.debug("Unable to encode. Falling back to cloudpickle.")
+            return self._to_cloudpickle(obj)
+
+    def _to_cloudpickle(self, obj):
+        """Wrap the cloudpickle dump of ``obj`` in a tagged dict."""
+        return {
+            "_type": "CLOUDPICKLE_FALLBACK",
+            "value": binary_to_hex(cloudpickle.dumps(obj))
+        }
+
+
+class TuneFunctionDecoder(json.JSONDecoder):
+    """JSON decoder that restores objects written by ``TuneFunctionEncoder``.
+
+    NOTE(security): tagged payloads are unpickled with cloudpickle, which can
+    execute arbitrary code. Only decode data from a trusted source.
+    """
+
+    def __init__(self, *args, **kwargs):
+        # ``json.load``/``json.loads`` forward a caller-supplied
+        # ``object_hook`` as a keyword argument. Remove it from ``kwargs``
+        # so it cannot collide with the hook installed below; it is chained
+        # in ``object_hook`` instead.
+        self._user_object_hook = kwargs.pop("object_hook", None)
+        json.JSONDecoder.__init__(
+            self, *args, object_hook=self.object_hook, **kwargs)
+
+    def object_hook(self, obj):
+        """Decode tagged fallback dicts; defer others to the caller's hook."""
+        if obj.get("_type") == "CLOUDPICKLE_FALLBACK":
+            return self._from_cloudpickle(obj)
+        if self._user_object_hook is not None:
+            return self._user_object_hook(obj)
+        return obj
+
+    def _from_cloudpickle(self, obj):
+        """Unpickle the payload stored under ``obj["value"]``."""
+        return cloudpickle.loads(hex_to_binary(obj["value"]))