mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 03:34:48 +08:00
[tune] Deprecate tune.function (#5601)
* remove tune function * remove examples * Update tune-usage.rst
This commit is contained in:
@@ -64,7 +64,7 @@ if __name__ == "__main__":
|
||||
MyTrainableClass,
|
||||
name="hyperband_test",
|
||||
num_samples=5,
|
||||
trial_name_creator=tune.function(trial_str_creator),
|
||||
trial_name_creator=trial_str_creator,
|
||||
loggers=[TestLogger],
|
||||
stop={"training_iteration": 1 if args.smoke_test else 99999},
|
||||
config={
|
||||
|
||||
@@ -11,9 +11,6 @@ logger = logging.getLogger(__name__)
|
||||
class sample_from(object):
|
||||
"""Specify that tune should sample configuration values from this function.
|
||||
|
||||
The use of function arguments in tune configs must be disambiguated by
|
||||
either wrapped the function in tune.sample_from() or tune.function().
|
||||
|
||||
Arguments:
|
||||
func: An callable function to draw a sample from.
|
||||
"""
|
||||
@@ -28,27 +25,10 @@ class sample_from(object):
|
||||
return "tune.sample_from({})".format(repr(self.func))
|
||||
|
||||
|
||||
class function(object):
|
||||
"""Wraps `func` to make sure it is not expanded during resolution.
|
||||
|
||||
The use of function arguments in tune configs must be disambiguated by
|
||||
either wrapped the function in tune.sample_from() or tune.function().
|
||||
|
||||
Arguments:
|
||||
func: A function literal.
|
||||
"""
|
||||
|
||||
def __init__(self, func):
|
||||
self.func = func
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
def __str__(self):
|
||||
return "tune.function({})".format(str(self.func))
|
||||
|
||||
def __repr__(self):
|
||||
return "tune.function({})".format(repr(self.func))
|
||||
def function(func):
|
||||
logger.warn("DeprecationWarning: wrapping {} with tune.function() is no "
|
||||
"longer needed".format(func))
|
||||
return func
|
||||
|
||||
|
||||
def uniform(*args, **kwargs):
|
||||
|
||||
@@ -162,9 +162,7 @@ def _resolve_lambda_vars(spec, lambda_vars):
|
||||
error = e
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
"Failed to evaluate expression: {}: {}".format(path, fn) +
|
||||
". If you meant to pass this as a function literal, use "
|
||||
"tune.function() to escape it.")
|
||||
"Failed to evaluate expression: {}: {}".format(path, fn))
|
||||
else:
|
||||
_assign_value(spec, path, value)
|
||||
resolved[path] = value
|
||||
@@ -207,16 +205,7 @@ def _is_resolved(v):
|
||||
|
||||
|
||||
def _try_resolve(v):
|
||||
if isinstance(v, types.FunctionType):
|
||||
raise DeprecationWarning(
|
||||
"Function values are ambiguous in Tune "
|
||||
"configuations. Either wrap the function with "
|
||||
"`tune.function(func)` to specify a function literal, or "
|
||||
"`tune.sample_from(func)` to tell Tune to "
|
||||
"sample values from the function during variant generation: "
|
||||
"{}".format(v))
|
||||
return False, v
|
||||
elif isinstance(v, sample_from):
|
||||
if isinstance(v, sample_from):
|
||||
# Function to sample from
|
||||
return False, v.func
|
||||
elif isinstance(v, dict) and len(v) == 1 and "eval" in v:
|
||||
|
||||
@@ -15,7 +15,6 @@ try: # py3
|
||||
except ImportError: # py2
|
||||
from pipes import quote
|
||||
|
||||
from ray.tune.sample import function as tune_function
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.log_sync import log_sync_template, NodeSyncMixin
|
||||
|
||||
@@ -169,8 +168,7 @@ class CommandSyncer(BaseSyncer):
|
||||
def _get_sync_cls(sync_function):
|
||||
if not sync_function:
|
||||
return
|
||||
if isinstance(sync_function, types.FunctionType) or isinstance(
|
||||
sync_function, tune_function):
|
||||
if isinstance(sync_function, types.FunctionType):
|
||||
return BaseSyncer
|
||||
elif isinstance(sync_function, str):
|
||||
return CommandSyncer
|
||||
|
||||
@@ -972,8 +972,8 @@ class RunExperimentTest(unittest.TestCase):
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"trial_name_creator": tune.function(
|
||||
lambda t: "{}_{}_321".format(t.trainable_name, t.trial_id))
|
||||
"trial_name_creator":
|
||||
lambda t: "{}_{}_321".format(t.trainable_name, t.trial_id)
|
||||
}
|
||||
})
|
||||
self.assertEquals(
|
||||
@@ -1113,7 +1113,7 @@ class TestSyncFunctionality(unittest.TestCase):
|
||||
"training_iteration": 1
|
||||
},
|
||||
upload_dir=tmpdir2,
|
||||
sync_to_cloud=tune.function(sync_func)).trials
|
||||
sync_to_cloud=sync_func).trials
|
||||
test_file_path = glob.glob(os.path.join(tmpdir2, "foo", "*.json"))
|
||||
self.assertTrue(test_file_path)
|
||||
shutil.rmtree(tmpdir)
|
||||
@@ -1134,7 +1134,7 @@ class TestSyncFunctionality(unittest.TestCase):
|
||||
stop={
|
||||
"training_iteration": 1
|
||||
},
|
||||
sync_to_driver=tune.function(sync_func_driver)).trials
|
||||
sync_to_driver=sync_func_driver).trials
|
||||
test_file_path = os.path.join(trial.logdir, "test.log2")
|
||||
self.assertFalse(os.path.exists(test_file_path))
|
||||
|
||||
@@ -1147,7 +1147,7 @@ class TestSyncFunctionality(unittest.TestCase):
|
||||
stop={
|
||||
"training_iteration": 1
|
||||
},
|
||||
sync_to_driver=tune.function(sync_func_driver)).trials
|
||||
sync_to_driver=sync_func_driver).trials
|
||||
test_file_path = os.path.join(trial.logdir, "test.log2")
|
||||
self.assertTrue(os.path.exists(test_file_path))
|
||||
os.remove(test_file_path)
|
||||
@@ -1166,8 +1166,8 @@ class TestSyncFunctionality(unittest.TestCase):
|
||||
"training_iteration": 1
|
||||
},
|
||||
"upload_dir": "test",
|
||||
"sync_to_driver": tune.function(sync_func),
|
||||
"sync_to_cloud": tune.function(sync_func)
|
||||
"sync_to_driver": sync_func,
|
||||
"sync_to_cloud": sync_func
|
||||
}).trials
|
||||
self.assertEqual(mock_sync.call_count, 0)
|
||||
|
||||
@@ -2273,11 +2273,9 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
ray.init()
|
||||
trial = Trial(
|
||||
"__fake",
|
||||
config={
|
||||
"callbacks": {
|
||||
"on_episode_start": tune.function(lambda i: i),
|
||||
}
|
||||
},
|
||||
config={"callbacks": {
|
||||
"on_episode_start": lambda i: i,
|
||||
}},
|
||||
checkpoint_freq=1)
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0)
|
||||
|
||||
@@ -8,7 +8,6 @@ import subprocess
|
||||
import json
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib import _register_all
|
||||
from ray.tune.trial import Trial, Resources
|
||||
from ray.tune.web_server import TuneClient
|
||||
@@ -94,11 +93,9 @@ class TuneServerSuite(unittest.TestCase):
|
||||
"__fake",
|
||||
trial_id="function_trial",
|
||||
stopping_criterion={"training_iteration": 3},
|
||||
config={
|
||||
"callbacks": {
|
||||
"on_episode_start": tune.function(lambda x: None)
|
||||
}
|
||||
})
|
||||
config={"callbacks": {
|
||||
"on_episode_start": lambda x: None
|
||||
}})
|
||||
runner.add_trial(test_trial)
|
||||
|
||||
for i in range(3):
|
||||
|
||||
@@ -11,6 +11,7 @@ import os
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
import types
|
||||
|
||||
import ray.cloudpickle as cloudpickle
|
||||
from ray.tune import TuneError
|
||||
@@ -19,7 +20,6 @@ from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE,
|
||||
SHOULD_CHECKPOINT)
|
||||
from ray.tune.syncer import get_syncer
|
||||
from ray.tune.trial import Trial, Checkpoint
|
||||
from ray.tune.sample import function
|
||||
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.suggest import BasicVariantGenerator
|
||||
from ray.tune.util import warn_if_slow, flatten_dict
|
||||
@@ -48,7 +48,7 @@ def _find_newest_ckpt(ckpt_dir):
|
||||
|
||||
class _TuneFunctionEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, function):
|
||||
if isinstance(obj, types.FunctionType):
|
||||
return self._to_cloudpickle(obj)
|
||||
try:
|
||||
return super(_TuneFunctionEncoder, self).default(obj)
|
||||
|
||||
Reference in New Issue
Block a user