[tune] cache checkpoint serialization (#12064)

This commit is contained in:
Kai Fricke
2020-11-18 18:03:53 +01:00
committed by GitHub
parent 6da4342822
commit 2b60c5774b
10 changed files with 141 additions and 53 deletions
@@ -5,6 +5,7 @@ from numbers import Number
from typing import Any, Dict, List, Optional, Tuple
from ray.tune.utils import flatten_dict
from ray.tune.utils.serialization import TuneFunctionDecoder
from ray.tune.utils.util import is_nan_or_inf
try:
@@ -338,12 +339,16 @@ class ExperimentAnalysis(Analysis):
raise ValueError(
"{} is not a valid file.".format(experiment_checkpoint_path))
with open(experiment_checkpoint_path) as f:
_experiment_state = json.load(f)
_experiment_state = json.load(f, cls=TuneFunctionDecoder)
self._experiment_state = _experiment_state
if "checkpoints" not in _experiment_state:
raise TuneError("Experiment state invalid; no checkpoints found.")
self._checkpoints = _experiment_state["checkpoints"]
self._checkpoints = [
json.loads(cp, cls=TuneFunctionDecoder)
if isinstance(cp, str) else cp
for cp in _experiment_state["checkpoints"]
]
self.trials = trials
super(ExperimentAnalysis, self).__init__(
+3
View File
@@ -114,6 +114,8 @@ class AutoMLSearcher(SearchAlgorithm):
trial.param_config = param_config
trial.extra_arg = extra_arg
trial.invalidate_json_state()
trials.append(trial)
self._running_trials[trial.trial_id] = trial
@@ -142,6 +144,7 @@ class AutoMLSearcher(SearchAlgorithm):
or result[self.reward_attr] \
> trial.best_result[self.reward_attr]:
trial.best_result = result
trial.invalidate_json_state()
# Update job's best trial
if self.best_trial is None \
+2 -2
View File
@@ -404,8 +404,8 @@ class RayTrialExecutor(TrialExecutor):
Returns:
True if `reset_config` is successful else False.
"""
trial.experiment_tag = new_experiment_tag
trial.config = new_config
trial.set_experiment_tag(new_experiment_tag)
trial.set_config(new_config)
trainable = trial.runner
with self._change_working_directory(trial):
with warn_if_slow("reset"):
+7 -7
View File
@@ -557,8 +557,8 @@ class PopulationBasedTraining(FIFOScheduler):
raise TuneError("Trials should be paused here only if in "
"synchronous mode. If you encounter this error"
" please raise an issue on Ray Github.")
trial.config = new_config
trial.experiment_tag = new_tag
trial.set_experiment_tag(new_tag)
trial.set_config(new_config)
trial.on_checkpoint(new_state.last_checkpoint)
else:
# If trial is running, we first try to reset it.
@@ -575,8 +575,8 @@ class PopulationBasedTraining(FIFOScheduler):
trial, new_state.last_checkpoint, block=True)
else:
trial_executor.stop_trial(trial)
trial.config = new_config
trial.experiment_tag = new_tag
trial.set_experiment_tag(new_tag)
trial.set_config(new_config)
trial_executor.start_trial(
trial, new_state.last_checkpoint, train=False)
@@ -761,7 +761,7 @@ class PopulationBasedTrainingReplay(FIFOScheduler):
"No replay policy found and trial initialized without a "
"valid config. Either pass a `config` argument to `tune.run()`"
"or consider not using PBT replay for this run.")
self._trial.config = self.config
self._trial.set_config(self.config)
def on_trial_result(self, trial_runner: "trial_runner.TrialRunner",
trial: Trial, result: Dict) -> str:
@@ -800,8 +800,8 @@ class PopulationBasedTrainingReplay(FIFOScheduler):
trial_executor.restore(trial, checkpoint, block=True)
else:
trial_executor.stop_trial(trial, stop_logger=False)
trial.config = new_config
trial.experiment_tag = new_tag
trial.set_experiment_tag(new_tag)
trial.set_config(new_config)
trial_executor.start_trial(trial, checkpoint, train=False)
self.current_config = new_config
@@ -12,6 +12,7 @@ import ray
from ray.tune import (run, Trainable, sample_from, Analysis,
ExperimentAnalysis, grid_search)
from ray.tune.utils.mock import MyTrainableClass
from ray.tune.utils.serialization import TuneFunctionEncoder
class ExperimentAnalysisInMemorySuite(unittest.TestCase):
@@ -54,7 +55,8 @@ class ExperimentAnalysisInMemorySuite(unittest.TestCase):
def tearDown(self):
    """Remove the per-test directory; deletion errors are ignored."""
    shutil.rmtree(self.test_dir, ignore_errors=True)
def testInit(self):
def testInitLegacy(self):
"""Should still work if checkpoints are not json strings"""
experiment_checkpoint_path = os.path.join(self.test_dir,
"experiment_state.json")
checkpoint_data = {
@@ -71,6 +73,27 @@ class ExperimentAnalysisInMemorySuite(unittest.TestCase):
self.assertEqual(len(experiment_analysis._checkpoints), 1)
self.assertTrue(experiment_analysis.trials is None)
def testInit(self):
    """Checkpoints stored as JSON-encoded strings should be loadable."""
    state_path = os.path.join(self.test_dir, "experiment_state.json")

    # Each checkpoint entry is itself a JSON string produced with the
    # Tune encoder, nested inside the outer experiment-state document.
    trial_state = {
        "trainable_name": "MockTrainable",
        "logdir": "/mock/test/MockTrainable_0_id=3_2020-07-12"
    }
    encoded_trial = json.dumps(trial_state, cls=TuneFunctionEncoder)
    experiment_state = {"checkpoints": [encoded_trial]}

    with open(state_path, "w") as f:
        f.write(json.dumps(experiment_state))

    analysis = ExperimentAnalysis(state_path)
    self.assertEqual(len(analysis._checkpoints), 1)
    self.assertTrue(analysis.trials is None)
def testInitException(self):
experiment_checkpoint_path = os.path.join(self.test_dir, "mock.json")
with pytest.raises(ValueError):
+40 -1
View File
@@ -1,3 +1,4 @@
import json
from typing import Sequence
import ray.cloudpickle as cloudpickle
@@ -18,6 +19,7 @@ from ray.tune.checkpoint_manager import Checkpoint, CheckpointManager
from ray.tune.registry import get_trainable_cls, validate_trainable
from ray.tune.result import DEFAULT_RESULTS_DIR, DONE, TRAINING_ITERATION
from ray.tune.resources import Resources, json_to_resources, resources_to_json
from ray.tune.utils.serialization import TuneFunctionEncoder
from ray.tune.utils.trainable import TrainableUtil
from ray.tune.utils import date_str, flatten_dict
from ray.utils import binary_to_hex, hex_to_binary
@@ -294,6 +296,9 @@ class Trial:
raise ValueError(f"Trial dirname must not contain '/'. "
"Got {self.custom_dirname}")
self._state_json = None
self._state_valid = False
@property
def node_ip(self):
    """Return the ``hostname`` recorded on the trial's current location."""
    current_location = self.location
    return current_location.hostname
@@ -355,6 +360,7 @@ class Trial:
self.local_dir)
else:
os.makedirs(self.logdir, exist_ok=True)
self.invalidate_json_state()
def update_resources(self, cpu, gpu, **kwargs):
"""EXPERIMENTAL: Updates the resource requirements.
@@ -367,15 +373,20 @@ class Trial:
if self.status is Trial.RUNNING:
raise ValueError("Cannot update resources while Trial is running.")
self.resources = Resources(cpu, gpu, **kwargs)
self.invalidate_json_state()
def set_runner(self, runner):
    """Attach ``runner`` and rebuild the checkpoint manager's delete hook.

    The cached JSON state is deliberately left untouched: the runner is
    not stored in the serialized trial state.
    """
    self.runner = runner
    deleter = checkpoint_deleter(self._trainable_name(), runner)
    self.checkpoint_manager.delete = deleter
def set_location(self, location):
    """Sets the location of the trial.

    The cached JSON state is deliberately left untouched: the location
    is not stored in the serialized trial state.
    """
    self.location = location
def set_status(self, status):
"""Sets the status of the trial."""
@@ -383,6 +394,15 @@ class Trial:
if status == Trial.RUNNING:
if self.start_time is None:
self.start_time = time.time()
self.invalidate_json_state()
def set_config(self, config):
    """Replace the trial's config and mark the cached JSON state stale."""
    self.config = config
    # The config is part of the serialized state, so the cache must be
    # rebuilt on the next ``get_json_state()`` call.
    self.invalidate_json_state()
def set_experiment_tag(self, experiment_tag):
    """Replace the trial's experiment tag and mark the JSON cache stale."""
    self.experiment_tag = experiment_tag
    # The tag is part of the serialized state, so the cache must be
    # rebuilt on the next ``get_json_state()`` call.
    self.invalidate_json_state()
def write_error_log(self, error_msg):
if error_msg and self.logdir:
@@ -393,6 +413,7 @@ class Trial:
self.num_failures, date_str()))
f.write(error_msg + "\n")
self.error_msg = error_msg
self.invalidate_json_state()
def should_stop(self, result):
"""Whether the given result meets this trial's stopping criteria."""
@@ -426,6 +447,7 @@ class Trial:
def clear_checkpoint(self):
    """Drop the current checkpoint value and any pending restore.

    Also marks the cached JSON state stale so it is re-serialized on the
    next ``get_json_state()`` call.
    """
    current = self.checkpoint
    current.value = None
    self.restoring_from = None
    self.invalidate_json_state()
def on_checkpoint(self, checkpoint):
"""Hook for handling checkpoints taken by the Trainable.
@@ -434,12 +456,14 @@ class Trial:
checkpoint (Checkpoint): Checkpoint taken.
"""
self.checkpoint_manager.on_checkpoint(checkpoint)
self.invalidate_json_state()
def on_restore(self):
    """Handles restoration completion.

    Adopts the result attached to the checkpoint that was being restored
    as ``last_result``, clears ``restoring_from`` and marks the cached
    JSON state stale. Asserts that a restore is actually in progress.
    """
    assert self.is_restoring
    source_checkpoint = self.restoring_from
    self.last_result = source_checkpoint.result
    self.restoring_from = None
    self.invalidate_json_state()
def should_recover(self):
"""Returns whether the trial qualifies for retrying.
@@ -492,6 +516,7 @@ class Trial:
self.metric_analysis[metric][key] = sum(
self.metric_n_steps[metric][str(n)]) / len(
self.metric_n_steps[metric][str(n)])
self.invalidate_json_state()
def get_trainable_cls(self):
    """Look up the trainable class registered under ``trainable_name``."""
    registered_name = self.trainable_name
    return get_trainable_cls(registered_name)
@@ -541,6 +566,17 @@ class Trial:
generated_dirname += f"_{date_str()}"
return generated_dirname.replace("/", "_")
def invalidate_json_state(self):
    """Mark the cached JSON state stale.

    The cached string itself is kept; ``get_json_state()`` checks the
    validity flag and re-serializes when it is ``False``.
    """
    self._state_valid = False
def get_json_state(self) -> str:
    """Return the trial state as an indented JSON string, with caching.

    The state from ``__getstate__()`` is serialized only when there is no
    cached string or the cache was invalidated; otherwise the cached
    string is returned unchanged.
    """
    cache_is_fresh = self._state_valid and self._state_json
    if cache_is_fresh:
        return self._state_json

    self._state_json = json.dumps(
        self.__getstate__(), indent=2, cls=TuneFunctionEncoder)
    self._state_valid = True
    return self._state_json
def __getstate__(self):
"""Memento generator for Trial.
@@ -558,9 +594,12 @@ class Trial:
state["runner"] = None
state["location"] = Location()
# Avoid waiting for events that will never occur on resume.
state["resuming_from"] = None
state["restoring_from"] = None
state["saving_to"] = None
state["_state_json"] = None
state["_state_valid"] = False
return copy.deepcopy(state)
def __setstate__(self, state):
+7 -2
View File
@@ -26,6 +26,7 @@ class TrialExecutor:
"""
self._queue_trials = queue_trials
self._cached_trial_state = {}
self._trials_to_cache = set()
def set_status(self, trial, status):
"""Sets status and checkpoints metadata if needed.
@@ -59,14 +60,18 @@ class TrialExecutor:
return
try:
logger.debug("Trial %s: Saving trial metadata.", trial)
self._cached_trial_state[trial.trial_id] = trial.__getstate__()
# Lazy cache trials
self._trials_to_cache.add(trial)
except Exception:
logger.exception("Trial %s: Error checkpointing trial metadata.",
trial)
def get_checkpoints(self):
"""Returns a copy of mapping of the trial ID to pickled metadata."""
return self._cached_trial_state.copy()
for trial in self._trials_to_cache:
self._cached_trial_state[trial.trial_id] = trial.get_json_state()
self._trials_to_cache.clear()
return self._cached_trial_state
def has_resources(self, resources):
"""Returns whether this runner has at least the specified resources."""
+11 -37
View File
@@ -5,10 +5,8 @@ import logging
import os
import time
import traceback
import types
import warnings
import ray.cloudpickle as cloudpickle
from ray.services import get_node_ip_address
from ray.tune import TuneError
from ray.tune.callback import CallbackList
@@ -22,8 +20,9 @@ from ray.tune.schedulers import FIFOScheduler, TrialScheduler
from ray.tune.suggest import BasicVariantGenerator
from ray.tune.utils import warn_if_slow, flatten_dict, env_integer
from ray.tune.utils.log import Verbosity, has_verbosity
from ray.tune.utils.serialization import TuneFunctionDecoder, \
TuneFunctionEncoder
from ray.tune.web_server import TuneServer
from ray.utils import binary_to_hex, hex_to_binary
from ray.util.debug import log_once
MAX_DEBUG_TRIALS = 20
@@ -40,37 +39,6 @@ def _find_newest_ckpt(ckpt_dir):
return max(full_paths)
class _TuneFunctionEncoder(json.JSONEncoder):
    """JSON encoder that falls back to cloudpickle for unsupported objects.

    Plain Python functions are always cloudpickled. Any other object is
    first offered to the standard encoder and cloudpickled only if that
    raises.
    """

    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj``."""
        if not isinstance(obj, types.FunctionType):
            try:
                return super().default(obj)
            except Exception:
                logger.debug("Unable to encode. Falling back to cloudpickle.")
        return self._to_cloudpickle(obj)

    def _to_cloudpickle(self, obj):
        """Wrap the cloudpickled ``obj`` in a tagged dict."""
        pickled = cloudpickle.dumps(obj)
        return {
            "_type": "CLOUDPICKLE_FALLBACK",
            "value": binary_to_hex(pickled)
        }
class _TuneFunctionDecoder(json.JSONDecoder):
    """JSON decoder that revives objects stored via the cloudpickle fallback.

    NOTE(review): ``cloudpickle.loads`` can execute arbitrary code, so
    only decode files that come from a trusted source.
    """

    def __init__(self, *args, **kwargs):
        # Route every decoded JSON object through ``object_hook`` below.
        super().__init__(*args, object_hook=self.object_hook, **kwargs)

    def object_hook(self, obj):
        """Unpickle tagged fallback dicts; return other dicts unchanged."""
        is_fallback = obj.get("_type") == "CLOUDPICKLE_FALLBACK"
        return self._from_cloudpickle(obj) if is_fallback else obj

    def _from_cloudpickle(self, obj):
        """Rebuild the original object from the hex-encoded ``value``."""
        raw = hex_to_binary(obj["value"])
        return cloudpickle.loads(raw)
class TrialRunner:
"""A TrialRunner implements the event loop for scheduling trials on Ray.
@@ -310,7 +278,7 @@ class TrialRunner:
tmp_file_name = os.path.join(self._local_checkpoint_dir,
".tmp_checkpoint")
with open(tmp_file_name, "w") as f:
json.dump(runner_state, f, indent=2, cls=_TuneFunctionEncoder)
json.dump(runner_state, f, indent=2, cls=TuneFunctionEncoder)
os.replace(tmp_file_name, self.checkpoint_file)
self._search_alg.save_to_dir(
@@ -330,7 +298,7 @@ class TrialRunner:
"""
newest_ckpt_path = _find_newest_ckpt(self._local_checkpoint_dir)
with open(newest_ckpt_path, "r") as f:
runner_state = json.load(f, cls=_TuneFunctionDecoder)
runner_state = json.load(f, cls=TuneFunctionDecoder)
self.checkpoint_file = newest_ckpt_path
logger.warning("".join([
@@ -343,8 +311,14 @@ class TrialRunner:
if self._search_alg.has_checkpoint(self._local_checkpoint_dir):
self._search_alg.restore_from_dir(self._local_checkpoint_dir)
checkpoints = [
json.loads(cp, cls=TuneFunctionDecoder)
if isinstance(cp, str) else cp
for cp in runner_state["checkpoints"]
]
trials = []
for trial_cp in runner_state["checkpoints"]:
for trial_cp in checkpoints:
new_trial = Trial(trial_cp["trainable_name"])
new_trial.__setstate__(trial_cp)
trials += [new_trial]
+1 -1
View File
@@ -248,7 +248,7 @@ def run(
trial (of ERROR state) when the experiments complete.
callbacks (list): List of callbacks that will be called at different
times in the training loop. Must be instances of the
``ray.tune.trial_runner.Callback`` class. If not passed,
``ray.tune.callback.Callback`` class. If not passed,
`LoggerCallback` and `SyncerCallback` callbacks are automatically
added.
+39
View File
@@ -0,0 +1,39 @@
import json
import logging
import types
from ray import cloudpickle as cloudpickle
from ray.utils import binary_to_hex, hex_to_binary
logger = logging.getLogger(__name__)
class TuneFunctionEncoder(json.JSONEncoder):
    """JSON encoder that falls back to cloudpickle for unsupported objects.

    Plain Python functions are always cloudpickled. Any other object is
    first offered to the standard encoder and cloudpickled only if that
    raises.
    """

    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj``."""
        if not isinstance(obj, types.FunctionType):
            try:
                return super().default(obj)
            except Exception:
                logger.debug("Unable to encode. Falling back to cloudpickle.")
        return self._to_cloudpickle(obj)

    def _to_cloudpickle(self, obj):
        """Wrap the cloudpickled ``obj`` in a tagged dict."""
        pickled = cloudpickle.dumps(obj)
        return {
            "_type": "CLOUDPICKLE_FALLBACK",
            "value": binary_to_hex(pickled)
        }
class TuneFunctionDecoder(json.JSONDecoder):
    """JSON decoder that revives objects stored via the cloudpickle fallback.

    NOTE(review): ``cloudpickle.loads`` can execute arbitrary code, so
    only decode files that come from a trusted source.
    """

    def __init__(self, *args, **kwargs):
        # Route every decoded JSON object through ``object_hook`` below.
        super().__init__(*args, object_hook=self.object_hook, **kwargs)

    def object_hook(self, obj):
        """Unpickle tagged fallback dicts; return other dicts unchanged."""
        is_fallback = obj.get("_type") == "CLOUDPICKLE_FALLBACK"
        return self._from_cloudpickle(obj) if is_fallback else obj

    def _from_cloudpickle(self, obj):
        """Rebuild the original object from the hex-encoded ``value``."""
        raw = hex_to_binary(obj["value"])
        return cloudpickle.loads(raw)