[tune] Deprecate tune.function (#5601)

* remove tune function

* remove examples

* Update tune-usage.rst
This commit is contained in:
Eric Liang
2019-08-31 16:00:10 -07:00
committed by GitHub
parent 747daff2cb
commit daf38c8723
18 changed files with 52 additions and 100 deletions
+1 -1
View File
@@ -64,7 +64,7 @@ if __name__ == "__main__":
MyTrainableClass,
name="hyperband_test",
num_samples=5,
trial_name_creator=tune.function(trial_str_creator),
trial_name_creator=trial_str_creator,
loggers=[TestLogger],
stop={"training_iteration": 1 if args.smoke_test else 99999},
config={
+4 -24
View File
@@ -11,9 +11,6 @@ logger = logging.getLogger(__name__)
class sample_from(object):
"""Specify that tune should sample configuration values from this function.
The use of function arguments in tune configs must be disambiguated by
either wrapped the function in tune.sample_from() or tune.function().
Arguments:
func: An callable function to draw a sample from.
"""
@@ -28,27 +25,10 @@ class sample_from(object):
return "tune.sample_from({})".format(repr(self.func))
class function(object):
"""Wraps `func` to make sure it is not expanded during resolution.
The use of function arguments in tune configs must be disambiguated by
either wrapped the function in tune.sample_from() or tune.function().
Arguments:
func: A function literal.
"""
def __init__(self, func):
self.func = func
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)
def __str__(self):
return "tune.function({})".format(str(self.func))
def __repr__(self):
return "tune.function({})".format(repr(self.func))
def function(func):
logger.warn("DeprecationWarning: wrapping {} with tune.function() is no "
"longer needed".format(func))
return func
def uniform(*args, **kwargs):
+2 -13
View File
@@ -162,9 +162,7 @@ def _resolve_lambda_vars(spec, lambda_vars):
error = e
except Exception:
raise ValueError(
"Failed to evaluate expression: {}: {}".format(path, fn) +
". If you meant to pass this as a function literal, use "
"tune.function() to escape it.")
"Failed to evaluate expression: {}: {}".format(path, fn))
else:
_assign_value(spec, path, value)
resolved[path] = value
@@ -207,16 +205,7 @@ def _is_resolved(v):
def _try_resolve(v):
if isinstance(v, types.FunctionType):
raise DeprecationWarning(
"Function values are ambiguous in Tune "
"configuations. Either wrap the function with "
"`tune.function(func)` to specify a function literal, or "
"`tune.sample_from(func)` to tell Tune to "
"sample values from the function during variant generation: "
"{}".format(v))
return False, v
elif isinstance(v, sample_from):
if isinstance(v, sample_from):
# Function to sample from
return False, v.func
elif isinstance(v, dict) and len(v) == 1 and "eval" in v:
+1 -3
View File
@@ -15,7 +15,6 @@ try: # py3
except ImportError: # py2
from pipes import quote
from ray.tune.sample import function as tune_function
from ray.tune.error import TuneError
from ray.tune.log_sync import log_sync_template, NodeSyncMixin
@@ -169,8 +168,7 @@ class CommandSyncer(BaseSyncer):
def _get_sync_cls(sync_function):
if not sync_function:
return
if isinstance(sync_function, types.FunctionType) or isinstance(
sync_function, tune_function):
if isinstance(sync_function, types.FunctionType):
return BaseSyncer
elif isinstance(sync_function, str):
return CommandSyncer
+10 -12
View File
@@ -972,8 +972,8 @@ class RunExperimentTest(unittest.TestCase):
"stop": {
"training_iteration": 1
},
"trial_name_creator": tune.function(
lambda t: "{}_{}_321".format(t.trainable_name, t.trial_id))
"trial_name_creator":
lambda t: "{}_{}_321".format(t.trainable_name, t.trial_id)
}
})
self.assertEquals(
@@ -1113,7 +1113,7 @@ class TestSyncFunctionality(unittest.TestCase):
"training_iteration": 1
},
upload_dir=tmpdir2,
sync_to_cloud=tune.function(sync_func)).trials
sync_to_cloud=sync_func).trials
test_file_path = glob.glob(os.path.join(tmpdir2, "foo", "*.json"))
self.assertTrue(test_file_path)
shutil.rmtree(tmpdir)
@@ -1134,7 +1134,7 @@ class TestSyncFunctionality(unittest.TestCase):
stop={
"training_iteration": 1
},
sync_to_driver=tune.function(sync_func_driver)).trials
sync_to_driver=sync_func_driver).trials
test_file_path = os.path.join(trial.logdir, "test.log2")
self.assertFalse(os.path.exists(test_file_path))
@@ -1147,7 +1147,7 @@ class TestSyncFunctionality(unittest.TestCase):
stop={
"training_iteration": 1
},
sync_to_driver=tune.function(sync_func_driver)).trials
sync_to_driver=sync_func_driver).trials
test_file_path = os.path.join(trial.logdir, "test.log2")
self.assertTrue(os.path.exists(test_file_path))
os.remove(test_file_path)
@@ -1166,8 +1166,8 @@ class TestSyncFunctionality(unittest.TestCase):
"training_iteration": 1
},
"upload_dir": "test",
"sync_to_driver": tune.function(sync_func),
"sync_to_cloud": tune.function(sync_func)
"sync_to_driver": sync_func,
"sync_to_cloud": sync_func
}).trials
self.assertEqual(mock_sync.call_count, 0)
@@ -2273,11 +2273,9 @@ class TrialRunnerTest(unittest.TestCase):
ray.init()
trial = Trial(
"__fake",
config={
"callbacks": {
"on_episode_start": tune.function(lambda i: i),
}
},
config={"callbacks": {
"on_episode_start": lambda i: i,
}},
checkpoint_freq=1)
tmpdir = tempfile.mkdtemp()
runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0)
+3 -6
View File
@@ -8,7 +8,6 @@ import subprocess
import json
import ray
from ray import tune
from ray.rllib import _register_all
from ray.tune.trial import Trial, Resources
from ray.tune.web_server import TuneClient
@@ -94,11 +93,9 @@ class TuneServerSuite(unittest.TestCase):
"__fake",
trial_id="function_trial",
stopping_criterion={"training_iteration": 3},
config={
"callbacks": {
"on_episode_start": tune.function(lambda x: None)
}
})
config={"callbacks": {
"on_episode_start": lambda x: None
}})
runner.add_trial(test_trial)
for i in range(3):
+2 -2
View File
@@ -11,6 +11,7 @@ import os
import re
import time
import traceback
import types
import ray.cloudpickle as cloudpickle
from ray.tune import TuneError
@@ -19,7 +20,6 @@ from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE,
SHOULD_CHECKPOINT)
from ray.tune.syncer import get_syncer
from ray.tune.trial import Trial, Checkpoint
from ray.tune.sample import function
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
from ray.tune.suggest import BasicVariantGenerator
from ray.tune.util import warn_if_slow, flatten_dict
@@ -48,7 +48,7 @@ def _find_newest_ckpt(ckpt_dir):
class _TuneFunctionEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, function):
if isinstance(obj, types.FunctionType):
return self._to_cloudpickle(obj)
try:
return super(_TuneFunctionEncoder, self).default(obj)