[tune] Function API checkpointing (#8471)

Co-authored-by: krfricke <krfricke@users.noreply.github.com>
This commit is contained in:
Richard Liaw
2020-06-15 10:42:54 -07:00
committed by GitHub
parent 91e57f2e53
commit 6c49c01837
21 changed files with 897 additions and 237 deletions
+131 -22
View File
@@ -1,13 +1,18 @@
import logging
import os
import io
import time
import inspect
import shutil
import threading
import traceback
from six.moves import queue
from ray.tune import TuneError, session
from ray.tune.trainable import Trainable
from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE
from ray.tune.trainable import Trainable, TrainableUtil
from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE,
SHOULD_CHECKPOINT)
logger = logging.getLogger(__name__)
@@ -40,6 +45,8 @@ class StatusReporter:
self._trial_name = trial_name
self._trial_id = trial_id
self._logdir = logdir
self._last_checkpoint = {}
self._fresh_checkpoint = False
def __call__(self, **kwargs):
"""Report updated training status.
@@ -77,6 +84,29 @@ class StatusReporter:
# resume training.
self._continue_semaphore.acquire()
def make_checkpoint_dir(self, step=None):
    """Create a checkpoint directory under this reporter's logdir.

    Delegates to ``TrainableUtil.make_checkpoint_dir`` with ``self.logdir``
    and forwards ``step`` as the ``index`` argument.

    Args:
        step: Value passed through as the directory index. Defaults to
            None, in which case the helper decides the naming.

    Returns:
        Whatever ``TrainableUtil.make_checkpoint_dir`` returns
        (presumably the path of the created directory).
    """
    return TrainableUtil.make_checkpoint_dir(self.logdir, index=step)
def save_checkpoint(self, checkpoint):
    """Record ``checkpoint`` as the latest checkpoint and mark it fresh.

    String checkpoints are treated as paths and validated with
    ``TrainableUtil.find_checkpoint_dir``; any other object (e.g. a dict)
    is stored without validation.

    Args:
        checkpoint: A checkpoint path (str) or an arbitrary checkpoint
            object.

    Raises:
        FileNotFoundError: Re-raised (after logging an error) when the
            path lookup for a string checkpoint fails.
    """
    is_path = isinstance(checkpoint, str)
    if is_path:
        try:
            # Only the lookup's success matters; its result is discarded.
            TrainableUtil.find_checkpoint_dir(checkpoint)
        except FileNotFoundError:
            logger.error("Checkpoint must be created with path given from "
                         "make_checkpoint_dir.")
            raise
    self._last_checkpoint, self._fresh_checkpoint = checkpoint, True
def has_new_checkpoint(self):
    """Return the current value of the fresh-checkpoint flag.

    The flag is set by ``save_checkpoint`` and cleared by
    ``get_checkpoint``.
    """
    is_fresh = self._fresh_checkpoint
    return is_fresh
def get_checkpoint(self):
    """Return the last stored checkpoint and clear the fresh flag.

    Returns:
        The object most recently assigned to ``_last_checkpoint``.
    """
    latest = self._last_checkpoint
    # Reading the checkpoint consumes it: it is no longer reported as new.
    self._fresh_checkpoint = False
    return latest
def _start(self):
self._last_report_time = time.time()
@@ -155,21 +185,33 @@ class FunctionRunner(Trainable):
trial_id=self.trial_id,
logdir=self.logdir)
self._last_result = {}
config = config.copy()
session.init(self._status_reporter)
def entrypoint():
return self._trainable_func(config, self._status_reporter)
# the runner thread is not started until the first call to _train
self._runner = _RunnerThread(entrypoint, self._error_queue)
self._runner = None
self._restore_tmpdir = None
self.default_checkpoint_dir = None
def _trainable_func(self):
"""Subclasses can override this to set the trainable func."""
raise NotImplementedError
def _start(self):
    """Create the runner thread for the training function and start it.

    Builds a fresh ``_RunnerThread`` whose target calls
    ``self._trainable_func`` with the config, the status reporter and the
    reporter's current checkpoint, then starts the reporter followed by
    the thread.
    """

    def entrypoint():
        # The checkpoint is fetched when the thread body runs, not when
        # the closure is defined.
        restored = self._status_reporter.get_checkpoint()
        return self._trainable_func(self.config, self._status_reporter,
                                    restored)

    # The runner thread is not started until the first call to _train.
    runner = _RunnerThread(entrypoint, self._error_queue)
    self._runner = runner
    # If not alive, try to start.
    self._status_reporter._start()
    try:
        runner.start()
    except RuntimeError:
        # If this is reached, it means the thread was started and is
        # now done or has raised an exception.
        pass
def _train(self):
"""Implements train() for a Function API.
@@ -178,19 +220,12 @@ class FunctionRunner(Trainable):
along with a result with "done=True". The TrialRunner will handle the
result accordingly (see tune/trial_runner.py).
"""
if self._runner.is_alive():
if self._runner and self._runner.is_alive():
# if started and alive, inform the reporter to continue and
# generate the next result
self._continue_semaphore.release()
else:
# if not alive, try to start
self._status_reporter._start()
try:
self._runner.start()
except RuntimeError:
# If this is reached, it means the thread was started and is
# now done or has raised an exception.
pass
self._start()
result = None
while result is None and self._runner.is_alive():
@@ -240,8 +275,61 @@ class FunctionRunner(Trainable):
result = new_result
self._last_result = result
if self._status_reporter.has_new_checkpoint():
result[SHOULD_CHECKPOINT] = True
return result
def create_default_checkpoint_dir(self):
    """Create the "default" checkpoint directory and remember its location.

    Calls ``TrainableUtil.make_checkpoint_dir`` with ``self.logdir`` and
    ``index="default"``, storing the result on
    ``self.default_checkpoint_dir``.

    Returns:
        The value stored in ``self.default_checkpoint_dir``.
    """
    default_dir = TrainableUtil.make_checkpoint_dir(
        self.logdir, index="default")
    self.default_checkpoint_dir = default_dir
    return default_dir
def save(self, checkpoint_path=None):
    """Save the latest reported checkpoint together with trainable state.

    The target directory depends on what the status reporter holds:

    * nothing (falsy): the default checkpoint directory is (re)created and
      the iteration/timestep/episode counters in the saved state are reset
      to 0;
    * a dict: a new directory indexed by ``self.training_iteration``;
    * anything else: the directory located by
      ``TrainableUtil.find_checkpoint_dir`` for that checkpoint.

    Args:
        checkpoint_path: Must be left unset; a truthy value is rejected.

    Returns:
        The result of ``TrainableUtil.process_checkpoint`` (presumably the
        path of the written checkpoint).

    Raises:
        ValueError: If ``checkpoint_path`` is truthy.
    """
    if checkpoint_path:
        raise ValueError(
            "Checkpoint path should not be used with function API.")

    # Fetching the checkpoint also clears the reporter's fresh flag.
    latest = self._status_reporter.get_checkpoint()
    trainable_state = self.get_state()

    if latest and isinstance(latest, dict):
        target_dir = TrainableUtil.make_checkpoint_dir(
            self.logdir, index=self.training_iteration)
    elif latest:
        target_dir = TrainableUtil.find_checkpoint_dir(latest)
    else:
        trainable_state.update(
            iteration=0, timesteps_total=0, episodes_total=0)
        target_dir = self.create_default_checkpoint_dir()

    return TrainableUtil.process_checkpoint(latest, target_dir,
                                            trainable_state)
def save_to_object(self):
    """Save a checkpoint and return it as an in-memory bytes object.

    Calls ``self.save()``, converts the resulting checkpoint with
    ``TrainableUtil.pickle_checkpoint`` and copies the payload through a
    ``BytesIO`` buffer.

    Returns:
        bytes: The buffer contents holding the pickled checkpoint.
    """
    saved_path = self.save()
    payload = TrainableUtil.pickle_checkpoint(saved_path)
    if len(payload) > 10e6:  # getting pretty large
        logger.info("Checkpoint size is {} bytes".format(len(payload)))
    buffer = io.BytesIO()
    buffer.write(payload)
    return buffer.getvalue()
def _restore(self, checkpoint):
# This should be removed once Trainables are refactored.
if "tune_checkpoint_path" in checkpoint:
del checkpoint["tune_checkpoint_path"]
self._status_reporter.save_checkpoint(checkpoint)
def restore_from_object(self, obj):
    """Restore this trainable from an in-memory checkpoint object.

    Any existing default checkpoint directory is removed first, then a
    fresh one is created, ``obj`` is materialized into it with
    ``TrainableUtil.create_from_pickle`` and the resulting checkpoint is
    passed to ``self.restore``.

    Args:
        obj: The pickled checkpoint, as produced by ``save_to_object``
            (presumably bytes -- confirm against
            ``TrainableUtil.create_from_pickle``).
    """
    # ``os.exists`` does not exist; the check lives in ``os.path``.
    if self.default_checkpoint_dir is not None and os.path.exists(
            self.default_checkpoint_dir):
        shutil.rmtree(self.default_checkpoint_dir)
        logger.debug("Clearing default checkpoint: %s",
                     self.default_checkpoint_dir)

    checkpoint_dir = self.create_default_checkpoint_dir()
    checkpoint_path = TrainableUtil.create_from_pickle(obj, checkpoint_dir)
    self.restore(checkpoint_path)
def _stop(self):
# If everything stayed in synch properly, this should never happen.
if not self._results_queue.empty():
@@ -251,7 +339,6 @@ class FunctionRunner(Trainable):
# Check for any errors that might have been missed.
self._report_thread_runner_error()
session.shutdown()
def _report_thread_runner_error(self, block=False):
@@ -264,13 +351,35 @@ class FunctionRunner(Trainable):
pass
def detect_checkpoint_function(train_func):
    """Tell whether ``train_func`` declares a ``checkpoint`` argument.

    Only the positional-or-keyword argument names reported by
    ``inspect.getfullargspec(...).args`` are inspected.

    Args:
        train_func: The user-supplied training function.

    Returns:
        bool: True if one of the argument names is ``"checkpoint"``.
    """
    return "checkpoint" in inspect.getfullargspec(train_func).args
def wrap_function(train_func):
class ImplicitFunc(FunctionRunner):
def _trainable_func(self, config, reporter):
def _trainable_func(self, config, reporter, checkpoint):
func_args = inspect.getfullargspec(train_func).args
use_track = ("reporter" not in func_args and len(func_args) == 1)
if use_track:
if len(func_args) > 1: # more arguments than just the config
if "reporter" not in func_args and (
"checkpoint" not in func_args):
raise ValueError(
"Unknown argument found in the Trainable function. "
"Arguments other than the 'config' arg must be one "
"of ['reporter', 'checkpoint']. Found: {}".format(
func_args))
use_reporter = "reporter" in func_args
use_checkpoint = "checkpoint" in func_args
if not use_checkpoint and not use_reporter:
logger.warning(
"Function checkpointing is disabled. This may result in "
"unexpected behavior when using checkpointing features or "
"certain schedulers. To enable, set the train function "
"arguments to be `func(config, checkpoint)`.")
output = train_func(config)
elif use_checkpoint:
output = train_func(config, checkpoint=checkpoint)
else:
output = train_func(config, reporter)