From daf38c8723581f2cfa183c940b7fbf3928dd7655 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sat, 31 Aug 2019 16:00:10 -0700 Subject: [PATCH] [tune] Deprecate tune.function (#5601) * remove tune function * remove examples * Update tune-usage.rst --- doc/source/rllib-training.rst | 12 ++++---- doc/source/tune-usage.rst | 16 +++++------ python/ray/tune/examples/logging_example.py | 2 +- python/ray/tune/sample.py | 28 +++---------------- python/ray/tune/suggest/variant_generator.py | 15 ++-------- python/ray/tune/syncer.py | 4 +-- python/ray/tune/tests/test_trial_runner.py | 22 +++++++-------- python/ray/tune/tests/test_tune_server.py | 9 ++---- python/ray/tune/trial_runner.py | 4 +-- rllib/agents/dqn/dqn.py | 6 ++-- rllib/evaluation/rollout_worker.py | 5 +--- rllib/examples/centralized_critic.py | 3 +- rllib/examples/centralized_critic_2.py | 5 ++-- .../examples/custom_metrics_and_callbacks.py | 12 ++++---- rllib/examples/multiagent_cartpole.py | 2 +- rllib/examples/multiagent_custom_policy.py | 2 +- .../rock_paper_scissors_multiagent.py | 2 +- rllib/examples/twostep_game.py | 3 +- 18 files changed, 52 insertions(+), 100 deletions(-) diff --git a/doc/source/rllib-training.rst b/doc/source/rllib-training.rst index 76081e946..2d58ad89f 100644 --- a/doc/source/rllib-training.rst +++ b/doc/source/rllib-training.rst @@ -286,11 +286,11 @@ You can provide callback functions to be called at points during policy evaluati config={ "env": "CartPole-v0", "callbacks": { - "on_episode_start": tune.function(on_episode_start), - "on_episode_step": tune.function(on_episode_step), - "on_episode_end": tune.function(on_episode_end), - "on_train_result": tune.function(on_train_result), - "on_postprocess_traj": tune.function(on_postprocess_traj), + "on_episode_start": on_episode_start, + "on_episode_step": on_episode_step, + "on_episode_end": on_episode_end, + "on_train_result": on_train_result, + "on_postprocess_traj": on_postprocess_traj, }, }, ) @@ -377,7 +377,7 @@ Approach 2: Use the callbacks API to update the environment on new training resu config={ "env": YourEnv, "callbacks": { - "on_train_result": tune.function(on_train_result), + "on_train_result": on_train_result, }, }, ) diff --git a/doc/source/tune-usage.rst b/doc/source/tune-usage.rst index 54fc37a70..809972d99 100644 --- a/doc/source/tune-usage.rst +++ b/doc/source/tune-usage.rst @@ -167,9 +167,8 @@ The following shows grid search over two nested parameters combined with random } ) - -.. note:: - Use ``tune.sample_from(...)`` to sample from a function during trial variant generation. If you need to pass a literal function in your config, use ``tune.function(...)`` to escape it. +.. note:: + Use ``tune.sample_from(...)`` to sample from a function during trial variant generation. For more information on variant generation, see `basic_variant.py `__. @@ -177,8 +176,7 @@ Custom Trial Names ------------------ To specify custom trial names, you can pass use the ``trial_name_creator`` argument -to `tune.run`. This takes a function with the following signature, and -be sure to wrap it with `tune.function`: +to `tune.run`. This takes a function with the following signature: .. code-block:: python @@ -196,7 +194,7 @@ be sure to wrap it with `tune.function`: MyTrainableClass, name="example-experiment", num_samples=1, - trial_name_creator=tune.function(trial_name_string) + trial_name_creator=trial_name_string ) An example can be found in `logging_example.py `__. @@ -496,7 +494,7 @@ Uploading/Syncing Tune automatically syncs the trial folder on remote nodes back to the head node. This requires the ray cluster to be started with the `autoscaler `__. By default, local syncing requires rsync to be installed. You can customize the sync command with the ``sync_to_driver`` argument in ``tune.run`` by providing either a function or a string. -If a string is provided, then it must include replacement fields ``{source}`` and ``{target}``, like ``rsync -savz -e "ssh -i ssh_key.pem" {source} {target}``. Alternatively, a function can be provided with the following signature (and must be wrapped with ``tune.function``): +If a string is provided, then it must include replacement fields ``{source}`` and ``{target}``, like ``rsync -savz -e "ssh -i ssh_key.pem" {source} {target}``. Alternatively, a function can be provided with the following signature: .. code-block:: python @@ -510,7 +508,7 @@ If a string is provided, then it must include replacement fields ``{source}`` an tune.run( MyTrainableClass, name="experiment_name", - sync_to_driver=tune.function(custom_sync_func), + sync_to_driver=custom_sync_func, ) When syncing results back to the driver, the source would be a path similar to ``ubuntu@192.0.0.1:/home/ubuntu/ray_results/trial1``, and the target would be a local path. @@ -524,7 +522,7 @@ You can customize this to specify arbitrary storages with the ``sync_to_cloud`` tune.run( MyTrainableClass, name="experiment_name", - sync_to_cloud=tune.function(custom_sync_func), + sync_to_cloud=custom_sync_func, ) Tune Client API diff --git a/python/ray/tune/examples/logging_example.py b/python/ray/tune/examples/logging_example.py index eafc23e69..3bacd0975 100755 --- a/python/ray/tune/examples/logging_example.py +++ b/python/ray/tune/examples/logging_example.py @@ -64,7 +64,7 @@ if __name__ == "__main__": MyTrainableClass, name="hyperband_test", num_samples=5, - trial_name_creator=tune.function(trial_str_creator), + trial_name_creator=trial_str_creator, loggers=[TestLogger], stop={"training_iteration": 1 if args.smoke_test else 99999}, config={ diff --git a/python/ray/tune/sample.py b/python/ray/tune/sample.py index f06470f62..ec632e8c1 100644 --- a/python/ray/tune/sample.py +++ b/python/ray/tune/sample.py @@ -11,9 +11,6 @@ logger = logging.getLogger(__name__) class sample_from(object): """Specify that tune should sample configuration values from this function. - The use of function arguments in tune configs must be disambiguated by - either wrapped the function in tune.sample_from() or tune.function(). - Arguments: func: An callable function to draw a sample from. """ @@ -28,27 +25,10 @@ class sample_from(object): return "tune.sample_from({})".format(repr(self.func)) -class function(object): - """Wraps `func` to make sure it is not expanded during resolution. - - The use of function arguments in tune configs must be disambiguated by - either wrapped the function in tune.sample_from() or tune.function(). - - Arguments: - func: A function literal. - """ - - def __init__(self, func): - self.func = func - - def __call__(self, *args, **kwargs): - return self.func(*args, **kwargs) - - def __str__(self): - return "tune.function({})".format(str(self.func)) - - def __repr__(self): - return "tune.function({})".format(repr(self.func)) +def function(func): + logger.warn("DeprecationWarning: wrapping {} with tune.function() is no " + "longer needed".format(func)) + return func def uniform(*args, **kwargs): diff --git a/python/ray/tune/suggest/variant_generator.py b/python/ray/tune/suggest/variant_generator.py index cf2b01eea..2488b14ae 100644 --- a/python/ray/tune/suggest/variant_generator.py +++ b/python/ray/tune/suggest/variant_generator.py @@ -162,9 +162,7 @@ def _resolve_lambda_vars(spec, lambda_vars): error = e except Exception: raise ValueError( - "Failed to evaluate expression: {}: {}".format(path, fn) + - ". If you meant to pass this as a function literal, use " - "tune.function() to escape it.") + "Failed to evaluate expression: {}: {}".format(path, fn)) else: _assign_value(spec, path, value) resolved[path] = value @@ -207,16 +205,7 @@ def _is_resolved(v): def _try_resolve(v): - if isinstance(v, types.FunctionType): - raise DeprecationWarning( - "Function values are ambiguous in Tune " - "configuations. Either wrap the function with " - "`tune.function(func)` to specify a function literal, or " - "`tune.sample_from(func)` to tell Tune to " - "sample values from the function during variant generation: " - "{}".format(v)) - return False, v - elif isinstance(v, sample_from): + if isinstance(v, sample_from): # Function to sample from return False, v.func elif isinstance(v, dict) and len(v) == 1 and "eval" in v: diff --git a/python/ray/tune/syncer.py b/python/ray/tune/syncer.py index e83c966a6..ba54163e8 100644 --- a/python/ray/tune/syncer.py +++ b/python/ray/tune/syncer.py @@ -15,7 +15,6 @@ try: # py3 except ImportError: # py2 from pipes import quote -from ray.tune.sample import function as tune_function from ray.tune.error import TuneError from ray.tune.log_sync import log_sync_template, NodeSyncMixin @@ -169,8 +168,7 @@ class CommandSyncer(BaseSyncer): def _get_sync_cls(sync_function): if not sync_function: return - if isinstance(sync_function, types.FunctionType) or isinstance( - sync_function, tune_function): + if isinstance(sync_function, types.FunctionType): return BaseSyncer elif isinstance(sync_function, str): return CommandSyncer diff --git a/python/ray/tune/tests/test_trial_runner.py b/python/ray/tune/tests/test_trial_runner.py index c53bd84de..0e75f6f73 100644 --- a/python/ray/tune/tests/test_trial_runner.py +++ b/python/ray/tune/tests/test_trial_runner.py @@ -972,8 +972,8 @@ class RunExperimentTest(unittest.TestCase): "stop": { "training_iteration": 1 }, - "trial_name_creator": tune.function( - lambda t: "{}_{}_321".format(t.trainable_name, t.trial_id)) + "trial_name_creator": + lambda t: "{}_{}_321".format(t.trainable_name, t.trial_id) } }) self.assertEquals( @@ -1113,7 +1113,7 @@ class TestSyncFunctionality(unittest.TestCase): "training_iteration": 1 }, upload_dir=tmpdir2, - sync_to_cloud=tune.function(sync_func)).trials + sync_to_cloud=sync_func).trials test_file_path = glob.glob(os.path.join(tmpdir2, "foo", "*.json")) self.assertTrue(test_file_path) shutil.rmtree(tmpdir) @@ -1134,7 +1134,7 @@ class TestSyncFunctionality(unittest.TestCase): stop={ "training_iteration": 1 }, - sync_to_driver=tune.function(sync_func_driver)).trials + sync_to_driver=sync_func_driver).trials test_file_path = os.path.join(trial.logdir, "test.log2") self.assertFalse(os.path.exists(test_file_path)) @@ -1147,7 +1147,7 @@ class TestSyncFunctionality(unittest.TestCase): stop={ "training_iteration": 1 }, - sync_to_driver=tune.function(sync_func_driver)).trials + sync_to_driver=sync_func_driver).trials test_file_path = os.path.join(trial.logdir, "test.log2") self.assertTrue(os.path.exists(test_file_path)) os.remove(test_file_path) @@ -1166,8 +1166,8 @@ class TestSyncFunctionality(unittest.TestCase): "training_iteration": 1 }, "upload_dir": "test", - "sync_to_driver": tune.function(sync_func), - "sync_to_cloud": tune.function(sync_func) + "sync_to_driver": sync_func, + "sync_to_cloud": sync_func }).trials self.assertEqual(mock_sync.call_count, 0) @@ -2273,11 +2273,9 @@ class TrialRunnerTest(unittest.TestCase): ray.init() trial = Trial( "__fake", - config={ - "callbacks": { - "on_episode_start": tune.function(lambda i: i), - } - }, + config={"callbacks": { + "on_episode_start": lambda i: i, + }}, checkpoint_freq=1) tmpdir = tempfile.mkdtemp() runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0) diff --git a/python/ray/tune/tests/test_tune_server.py b/python/ray/tune/tests/test_tune_server.py index dd9a9134c..24bbef5a9 100644 --- a/python/ray/tune/tests/test_tune_server.py +++ b/python/ray/tune/tests/test_tune_server.py @@ -8,7 +8,6 @@ import subprocess import json import ray -from ray import tune from ray.rllib import _register_all from ray.tune.trial import Trial, Resources from ray.tune.web_server import TuneClient @@ -94,11 +93,9 @@ class TuneServerSuite(unittest.TestCase): "__fake", trial_id="function_trial", stopping_criterion={"training_iteration": 3}, - config={ - "callbacks": { - "on_episode_start": tune.function(lambda x: None) - } - }) + config={"callbacks": { + "on_episode_start": lambda x: None + }}) runner.add_trial(test_trial) for i in range(3): diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 1910c1b97..94666985c 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -11,6 +11,7 @@ import os import re import time import traceback +import types import ray.cloudpickle as cloudpickle from ray.tune import TuneError @@ -19,7 +20,6 @@ from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE, SHOULD_CHECKPOINT) from ray.tune.syncer import get_syncer from ray.tune.trial import Trial, Checkpoint -from ray.tune.sample import function from ray.tune.schedulers import FIFOScheduler, TrialScheduler from ray.tune.suggest import BasicVariantGenerator from ray.tune.util import warn_if_slow, flatten_dict @@ -48,7 +48,7 @@ def _find_newest_ckpt(ckpt_dir): class _TuneFunctionEncoder(json.JSONEncoder): def default(self, obj): - if isinstance(obj, function): + if isinstance(obj, types.FunctionType): return self._to_cloudpickle(obj) try: return super(_TuneFunctionEncoder, self).default(obj) diff --git a/rllib/agents/dqn/dqn.py b/rllib/agents/dqn/dqn.py index cb862c05c..f339b512d 100644 --- a/rllib/agents/dqn/dqn.py +++ b/rllib/agents/dqn/dqn.py @@ -4,7 +4,6 @@ from __future__ import print_function import logging -from ray import tune from ray.rllib.agents.trainer import with_common_config from ray.rllib.agents.trainer_template import build_trainer from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy @@ -174,8 +173,7 @@ def check_config_and_setup_param_noise(config): if start_callback: start_callback(info) - config["callbacks"]["on_episode_start"] = tune.function( - on_episode_start) + config["callbacks"]["on_episode_start"] = on_episode_start if config["callbacks"]["on_episode_end"]: end_callback = config["callbacks"]["on_episode_end"] else: @@ -191,7 +189,7 @@ def check_config_and_setup_param_noise(config): if end_callback: end_callback(info) - config["callbacks"]["on_episode_end"] = tune.function(on_episode_end) + config["callbacks"]["on_episode_end"] = on_episode_end def get_initial_state(config): diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index e66e20700..d42fb0d3b 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -259,10 +259,7 @@ class RolloutWorker(EvaluatorInterface): policy_mapping_fn = (policy_mapping_fn or (lambda agent_id: DEFAULT_POLICY_ID)) if not callable(policy_mapping_fn): - raise ValueError( - "Policy mapping function not callable. If you're using Tune, " - "make sure to escape the function with tune.function() " - "to prevent it from being evaluated as an expression.") + raise ValueError("Policy mapping function not callable?") self.env_creator = env_creator self.sample_batch_size = batch_steps * num_envs self.batch_mode = batch_mode diff --git a/rllib/examples/centralized_critic.py b/rllib/examples/centralized_critic.py index b26e11762..d2bebe301 100644 --- a/rllib/examples/centralized_critic.py +++ b/rllib/examples/centralized_critic.py @@ -222,8 +222,7 @@ if __name__ == "__main__": "pol2": (None, TwoStepGame.observation_space, TwoStepGame.action_space, {}), }, - "policy_mapping_fn": tune.function( - lambda x: "pol1" if x == 0 else "pol2"), + "policy_mapping_fn": lambda x: "pol1" if x == 0 else "pol2", }, "model": { "custom_model": "cc_model", diff --git a/rllib/examples/centralized_critic_2.py b/rllib/examples/centralized_critic_2.py index 419b5fd74..1174052e6 100644 --- a/rllib/examples/centralized_critic_2.py +++ b/rllib/examples/centralized_critic_2.py @@ -139,7 +139,7 @@ if __name__ == "__main__": "env": GlobalObsTwoStepGame, "batch_mode": "complete_episodes", "callbacks": { - "on_postprocess_traj": tune.function(fill_in_actions), + "on_postprocess_traj": fill_in_actions, }, "num_workers": 0, "multiagent": { @@ -149,8 +149,7 @@ if __name__ == "__main__": "pol2": (None, GlobalObsTwoStepGame.observation_space, GlobalObsTwoStepGame.action_space, {}), }, - "policy_mapping_fn": tune.function( - lambda x: "pol1" if x == 0 else "pol2"), + "policy_mapping_fn": lambda x: "pol1" if x == 0 else "pol2", }, "model": { "custom_model": "cc_model", diff --git a/rllib/examples/custom_metrics_and_callbacks.py b/rllib/examples/custom_metrics_and_callbacks.py index b8772ad53..99512cab9 100644 --- a/rllib/examples/custom_metrics_and_callbacks.py +++ b/rllib/examples/custom_metrics_and_callbacks.py @@ -67,12 +67,12 @@ if __name__ == "__main__": config={ "env": "CartPole-v0", "callbacks": { - "on_episode_start": tune.function(on_episode_start), - "on_episode_step": tune.function(on_episode_step), - "on_episode_end": tune.function(on_episode_end), - "on_sample_end": tune.function(on_sample_end), - "on_train_result": tune.function(on_train_result), - "on_postprocess_traj": tune.function(on_postprocess_traj), + "on_episode_start": on_episode_start, + "on_episode_step": on_episode_step, + "on_episode_end": on_episode_end, + "on_sample_end": on_sample_end, + "on_train_result": on_train_result, + "on_postprocess_traj": on_postprocess_traj, }, }, return_trials=True) diff --git a/rllib/examples/multiagent_cartpole.py b/rllib/examples/multiagent_cartpole.py index 275c54390..6cc00bdeb 100644 --- a/rllib/examples/multiagent_cartpole.py +++ b/rllib/examples/multiagent_cartpole.py @@ -108,7 +108,7 @@ if __name__ == "__main__": "num_sgd_iter": 10, "multiagent": { "policies": policies, - "policy_mapping_fn": tune.function( + "policy_mapping_fn": ( lambda agent_id: random.choice(policy_ids)), }, }, diff --git a/rllib/examples/multiagent_custom_policy.py b/rllib/examples/multiagent_custom_policy.py index d34d67809..b19dbfb3d 100644 --- a/rllib/examples/multiagent_custom_policy.py +++ b/rllib/examples/multiagent_custom_policy.py @@ -69,7 +69,7 @@ if __name__ == "__main__": "pg_policy": (None, obs_space, act_space, {}), "random": (RandomPolicy, obs_space, act_space, {}), }, - "policy_mapping_fn": tune.function( + "policy_mapping_fn": ( lambda agent_id: ["pg_policy", "random"][agent_id % 2]), }, }, diff --git a/rllib/examples/rock_paper_scissors_multiagent.py b/rllib/examples/rock_paper_scissors_multiagent.py index 4a9feb1b7..e340e3612 100644 --- a/rllib/examples/rock_paper_scissors_multiagent.py +++ b/rllib/examples/rock_paper_scissors_multiagent.py @@ -185,7 +185,7 @@ def run_heuristic_vs_learned(use_lstm=False, trainer="PG"): } }), }, - "policy_mapping_fn": tune.function(select_policy), + "policy_mapping_fn": select_policy, }, }) diff --git a/rllib/examples/twostep_game.py b/rllib/examples/twostep_game.py index f47f64730..3ddc0fd04 100644 --- a/rllib/examples/twostep_game.py +++ b/rllib/examples/twostep_game.py @@ -129,8 +129,7 @@ if __name__ == "__main__": "agent_id": 1, }), }, - "policy_mapping_fn": tune.function( - lambda x: "pol1" if x == 0 else "pol2"), + "policy_mapping_fn": lambda x: "pol1" if x == 0 else "pol2", }, } group = False