mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 16:49:48 +08:00
[tune] Fix flaky test in test_sample (#10602)
This commit is contained in:
@@ -2,6 +2,7 @@ import logging
|
||||
import random
|
||||
from copy import copy
|
||||
from inspect import signature
|
||||
from math import isclose
|
||||
from numbers import Number
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
|
||||
|
||||
@@ -187,7 +188,18 @@ class Float(Domain):
|
||||
new.set_sampler(self._Normal(mean, sd))
|
||||
return new
|
||||
|
||||
def quantized(self, q: Number):
|
||||
def quantized(self, q: float):
|
||||
if self.lower > float("-inf") and not isclose(self.lower / q,
|
||||
round(self.lower / q)):
|
||||
raise ValueError(
|
||||
f"Your lower variable bound {self.lower} is not divisible by "
|
||||
f"quantization factor {q}.")
|
||||
if self.upper < float("inf") and not isclose(self.upper / q,
|
||||
round(self.upper / q)):
|
||||
raise ValueError(
|
||||
f"Your upper variable bound {self.upper} is not divisible by "
|
||||
f"quantization factor {q}.")
|
||||
|
||||
new = copy(self)
|
||||
new.set_sampler(Quantized(new.get_sampler(), q), allow_override=True)
|
||||
return new
|
||||
|
||||
@@ -22,7 +22,7 @@ class SearchSpaceTest(unittest.TestCase):
|
||||
"uniform": tune.uniform(-5, -1),
|
||||
"quniform": tune.quniform(3.2, 5.4, 0.2),
|
||||
"loguniform": tune.loguniform(1e-4, 1e-2),
|
||||
"qloguniform": tune.qloguniform(1e-4, 1e-1, 5e-4),
|
||||
"qloguniform": tune.qloguniform(1e-4, 1e-1, 5e-5),
|
||||
"choice": tune.choice([2, 3, 4]),
|
||||
"randint": tune.randint(-9, 15),
|
||||
"qrandint": tune.qrandint(-21, 12, 3),
|
||||
@@ -30,35 +30,35 @@ class SearchSpaceTest(unittest.TestCase):
|
||||
"qrandn": tune.qrandn(10, 2, 0.2),
|
||||
}
|
||||
for _, (_, generated) in zip(
|
||||
range(10), generate_variants({
|
||||
range(1000), generate_variants({
|
||||
"config": config
|
||||
})):
|
||||
out = generated["config"]
|
||||
|
||||
self.assertAlmostEqual(out["func"], out["uniform"] * 0.01)
|
||||
|
||||
self.assertGreater(out["uniform"], -5)
|
||||
self.assertGreaterEqual(out["uniform"], -5)
|
||||
self.assertLess(out["uniform"], -1)
|
||||
|
||||
self.assertGreater(out["quniform"], 3.2)
|
||||
self.assertGreaterEqual(out["quniform"], 3.2)
|
||||
self.assertLessEqual(out["quniform"], 5.4)
|
||||
self.assertAlmostEqual(out["quniform"] / 0.2,
|
||||
round(out["quniform"] / 0.2))
|
||||
|
||||
self.assertGreater(out["loguniform"], 1e-4)
|
||||
self.assertGreaterEqual(out["loguniform"], 1e-4)
|
||||
self.assertLess(out["loguniform"], 1e-2)
|
||||
|
||||
self.assertGreater(out["qloguniform"], 1e-4)
|
||||
self.assertGreaterEqual(out["qloguniform"], 1e-4)
|
||||
self.assertLessEqual(out["qloguniform"], 1e-1)
|
||||
self.assertAlmostEqual(out["qloguniform"] / 5e-4,
|
||||
round(out["qloguniform"] / 5e-4))
|
||||
self.assertAlmostEqual(out["qloguniform"] / 5e-5,
|
||||
round(out["qloguniform"] / 5e-5))
|
||||
|
||||
self.assertIn(out["choice"], [2, 3, 4])
|
||||
|
||||
self.assertGreater(out["randint"], -9)
|
||||
self.assertGreaterEqual(out["randint"], -9)
|
||||
self.assertLess(out["randint"], 15)
|
||||
|
||||
self.assertGreater(out["qrandint"], -21)
|
||||
self.assertGreaterEqual(out["qrandint"], -21)
|
||||
self.assertLessEqual(out["qrandint"], 12)
|
||||
self.assertEqual(out["qrandint"] % 3, 0)
|
||||
|
||||
@@ -130,12 +130,29 @@ class SearchSpaceTest(unittest.TestCase):
|
||||
|
||||
def testQuantized(self):
|
||||
bounded_positive = tune.sample.Float(1e-4, 1e-1)
|
||||
samples = bounded_positive.loguniform().quantized(5e-4).sample(size=10)
|
||||
|
||||
bounded = tune.sample.Float(1e-4, 1e-1)
|
||||
with self.assertRaises(ValueError):
|
||||
# Granularity too high
|
||||
bounded.quantized(5e-4)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tune.sample.Float(-1e-1, -1e-4).quantized(5e-4)
|
||||
|
||||
samples = bounded_positive.loguniform().quantized(5e-5).sample(
|
||||
size=1000)
|
||||
|
||||
for sample in samples:
|
||||
factor = sample / 5e-4
|
||||
factor = sample / 5e-5
|
||||
assert 1e-4 <= sample <= 1e-1
|
||||
self.assertAlmostEqual(factor, round(factor), places=10)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tune.sample.Float(0, 32).quantized(3)
|
||||
|
||||
samples = tune.sample.Float(0, 33).quantized(3).sample(size=1000)
|
||||
self.assertTrue(all(0 <= s <= 33 for s in samples))
|
||||
|
||||
def testConvertAx(self):
|
||||
from ray.tune.suggest.ax import AxSearch
|
||||
from ax.service.ax_client import AxClient
|
||||
|
||||
Reference in New Issue
Block a user