diff --git a/python/ray/tune/checkpoint_manager.py b/python/ray/tune/checkpoint_manager.py index df50eb185..9b0aa619a 100644 --- a/python/ray/tune/checkpoint_manager.py +++ b/python/ray/tune/checkpoint_manager.py @@ -44,6 +44,9 @@ class Checkpoint: return isinstance(self.value, str) return self.storage == Checkpoint.MEMORY + def __repr__(self): + return f"Checkpoint({self.storage}, {self.value})" + class QueueItem: def __init__(self, priority, value): @@ -53,6 +56,9 @@ class QueueItem: def __lt__(self, other): return self.priority < other.priority + def __repr__(self): + return f"QueueItem({repr(self.value)})" + class CheckpointManager: """Manages checkpoints on the driver for a trial.""" @@ -82,7 +88,7 @@ class CheckpointManager: self.delete = delete_fn self.newest_persistent_checkpoint = Checkpoint(Checkpoint.PERSISTENT, None) - self.newest_memory_checkpoint = Checkpoint(Checkpoint.MEMORY, None) + self._newest_memory_checkpoint = Checkpoint(Checkpoint.MEMORY, None) self._best_checkpoints = [] self._membership = set() @@ -94,6 +100,10 @@ class CheckpointManager: key=lambda c: c.result.get(TRAINING_ITERATION, -1)) return newest_checkpoint + @property + def newest_memory_checkpoint(self): + return self._newest_memory_checkpoint + def on_checkpoint(self, checkpoint): """Starts tracking checkpoint metadata on checkpoint. @@ -105,7 +115,9 @@ class CheckpointManager: checkpoint (Checkpoint): Trial state checkpoint. """ if checkpoint.storage == Checkpoint.MEMORY: - self.newest_memory_checkpoint = checkpoint + # Forcibly remove the memory checkpoint + del self._newest_memory_checkpoint + self._newest_memory_checkpoint = checkpoint return old_checkpoint = self.newest_persistent_checkpoint @@ -151,6 +163,9 @@ class CheckpointManager: def __getstate__(self): state = self.__dict__.copy() + # Avoid serializing the memory checkpoint. + state["_newest_memory_checkpoint"] = Checkpoint( + Checkpoint.MEMORY, None) # Avoid serializing lambda since it may capture cyclical dependencies. state.pop("delete") return state diff --git a/python/ray/tune/tests/test_trial_scheduler_pbt.py b/python/ray/tune/tests/test_trial_scheduler_pbt.py index 09831fa01..bfa69a393 100644 --- a/python/ray/tune/tests/test_trial_scheduler_pbt.py +++ b/python/ray/tune/tests/test_trial_scheduler_pbt.py @@ -8,6 +8,8 @@ import time import ray from ray import tune +from ray.tune import Trainable +from ray.tune.ray_trial_executor import RayTrialExecutor from ray.tune.schedulers import PopulationBasedTraining @@ -22,6 +24,67 @@ class MockParam(object): return val +class PopulationBasedTrainingMemoryTest(unittest.TestCase): + def setUp(self): + ray.init(num_cpus=1) + + def tearDown(self): + ray.shutdown() + + def testMemoryCheckpointFree(self): + class MyTrainable(Trainable): + def setup(self, config): + # Make sure this is large enough so ray uses object store + # instead of in-process store. + self.large_object = random.getrandbits(int(10e7)) + self.iter = 0 + self.a = config["a"] + + def step(self): + self.iter += 1 + return {"metric": self.iter + self.a} + + def save_checkpoint(self, checkpoint_dir): + file_path = os.path.join(checkpoint_dir, "model.mock") + + with open(file_path, "wb") as fp: + pickle.dump((self.large_object, self.iter, self.a), fp) + return file_path + + def load_checkpoint(self, path): + with open(path, "rb") as fp: + self.large_object, self.iter, self.a = pickle.load(fp) + + class CustomExecutor(RayTrialExecutor): + def save(self, *args, **kwargs): + checkpoint = super(CustomExecutor, self).save(*args, **kwargs) + assert len(ray.objects()) <= 10 + return checkpoint + + param_a = MockParam([1, -1]) + + pbt = PopulationBasedTraining( + time_attr="training_iteration", + metric="metric", + mode="max", + perturbation_interval=1, + hyperparam_mutations={"b": [-1]}, + ) + + tune.run( + MyTrainable, + name="ray_demo", + scheduler=pbt, + stop={"training_iteration": 10}, + num_samples=3, + checkpoint_freq=1, + fail_fast=True, + config={"a": tune.sample_from(lambda _: param_a())}, + trial_executor=CustomExecutor( + queue_trials=False, reuse_actors=False), + ) + + class PopulationBasedTrainingSynchTest(unittest.TestCase): def setUp(self): ray.init(num_cpus=2)