Files
ray/python/ray/tune/checkpoint_manager.py
T
Ujval Misra 2fca550096 [tune] Prevent MEMORY checkpoints from breaking trial FT (#6691)
* Prevent MEMORY checkpoints from breaking FT

* Add save/pause/resume/restore test

* change checkpoint return value based on status

* Fix test_checkpoint_manager_tests.

* Fix test + checkpoint manager bug

* lint

* Add docstring

* Add docstring to checkpoint_manager constructor

* Change variable name for clarity

* Revert on_checkpoint docstring wording

* Break after success

* nit: more informative warning

* Quarantine test
2020-01-23 14:03:40 -08:00

134 lines
4.8 KiB
Python

# coding: utf-8
import heapq
import logging
# Module-level logger; used below to report a checkpoint whose result dict
# lacks the configured score attribute.
logger = logging.getLogger(__name__)
class Checkpoint:
    """A snapshot of trial state together with where it is stored.

    A checkpoint may live in different kinds of storage.

    Attributes:
        storage (str): Storage type (see ``MEMORY`` and ``PERSISTENT``).
        value: If storage==MEMORY, it is a Python object.
            If storage==PERSISTENT, it is a path to persistent storage.
        result: Result passed by the caller, or an empty dict when a falsy
            value (such as None) was supplied.
    """

    MEMORY = "memory"
    PERSISTENT = "persistent"

    def __init__(self, storage, value, result=None):
        self.storage = storage
        self.value = value
        # Any falsy result is normalized to a fresh empty dict.
        self.result = result if result else {}

    @staticmethod
    def from_object(value=None):
        """Wraps a Python object in a MEMORY checkpoint."""
        return Checkpoint(storage=Checkpoint.MEMORY, value=value)
class QueueItem:
    """Pairs a value with a priority so the pair can be ordered in a heap.

    Only ``<`` is defined; it compares the two items' priorities and
    ignores their values.
    """

    def __init__(self, priority, value):
        self.priority, self.value = priority, value

    def __lt__(self, other):
        mine, theirs = self.priority, other.priority
        return mine < theirs
class CheckpointManager:
    """Manages checkpoints on the driver for a trial."""

    def __init__(self, keep_checkpoints_num, checkpoint_score_attr, delete_fn):
        """Initializes a new CheckpointManager.

        `newest_persistent_checkpoint` and `newest_memory_checkpoint` are
        initialized to Checkpoint objects with values of None.

        Args:
            keep_checkpoints_num (int): Maximum number of best-scoring
                checkpoints to retain. A falsy value (None or 0) means no
                limit. Must be greater than 0 otherwise.
            checkpoint_score_attr (str): Key in a checkpoint's result used
                to rank checkpoints. A "min-" prefix means lower values
                are better; the prefix is stripped from the key.
            delete_fn (function): Function that deletes checkpoints. Must be
                idempotent.
        """
        self.keep_checkpoints_num = keep_checkpoints_num or float("inf")
        assert self.keep_checkpoints_num > 0, (
            "keep_checkpoints_num must be greater than 0.")
        self._checkpoint_score_desc = checkpoint_score_attr.startswith("min-")
        if self._checkpoint_score_desc:
            self._checkpoint_score_attr = checkpoint_score_attr[4:]
        else:
            self._checkpoint_score_attr = checkpoint_score_attr

        self.delete = delete_fn
        self.newest_persistent_checkpoint = Checkpoint(Checkpoint.PERSISTENT,
                                                       None)
        self.newest_memory_checkpoint = Checkpoint(Checkpoint.MEMORY, None)
        # Min-heap of QueueItems: index 0 is the worst retained checkpoint.
        self._best_checkpoints = []
        # Checkpoints currently held in the heap, for fast membership tests.
        self._membership = set()

    def on_checkpoint(self, checkpoint):
        """Starts tracking checkpoint metadata on checkpoint.

        Sets the newest checkpoint. For PERSISTENT checkpoints: Deletes
        previous checkpoint as long as it isn't one of the best ones. Also
        deletes the worst checkpoint if at capacity. Storage that backs the
        checkpoint being registered is never deleted.

        Args:
            checkpoint (Checkpoint): Trial state checkpoint.
        """
        if checkpoint.storage == Checkpoint.MEMORY:
            self.newest_memory_checkpoint = checkpoint
            return

        old_checkpoint = self.newest_persistent_checkpoint

        # Registering the same data twice must be a no-op. Otherwise the
        # "delete the previous checkpoint" step below would remove the very
        # data the incoming checkpoint points to. The None check keeps the
        # initial placeholder from matching.
        if (old_checkpoint.value is not None
                and old_checkpoint.value == checkpoint.value):
            return

        self.newest_persistent_checkpoint = checkpoint

        # Remove the old checkpoint if it isn't one of the best ones.
        if old_checkpoint.value and old_checkpoint not in self._membership:
            self.delete(old_checkpoint)

        try:
            queue_item = QueueItem(self._priority(checkpoint), checkpoint)
        except KeyError:
            logger.error("Result dict has no key: {}. "
                         "checkpoint_score_attr must be set to a key in the "
                         "result dict.".format(self._checkpoint_score_attr))
            return

        if len(self._best_checkpoints) < self.keep_checkpoints_num:
            heapq.heappush(self._best_checkpoints, queue_item)
            self._membership.add(checkpoint)
        elif queue_item.priority >= self._best_checkpoints[0].priority:
            # At capacity: push the new item and pop the worst in one step.
            # On a priority tie the popped item is the new one itself.
            worst = heapq.heappushpop(self._best_checkpoints, queue_item).value
            self._membership.add(checkpoint)
            if worst in self._membership:
                self._membership.remove(worst)
            # Don't delete the newest checkpoint. It will be deleted on the
            # next on_checkpoint() call since it isn't in self._membership.
            # Compare by value so an evicted entry that shares its data
            # with the newest checkpoint is not deleted either.
            if worst.value != checkpoint.value:
                self.delete(worst)

    def best_checkpoints(self):
        """Returns best checkpoints, sorted by score.

        The list is in ascending priority order, so the best-ranked
        checkpoint is last.
        """
        checkpoints = sorted(self._best_checkpoints, key=lambda c: c.priority)
        return [queue_item.value for queue_item in checkpoints]

    def _priority(self, checkpoint):
        """Returns the heap priority of a checkpoint (higher is better).

        Raises:
            KeyError: If the checkpoint's result lacks the score attribute.
        """
        priority = checkpoint.result[self._checkpoint_score_attr]
        # Negate "min-" scores so that smaller values rank as better.
        return -priority if self._checkpoint_score_desc else priority

    def __getstate__(self):
        """Returns the state to pickle, without the delete function."""
        state = self.__dict__.copy()
        # Avoid serializing lambda since it may capture cyclical dependencies.
        state.pop("delete", None)
        return state

    def __setstate__(self, state):
        """Restores pickled state; `delete` is reset to None."""
        self.__dict__.update(state)
        self.delete = None