[Tune] Fix Memory Leak (#10989)

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
Amog Kamsetty
2020-09-24 20:26:55 -07:00
committed by GitHub
parent a26394d184
commit ee85cb31a5
2 changed files with 80 additions and 2 deletions
+17 -2
View File
@@ -44,6 +44,9 @@ class Checkpoint:
return isinstance(self.value, str)
return self.storage == Checkpoint.MEMORY
def __repr__(self):
return f"Checkpoint({self.storage}, {self.value})"
class QueueItem:
def __init__(self, priority, value):
@@ -53,6 +56,9 @@ class QueueItem:
def __lt__(self, other):
return self.priority < other.priority
def __repr__(self):
return f"QueueItem({repr(self.value)})"
class CheckpointManager:
"""Manages checkpoints on the driver for a trial."""
@@ -82,7 +88,7 @@ class CheckpointManager:
self.delete = delete_fn
self.newest_persistent_checkpoint = Checkpoint(Checkpoint.PERSISTENT,
None)
self.newest_memory_checkpoint = Checkpoint(Checkpoint.MEMORY, None)
self._newest_memory_checkpoint = Checkpoint(Checkpoint.MEMORY, None)
self._best_checkpoints = []
self._membership = set()
@@ -94,6 +100,10 @@ class CheckpointManager:
key=lambda c: c.result.get(TRAINING_ITERATION, -1))
return newest_checkpoint
@property
def newest_memory_checkpoint(self):
return self._newest_memory_checkpoint
def on_checkpoint(self, checkpoint):
"""Starts tracking checkpoint metadata on checkpoint.
@@ -105,7 +115,9 @@ class CheckpointManager:
checkpoint (Checkpoint): Trial state checkpoint.
"""
if checkpoint.storage == Checkpoint.MEMORY:
self.newest_memory_checkpoint = checkpoint
# Forcibly remove the memory checkpoint
del self._newest_memory_checkpoint
self._newest_memory_checkpoint = checkpoint
return
old_checkpoint = self.newest_persistent_checkpoint
@@ -151,6 +163,9 @@ class CheckpointManager:
def __getstate__(self):
state = self.__dict__.copy()
# Avoid serializing the memory checkpoint.
state["_newest_memory_checkpoint"] = Checkpoint(
Checkpoint.MEMORY, None)
# Avoid serializing lambda since it may capture cyclical dependencies.
state.pop("delete")
return state
@@ -8,6 +8,8 @@ import time
import ray
from ray import tune
from ray.tune import Trainable
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.schedulers import PopulationBasedTraining
@@ -22,6 +24,67 @@ class MockParam(object):
return val
class PopulationBasedTrainingMemoryTest(unittest.TestCase):
def setUp(self):
ray.init(num_cpus=1)
def tearDown(self):
ray.shutdown()
def testMemoryCheckpointFree(self):
class MyTrainable(Trainable):
def setup(self, config):
# Make sure this is large enough so ray uses object store
# instead of in-process store.
self.large_object = random.getrandbits(int(10e7))
self.iter = 0
self.a = config["a"]
def step(self):
self.iter += 1
return {"metric": self.iter + self.a}
def save_checkpoint(self, checkpoint_dir):
file_path = os.path.join(checkpoint_dir, "model.mock")
with open(file_path, "wb") as fp:
pickle.dump((self.large_object, self.iter, self.a), fp)
return file_path
def load_checkpoint(self, path):
with open(path, "rb") as fp:
self.large_object, self.iter, self.a = pickle.load(fp)
class CustomExecutor(RayTrialExecutor):
def save(self, *args, **kwargs):
checkpoint = super(CustomExecutor, self).save(*args, **kwargs)
assert len(ray.objects()) <= 10
return checkpoint
param_a = MockParam([1, -1])
pbt = PopulationBasedTraining(
time_attr="training_iteration",
metric="metric",
mode="max",
perturbation_interval=1,
hyperparam_mutations={"b": [-1]},
)
tune.run(
MyTrainable,
name="ray_demo",
scheduler=pbt,
stop={"training_iteration": 10},
num_samples=3,
checkpoint_freq=1,
fail_fast=True,
config={"a": tune.sample_from(lambda _: param_a())},
trial_executor=CustomExecutor(
queue_trials=False, reuse_actors=False),
)
class PopulationBasedTrainingSynchTest(unittest.TestCase):
def setUp(self):
ray.init(num_cpus=2)