mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 23:39:37 +08:00
[tune] Fix restoration for function API PBT (#9853)
This commit is contained in:
committed by
SangBin Cho
parent
ea1ac15da0
commit
a96ddec358
@@ -5,6 +5,7 @@ import inspect
|
||||
import shutil
|
||||
import threading
|
||||
import traceback
|
||||
import uuid
|
||||
|
||||
from six.moves import queue
|
||||
|
||||
@@ -22,6 +23,84 @@ RESULT_FETCH_TIMEOUT = 0.2
|
||||
ERROR_REPORT_TIMEOUT = 10
|
||||
ERROR_FETCH_TIMEOUT = 1
|
||||
|
||||
NULL_MARKER = ".null_marker"
|
||||
TEMP_MARKER = ".temp_marker"
|
||||
|
||||
|
||||
class FuncCheckpointUtil:
|
||||
"""Utility class holding various function-checkpointing mechanisms.
|
||||
|
||||
The two special modes are "null" and "temporary" checkpoints.
|
||||
|
||||
*Null Checkpoints*
|
||||
-------------------
|
||||
|
||||
Null checkpoints are generated when a trial is being saved
|
||||
but a checkpoint has not been created. In this case,
|
||||
a marker is set, indicating that the checkpoint is null.
|
||||
|
||||
When restoring from an null checkpoint, the FunctionRunner
|
||||
will detect this and *not* restore from any checkpoint at all.
|
||||
|
||||
*Temporary Checkpoints*
|
||||
-----------------------
|
||||
|
||||
Temporary checkpoints are generated when a trial is being
|
||||
restored from a prior in-memory checkpoint. In this case, a marker
|
||||
will be set indicating that a checkpoint is temporary.
|
||||
|
||||
Upon termination of the trial, temporary checkpoints
|
||||
will be removed. We cannot remove them any earlier because
|
||||
the loading of checkpoints is non-deterministic.
|
||||
|
||||
If "save" is called on a trial whose most recent checkpoint
|
||||
is temporary, "create_perm_checkpoint" will be called. This
|
||||
copies the temporary checkpoint to a permanent checkpoint directory.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def mk_null_checkpoint_dir(logdir):
|
||||
"""Indicate that the given checkpoint doesn't have state."""
|
||||
checkpoint_dir = TrainableUtil.make_checkpoint_dir(
|
||||
logdir, index=-1, override=True)
|
||||
open(os.path.join(checkpoint_dir, NULL_MARKER), "a").close()
|
||||
return checkpoint_dir
|
||||
|
||||
@staticmethod
|
||||
def mk_temp_checkpoint_dir(logdir):
|
||||
"""Indicate that the checkpoint is only for restoration."""
|
||||
temporary_checkpoint_dir = TrainableUtil.make_checkpoint_dir(
|
||||
logdir, index="tmp" + uuid.uuid4().hex[:6], override=True)
|
||||
open(os.path.join(temporary_checkpoint_dir, TEMP_MARKER), "a").close()
|
||||
return temporary_checkpoint_dir
|
||||
|
||||
@staticmethod
|
||||
def is_temp_checkpoint_dir(checkpoint_dir):
|
||||
"""Checks for the temp checkpoint marker."""
|
||||
return os.path.exists(os.path.join(checkpoint_dir, TEMP_MARKER))
|
||||
|
||||
@staticmethod
|
||||
def is_null_checkpoint(checkpoint_dir):
|
||||
"""Checks for the empty checkpoint marker."""
|
||||
return os.path.exists(os.path.join(checkpoint_dir, NULL_MARKER))
|
||||
|
||||
@staticmethod
|
||||
def create_perm_checkpoint(checkpoint_dir, logdir, step):
|
||||
"""Copies temporary checkpoint to a permanent checkpoint directory."""
|
||||
checkpoint_dir = os.path.abspath(checkpoint_dir)
|
||||
temporary_marker = os.path.join(checkpoint_dir, TEMP_MARKER)
|
||||
assert os.path.exists(temporary_marker), (
|
||||
"Should not be calling this method on a permanent checkpoint.")
|
||||
os.remove(temporary_marker)
|
||||
perm_checkpoint_dir = TrainableUtil.make_checkpoint_dir(
|
||||
logdir, index=step, override=True)
|
||||
shutil.rmtree(perm_checkpoint_dir)
|
||||
|
||||
shutil.copytree(checkpoint_dir, perm_checkpoint_dir)
|
||||
assert not os.path.exists(
|
||||
os.path.join(perm_checkpoint_dir, TEMP_MARKER))
|
||||
return perm_checkpoint_dir
|
||||
|
||||
|
||||
class StatusReporter:
|
||||
"""Object passed into your function that you can report status through.
|
||||
@@ -44,7 +123,7 @@ class StatusReporter:
|
||||
self._trial_name = trial_name
|
||||
self._trial_id = trial_id
|
||||
self._logdir = logdir
|
||||
self._last_checkpoint = {}
|
||||
self._last_checkpoint = None
|
||||
self._fresh_checkpoint = False
|
||||
|
||||
def __call__(self, **kwargs):
|
||||
@@ -83,13 +162,18 @@ class StatusReporter:
|
||||
# resume training.
|
||||
self._continue_semaphore.acquire()
|
||||
|
||||
def make_checkpoint_dir(self, step=None):
|
||||
def make_checkpoint_dir(self, step):
|
||||
checkpoint_dir = TrainableUtil.make_checkpoint_dir(
|
||||
self.logdir, index=step)
|
||||
logger.debug("Making checkpoint dir at %s", checkpoint_dir)
|
||||
return checkpoint_dir
|
||||
|
||||
def save_checkpoint(self, checkpoint):
|
||||
def set_checkpoint(self, checkpoint, is_new=True):
|
||||
"""Sets the checkpoint to be returned upon get_checkpoint.
|
||||
|
||||
If this is a "new" checkpoint, it will notify Tune
|
||||
(via has_new_checkpoint). Otherwise, it will NOT notify Tune.
|
||||
"""
|
||||
if isinstance(checkpoint, str):
|
||||
try:
|
||||
TrainableUtil.find_checkpoint_dir(checkpoint)
|
||||
@@ -98,7 +182,8 @@ class StatusReporter:
|
||||
"make_checkpoint_dir.")
|
||||
raise
|
||||
self._last_checkpoint = checkpoint
|
||||
self._fresh_checkpoint = True
|
||||
if is_new:
|
||||
self._fresh_checkpoint = True
|
||||
|
||||
def has_new_checkpoint(self):
|
||||
return self._fresh_checkpoint
|
||||
@@ -189,7 +274,7 @@ class FunctionRunner(Trainable):
|
||||
session.init(self._status_reporter)
|
||||
self._runner = None
|
||||
self._restore_tmpdir = None
|
||||
self.default_checkpoint_dir = None
|
||||
self.temp_checkpoint_dir = None
|
||||
|
||||
def _trainable_func(self):
|
||||
"""Subclasses can override this to set the trainable func."""
|
||||
@@ -282,11 +367,6 @@ class FunctionRunner(Trainable):
|
||||
def execute(self, fn):
|
||||
return fn(self)
|
||||
|
||||
def create_default_checkpoint_dir(self):
|
||||
self.default_checkpoint_dir = TrainableUtil.make_checkpoint_dir(
|
||||
self.logdir, index="default")
|
||||
return self.default_checkpoint_dir
|
||||
|
||||
def save(self, checkpoint_path=None):
|
||||
if checkpoint_path:
|
||||
raise ValueError(
|
||||
@@ -297,12 +377,31 @@ class FunctionRunner(Trainable):
|
||||
|
||||
if not checkpoint:
|
||||
state.update(iteration=0, timesteps_total=0, episodes_total=0)
|
||||
parent_dir = self.create_default_checkpoint_dir()
|
||||
# We drop a marker here to indicate that the checkpoint is empty
|
||||
checkpoint = FuncCheckpointUtil.mk_null_checkpoint_dir(self.logdir)
|
||||
parent_dir = checkpoint
|
||||
elif isinstance(checkpoint, dict):
|
||||
parent_dir = TrainableUtil.make_checkpoint_dir(
|
||||
self.logdir, index=self.training_iteration)
|
||||
else:
|
||||
elif isinstance(checkpoint, str):
|
||||
parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint)
|
||||
# When the trainable is restored, a temporary checkpoint
|
||||
# is created. However, when saved, it should become permanent.
|
||||
# Ideally, there are no save calls upon a temporary
|
||||
# checkpoint, but certain schedulers might.
|
||||
if FuncCheckpointUtil.is_temp_checkpoint_dir(parent_dir):
|
||||
relative_path = os.path.relpath(checkpoint, parent_dir)
|
||||
parent_dir = FuncCheckpointUtil.create_perm_checkpoint(
|
||||
checkpoint_dir=parent_dir,
|
||||
logdir=self.logdir,
|
||||
step=self.training_iteration)
|
||||
checkpoint = os.path.abspath(
|
||||
os.path.join(parent_dir, relative_path))
|
||||
else:
|
||||
raise ValueError("Provided checkpoint was expected to have "
|
||||
"type (str, dict). Got {}.".format(
|
||||
type(checkpoint)))
|
||||
|
||||
checkpoint_path = TrainableUtil.process_checkpoint(
|
||||
checkpoint, parent_dir, state)
|
||||
return checkpoint_path
|
||||
@@ -316,17 +415,20 @@ class FunctionRunner(Trainable):
|
||||
# This should be removed once Trainables are refactored.
|
||||
if "tune_checkpoint_path" in checkpoint:
|
||||
del checkpoint["tune_checkpoint_path"]
|
||||
self._status_reporter.save_checkpoint(checkpoint)
|
||||
# If there does not exist a checkpoint, we will not restore
|
||||
# from it and will remove the marker.
|
||||
if FuncCheckpointUtil.is_null_checkpoint(checkpoint):
|
||||
return
|
||||
# By informing that this checkpoint is not new,
|
||||
# we will not return the checkpoint path
|
||||
# as a new checkpoint.
|
||||
self._status_reporter.set_checkpoint(checkpoint, is_new=False)
|
||||
|
||||
def restore_from_object(self, obj):
|
||||
if self.default_checkpoint_dir is not None and os.exists(
|
||||
self.default_checkpoint_dir):
|
||||
shutil.rmtree(self.default_checkpoint_dir)
|
||||
logger.debug("Clearing default checkpoint: %s",
|
||||
self.default_checkpoint_dir)
|
||||
|
||||
checkpoint_dir = self.create_default_checkpoint_dir()
|
||||
checkpoint_path = TrainableUtil.create_from_pickle(obj, checkpoint_dir)
|
||||
self.temp_checkpoint_dir = (FuncCheckpointUtil.mk_temp_checkpoint_dir(
|
||||
self.logdir))
|
||||
checkpoint_path = TrainableUtil.create_from_pickle(
|
||||
obj, self.temp_checkpoint_dir)
|
||||
self.restore(checkpoint_path)
|
||||
|
||||
def cleanup(self):
|
||||
@@ -340,6 +442,12 @@ class FunctionRunner(Trainable):
|
||||
self._report_thread_runner_error()
|
||||
session.shutdown()
|
||||
|
||||
if self.temp_checkpoint_dir is not None and os.path.exists(
|
||||
self.temp_checkpoint_dir):
|
||||
shutil.rmtree(self.temp_checkpoint_dir)
|
||||
logger.debug("Clearing temporary checkpoint: %s",
|
||||
self.temp_checkpoint_dir)
|
||||
|
||||
def _report_thread_runner_error(self, block=False):
|
||||
try:
|
||||
err_tb_str = self._error_queue.get(
|
||||
|
||||
@@ -98,12 +98,16 @@ def save_checkpoint(checkpoint):
|
||||
|
||||
|
||||
@contextmanager
|
||||
def checkpoint_dir(step=None):
|
||||
def checkpoint_dir(step):
|
||||
"""Returns a checkpoint dir inside a context.
|
||||
|
||||
Store any files related to restoring state within the
|
||||
provided checkpoint dir.
|
||||
|
||||
Args:
|
||||
step (int): Index for the checkpoint. Expected to be a
|
||||
monotonically increasing quantity.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import os
|
||||
@@ -136,6 +140,9 @@ def checkpoint_dir(step=None):
|
||||
"""
|
||||
_session = get_session()
|
||||
|
||||
if step is None:
|
||||
raise ValueError("checkpoint_dir(step) must be provided - got None.")
|
||||
|
||||
if _session:
|
||||
_checkpoint_dir = _session.make_checkpoint_dir(step=step)
|
||||
else:
|
||||
@@ -144,7 +151,7 @@ def checkpoint_dir(step=None):
|
||||
yield _checkpoint_dir
|
||||
|
||||
if _session:
|
||||
_session.save_checkpoint(_checkpoint_dir)
|
||||
_session.set_checkpoint(_checkpoint_dir)
|
||||
|
||||
|
||||
def get_trial_dir():
|
||||
|
||||
@@ -1,36 +1,230 @@
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib import _register_all
|
||||
|
||||
from ray import tune
|
||||
from ray.tune.function_runner import wrap_function
|
||||
from ray.tune.logger import NoopLogger
|
||||
from ray.tune.trainable import TrainableUtil
|
||||
from ray.tune.function_runner import wrap_function, FuncCheckpointUtil
|
||||
from ray.tune.result import TRAINING_ITERATION
|
||||
|
||||
|
||||
class FunctionApiTest(unittest.TestCase):
|
||||
def creator_generator(logdir):
|
||||
def logger_creator(config):
|
||||
return NoopLogger(config, logdir)
|
||||
|
||||
return logger_creator
|
||||
|
||||
|
||||
class FuncCheckpointUtilTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
ray.init(num_cpus=4, num_gpus=0, object_store_memory=150 * 1024 * 1024)
|
||||
self.logdir = tempfile.mkdtemp()
|
||||
|
||||
def tearDown(self):
|
||||
ray.shutdown()
|
||||
_register_all() # re-register the evicted objects
|
||||
shutil.rmtree(self.logdir)
|
||||
|
||||
def testEmptyCheckpoint(self):
|
||||
checkpoint_dir = FuncCheckpointUtil.mk_null_checkpoint_dir(self.logdir)
|
||||
assert FuncCheckpointUtil.is_null_checkpoint(checkpoint_dir)
|
||||
|
||||
def testTempCheckpointDir(self):
|
||||
checkpoint_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(self.logdir)
|
||||
assert FuncCheckpointUtil.is_temp_checkpoint_dir(checkpoint_dir)
|
||||
|
||||
def testConvertTempToPermanent(self):
|
||||
checkpoint_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(self.logdir)
|
||||
new_checkpoint_dir = FuncCheckpointUtil.create_perm_checkpoint(
|
||||
checkpoint_dir, self.logdir, step=4)
|
||||
assert new_checkpoint_dir == TrainableUtil.find_checkpoint_dir(
|
||||
new_checkpoint_dir)
|
||||
assert os.path.exists(new_checkpoint_dir)
|
||||
assert not FuncCheckpointUtil.is_temp_checkpoint_dir(
|
||||
new_checkpoint_dir)
|
||||
|
||||
tmp_checkpoint_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(
|
||||
self.logdir)
|
||||
assert tmp_checkpoint_dir != new_checkpoint_dir
|
||||
|
||||
|
||||
class FunctionCheckpointingTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.logdir = tempfile.mkdtemp()
|
||||
self.logger_creator = creator_generator(self.logdir)
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.logdir)
|
||||
|
||||
def testCheckpointReuse(self):
|
||||
"""Test that repeated save/restore never reuses same checkpoint dir."""
|
||||
|
||||
def train(config, checkpoint_dir=None):
|
||||
if checkpoint_dir:
|
||||
count = sum("checkpoint-" in path
|
||||
for path in os.listdir(checkpoint_dir))
|
||||
assert count == 1, os.listdir(checkpoint_dir)
|
||||
|
||||
for step in range(20):
|
||||
with tune.checkpoint_dir(step=step) as checkpoint_dir:
|
||||
path = os.path.join(checkpoint_dir,
|
||||
"checkpoint-{}".format(step))
|
||||
open(path, "a").close()
|
||||
tune.report(test=step)
|
||||
|
||||
wrapped = wrap_function(train)
|
||||
checkpoint = None
|
||||
for i in range(5):
|
||||
new_trainable = wrapped(logger_creator=self.logger_creator)
|
||||
if checkpoint:
|
||||
new_trainable.restore(checkpoint)
|
||||
for i in range(2):
|
||||
result = new_trainable.train()
|
||||
checkpoint = new_trainable.save()
|
||||
new_trainable.stop()
|
||||
assert result[TRAINING_ITERATION] == 10
|
||||
|
||||
def testCheckpointReuseObject(self):
|
||||
"""Test that repeated save/restore never reuses same checkpoint dir."""
|
||||
|
||||
def train(config, checkpoint_dir=None):
|
||||
if checkpoint_dir:
|
||||
count = sum("checkpoint-" in path
|
||||
for path in os.listdir(checkpoint_dir))
|
||||
assert count == 1, os.listdir(checkpoint_dir)
|
||||
|
||||
for step in range(20):
|
||||
with tune.checkpoint_dir(step=step) as checkpoint_dir:
|
||||
path = os.path.join(checkpoint_dir,
|
||||
"checkpoint-{}".format(step))
|
||||
open(path, "a").close()
|
||||
tune.report(test=step)
|
||||
|
||||
wrapped = wrap_function(train)
|
||||
checkpoint = None
|
||||
for i in range(5):
|
||||
new_trainable = wrapped(logger_creator=self.logger_creator)
|
||||
if checkpoint:
|
||||
new_trainable.restore_from_object(checkpoint)
|
||||
for i in range(2):
|
||||
result = new_trainable.train()
|
||||
checkpoint = new_trainable.save_to_object()
|
||||
new_trainable.stop()
|
||||
self.assertTrue(result[TRAINING_ITERATION] == 10)
|
||||
|
||||
def testCheckpointReuseObjectWithoutTraining(self):
|
||||
"""Test that repeated save/restore never reuses same checkpoint dir."""
|
||||
|
||||
def train(config, checkpoint_dir=None):
|
||||
if checkpoint_dir:
|
||||
count = sum("checkpoint-" in path
|
||||
for path in os.listdir(checkpoint_dir))
|
||||
assert count == 1, os.listdir(checkpoint_dir)
|
||||
|
||||
for step in range(20):
|
||||
with tune.checkpoint_dir(step=step) as checkpoint_dir:
|
||||
path = os.path.join(checkpoint_dir,
|
||||
"checkpoint-{}".format(step))
|
||||
open(path, "a").close()
|
||||
tune.report(test=step)
|
||||
|
||||
wrapped = wrap_function(train)
|
||||
new_trainable = wrapped(logger_creator=self.logger_creator)
|
||||
for i in range(2):
|
||||
result = new_trainable.train()
|
||||
checkpoint = new_trainable.save_to_object()
|
||||
new_trainable.stop()
|
||||
|
||||
new_trainable2 = wrapped(logger_creator=self.logger_creator)
|
||||
new_trainable2.restore_from_object(checkpoint)
|
||||
new_trainable2.stop()
|
||||
|
||||
new_trainable2 = wrapped(logger_creator=self.logger_creator)
|
||||
new_trainable2.restore_from_object(checkpoint)
|
||||
result = new_trainable2.train()
|
||||
new_trainable2.stop()
|
||||
self.assertTrue(result[TRAINING_ITERATION] == 3)
|
||||
|
||||
def testReuseNullCheckpoint(self):
|
||||
def train(config, checkpoint_dir=None):
|
||||
assert not checkpoint_dir
|
||||
for step in range(10):
|
||||
tune.report(test=step)
|
||||
|
||||
# Create checkpoint
|
||||
wrapped = wrap_function(train)
|
||||
checkpoint = None
|
||||
new_trainable = wrapped(logger_creator=self.logger_creator)
|
||||
new_trainable.train()
|
||||
checkpoint = new_trainable.save()
|
||||
new_trainable.stop()
|
||||
|
||||
# Use the checkpoint a couple of times
|
||||
for i in range(3):
|
||||
new_trainable = wrapped(logger_creator=self.logger_creator)
|
||||
new_trainable.restore(checkpoint)
|
||||
new_trainable.stop()
|
||||
|
||||
# Make sure the result is still good
|
||||
new_trainable = wrapped(logger_creator=self.logger_creator)
|
||||
new_trainable.restore(checkpoint)
|
||||
result = new_trainable.train()
|
||||
checkpoint = new_trainable.save()
|
||||
new_trainable.stop()
|
||||
self.assertTrue(result[TRAINING_ITERATION] == 1)
|
||||
|
||||
def testMultipleNullCheckpoints(self):
|
||||
def train(config, checkpoint_dir=None):
|
||||
assert not checkpoint_dir
|
||||
for step in range(10):
|
||||
tune.report(test=step)
|
||||
|
||||
wrapped = wrap_function(train)
|
||||
checkpoint = None
|
||||
for i in range(5):
|
||||
new_trainable = wrapped(logger_creator=self.logger_creator)
|
||||
if checkpoint:
|
||||
new_trainable.restore(checkpoint)
|
||||
result = new_trainable.train()
|
||||
checkpoint = new_trainable.save()
|
||||
new_trainable.stop()
|
||||
self.assertTrue(result[TRAINING_ITERATION] == 1)
|
||||
|
||||
def testMultipleNullMemoryCheckpoints(self):
|
||||
def train(config, checkpoint_dir=None):
|
||||
assert not checkpoint_dir
|
||||
for step in range(10):
|
||||
tune.report(test=step)
|
||||
|
||||
wrapped = wrap_function(train)
|
||||
checkpoint = None
|
||||
for i in range(5):
|
||||
new_trainable = wrapped(logger_creator=self.logger_creator)
|
||||
if checkpoint:
|
||||
new_trainable.restore_from_object(checkpoint)
|
||||
result = new_trainable.train()
|
||||
checkpoint = new_trainable.save_to_object()
|
||||
new_trainable.stop()
|
||||
assert result[TRAINING_ITERATION] == 1
|
||||
|
||||
def testFunctionNoCheckpointing(self):
|
||||
def train(config, checkpoint_dir=None):
|
||||
for i in range(10):
|
||||
tune.report(test=i)
|
||||
if checkpoint_dir:
|
||||
assert os.path.exists(checkpoint_dir)
|
||||
for step in range(10):
|
||||
tune.report(test=step)
|
||||
|
||||
wrapped = wrap_function(train)
|
||||
|
||||
new_trainable = wrapped()
|
||||
new_trainable = wrapped(logger_creator=self.logger_creator)
|
||||
result = new_trainable.train()
|
||||
checkpoint = new_trainable.save()
|
||||
new_trainable.stop()
|
||||
|
||||
new_trainable2 = wrapped()
|
||||
new_trainable2 = wrapped(logger_creator=self.logger_creator)
|
||||
new_trainable2.restore(checkpoint)
|
||||
result = new_trainable2.train()
|
||||
self.assertEquals(result[TRAINING_ITERATION], 1)
|
||||
@@ -41,6 +235,8 @@ class FunctionApiTest(unittest.TestCase):
|
||||
"""This tests that save and restore are commutative."""
|
||||
|
||||
def train(config, checkpoint_dir=None):
|
||||
if checkpoint_dir:
|
||||
assert os.path.exists(checkpoint_dir)
|
||||
for step in range(10):
|
||||
if step % 3 == 0:
|
||||
with tune.checkpoint_dir(step=step) as checkpoint_dir:
|
||||
@@ -51,18 +247,59 @@ class FunctionApiTest(unittest.TestCase):
|
||||
|
||||
wrapped = wrap_function(train)
|
||||
|
||||
new_trainable = wrapped()
|
||||
new_trainable = wrapped(logger_creator=self.logger_creator)
|
||||
new_trainable.train()
|
||||
checkpoint_obj = new_trainable.save_to_object()
|
||||
new_trainable.restore_from_object(checkpoint_obj)
|
||||
checkpoint = new_trainable.save()
|
||||
|
||||
new_trainable.stop()
|
||||
|
||||
new_trainable2 = wrapped()
|
||||
new_trainable2 = wrapped(logger_creator=self.logger_creator)
|
||||
new_trainable2.restore(checkpoint)
|
||||
new_trainable2.train()
|
||||
new_trainable2.stop()
|
||||
|
||||
def testFunctionImmediateSave(self):
|
||||
"""This tests that save and restore are commutative."""
|
||||
|
||||
def train(config, checkpoint_dir=None):
|
||||
if checkpoint_dir:
|
||||
assert os.path.exists(checkpoint_dir)
|
||||
for step in range(10):
|
||||
with tune.checkpoint_dir(step=step) as checkpoint_dir:
|
||||
print(checkpoint_dir)
|
||||
path = os.path.join(checkpoint_dir,
|
||||
"checkpoint-{}".format(step))
|
||||
open(path, "w").close()
|
||||
tune.report(test=step)
|
||||
|
||||
wrapped = wrap_function(train)
|
||||
new_trainable = wrapped(logger_creator=self.logger_creator)
|
||||
new_trainable.train()
|
||||
new_trainable.train()
|
||||
checkpoint_obj = new_trainable.save_to_object()
|
||||
new_trainable.stop()
|
||||
|
||||
new_trainable2 = wrapped(logger_creator=self.logger_creator)
|
||||
new_trainable2.restore_from_object(checkpoint_obj)
|
||||
checkpoint_obj = new_trainable2.save_to_object()
|
||||
new_trainable2.train()
|
||||
result = new_trainable2.train()
|
||||
assert sum("tmp" in path for path in os.listdir(self.logdir)) == 1
|
||||
new_trainable2.stop()
|
||||
assert sum("tmp" in path for path in os.listdir(self.logdir)) == 0
|
||||
assert result[TRAINING_ITERATION] == 4
|
||||
|
||||
|
||||
class FunctionApiTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
ray.init(num_cpus=4, num_gpus=0, object_store_memory=150 * 1024 * 1024)
|
||||
|
||||
def tearDown(self):
|
||||
ray.shutdown()
|
||||
_register_all() # re-register the evicted objects
|
||||
|
||||
def testCheckpointFunctionAtEnd(self):
|
||||
def train(config, checkpoint_dir=False):
|
||||
for i in range(10):
|
||||
@@ -90,12 +327,12 @@ class FunctionApiTest(unittest.TestCase):
|
||||
def testVariousCheckpointFunctionAtEnd(self):
|
||||
def train(config, checkpoint_dir=False):
|
||||
for i in range(10):
|
||||
with tune.checkpoint_dir() as checkpoint_dir:
|
||||
with tune.checkpoint_dir(step=i) as checkpoint_dir:
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log")
|
||||
with open(checkpoint_path, "w") as f:
|
||||
f.write("hello")
|
||||
tune.report(test=i)
|
||||
with tune.checkpoint_dir() as checkpoint_dir:
|
||||
with tune.checkpoint_dir(step=i) as checkpoint_dir:
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log2")
|
||||
with open(checkpoint_path, "w") as f:
|
||||
f.write("goodbye")
|
||||
@@ -164,7 +401,7 @@ class FunctionApiTest(unittest.TestCase):
|
||||
for i in range(itr, 10):
|
||||
if i == 5 and not restored:
|
||||
raise Exception("try to fail me")
|
||||
with tune.checkpoint_dir() as checkpoint_dir:
|
||||
with tune.checkpoint_dir(step=itr) as checkpoint_dir:
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log")
|
||||
with open(checkpoint_path, "w") as f:
|
||||
f.write(str(i))
|
||||
|
||||
@@ -109,19 +109,23 @@ class TrainableUtil:
|
||||
return checkpoint_dir
|
||||
|
||||
@staticmethod
|
||||
def make_checkpoint_dir(checkpoint_dir, index):
|
||||
def make_checkpoint_dir(checkpoint_dir, index, override=False):
|
||||
"""Creates a checkpoint directory within the provided path.
|
||||
|
||||
Args:
|
||||
checkpoint_dir (str): Path to checkpoint directory.
|
||||
index (str): A subdirectory will be created
|
||||
at the checkpoint directory named 'checkpoint_{index}'.
|
||||
override (bool): Deletes checkpoint_dir before creating
|
||||
a new one.
|
||||
"""
|
||||
suffix = "checkpoint"
|
||||
if index is not None:
|
||||
suffix += "_{}".format(index)
|
||||
checkpoint_dir = os.path.join(checkpoint_dir, suffix)
|
||||
|
||||
if override and os.path.exists(checkpoint_dir):
|
||||
shutil.rmtree(checkpoint_dir)
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
# Drop marker in directory to identify it as a checkpoint dir.
|
||||
open(os.path.join(checkpoint_dir, ".is_checkpoint"), "a").close()
|
||||
|
||||
Reference in New Issue
Block a user