From 287aba6dc39ad4e88d56884b27169d44bed5e5cf Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Mon, 9 Nov 2020 20:37:05 +0100 Subject: [PATCH] [tune] schedulers: Add test for context finalization (#11889) --- python/ray/tune/tests/test_trial_scheduler.py | 57 ++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index d5708f346..a2cd02fb2 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -2,6 +2,7 @@ import os import json import random import unittest + import numpy as np import sys import tempfile @@ -12,7 +13,8 @@ import ray from ray import tune from ray.tune import Trainable from ray.tune.result import TRAINING_ITERATION -from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler, +from ray.tune.schedulers import (FIFOScheduler, HyperBandScheduler, + AsyncHyperBandScheduler, PopulationBasedTraining, MedianStoppingRule, TrialScheduler, HyperBandForBOHB) @@ -1696,6 +1698,59 @@ class PopulationBasedTestingSuite(unittest.TestCase): pbt._exploit(runner.trial_executor, trials[1], trials[2]) shutil.rmtree(tmpdir) + def testContextExit(self): + vals = [5, 1] + + class MockContext: + def __init__(self, config): + self.config = config + self.active = False + + def __enter__(self): + print("Set up resource.", self.config) + with open("status.txt", "wt") as fp: + fp.write("Activate\n") + self.active = True + return self + + def __exit__(self, type, value, traceback): + print("Clean up resource.", self.config) + with open("status.txt", "at") as fp: + fp.write("Cleanup\n") + self.active = False + + def train(config): + with MockContext(config): + for i in range(10): + tune.report(metric=i + config["x"]) + + class MockScheduler(FIFOScheduler): + def on_trial_result(self, trial_runner, trial, result): + return TrialScheduler.STOP + + scheduler = MockScheduler() + + out = tune.run( + train, config={"x": tune.grid_search(vals)}, scheduler=scheduler) + + ever_active = set() + active = set() + for trial in out.trials: + with open(os.path.join(trial.logdir, "status.txt"), "rt") as fp: + status = fp.read() + print(f"Status for trial {trial}: {status}") + if "Activate" in status: + ever_active.add(trial) + active.add(trial) + if "Cleanup" in status: + active.remove(trial) + + print(f"Ever active: {ever_active}") + print(f"Still active: {active}") + + self.assertEqual(len(ever_active), len(vals)) + self.assertEqual(len(active), 0) + class E2EPopulationBasedTestingSuite(unittest.TestCase): def setUp(self):