From 4edfaf2f38d9a9e6a6fbefc30fe87a5f9d1a20d2 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Tue, 21 Jan 2020 10:24:25 -0800 Subject: [PATCH] [tune] Support callable objects in variant generation (#6849) * minorcallable * format --- python/ray/tune/suggest/variant_generator.py | 3 +-- python/ray/tune/tests/test_var.py | 24 ++++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/python/ray/tune/suggest/variant_generator.py b/python/ray/tune/suggest/variant_generator.py index 6e7092536..e772ffb49 100644 --- a/python/ray/tune/suggest/variant_generator.py +++ b/python/ray/tune/suggest/variant_generator.py @@ -2,7 +2,6 @@ import copy import logging import numpy import random -import types from ray.tune import TuneError from ray.tune.sample import sample_from @@ -126,7 +125,7 @@ def _generate_variants(spec): grid_vars = [] lambda_vars = [] for path, value in unresolved.items(): - if isinstance(value, types.FunctionType): + if callable(value): lambda_vars.append((path, value)) else: grid_vars.append((path, value)) diff --git a/python/ray/tune/tests/test_var.py b/python/ray/tune/tests/test_var.py index 3b9695a00..0ed98d860 100644 --- a/python/ray/tune/tests/test_var.py +++ b/python/ray/tune/tests/test_var.py @@ -1,5 +1,6 @@ import os import numpy as np +import random import unittest import ray @@ -210,6 +211,29 @@ class VariantGeneratorTest(unittest.TestCase): self.assertEqual(trials[0].config, {"x": 100, "y": 1}) self.assertEqual(trials[1].config, {"x": 200, "y": 1}) + def testDependentGridSearchCallable(self): + class Normal: + def __call__(self, _config): + return random.normalvariate(mu=0, sigma=1) + + class Single: + def __call__(self, _config): + return 20 + + trials = self.generate_trials({ + "run": "PPO", + "config": { + "x": grid_search( + [tune.sample_from(Normal()), + tune.sample_from(Normal())]), + "y": tune.sample_from(Single()), + }, + }, "dependent_grid_search") + trials = list(trials) + self.assertEqual(len(trials), 2) + self.assertEqual(trials[0].config["y"], 20) + self.assertEqual(trials[1].config["y"], 20) + def testNestedValues(self): trials = self.generate_trials({ "run": "PPO",