[tune] Support callable objects in variant generation (#6849)

* minorcallable

* format
This commit is contained in:
Richard Liaw
2020-01-21 10:24:25 -08:00
committed by GitHub
parent dac6268c5b
commit 4edfaf2f38
2 changed files with 25 additions and 2 deletions
+1 -2
View File
@@ -2,7 +2,6 @@ import copy
import logging
import numpy
import random
import types
from ray.tune import TuneError
from ray.tune.sample import sample_from
@@ -126,7 +125,7 @@ def _generate_variants(spec):
grid_vars = []
lambda_vars = []
for path, value in unresolved.items():
if isinstance(value, types.FunctionType):
if callable(value):
lambda_vars.append((path, value))
else:
grid_vars.append((path, value))
+24
View File
@@ -1,5 +1,6 @@
import os
import numpy as np
import random
import unittest
import ray
@@ -210,6 +211,29 @@ class VariantGeneratorTest(unittest.TestCase):
self.assertEqual(trials[0].config, {"x": 100, "y": 1})
self.assertEqual(trials[1].config, {"x": 200, "y": 1})
def testDependentGridSearchCallable(self):
class Normal:
def __call__(self, _config):
return random.normalvariate(mu=0, sigma=1)
class Single:
def __call__(self, _config):
return 20
trials = self.generate_trials({
"run": "PPO",
"config": {
"x": grid_search(
[tune.sample_from(Normal()),
tune.sample_from(Normal())]),
"y": tune.sample_from(Single()),
},
}, "dependent_grid_search")
trials = list(trials)
self.assertEqual(len(trials), 2)
self.assertEqual(trials[0].config["y"], 20)
self.assertEqual(trials[1].config["y"], 20)
def testNestedValues(self):
trials = self.generate_trials({
"run": "PPO",