mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:38:18 +08:00
[tune] fix Tensorboard file descriptor leak (#12425)
This commit is contained in:
@@ -630,6 +630,8 @@ class TBXLoggerCallback(LoggerCallback):
|
||||
self._trial_result: Dict["Trial", Dict] = {}
|
||||
|
||||
def log_trial_start(self, trial: "Trial"):
|
||||
if trial in self._trial_writer:
|
||||
self._trial_writer[trial].close()
|
||||
trial.init_logdir()
|
||||
self._trial_writer[trial] = self._summary_writer_cls(
|
||||
trial.logdir, flush_secs=30)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import pickle
|
||||
import psutil
|
||||
import random
|
||||
import unittest
|
||||
import sys
|
||||
@@ -87,6 +88,84 @@ class PopulationBasedTrainingMemoryTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
|
||||
class PopulationBasedTrainingFileDescriptorTest(unittest.TestCase):
    """Regression test: a PBT run must not accumulate open files.

    Runs a small PopulationBasedTraining experiment and, on every trial
    result, asserts that the number of files the driver process holds
    open stays below a fixed bound.
    """

    _CHECKPOINT_ENV = "TUNE_GLOBAL_CHECKPOINT_S"

    def setUp(self):
        # Remember whatever value was there before so tearDown can put it
        # back; os.environ is process-global and would otherwise leak
        # this test's setting into every test that runs afterwards.
        self._prev_checkpoint_s = os.environ.get(self._CHECKPOINT_ENV)
        ray.init(num_cpus=2)
        # NOTE(review): presumably sets Tune's experiment-checkpoint
        # period to zero seconds -- confirm against the Tune docs.
        os.environ[self._CHECKPOINT_ENV] = "0"

    def tearDown(self):
        ray.shutdown()
        # Restore the environment exactly as setUp found it.
        if self._prev_checkpoint_s is None:
            os.environ.pop(self._CHECKPOINT_ENV, None)
        else:
            os.environ[self._CHECKPOINT_ENV] = self._prev_checkpoint_s

    def testFileFree(self):
        """Run PBT with frequent checkpoints and check open-file count."""

        class MyTrainable(Trainable):
            """Minimal trainable whose metric is ``iter + a``."""

            def setup(self, config):
                self.iter = 0
                self.a = config["a"]

            def step(self):
                self.iter += 1
                return {"metric": self.iter + self.a}

            def save_checkpoint(self, checkpoint_dir):
                file_path = os.path.join(checkpoint_dir, "model.mock")

                with open(file_path, "wb") as fp:
                    pickle.dump((self.iter, self.a), fp)
                return file_path

            def load_checkpoint(self, path):
                # The checkpoint is one written by save_checkpoint above.
                with open(path, "rb") as fp:
                    self.iter, self.a = pickle.load(fp)

        from ray.tune.callback import Callback

        class FileCheck(Callback):
            """Callback asserting the open-file count on each result."""

            def __init__(self, verbose=False):
                self.iter_ = 0
                self.process = psutil.Process()
                self.verbose = verbose

            def on_trial_result(self, *args, **kwargs):
                self.iter_ += 1
                all_files = self.process.open_files()
                if self.verbose:
                    print("Iteration", self.iter_)
                    print("=" * 10)
                    print("Number of objects: ", len(ray.objects()))
                    print("Virtual Mem:", self.get_virt_mem() >> 30, "gb")
                    print("File Descriptors:", len(all_files))
                # Include the count so a failure is diagnosable from
                # the test log alone.
                assert len(all_files) < 20, (
                    "{} files open after {} results".format(
                        len(all_files), self.iter_))

            @classmethod
            def get_virt_mem(cls):
                return psutil.virtual_memory().used

        param_a = MockParam([1, -1])

        pbt = PopulationBasedTraining(
            time_attr="training_iteration",
            metric="metric",
            mode="max",
            perturbation_interval=1,
            quantile_fraction=0.5,
            hyperparam_mutations={"b": [-1]},
        )

        tune.run(
            MyTrainable,
            name="ray_demo",
            scheduler=pbt,
            stop={"training_iteration": 10},
            num_samples=4,
            checkpoint_freq=2,
            keep_checkpoints_num=1,
            verbose=False,
            fail_fast=True,
            config={"a": tune.sample_from(lambda _: param_a())},
            callbacks=[FileCheck()],
        )
|
||||
|
||||
|
||||
class PopulationBasedTrainingSynchTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
ray.init(num_cpus=2)
|
||||
|
||||
Reference in New Issue
Block a user