mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 21:12:15 +08:00
[tune] Support callable objects in variant generation (#6849)
* minorcallable * format
This commit is contained in:
@@ -2,7 +2,6 @@ import copy
|
||||
import logging
|
||||
import numpy
|
||||
import random
|
||||
import types
|
||||
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.sample import sample_from
|
||||
@@ -126,7 +125,7 @@ def _generate_variants(spec):
|
||||
grid_vars = []
|
||||
lambda_vars = []
|
||||
for path, value in unresolved.items():
|
||||
if isinstance(value, types.FunctionType):
|
||||
if callable(value):
|
||||
lambda_vars.append((path, value))
|
||||
else:
|
||||
grid_vars.append((path, value))
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
@@ -210,6 +211,29 @@ class VariantGeneratorTest(unittest.TestCase):
|
||||
self.assertEqual(trials[0].config, {"x": 100, "y": 1})
|
||||
self.assertEqual(trials[1].config, {"x": 200, "y": 1})
|
||||
|
||||
def testDependentGridSearchCallable(self):
|
||||
class Normal:
|
||||
def __call__(self, _config):
|
||||
return random.normalvariate(mu=0, sigma=1)
|
||||
|
||||
class Single:
|
||||
def __call__(self, _config):
|
||||
return 20
|
||||
|
||||
trials = self.generate_trials({
|
||||
"run": "PPO",
|
||||
"config": {
|
||||
"x": grid_search(
|
||||
[tune.sample_from(Normal()),
|
||||
tune.sample_from(Normal())]),
|
||||
"y": tune.sample_from(Single()),
|
||||
},
|
||||
}, "dependent_grid_search")
|
||||
trials = list(trials)
|
||||
self.assertEqual(len(trials), 2)
|
||||
self.assertEqual(trials[0].config["y"], 20)
|
||||
self.assertEqual(trials[1].config["y"], 20)
|
||||
|
||||
def testNestedValues(self):
|
||||
trials = self.generate_trials({
|
||||
"run": "PPO",
|
||||
|
||||
Reference in New Issue
Block a user