mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 21:56:20 +08:00
[tune] Function API checkpointing (#8471)
Co-authored-by: krfricke <krfricke@users.noreply.github.com>
This commit is contained in:
@@ -1,13 +1,18 @@
|
||||
import logging
|
||||
import os
|
||||
import io
|
||||
import time
|
||||
import inspect
|
||||
import shutil
|
||||
import threading
|
||||
import traceback
|
||||
|
||||
from six.moves import queue
|
||||
|
||||
from ray.tune import TuneError, session
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE
|
||||
from ray.tune.trainable import Trainable, TrainableUtil
|
||||
from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE,
|
||||
SHOULD_CHECKPOINT)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -40,6 +45,8 @@ class StatusReporter:
|
||||
self._trial_name = trial_name
|
||||
self._trial_id = trial_id
|
||||
self._logdir = logdir
|
||||
self._last_checkpoint = {}
|
||||
self._fresh_checkpoint = False
|
||||
|
||||
def __call__(self, **kwargs):
|
||||
"""Report updated training status.
|
||||
@@ -77,6 +84,29 @@ class StatusReporter:
|
||||
# resume training.
|
||||
self._continue_semaphore.acquire()
|
||||
|
||||
def make_checkpoint_dir(self, step=None):
    """Create a checkpoint directory under this reporter's log directory.

    Args:
        step: Forwarded to ``TrainableUtil.make_checkpoint_dir`` as its
            ``index`` argument. Defaults to None.

    Returns:
        The value returned by ``TrainableUtil.make_checkpoint_dir``
        (presumably the path of the new directory -- confirm against
        TrainableUtil).
    """
    return TrainableUtil.make_checkpoint_dir(self.logdir, index=step)
|
||||
|
||||
def save_checkpoint(self, checkpoint):
    """Record ``checkpoint`` as the latest checkpoint and mark it as fresh.

    Args:
        checkpoint: The checkpoint to remember. When it is a string it is
            validated with ``TrainableUtil.find_checkpoint_dir`` first; any
            other object is stored without validation.

    Raises:
        FileNotFoundError: Re-raised from
            ``TrainableUtil.find_checkpoint_dir`` (after logging an error)
            when a string checkpoint cannot be resolved.
    """
    given_as_path = isinstance(checkpoint, str)
    if given_as_path:
        try:
            TrainableUtil.find_checkpoint_dir(checkpoint)
        except FileNotFoundError:
            logger.error("Checkpoint must be created with path given from "
                         "make_checkpoint_dir.")
            raise

    # Only reached when validation passed (or was not needed).
    self._fresh_checkpoint = True
    self._last_checkpoint = checkpoint
|
||||
|
||||
def has_new_checkpoint(self):
    """Return the current value of the ``_fresh_checkpoint`` flag."""
    is_fresh = self._fresh_checkpoint
    return is_fresh
|
||||
|
||||
def get_checkpoint(self):
    """Return the last stored checkpoint and clear the freshness flag.

    Returns:
        The object currently held in ``_last_checkpoint``.
    """
    latest = self._last_checkpoint
    # Reading the checkpoint consumes it: it no longer counts as new.
    self._fresh_checkpoint = False
    return latest
|
||||
|
||||
def _start(self):
    """Stamp ``_last_report_time`` with the current wall-clock time."""
    now = time.time()
    self._last_report_time = now
|
||||
|
||||
@@ -155,21 +185,33 @@ class FunctionRunner(Trainable):
|
||||
trial_id=self.trial_id,
|
||||
logdir=self.logdir)
|
||||
self._last_result = {}
|
||||
config = config.copy()
|
||||
|
||||
session.init(self._status_reporter)
|
||||
|
||||
def entrypoint():
|
||||
return self._trainable_func(config, self._status_reporter)
|
||||
|
||||
# the runner thread is not started until the first call to _train
|
||||
self._runner = _RunnerThread(entrypoint, self._error_queue)
|
||||
self._runner = None
|
||||
self._restore_tmpdir = None
|
||||
self.default_checkpoint_dir = None
|
||||
|
||||
def _trainable_func(self):
    """Subclasses can override this to set the trainable func.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
|
||||
|
||||
def _start(self):
    """Build the runner thread for the trainable function and start it.

    A fresh ``_RunnerThread`` wrapping the trainable function is stored in
    ``self._runner``, the status reporter's ``_start`` hook is called, and
    the thread is started. A ``RuntimeError`` from ``start()`` is
    deliberately ignored.
    """

    def run_trainable():
        # The checkpoint is fetched only when this callable is invoked
        # (presumably on the runner thread -- confirm in _RunnerThread).
        reporter = self._status_reporter
        return self._trainable_func(self.config, reporter,
                                    reporter.get_checkpoint())

    # Constructing the thread does not run it; the start() call below does.
    self._runner = _RunnerThread(run_trainable, self._error_queue)
    self._status_reporter._start()
    try:
        self._runner.start()
    except RuntimeError:
        # Reaching this point means the thread had already been started
        # and has since finished or raised an exception.
        pass
|
||||
|
||||
def _train(self):
|
||||
"""Implements train() for a Function API.
|
||||
|
||||
@@ -178,19 +220,12 @@ class FunctionRunner(Trainable):
|
||||
along with a result with "done=True". The TrialRunner will handle the
|
||||
result accordingly (see tune/trial_runner.py).
|
||||
"""
|
||||
if self._runner.is_alive():
|
||||
if self._runner and self._runner.is_alive():
|
||||
# if started and alive, inform the reporter to continue and
|
||||
# generate the next result
|
||||
self._continue_semaphore.release()
|
||||
else:
|
||||
# if not alive, try to start
|
||||
self._status_reporter._start()
|
||||
try:
|
||||
self._runner.start()
|
||||
except RuntimeError:
|
||||
# If this is reached, it means the thread was started and is
|
||||
# now done or has raised an exception.
|
||||
pass
|
||||
self._start()
|
||||
|
||||
result = None
|
||||
while result is None and self._runner.is_alive():
|
||||
@@ -240,8 +275,61 @@ class FunctionRunner(Trainable):
|
||||
result = new_result
|
||||
|
||||
self._last_result = result
|
||||
if self._status_reporter.has_new_checkpoint():
|
||||
result[SHOULD_CHECKPOINT] = True
|
||||
return result
|
||||
|
||||
def create_default_checkpoint_dir(self):
    """Create the "default" checkpoint directory and remember it.

    Returns:
        The value returned by ``TrainableUtil.make_checkpoint_dir`` for
        index ``"default"``, which is also stored on
        ``self.default_checkpoint_dir``.
    """
    default_dir = TrainableUtil.make_checkpoint_dir(
        self.logdir, index="default")
    self.default_checkpoint_dir = default_dir
    return self.default_checkpoint_dir
|
||||
|
||||
def save(self, checkpoint_path=None):
    """Persist the most recent checkpoint reported by the function.

    Args:
        checkpoint_path: Must be left unset (falsy); explicit paths are not
            supported for the function API.

    Returns:
        The value returned by ``TrainableUtil.process_checkpoint``.

    Raises:
        ValueError: If a truthy ``checkpoint_path`` is supplied.
    """
    if checkpoint_path:
        raise ValueError(
            "Checkpoint path should not be used with function API.")

    latest = self._status_reporter.get_checkpoint()
    trainable_state = self.get_state()

    if isinstance(latest, dict) and latest:
        # Non-empty dict checkpoint: give it a directory for this iteration.
        target_dir = TrainableUtil.make_checkpoint_dir(
            self.logdir, index=self.training_iteration)
    elif latest:
        # Any other truthy checkpoint is resolved to its existing directory.
        target_dir = TrainableUtil.find_checkpoint_dir(latest)
    else:
        # Nothing has been checkpointed yet: zero the counters and fall
        # back to the default checkpoint directory.
        trainable_state.update(
            iteration=0, timesteps_total=0, episodes_total=0)
        target_dir = self.create_default_checkpoint_dir()

    return TrainableUtil.process_checkpoint(
        latest, target_dir, trainable_state)
|
||||
|
||||
def save_to_object(self):
    """Save a checkpoint and return its pickled form as bytes.

    Returns:
        bytes: The data produced by ``TrainableUtil.pickle_checkpoint`` for
        the path returned by ``self.save()``.
    """
    saved_path = self.save()
    payload = TrainableUtil.pickle_checkpoint(saved_path)
    payload_size = len(payload)
    if payload_size > 10e6:  # getting pretty large
        logger.info("Checkpoint size is {} bytes".format(payload_size))
    with io.BytesIO() as buffer:
        buffer.write(payload)
        return buffer.getvalue()
|
||||
|
||||
def _restore(self, checkpoint):
    """Hand a restored checkpoint to the status reporter.

    Args:
        checkpoint: The checkpoint to restore from. If it contains a
            ``"tune_checkpoint_path"`` entry, that entry is deleted in place
            before the checkpoint is passed on.
    """
    # This should be removed once Trainables are refactored.
    legacy_key = "tune_checkpoint_path"
    if legacy_key in checkpoint:
        del checkpoint[legacy_key]
    reporter = self._status_reporter
    reporter.save_checkpoint(checkpoint)
|
||||
|
||||
def restore_from_object(self, obj):
    """Restore this trainable from a pickled checkpoint object.

    Any default checkpoint directory left over from an earlier call is
    removed first. A fresh default directory is then created, ``obj`` is
    unpacked into it with ``TrainableUtil.create_from_pickle``, and the
    resulting path is passed to ``self.restore``.

    Args:
        obj: Pickled checkpoint data (presumably the bytes produced by
            ``save_to_object`` -- confirm against TrainableUtil).
    """
    stale_dir = self.default_checkpoint_dir
    # The existence check must be ``os.path.exists``: the ``os`` module has
    # no ``exists`` attribute, so the previous ``os.exists`` call raised
    # AttributeError as soon as a default directory had been recorded.
    if stale_dir is not None and os.path.exists(stale_dir):
        shutil.rmtree(stale_dir)
        logger.debug("Clearing default checkpoint: %s", stale_dir)

    checkpoint_dir = self.create_default_checkpoint_dir()
    checkpoint_path = TrainableUtil.create_from_pickle(obj, checkpoint_dir)
    self.restore(checkpoint_path)
|
||||
|
||||
def _stop(self):
|
||||
# If everything stayed in synch properly, this should never happen.
|
||||
if not self._results_queue.empty():
|
||||
@@ -251,7 +339,6 @@ class FunctionRunner(Trainable):
|
||||
|
||||
# Check for any errors that might have been missed.
|
||||
self._report_thread_runner_error()
|
||||
|
||||
session.shutdown()
|
||||
|
||||
def _report_thread_runner_error(self, block=False):
|
||||
@@ -264,13 +351,35 @@ class FunctionRunner(Trainable):
|
||||
pass
|
||||
|
||||
|
||||
def detect_checkpoint_function(train_func):
    """Tell whether ``train_func`` declares a ``checkpoint`` argument.

    Args:
        train_func: The callable whose signature is inspected.

    Returns:
        bool: True if ``"checkpoint"`` is among the argument names reported
        in ``inspect.getfullargspec(train_func).args``, else False.
    """
    return "checkpoint" in inspect.getfullargspec(train_func).args
|
||||
|
||||
|
||||
def wrap_function(train_func):
|
||||
class ImplicitFunc(FunctionRunner):
|
||||
def _trainable_func(self, config, reporter):
|
||||
def _trainable_func(self, config, reporter, checkpoint):
|
||||
func_args = inspect.getfullargspec(train_func).args
|
||||
use_track = ("reporter" not in func_args and len(func_args) == 1)
|
||||
if use_track:
|
||||
if len(func_args) > 1: # more arguments than just the config
|
||||
if "reporter" not in func_args and (
|
||||
"checkpoint" not in func_args):
|
||||
raise ValueError(
|
||||
"Unknown argument found in the Trainable function. "
|
||||
"Arguments other than the 'config' arg must be one "
|
||||
"of ['reporter', 'checkpoint']. Found: {}".format(
|
||||
func_args))
|
||||
use_reporter = "reporter" in func_args
|
||||
use_checkpoint = "checkpoint" in func_args
|
||||
if not use_checkpoint and not use_reporter:
|
||||
logger.warning(
|
||||
"Function checkpointing is disabled. This may result in "
|
||||
"unexpected behavior when using checkpointing features or "
|
||||
"certain schedulers. To enable, set the train function "
|
||||
"arguments to be `func(config, checkpoint)`.")
|
||||
output = train_func(config)
|
||||
elif use_checkpoint:
|
||||
output = train_func(config, checkpoint=checkpoint)
|
||||
else:
|
||||
output = train_func(config, reporter)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user