[tune] Fully deprecate raw function literals in Tune (#3788)

Related: https://github.com/ray-project/ray/issues/3785
This commit is contained in:
Eric Liang
2019-01-19 17:09:36 -08:00
committed by Richard Liaw
parent 16f7ca45e4
commit aad48ee5a5
4 changed files with 17 additions and 18 deletions
-2
View File
@@ -23,8 +23,6 @@ DEFAULT_CONFIG = with_base_config(impala.DEFAULT_CONFIG, {
# == PPO surrogate loss options ==
"clip_param": 0.4,
"kl_coeff": 0.2,
"kl_target": 0.01,
# == IMPALA optimizer params (see documentation in impala.py) ==
"sample_batch_size": 50,
+3 -2
View File
@@ -69,8 +69,9 @@ if __name__ == "__main__":
custom_loggers=[TestLogger],
stop={"training_iteration": 1 if args.smoke_test else 99999},
config={
"width": lambda spec: 10 + int(90 * random.random()),
"height": lambda spec: int(100 * random.random())
"width": tune.sample_from(
lambda spec: 10 + int(90 * random.random())),
"height": tune.sample_from(lambda spec: int(100 * random.random()))
})
trials = run_experiments(exp)
+2 -2
View File
@@ -238,8 +238,8 @@ def _is_resolved(v):
def _try_resolve(v):
if isinstance(v, types.FunctionType):
logger.warning(
"Deprecation warning: Function values are ambiguous in Tune "
raise DeprecationWarning(
"Function values are ambiguous in Tune "
"configuations. Either wrap the function with "
"`tune.function(func)` to specify a function literal, or "
"`tune.sample_from(func)` to tell Tune to "
+12 -12
View File
@@ -261,8 +261,8 @@ class TrainableFunctionApiTest(unittest.TestCase):
"run": "f1",
"local_dir": "/tmp/logdir",
"config": {
"a" * 50: lambda spec: 5.0 / 7,
"b" * 50: lambda spec: "long" * 40
"a" * 50: tune.sample_from(lambda spec: 5.0 / 7),
"b" * 50: tune.sample_from(lambda spec: "long" * 40),
},
}
})
@@ -848,7 +848,7 @@ class VariantGeneratorTest(unittest.TestCase):
trials = self.generate_trials({
"run": "PPO",
"config": {
"qux": lambda spec: 2 + 2,
"qux": tune.sample_from(lambda spec: 2 + 2),
"bar": grid_search([True, False]),
"foo": grid_search([1, 2, 3]),
},
@@ -863,8 +863,8 @@ class VariantGeneratorTest(unittest.TestCase):
"run": "PPO",
"config": {
"x": 1,
"y": lambda spec: spec.config.x + 1,
"z": lambda spec: spec.config.y + 1,
"y": tune.sample_from(lambda spec: spec.config.x + 1),
"z": tune.sample_from(lambda spec: spec.config.y + 1),
},
}, "condition_resolution")
trials = list(trials)
@@ -876,7 +876,7 @@ class VariantGeneratorTest(unittest.TestCase):
"run": "PPO",
"config": {
"x": grid_search([1, 2]),
"y": lambda spec: spec.config.x * 100,
"y": tune.sample_from(lambda spec: spec.config.x * 100),
},
}, "dependent_lambda")
trials = list(trials)
@@ -889,10 +889,10 @@ class VariantGeneratorTest(unittest.TestCase):
"run": "PPO",
"config": {
"x": grid_search([
lambda spec: spec.config.y * 100,
lambda spec: spec.config.y * 200
tune.sample_from(lambda spec: spec.config.y * 100),
tune.sample_from(lambda spec: spec.config.y * 200)
]),
"y": lambda spec: 1,
"y": tune.sample_from(lambda spec: 1),
},
}, "dependent_grid_search")
trials = list(trials)
@@ -920,7 +920,7 @@ class VariantGeneratorTest(unittest.TestCase):
self.generate_trials({
"run": "PPO",
"config": {
"foo": lambda spec: spec.config.foo,
"foo": tune.sample_from(lambda spec: spec.config.foo),
},
}, "recursive_dep"))
except RecursiveDependencyError as e:
@@ -1007,8 +1007,8 @@ class TrialRunnerTest(unittest.TestCase):
"foo": {
"run": "f1",
"config": {
"a" * 50: lambda spec: 5.0 / 7,
"b" * 50: lambda spec: "long" * 40
"a" * 50: tune.sample_from(lambda spec: 5.0 / 7),
"b" * 50: tune.sample_from(lambda spec: "long" * 40)
},
}
}