[tune] Fix flaky test in test_sample (#10602)

This commit is contained in:
Kai Fricke
2020-09-06 18:29:48 +01:00
committed by GitHub
parent 28ab797cf5
commit c5e9bafe15
2 changed files with 42 additions and 13 deletions
+13 -1
View File
@@ -2,6 +2,7 @@ import logging
import random
from copy import copy
from inspect import signature
from math import isclose
from numbers import Number
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
@@ -187,7 +188,18 @@ class Float(Domain):
new.set_sampler(self._Normal(mean, sd))
return new
def quantized(self, q: Number):
def quantized(self, q: float):
if self.lower > float("-inf") and not isclose(self.lower / q,
round(self.lower / q)):
raise ValueError(
f"Your lower variable bound {self.lower} is not divisible by "
f"quantization factor {q}.")
if self.upper < float("inf") and not isclose(self.upper / q,
round(self.upper / q)):
raise ValueError(
f"Your upper variable bound {self.upper} is not divisible by "
f"quantization factor {q}.")
new = copy(self)
new.set_sampler(Quantized(new.get_sampler(), q), allow_override=True)
return new
+29 -12
View File
@@ -22,7 +22,7 @@ class SearchSpaceTest(unittest.TestCase):
"uniform": tune.uniform(-5, -1),
"quniform": tune.quniform(3.2, 5.4, 0.2),
"loguniform": tune.loguniform(1e-4, 1e-2),
"qloguniform": tune.qloguniform(1e-4, 1e-1, 5e-4),
"qloguniform": tune.qloguniform(1e-4, 1e-1, 5e-5),
"choice": tune.choice([2, 3, 4]),
"randint": tune.randint(-9, 15),
"qrandint": tune.qrandint(-21, 12, 3),
@@ -30,35 +30,35 @@ class SearchSpaceTest(unittest.TestCase):
"qrandn": tune.qrandn(10, 2, 0.2),
}
for _, (_, generated) in zip(
range(10), generate_variants({
range(1000), generate_variants({
"config": config
})):
out = generated["config"]
self.assertAlmostEqual(out["func"], out["uniform"] * 0.01)
self.assertGreater(out["uniform"], -5)
self.assertGreaterEqual(out["uniform"], -5)
self.assertLess(out["uniform"], -1)
self.assertGreater(out["quniform"], 3.2)
self.assertGreaterEqual(out["quniform"], 3.2)
self.assertLessEqual(out["quniform"], 5.4)
self.assertAlmostEqual(out["quniform"] / 0.2,
round(out["quniform"] / 0.2))
self.assertGreater(out["loguniform"], 1e-4)
self.assertGreaterEqual(out["loguniform"], 1e-4)
self.assertLess(out["loguniform"], 1e-2)
self.assertGreater(out["qloguniform"], 1e-4)
self.assertGreaterEqual(out["qloguniform"], 1e-4)
self.assertLessEqual(out["qloguniform"], 1e-1)
self.assertAlmostEqual(out["qloguniform"] / 5e-4,
round(out["qloguniform"] / 5e-4))
self.assertAlmostEqual(out["qloguniform"] / 5e-5,
round(out["qloguniform"] / 5e-5))
self.assertIn(out["choice"], [2, 3, 4])
self.assertGreater(out["randint"], -9)
self.assertGreaterEqual(out["randint"], -9)
self.assertLess(out["randint"], 15)
self.assertGreater(out["qrandint"], -21)
self.assertGreaterEqual(out["qrandint"], -21)
self.assertLessEqual(out["qrandint"], 12)
self.assertEqual(out["qrandint"] % 3, 0)
@@ -130,12 +130,29 @@ class SearchSpaceTest(unittest.TestCase):
def testQuantized(self):
bounded_positive = tune.sample.Float(1e-4, 1e-1)
samples = bounded_positive.loguniform().quantized(5e-4).sample(size=10)
bounded = tune.sample.Float(1e-4, 1e-1)
with self.assertRaises(ValueError):
# Granularity too high
bounded.quantized(5e-4)
with self.assertRaises(ValueError):
tune.sample.Float(-1e-1, -1e-4).quantized(5e-4)
samples = bounded_positive.loguniform().quantized(5e-5).sample(
size=1000)
for sample in samples:
factor = sample / 5e-4
factor = sample / 5e-5
assert 1e-4 <= sample <= 1e-1
self.assertAlmostEqual(factor, round(factor), places=10)
with self.assertRaises(ValueError):
tune.sample.Float(0, 32).quantized(3)
samples = tune.sample.Float(0, 33).quantized(3).sample(size=1000)
self.assertTrue(all(0 <= s <= 33 for s in samples))
def testConvertAx(self):
from ray.tune.suggest.ax import AxSearch
from ax.service.ax_client import AxClient