From c5e9bafe159af557939f915a9d4238446bbd123e Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Sun, 6 Sep 2020 18:29:48 +0100 Subject: [PATCH] [tune] Fix flaky test in test_sample (#10602) --- python/ray/tune/sample.py | 14 +++++++++- python/ray/tune/tests/test_sample.py | 41 ++++++++++++++++++++-------- 2 files changed, 42 insertions(+), 13 deletions(-) diff --git a/python/ray/tune/sample.py b/python/ray/tune/sample.py index c0a85a62a..98fc84294 100644 --- a/python/ray/tune/sample.py +++ b/python/ray/tune/sample.py @@ -2,6 +2,7 @@ import logging import random from copy import copy from inspect import signature +from math import isclose from numbers import Number from typing import Any, Callable, Dict, List, Optional, Sequence, Union @@ -187,7 +188,18 @@ class Float(Domain): new.set_sampler(self._Normal(mean, sd)) return new - def quantized(self, q: Number): + def quantized(self, q: float): + if self.lower > float("-inf") and not isclose(self.lower / q, + round(self.lower / q)): + raise ValueError( + f"Your lower variable bound {self.lower} is not divisible by " + f"quantization factor {q}.") + if self.upper < float("inf") and not isclose(self.upper / q, + round(self.upper / q)): + raise ValueError( + f"Your upper variable bound {self.upper} is not divisible by " + f"quantization factor {q}.") + new = copy(self) new.set_sampler(Quantized(new.get_sampler(), q), allow_override=True) return new diff --git a/python/ray/tune/tests/test_sample.py b/python/ray/tune/tests/test_sample.py index dee68fe83..7cc34ee01 100644 --- a/python/ray/tune/tests/test_sample.py +++ b/python/ray/tune/tests/test_sample.py @@ -22,7 +22,7 @@ class SearchSpaceTest(unittest.TestCase): "uniform": tune.uniform(-5, -1), "quniform": tune.quniform(3.2, 5.4, 0.2), "loguniform": tune.loguniform(1e-4, 1e-2), - "qloguniform": tune.qloguniform(1e-4, 1e-1, 5e-4), + "qloguniform": tune.qloguniform(1e-4, 1e-1, 5e-5), "choice": tune.choice([2, 3, 4]), "randint": tune.randint(-9, 15), "qrandint": tune.qrandint(-21, 12, 3), @@ -30,35 +30,35 @@ class SearchSpaceTest(unittest.TestCase): "qrandn": tune.qrandn(10, 2, 0.2), } for _, (_, generated) in zip( - range(10), generate_variants({ + range(1000), generate_variants({ "config": config })): out = generated["config"] self.assertAlmostEqual(out["func"], out["uniform"] * 0.01) - self.assertGreater(out["uniform"], -5) + self.assertGreaterEqual(out["uniform"], -5) self.assertLess(out["uniform"], -1) - self.assertGreater(out["quniform"], 3.2) + self.assertGreaterEqual(out["quniform"], 3.2) self.assertLessEqual(out["quniform"], 5.4) self.assertAlmostEqual(out["quniform"] / 0.2, round(out["quniform"] / 0.2)) - self.assertGreater(out["loguniform"], 1e-4) + self.assertGreaterEqual(out["loguniform"], 1e-4) self.assertLess(out["loguniform"], 1e-2) - self.assertGreater(out["qloguniform"], 1e-4) + self.assertGreaterEqual(out["qloguniform"], 1e-4) self.assertLessEqual(out["qloguniform"], 1e-1) - self.assertAlmostEqual(out["qloguniform"] / 5e-4, - round(out["qloguniform"] / 5e-4)) + self.assertAlmostEqual(out["qloguniform"] / 5e-5, + round(out["qloguniform"] / 5e-5)) self.assertIn(out["choice"], [2, 3, 4]) - self.assertGreater(out["randint"], -9) + self.assertGreaterEqual(out["randint"], -9) self.assertLess(out["randint"], 15) - self.assertGreater(out["qrandint"], -21) + self.assertGreaterEqual(out["qrandint"], -21) self.assertLessEqual(out["qrandint"], 12) self.assertEqual(out["qrandint"] % 3, 0) @@ -130,12 +130,29 @@ class SearchSpaceTest(unittest.TestCase): def testQuantized(self): bounded_positive = tune.sample.Float(1e-4, 1e-1) - samples = bounded_positive.loguniform().quantized(5e-4).sample(size=10) + + bounded = tune.sample.Float(1e-4, 1e-1) + with self.assertRaises(ValueError): + # Granularity too high + bounded.quantized(5e-4) + + with self.assertRaises(ValueError): + tune.sample.Float(-1e-1, -1e-4).quantized(5e-4) + + samples = bounded_positive.loguniform().quantized(5e-5).sample( + size=1000) for sample in samples: - factor = sample / 5e-4 + factor = sample / 5e-5 + assert 1e-4 <= sample <= 1e-1 self.assertAlmostEqual(factor, round(factor), places=10) + with self.assertRaises(ValueError): + tune.sample.Float(0, 32).quantized(3) + + samples = tune.sample.Float(0, 33).quantized(3).sample(size=1000) + self.assertTrue(all(0 <= s <= 33 for s in samples)) + def testConvertAx(self): from ray.tune.suggest.ax import AxSearch from ax.service.ax_client import AxClient