diff --git a/doc/source/tune-config.rst b/doc/source/tune-config.rst index 3dd7da4b0..2dba809a9 100644 --- a/doc/source/tune-config.rst +++ b/doc/source/tune-config.rst @@ -34,8 +34,8 @@ dictionary. "trial_resources": { "cpu": 1, "gpu": 0 }, "stop": { "mean_accuracy": 100 }, "config": { - "alpha": grid_search([0.2, 0.4, 0.6]), - "beta": grid_search([1, 2]), + "alpha": tune.grid_search([0.2, 0.4, 0.6]), + "beta": tune.grid_search([1, 2]), }, "upload_dir": "s3://your_bucket/path", "local_dir": "~/ray_results", @@ -49,7 +49,7 @@ An example of this can be found in `async_hyperband_example.py `__. Resource Allocation diff --git a/python/ray/rllib/examples/multiagent_cartpole.py b/python/ray/rllib/examples/multiagent_cartpole.py index 767bf84aa..8faeb184b 100644 --- a/python/ray/rllib/examples/multiagent_cartpole.py +++ b/python/ray/rllib/examples/multiagent_cartpole.py @@ -17,10 +17,10 @@ import gym import random import ray -from ray.rllib.agents.pg.pg import PGAgent +from ray import tune from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph from ray.rllib.test.test_multi_agent_env import MultiCartpole -from ray.tune.logger import pretty_print +from ray.tune import run_experiments from ray.tune.registry import register_env parser = argparse.ArgumentParser() @@ -53,16 +53,19 @@ if __name__ == "__main__": } policy_ids = list(policy_graphs.keys()) - agent = PGAgent( - env="multi_cartpole", - config={ - "multiagent": { - "policy_graphs": policy_graphs, - "policy_mapping_fn": ( - lambda agent_id: random.choice(policy_ids)), + run_experiments({ + "test": { + "run": "PG", + "env": "multi_cartpole", + "stop": { + "training_iteration": args.num_iters }, - }) - - for i in range(args.num_iters): - print("== Iteration", i, "==") - print(pretty_print(agent.train())) + "config": { + "multiagent": { + "policy_graphs": policy_graphs, + "policy_mapping_fn": tune.function( + lambda agent_id: random.choice(policy_ids)), + }, + }, + } + }) diff --git a/python/ray/tune/__init__.py b/python/ray/tune/__init__.py index 535e728bc..83d4f4fde 100644 --- a/python/ray/tune/__init__.py +++ b/python/ray/tune/__init__.py @@ -7,9 +7,9 @@ from ray.tune.tune import run_experiments from ray.tune.experiment import Experiment from ray.tune.registry import register_env, register_trainable from ray.tune.trainable import Trainable -from ray.tune.suggest import grid_search +from ray.tune.suggest import grid_search, function __all__ = [ "Trainable", "TuneError", "grid_search", "register_env", - "register_trainable", "run_experiments", "Experiment" + "register_trainable", "run_experiments", "Experiment", "function" ] diff --git a/python/ray/tune/suggest/__init__.py b/python/ray/tune/suggest/__init__.py index f237a7f33..f0146ca5e 100644 --- a/python/ray/tune/suggest/__init__.py +++ b/python/ray/tune/suggest/__init__.py @@ -2,9 +2,9 @@ from ray.tune.suggest.search import SearchAlgorithm from ray.tune.suggest.basic_variant import BasicVariantGenerator from ray.tune.suggest.suggestion import SuggestionAlgorithm from ray.tune.suggest.hyperopt import HyperOptSearch -from ray.tune.suggest.variant_generator import grid_search +from ray.tune.suggest.variant_generator import grid_search, function __all__ = [ "SearchAlgorithm", "BasicVariantGenerator", "HyperOptSearch", - "SuggestionAlgorithm", "grid_search" + "SuggestionAlgorithm", "grid_search", "function" ] diff --git a/python/ray/tune/suggest/variant_generator.py b/python/ray/tune/suggest/variant_generator.py index bbd7f9f36..866f7d262 100644 --- a/python/ray/tune/suggest/variant_generator.py +++ b/python/ray/tune/suggest/variant_generator.py @@ -51,6 +51,16 @@ def grid_search(values): return {"grid_search": values} +class function(object): + """Wraps `func` to make sure it is not expanded during resolution.""" + + def __init__(self, func): + self.func = func + + def __call__(self, *args, **kwargs): + return self.func(*args, **kwargs) + + _STANDARD_IMPORTS = { "random": random, "np": numpy,