diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py
index 16765d09e..b4ff76bae 100644
--- a/python/ray/tune/logger.py
+++ b/python/ray/tune/logger.py
@@ -630,6 +630,8 @@ class TBXLoggerCallback(LoggerCallback):
         self._trial_result: Dict["Trial", Dict] = {}
 
     def log_trial_start(self, trial: "Trial"):
+        if trial in self._trial_writer:
+            self._trial_writer[trial].close()
         trial.init_logdir()
         self._trial_writer[trial] = self._summary_writer_cls(
             trial.logdir, flush_secs=30)
diff --git a/python/ray/tune/tests/test_trial_scheduler_pbt.py b/python/ray/tune/tests/test_trial_scheduler_pbt.py
index f94b8251b..300ea0bfb 100644
--- a/python/ray/tune/tests/test_trial_scheduler_pbt.py
+++ b/python/ray/tune/tests/test_trial_scheduler_pbt.py
@@ -1,6 +1,7 @@
 import numpy as np
 import os
 import pickle
+import psutil
 import random
 import unittest
 import sys
@@ -87,6 +88,84 @@ class PopulationBasedTrainingMemoryTest(unittest.TestCase):
         )
 
 
+class PopulationBasedTrainingFileDescriptorTest(unittest.TestCase):
+    def setUp(self):
+        ray.init(num_cpus=2)
+        os.environ["TUNE_GLOBAL_CHECKPOINT_S"] = "0"
+
+    def tearDown(self):
+        ray.shutdown()
+
+    def testFileFree(self):
+        class MyTrainable(Trainable):
+            def setup(self, config):
+                self.iter = 0
+                self.a = config["a"]
+
+            def step(self):
+                self.iter += 1
+                return {"metric": self.iter + self.a}
+
+            def save_checkpoint(self, checkpoint_dir):
+                file_path = os.path.join(checkpoint_dir, "model.mock")
+
+                with open(file_path, "wb") as fp:
+                    pickle.dump((self.iter, self.a), fp)
+                return file_path
+
+            def load_checkpoint(self, path):
+                with open(path, "rb") as fp:
+                    self.iter, self.a = pickle.load(fp)
+
+        from ray.tune.callback import Callback
+
+        class FileCheck(Callback):
+            def __init__(self, verbose=False):
+                self.iter_ = 0
+                self.process = psutil.Process()
+                self.verbose = verbose
+
+            def on_trial_result(self, *args, **kwargs):
+                self.iter_ += 1
+                all_files = self.process.open_files()
+                if self.verbose:
+                    print("Iteration", self.iter_)
+                    print("=" * 10)
+                    print("Number of objects: ", len(ray.objects()))
+                    print("Virtual Mem:", self.get_virt_mem() >> 30, "gb")
+                    print("File Descriptors:", len(all_files))
+                assert len(all_files) < 20
+
+            @classmethod
+            def get_virt_mem(cls):
+                return psutil.virtual_memory().used
+
+        param_a = MockParam([1, -1])
+
+        pbt = PopulationBasedTraining(
+            time_attr="training_iteration",
+            metric="metric",
+            mode="max",
+            perturbation_interval=1,
+            quantile_fraction=0.5,
+            hyperparam_mutations={"b": [-1]},
+        )
+
+        tune.run(
+            MyTrainable,
+            name="ray_demo",
+            scheduler=pbt,
+            stop={"training_iteration": 10},
+            num_samples=4,
+            checkpoint_freq=2,
+            keep_checkpoints_num=1,
+            verbose=False,
+            fail_fast=True,
+            config={"a": tune.sample_from(lambda _: param_a())},
+            callbacks=[FileCheck()],
+        )
+
+
 class PopulationBasedTrainingSynchTest(unittest.TestCase):
     def setUp(self):
         ray.init(num_cpus=2)