Summary of this patch (text before the first "diff --git" line is ignored
by git-apply and patch):

  * checkpoint_manager.py: the "drop the previous in-memory checkpoint, then
    store the new one" logic in CheckpointManager.on_checkpoint is moved into
    a new method, replace_newest_memory_checkpoint(), which additionally
    calls gc.collect() between the `del` and the re-assignment.
  * test_trial_scheduler_pbt.py: the memory test starts Ray with
    object_store_memory=100 * MB (MB = 1024**2), shrinks the per-trainable
    payload from random.getrandbits(int(10e7)) to int(10e6), and relaxes the
    bound on len(ray.objects()) from 10 to 12.

NOTE(review): line breaks and indentation below were reconstructed from a
whitespace-collapsed copy of the patch. The number of old/new lines in every
hunk matches its @@ header, but the indentation of the Python lines is
inferred -- verify against the upstream files before applying.

diff --git a/python/ray/tune/checkpoint_manager.py b/python/ray/tune/checkpoint_manager.py
index 9b0aa619a..a483adf8c 100644
--- a/python/ray/tune/checkpoint_manager.py
+++ b/python/ray/tune/checkpoint_manager.py
@@ -1,5 +1,6 @@
 # coding: utf-8
 import heapq
+import gc
 import logging
 
 from ray.tune.result import TRAINING_ITERATION
@@ -104,6 +105,13 @@ class CheckpointManager:
     def newest_memory_checkpoint(self):
         return self._newest_memory_checkpoint
 
+    def replace_newest_memory_checkpoint(self, new_checkpoint):
+        # Forcibly remove the memory checkpoint
+        del self._newest_memory_checkpoint
+        # Apparently avoids memory leaks on k8s/k3s/pods
+        gc.collect()
+        self._newest_memory_checkpoint = new_checkpoint
+
     def on_checkpoint(self, checkpoint):
         """Starts tracking checkpoint metadata on checkpoint.
 
@@ -115,9 +123,7 @@ class CheckpointManager:
             checkpoint (Checkpoint): Trial state checkpoint.
         """
         if checkpoint.storage == Checkpoint.MEMORY:
-            # Forcibly remove the memory checkpoint
-            del self._newest_memory_checkpoint
-            self._newest_memory_checkpoint = checkpoint
+            self.replace_newest_memory_checkpoint(checkpoint)
             return
 
         old_checkpoint = self.newest_persistent_checkpoint
diff --git a/python/ray/tune/tests/test_trial_scheduler_pbt.py b/python/ray/tune/tests/test_trial_scheduler_pbt.py
index bfa69a393..f94b8251b 100644
--- a/python/ray/tune/tests/test_trial_scheduler_pbt.py
+++ b/python/ray/tune/tests/test_trial_scheduler_pbt.py
@@ -12,6 +12,8 @@ from ray.tune import Trainable
 from ray.tune.ray_trial_executor import RayTrialExecutor
 from ray.tune.schedulers import PopulationBasedTraining
 
+MB = 1024**2
+
 
 class MockParam(object):
     def __init__(self, params):
@@ -26,7 +28,7 @@ class MockParam(object):
 
 class PopulationBasedTrainingMemoryTest(unittest.TestCase):
     def setUp(self):
-        ray.init(num_cpus=1)
+        ray.init(num_cpus=1, object_store_memory=100 * MB)
 
     def tearDown(self):
         ray.shutdown()
@@ -36,7 +38,7 @@ class PopulationBasedTrainingMemoryTest(unittest.TestCase):
             def setup(self, config):
                 # Make sure this is large enough so ray uses object store
                 # instead of in-process store.
-                self.large_object = random.getrandbits(int(10e7))
+                self.large_object = random.getrandbits(int(10e6))
                 self.iter = 0
                 self.a = config["a"]
 
@@ -58,7 +60,7 @@ class PopulationBasedTrainingMemoryTest(unittest.TestCase):
         class CustomExecutor(RayTrialExecutor):
             def save(self, *args, **kwargs):
                 checkpoint = super(CustomExecutor, self).save(*args, **kwargs)
-                assert len(ray.objects()) <= 10
+                assert len(ray.objects()) <= 12
                 return checkpoint
 
         param_a = MockParam([1, -1])