[tune] Support lambda functions in hyperparameters / tune rllib multiagent support (#2568)

* update

* func

* Update registry.py

* revert
This commit is contained in:
Eric Liang
2018-08-07 16:29:21 -07:00
committed by GitHub
parent e7f76d7914
commit 64053278aa
5 changed files with 40 additions and 23 deletions
@@ -17,10 +17,10 @@ import gym
import random
import ray
from ray.rllib.agents.pg.pg import PGAgent
from ray import tune
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
from ray.rllib.test.test_multi_agent_env import MultiCartpole
from ray.tune.logger import pretty_print
from ray.tune import run_experiments
from ray.tune.registry import register_env
parser = argparse.ArgumentParser()
@@ -53,16 +53,19 @@ if __name__ == "__main__":
}
policy_ids = list(policy_graphs.keys())
agent = PGAgent(
env="multi_cartpole",
config={
"multiagent": {
"policy_graphs": policy_graphs,
"policy_mapping_fn": (
lambda agent_id: random.choice(policy_ids)),
run_experiments({
"test": {
"run": "PG",
"env": "multi_cartpole",
"stop": {
"training_iteration": args.num_iters
},
})
for i in range(args.num_iters):
print("== Iteration", i, "==")
print(pretty_print(agent.train()))
"config": {
"multiagent": {
"policy_graphs": policy_graphs,
"policy_mapping_fn": tune.function(
lambda agent_id: random.choice(policy_ids)),
},
},
}
})
+2 -2
View File
@@ -7,9 +7,9 @@ from ray.tune.tune import run_experiments
from ray.tune.experiment import Experiment
from ray.tune.registry import register_env, register_trainable
from ray.tune.trainable import Trainable
from ray.tune.suggest import grid_search
from ray.tune.suggest import grid_search, function
__all__ = [
"Trainable", "TuneError", "grid_search", "register_env",
"register_trainable", "run_experiments", "Experiment"
"register_trainable", "run_experiments", "Experiment", "function"
]
+2 -2
View File
@@ -2,9 +2,9 @@ from ray.tune.suggest.search import SearchAlgorithm
from ray.tune.suggest.basic_variant import BasicVariantGenerator
from ray.tune.suggest.suggestion import SuggestionAlgorithm
from ray.tune.suggest.hyperopt import HyperOptSearch
from ray.tune.suggest.variant_generator import grid_search
from ray.tune.suggest.variant_generator import grid_search, function
__all__ = [
"SearchAlgorithm", "BasicVariantGenerator", "HyperOptSearch",
"SuggestionAlgorithm", "grid_search"
"SuggestionAlgorithm", "grid_search", "function"
]
@@ -51,6 +51,16 @@ def grid_search(values):
return {"grid_search": values}
class function(object):
    """Wraps `func` to make sure it is not expanded during resolution.

    The wrapper holds the callable in ``self.func`` and is itself callable,
    so code that receives the wrapped value can invoke it exactly as it would
    invoke ``func``.

    Args:
        func: The callable to wrap. It is stored as-is and is not invoked or
            validated at construction time.
    """

    def __init__(self, func):
        # Keep a reference only; calling is deferred to __call__.
        self.func = func

    def __call__(self, *args, **kwargs):
        """Forward all positional and keyword arguments to the wrapped func.

        Returns:
            Whatever ``self.func(*args, **kwargs)`` returns; any exception it
            raises propagates unchanged.
        """
        return self.func(*args, **kwargs)
_STANDARD_IMPORTS = {
"random": random,
"np": numpy,