mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 20:56:34 +08:00
[Tune] Fix Memory Leak (#10989)
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
@@ -44,6 +44,9 @@ class Checkpoint:
|
||||
return isinstance(self.value, str)
|
||||
return self.storage == Checkpoint.MEMORY
|
||||
|
||||
def __repr__(self):
|
||||
return f"Checkpoint({self.storage}, {self.value})"
|
||||
|
||||
|
||||
class QueueItem:
|
||||
def __init__(self, priority, value):
|
||||
@@ -53,6 +56,9 @@ class QueueItem:
|
||||
def __lt__(self, other):
|
||||
return self.priority < other.priority
|
||||
|
||||
def __repr__(self):
|
||||
return f"QueueItem({repr(self.value)})"
|
||||
|
||||
|
||||
class CheckpointManager:
|
||||
"""Manages checkpoints on the driver for a trial."""
|
||||
@@ -82,7 +88,7 @@ class CheckpointManager:
|
||||
self.delete = delete_fn
|
||||
self.newest_persistent_checkpoint = Checkpoint(Checkpoint.PERSISTENT,
|
||||
None)
|
||||
self.newest_memory_checkpoint = Checkpoint(Checkpoint.MEMORY, None)
|
||||
self._newest_memory_checkpoint = Checkpoint(Checkpoint.MEMORY, None)
|
||||
self._best_checkpoints = []
|
||||
self._membership = set()
|
||||
|
||||
@@ -94,6 +100,10 @@ class CheckpointManager:
|
||||
key=lambda c: c.result.get(TRAINING_ITERATION, -1))
|
||||
return newest_checkpoint
|
||||
|
||||
@property
|
||||
def newest_memory_checkpoint(self):
|
||||
return self._newest_memory_checkpoint
|
||||
|
||||
def on_checkpoint(self, checkpoint):
|
||||
"""Starts tracking checkpoint metadata on checkpoint.
|
||||
|
||||
@@ -105,7 +115,9 @@ class CheckpointManager:
|
||||
checkpoint (Checkpoint): Trial state checkpoint.
|
||||
"""
|
||||
if checkpoint.storage == Checkpoint.MEMORY:
|
||||
self.newest_memory_checkpoint = checkpoint
|
||||
# Forcibly remove the memory checkpoint
|
||||
del self._newest_memory_checkpoint
|
||||
self._newest_memory_checkpoint = checkpoint
|
||||
return
|
||||
|
||||
old_checkpoint = self.newest_persistent_checkpoint
|
||||
@@ -151,6 +163,9 @@ class CheckpointManager:
|
||||
|
||||
def __getstate__(self):
    """Build the state dictionary used when this manager is pickled.

    Returns:
        dict: A shallow copy of the instance ``__dict__`` in which the
            memory checkpoint is replaced by an empty ``Checkpoint`` and
            the ``delete`` entry is removed.

    Raises:
        KeyError: If the instance has no ``delete`` attribute.
    """
    state = dict(self.__dict__)
    # Avoid serializing lambda since it may capture cyclical dependencies.
    del state["delete"]
    # Avoid serializing the memory checkpoint.
    placeholder = Checkpoint(Checkpoint.MEMORY, None)
    state["_newest_memory_checkpoint"] = placeholder
    return state
|
||||
|
||||
@@ -8,6 +8,8 @@ import time
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune import Trainable
|
||||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
||||
from ray.tune.schedulers import PopulationBasedTraining
|
||||
|
||||
|
||||
@@ -22,6 +24,67 @@ class MockParam(object):
|
||||
return val
|
||||
|
||||
|
||||
class PopulationBasedTrainingMemoryTest(unittest.TestCase):
    """Runs PBT with a large trainable and bounds the live ray objects."""

    def setUp(self):
        """Start a single-CPU ray instance for the test."""
        ray.init(num_cpus=1)

    def tearDown(self):
        """Shut ray down after the test."""
        ray.shutdown()

    def testMemoryCheckpointFree(self):
        """Run PBT and assert on every save that ray tracks <= 10 objects."""

        class MyTrainable(Trainable):
            """Trainable whose state includes one very large integer."""

            def setup(self, config):
                # Make sure this is large enough so ray uses object store
                # instead of in-process store.
                self.large_object = random.getrandbits(int(10e7))
                self.iter = 0
                self.a = config["a"]

            def step(self):
                self.iter += 1
                score = self.iter + self.a
                return {"metric": score}

            def save_checkpoint(self, checkpoint_dir):
                file_path = os.path.join(checkpoint_dir, "model.mock")
                with open(file_path, "wb") as fp:
                    payload = (self.large_object, self.iter, self.a)
                    pickle.dump(payload, fp)
                return file_path

            def load_checkpoint(self, path):
                with open(path, "rb") as fp:
                    restored = pickle.load(fp)
                    self.large_object, self.iter, self.a = restored

        class CustomExecutor(RayTrialExecutor):
            """Executor that checks the ray object count after each save."""

            def save(self, *args, **kwargs):
                checkpoint = super().save(*args, **kwargs)
                assert len(ray.objects()) <= 10
                return checkpoint

        param_a = MockParam([1, -1])

        scheduler_options = dict(
            time_attr="training_iteration",
            metric="metric",
            mode="max",
            perturbation_interval=1,
            hyperparam_mutations={"b": [-1]},
        )
        pbt = PopulationBasedTraining(**scheduler_options)

        executor = CustomExecutor(queue_trials=False, reuse_actors=False)
        tune.run(
            MyTrainable,
            name="ray_demo",
            scheduler=pbt,
            stop={"training_iteration": 10},
            num_samples=3,
            checkpoint_freq=1,
            fail_fast=True,
            config={"a": tune.sample_from(lambda _: param_a())},
            trial_executor=executor,
        )
|
||||
|
||||
|
||||
class PopulationBasedTrainingSynchTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
ray.init(num_cpus=2)
|
||||
|
||||
Reference in New Issue
Block a user