mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 10:33:16 +08:00
[tune] cache checkpoint serialization (#12064)
This commit is contained in:
@@ -5,6 +5,7 @@ from numbers import Number
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from ray.tune.utils import flatten_dict
|
||||
from ray.tune.utils.serialization import TuneFunctionDecoder
|
||||
from ray.tune.utils.util import is_nan_or_inf
|
||||
|
||||
try:
|
||||
@@ -338,12 +339,16 @@ class ExperimentAnalysis(Analysis):
|
||||
raise ValueError(
|
||||
"{} is not a valid file.".format(experiment_checkpoint_path))
|
||||
with open(experiment_checkpoint_path) as f:
|
||||
_experiment_state = json.load(f)
|
||||
_experiment_state = json.load(f, cls=TuneFunctionDecoder)
|
||||
self._experiment_state = _experiment_state
|
||||
|
||||
if "checkpoints" not in _experiment_state:
|
||||
raise TuneError("Experiment state invalid; no checkpoints found.")
|
||||
self._checkpoints = _experiment_state["checkpoints"]
|
||||
self._checkpoints = [
|
||||
json.loads(cp, cls=TuneFunctionDecoder)
|
||||
if isinstance(cp, str) else cp
|
||||
for cp in _experiment_state["checkpoints"]
|
||||
]
|
||||
self.trials = trials
|
||||
|
||||
super(ExperimentAnalysis, self).__init__(
|
||||
|
||||
@@ -114,6 +114,8 @@ class AutoMLSearcher(SearchAlgorithm):
|
||||
trial.param_config = param_config
|
||||
trial.extra_arg = extra_arg
|
||||
|
||||
trial.invalidate_json_state()
|
||||
|
||||
trials.append(trial)
|
||||
self._running_trials[trial.trial_id] = trial
|
||||
|
||||
@@ -142,6 +144,7 @@ class AutoMLSearcher(SearchAlgorithm):
|
||||
or result[self.reward_attr] \
|
||||
> trial.best_result[self.reward_attr]:
|
||||
trial.best_result = result
|
||||
trial.invalidate_json_state()
|
||||
|
||||
# Update job's best trial
|
||||
if self.best_trial is None \
|
||||
|
||||
@@ -404,8 +404,8 @@ class RayTrialExecutor(TrialExecutor):
|
||||
Returns:
|
||||
True if `reset_config` is successful else False.
|
||||
"""
|
||||
trial.experiment_tag = new_experiment_tag
|
||||
trial.config = new_config
|
||||
trial.set_experiment_tag(new_experiment_tag)
|
||||
trial.set_config(new_config)
|
||||
trainable = trial.runner
|
||||
with self._change_working_directory(trial):
|
||||
with warn_if_slow("reset"):
|
||||
|
||||
@@ -557,8 +557,8 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
raise TuneError("Trials should be paused here only if in "
|
||||
"synchronous mode. If you encounter this error"
|
||||
" please raise an issue on Ray Github.")
|
||||
trial.config = new_config
|
||||
trial.experiment_tag = new_tag
|
||||
trial.set_experiment_tag(new_tag)
|
||||
trial.set_config(new_config)
|
||||
trial.on_checkpoint(new_state.last_checkpoint)
|
||||
else:
|
||||
# If trial is running, we first try to reset it.
|
||||
@@ -575,8 +575,8 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
trial, new_state.last_checkpoint, block=True)
|
||||
else:
|
||||
trial_executor.stop_trial(trial)
|
||||
trial.config = new_config
|
||||
trial.experiment_tag = new_tag
|
||||
trial.set_experiment_tag(new_tag)
|
||||
trial.set_config(new_config)
|
||||
trial_executor.start_trial(
|
||||
trial, new_state.last_checkpoint, train=False)
|
||||
|
||||
@@ -761,7 +761,7 @@ class PopulationBasedTrainingReplay(FIFOScheduler):
|
||||
"No replay policy found and trial initialized without a "
|
||||
"valid config. Either pass a `config` argument to `tune.run()`"
|
||||
"or consider not using PBT replay for this run.")
|
||||
self._trial.config = self.config
|
||||
self._trial.set_config(self.config)
|
||||
|
||||
def on_trial_result(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial, result: Dict) -> str:
|
||||
@@ -800,8 +800,8 @@ class PopulationBasedTrainingReplay(FIFOScheduler):
|
||||
trial_executor.restore(trial, checkpoint, block=True)
|
||||
else:
|
||||
trial_executor.stop_trial(trial, stop_logger=False)
|
||||
trial.config = new_config
|
||||
trial.experiment_tag = new_tag
|
||||
trial.set_experiment_tag(new_tag)
|
||||
trial.set_config(new_config)
|
||||
trial_executor.start_trial(trial, checkpoint, train=False)
|
||||
|
||||
self.current_config = new_config
|
||||
|
||||
@@ -12,6 +12,7 @@ import ray
|
||||
from ray.tune import (run, Trainable, sample_from, Analysis,
|
||||
ExperimentAnalysis, grid_search)
|
||||
from ray.tune.utils.mock import MyTrainableClass
|
||||
from ray.tune.utils.serialization import TuneFunctionEncoder
|
||||
|
||||
|
||||
class ExperimentAnalysisInMemorySuite(unittest.TestCase):
|
||||
@@ -54,7 +55,8 @@ class ExperimentAnalysisInMemorySuite(unittest.TestCase):
|
||||
def tearDown(self):
    """Delete the per-test directory; removal errors are ignored."""
    scratch_dir = self.test_dir
    shutil.rmtree(scratch_dir, ignore_errors=True)
|
||||
|
||||
def testInit(self):
|
||||
def testInitLegacy(self):
|
||||
"""Should still work if checkpoints are not json strings"""
|
||||
experiment_checkpoint_path = os.path.join(self.test_dir,
|
||||
"experiment_state.json")
|
||||
checkpoint_data = {
|
||||
@@ -71,6 +73,27 @@ class ExperimentAnalysisInMemorySuite(unittest.TestCase):
|
||||
self.assertEqual(len(experiment_analysis._checkpoints), 1)
|
||||
self.assertTrue(experiment_analysis.trials is None)
|
||||
|
||||
def testInit(self):
    """Checkpoints stored as JSON strings should load into the analysis."""
    state_path = os.path.join(self.test_dir, "experiment_state.json")
    trial_state = {
        "trainable_name": "MockTrainable",
        "logdir": "/mock/test/MockTrainable_0_id=3_2020-07-12",
    }
    # Each checkpoint entry is itself a JSON-encoded string.
    serialized_trial = json.dumps(trial_state, cls=TuneFunctionEncoder)

    with open(state_path, "w") as f:
        json.dump({"checkpoints": [serialized_trial]}, f)

    analysis = ExperimentAnalysis(state_path)
    self.assertEqual(len(analysis._checkpoints), 1)
    self.assertTrue(analysis.trials is None)
|
||||
|
||||
def testInitException(self):
|
||||
experiment_checkpoint_path = os.path.join(self.test_dir, "mock.json")
|
||||
with pytest.raises(ValueError):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
from typing import Sequence
|
||||
|
||||
import ray.cloudpickle as cloudpickle
|
||||
@@ -18,6 +19,7 @@ from ray.tune.checkpoint_manager import Checkpoint, CheckpointManager
|
||||
from ray.tune.registry import get_trainable_cls, validate_trainable
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR, DONE, TRAINING_ITERATION
|
||||
from ray.tune.resources import Resources, json_to_resources, resources_to_json
|
||||
from ray.tune.utils.serialization import TuneFunctionEncoder
|
||||
from ray.tune.utils.trainable import TrainableUtil
|
||||
from ray.tune.utils import date_str, flatten_dict
|
||||
from ray.utils import binary_to_hex, hex_to_binary
|
||||
@@ -294,6 +296,9 @@ class Trial:
|
||||
raise ValueError(f"Trial dirname must not contain '/'. "
|
||||
"Got {self.custom_dirname}")
|
||||
|
||||
self._state_json = None
|
||||
self._state_valid = False
|
||||
|
||||
@property
def node_ip(self):
    """Hostname recorded on this trial's current ``location``."""
    current_location = self.location
    return current_location.hostname
|
||||
@@ -355,6 +360,7 @@ class Trial:
|
||||
self.local_dir)
|
||||
else:
|
||||
os.makedirs(self.logdir, exist_ok=True)
|
||||
self.invalidate_json_state()
|
||||
|
||||
def update_resources(self, cpu, gpu, **kwargs):
|
||||
"""EXPERIMENTAL: Updates the resource requirements.
|
||||
@@ -367,15 +373,20 @@ class Trial:
|
||||
if self.status is Trial.RUNNING:
|
||||
raise ValueError("Cannot update resources while Trial is running.")
|
||||
self.resources = Resources(cpu, gpu, **kwargs)
|
||||
self.invalidate_json_state()
|
||||
|
||||
def set_runner(self, runner):
    """Attach ``runner`` and point the checkpoint manager's deleter at it."""
    self.runner = runner
    deleter = checkpoint_deleter(self._trainable_name(), runner)
    self.checkpoint_manager.delete = deleter
    # The runner is not stored in the JSON state, so the cached
    # serialization is deliberately left valid here.
|
||||
|
||||
def set_location(self, location):
    """Record where the trial is placed.

    The location is not stored in the JSON state, so the cached
    serialization is deliberately left valid here.
    """
    self.location = location
|
||||
|
||||
def set_status(self, status):
|
||||
"""Sets the status of the trial."""
|
||||
@@ -383,6 +394,15 @@ class Trial:
|
||||
if status == Trial.RUNNING:
|
||||
if self.start_time is None:
|
||||
self.start_time = time.time()
|
||||
self.invalidate_json_state()
|
||||
|
||||
def set_config(self, config):
    """Replace the trial config and mark the cached JSON state stale."""
    self.config = config
    # The cached serialization no longer matches the new config.
    self.invalidate_json_state()
|
||||
|
||||
def set_experiment_tag(self, experiment_tag):
    """Replace the experiment tag and mark the cached JSON state stale."""
    self.experiment_tag = experiment_tag
    # The cached serialization no longer matches the new tag.
    self.invalidate_json_state()
|
||||
|
||||
def write_error_log(self, error_msg):
|
||||
if error_msg and self.logdir:
|
||||
@@ -393,6 +413,7 @@ class Trial:
|
||||
self.num_failures, date_str()))
|
||||
f.write(error_msg + "\n")
|
||||
self.error_msg = error_msg
|
||||
self.invalidate_json_state()
|
||||
|
||||
def should_stop(self, result):
|
||||
"""Whether the given result meets this trial's stopping criteria."""
|
||||
@@ -426,6 +447,7 @@ class Trial:
|
||||
def clear_checkpoint(self):
    """Drop the current checkpoint value and any pending restore source."""
    current = self.checkpoint
    current.value = None
    self.restoring_from = None
    # Checkpoint information changed, so the cached JSON is stale.
    self.invalidate_json_state()
|
||||
|
||||
def on_checkpoint(self, checkpoint):
|
||||
"""Hook for handling checkpoints taken by the Trainable.
|
||||
@@ -434,12 +456,14 @@ class Trial:
|
||||
checkpoint (Checkpoint): Checkpoint taken.
|
||||
"""
|
||||
self.checkpoint_manager.on_checkpoint(checkpoint)
|
||||
self.invalidate_json_state()
|
||||
|
||||
def on_restore(self):
    """Finish a restore: adopt the restored result and clear the marker."""
    assert self.is_restoring
    restored_from = self.restoring_from
    self.last_result = restored_from.result
    self.restoring_from = None
    # last_result changed, so the cached JSON is stale.
    self.invalidate_json_state()
|
||||
|
||||
def should_recover(self):
|
||||
"""Returns whether the trial qualifies for retrying.
|
||||
@@ -492,6 +516,7 @@ class Trial:
|
||||
self.metric_analysis[metric][key] = sum(
|
||||
self.metric_n_steps[metric][str(n)]) / len(
|
||||
self.metric_n_steps[metric][str(n)])
|
||||
self.invalidate_json_state()
|
||||
|
||||
def get_trainable_cls(self):
    """Return the module-level ``get_trainable_cls`` result for this trial.

    The lookup key is ``self.trainable_name``.
    """
    name = self.trainable_name
    return get_trainable_cls(name)
|
||||
@@ -541,6 +566,17 @@ class Trial:
|
||||
generated_dirname += f"_{date_str()}"
|
||||
return generated_dirname.replace("/", "_")
|
||||
|
||||
def invalidate_json_state(self):
    """Mark the cached JSON state stale.

    ``get_json_state`` re-serializes on its next call when this flag
    is False.
    """
    self._state_valid = False
|
||||
|
||||
def get_json_state(self) -> str:
    """Return the trial state as a JSON string, reusing the cache if valid.

    The state is re-serialized from ``__getstate__()`` only when no
    cached string exists or the cache has been invalidated.
    """
    if self._state_json and self._state_valid:
        return self._state_json

    self._state_json = json.dumps(
        self.__getstate__(), indent=2, cls=TuneFunctionEncoder)
    self._state_valid = True
    return self._state_json
|
||||
|
||||
def __getstate__(self):
|
||||
"""Memento generator for Trial.
|
||||
|
||||
@@ -558,9 +594,12 @@ class Trial:
|
||||
state["runner"] = None
|
||||
state["location"] = Location()
|
||||
# Avoid waiting for events that will never occur on resume.
|
||||
state["resuming_from"] = None
|
||||
state["restoring_from"] = None
|
||||
state["saving_to"] = None
|
||||
|
||||
state["_state_json"] = None
|
||||
state["_state_valid"] = False
|
||||
|
||||
return copy.deepcopy(state)
|
||||
|
||||
def __setstate__(self, state):
|
||||
|
||||
@@ -26,6 +26,7 @@ class TrialExecutor:
|
||||
"""
|
||||
self._queue_trials = queue_trials
|
||||
self._cached_trial_state = {}
|
||||
self._trials_to_cache = set()
|
||||
|
||||
def set_status(self, trial, status):
|
||||
"""Sets status and checkpoints metadata if needed.
|
||||
@@ -59,14 +60,18 @@ class TrialExecutor:
|
||||
return
|
||||
try:
|
||||
logger.debug("Trial %s: Saving trial metadata.", trial)
|
||||
self._cached_trial_state[trial.trial_id] = trial.__getstate__()
|
||||
# Lazy cache trials
|
||||
self._trials_to_cache.add(trial)
|
||||
except Exception:
|
||||
logger.exception("Trial %s: Error checkpointing trial metadata.",
|
||||
trial)
|
||||
|
||||
def get_checkpoints(self):
    """Return a copy of the mapping from trial ID to serialized metadata.

    Trials queued in ``self._trials_to_cache`` are serialized here, via
    ``trial.get_json_state()``, rather than when they were queued. The
    queue is cleared once every queued trial has been serialized.

    Returns:
        dict: A shallow copy of ``self._cached_trial_state``, so callers
            cannot mutate the executor's internal cache.
    """
    # Serialize before returning; an early return here would leave the
    # queued trials unserialized and the queue uncleared.
    for trial in self._trials_to_cache:
        self._cached_trial_state[trial.trial_id] = trial.get_json_state()
    self._trials_to_cache.clear()
    return self._cached_trial_state.copy()
|
||||
|
||||
def has_resources(self, resources):
|
||||
"""Returns whether this runner has at least the specified resources."""
|
||||
|
||||
@@ -5,10 +5,8 @@ import logging
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
import types
|
||||
import warnings
|
||||
|
||||
import ray.cloudpickle as cloudpickle
|
||||
from ray.services import get_node_ip_address
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.callback import CallbackList
|
||||
@@ -22,8 +20,9 @@ from ray.tune.schedulers import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.suggest import BasicVariantGenerator
|
||||
from ray.tune.utils import warn_if_slow, flatten_dict, env_integer
|
||||
from ray.tune.utils.log import Verbosity, has_verbosity
|
||||
from ray.tune.utils.serialization import TuneFunctionDecoder, \
|
||||
TuneFunctionEncoder
|
||||
from ray.tune.web_server import TuneServer
|
||||
from ray.utils import binary_to_hex, hex_to_binary
|
||||
from ray.util.debug import log_once
|
||||
|
||||
MAX_DEBUG_TRIALS = 20
|
||||
@@ -40,37 +39,6 @@ def _find_newest_ckpt(ckpt_dir):
|
||||
return max(full_paths)
|
||||
|
||||
|
||||
class _TuneFunctionEncoder(json.JSONEncoder):
    """JSON encoder that cloudpickles objects it cannot encode itself."""

    def default(self, obj):
        """Encode ``obj``, using the cloudpickle envelope as a fallback."""
        if not isinstance(obj, types.FunctionType):
            try:
                return super().default(obj)
            except Exception:
                logger.debug("Unable to encode. Falling back to cloudpickle.")
                return self._to_cloudpickle(obj)
        # Plain Python functions always go straight to cloudpickle.
        return self._to_cloudpickle(obj)

    def _to_cloudpickle(self, obj):
        """Wrap the hex-encoded cloudpickle dump of ``obj`` in a tagged dict."""
        pickled = cloudpickle.dumps(obj)
        return {
            "_type": "CLOUDPICKLE_FALLBACK",
            "value": binary_to_hex(pickled),
        }
|
||||
|
||||
|
||||
class _TuneFunctionDecoder(json.JSONDecoder):
    """JSON decoder that unpickles ``CLOUDPICKLE_FALLBACK`` envelopes."""

    def __init__(self, *args, **kwargs):
        """Install ``self.object_hook`` as the decoder's object hook."""
        super().__init__(*args, object_hook=self.object_hook, **kwargs)

    def object_hook(self, obj):
        """Return ``obj`` unchanged unless it is a cloudpickle envelope."""
        if obj.get("_type") != "CLOUDPICKLE_FALLBACK":
            return obj
        return self._from_cloudpickle(obj)

    def _from_cloudpickle(self, obj):
        """Unpickle the hex-encoded payload stored under ``"value"``."""
        # NOTE(review): unpickling can execute arbitrary code; only decode
        # state that comes from a trusted source.
        payload = hex_to_binary(obj["value"])
        return cloudpickle.loads(payload)
|
||||
|
||||
|
||||
class TrialRunner:
|
||||
"""A TrialRunner implements the event loop for scheduling trials on Ray.
|
||||
|
||||
@@ -310,7 +278,7 @@ class TrialRunner:
|
||||
tmp_file_name = os.path.join(self._local_checkpoint_dir,
|
||||
".tmp_checkpoint")
|
||||
with open(tmp_file_name, "w") as f:
|
||||
json.dump(runner_state, f, indent=2, cls=_TuneFunctionEncoder)
|
||||
json.dump(runner_state, f, indent=2, cls=TuneFunctionEncoder)
|
||||
|
||||
os.replace(tmp_file_name, self.checkpoint_file)
|
||||
self._search_alg.save_to_dir(
|
||||
@@ -330,7 +298,7 @@ class TrialRunner:
|
||||
"""
|
||||
newest_ckpt_path = _find_newest_ckpt(self._local_checkpoint_dir)
|
||||
with open(newest_ckpt_path, "r") as f:
|
||||
runner_state = json.load(f, cls=_TuneFunctionDecoder)
|
||||
runner_state = json.load(f, cls=TuneFunctionDecoder)
|
||||
self.checkpoint_file = newest_ckpt_path
|
||||
|
||||
logger.warning("".join([
|
||||
@@ -343,8 +311,14 @@ class TrialRunner:
|
||||
if self._search_alg.has_checkpoint(self._local_checkpoint_dir):
|
||||
self._search_alg.restore_from_dir(self._local_checkpoint_dir)
|
||||
|
||||
checkpoints = [
|
||||
json.loads(cp, cls=TuneFunctionDecoder)
|
||||
if isinstance(cp, str) else cp
|
||||
for cp in runner_state["checkpoints"]
|
||||
]
|
||||
|
||||
trials = []
|
||||
for trial_cp in runner_state["checkpoints"]:
|
||||
for trial_cp in checkpoints:
|
||||
new_trial = Trial(trial_cp["trainable_name"])
|
||||
new_trial.__setstate__(trial_cp)
|
||||
trials += [new_trial]
|
||||
|
||||
@@ -248,7 +248,7 @@ def run(
|
||||
trial (of ERROR state) when the experiments complete.
|
||||
callbacks (list): List of callbacks that will be called at different
|
||||
times in the training loop. Must be instances of the
|
||||
``ray.tune.trial_runner.Callback`` class. If not passed,
|
||||
``ray.tune.callback.Callback`` class. If not passed,
|
||||
`LoggerCallback` and `SyncerCallback` callbacks are automatically
|
||||
added.
|
||||
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
import json
|
||||
import logging
|
||||
import types
|
||||
|
||||
from ray import cloudpickle as cloudpickle
|
||||
from ray.utils import binary_to_hex, hex_to_binary
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TuneFunctionEncoder(json.JSONEncoder):
    """JSON encoder that cloudpickles objects it cannot encode itself.

    Such objects are emitted as a dict tagged with
    ``"_type": "CLOUDPICKLE_FALLBACK"``.
    """

    def default(self, obj):
        """Encode ``obj``, using the cloudpickle envelope as a fallback."""
        if not isinstance(obj, types.FunctionType):
            try:
                return super().default(obj)
            except Exception:
                logger.debug("Unable to encode. Falling back to cloudpickle.")
                return self._to_cloudpickle(obj)
        # Plain Python functions always go straight to cloudpickle.
        return self._to_cloudpickle(obj)

    def _to_cloudpickle(self, obj):
        """Wrap the hex-encoded cloudpickle dump of ``obj`` in a tagged dict."""
        pickled = cloudpickle.dumps(obj)
        return {
            "_type": "CLOUDPICKLE_FALLBACK",
            "value": binary_to_hex(pickled),
        }
|
||||
|
||||
|
||||
class TuneFunctionDecoder(json.JSONDecoder):
    """JSON decoder that unpickles ``CLOUDPICKLE_FALLBACK`` envelopes."""

    def __init__(self, *args, **kwargs):
        """Install ``self.object_hook`` as the decoder's object hook."""
        super().__init__(*args, object_hook=self.object_hook, **kwargs)

    def object_hook(self, obj):
        """Return ``obj`` unchanged unless it is a cloudpickle envelope."""
        if obj.get("_type") != "CLOUDPICKLE_FALLBACK":
            return obj
        return self._from_cloudpickle(obj)

    def _from_cloudpickle(self, obj):
        """Unpickle the hex-encoded payload stored under ``"value"``."""
        # NOTE(review): unpickling can execute arbitrary code; only decode
        # state that comes from a trusted source.
        payload = hex_to_binary(obj["value"])
        return cloudpickle.loads(payload)
|
||||
Reference in New Issue
Block a user