mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 09:41:11 +08:00
[tune] Support lambda functions in hyperparameters / tune rllib multiagent support (#2568)
* update * func * Update registry.py * revert
This commit is contained in:
@@ -17,10 +17,10 @@ import gym
|
||||
import random
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.pg.pg import PGAgent
|
||||
from ray import tune
|
||||
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
|
||||
from ray.rllib.test.test_multi_agent_env import MultiCartpole
|
||||
from ray.tune.logger import pretty_print
|
||||
from ray.tune import run_experiments
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -53,16 +53,19 @@ if __name__ == "__main__":
|
||||
}
|
||||
policy_ids = list(policy_graphs.keys())
|
||||
|
||||
agent = PGAgent(
|
||||
env="multi_cartpole",
|
||||
config={
|
||||
"multiagent": {
|
||||
"policy_graphs": policy_graphs,
|
||||
"policy_mapping_fn": (
|
||||
lambda agent_id: random.choice(policy_ids)),
|
||||
run_experiments({
|
||||
"test": {
|
||||
"run": "PG",
|
||||
"env": "multi_cartpole",
|
||||
"stop": {
|
||||
"training_iteration": args.num_iters
|
||||
},
|
||||
})
|
||||
|
||||
for i in range(args.num_iters):
|
||||
print("== Iteration", i, "==")
|
||||
print(pretty_print(agent.train()))
|
||||
"config": {
|
||||
"multiagent": {
|
||||
"policy_graphs": policy_graphs,
|
||||
"policy_mapping_fn": tune.function(
|
||||
lambda agent_id: random.choice(policy_ids)),
|
||||
},
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
@@ -7,9 +7,9 @@ from ray.tune.tune import run_experiments
|
||||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.registry import register_env, register_trainable
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.suggest import grid_search
|
||||
from ray.tune.suggest import grid_search, function
|
||||
|
||||
__all__ = [
|
||||
"Trainable", "TuneError", "grid_search", "register_env",
|
||||
"register_trainable", "run_experiments", "Experiment"
|
||||
"register_trainable", "run_experiments", "Experiment", "function"
|
||||
]
|
||||
|
||||
@@ -2,9 +2,9 @@ from ray.tune.suggest.search import SearchAlgorithm
|
||||
from ray.tune.suggest.basic_variant import BasicVariantGenerator
|
||||
from ray.tune.suggest.suggestion import SuggestionAlgorithm
|
||||
from ray.tune.suggest.hyperopt import HyperOptSearch
|
||||
from ray.tune.suggest.variant_generator import grid_search
|
||||
from ray.tune.suggest.variant_generator import grid_search, function
|
||||
|
||||
__all__ = [
|
||||
"SearchAlgorithm", "BasicVariantGenerator", "HyperOptSearch",
|
||||
"SuggestionAlgorithm", "grid_search"
|
||||
"SuggestionAlgorithm", "grid_search", "function"
|
||||
]
|
||||
|
||||
@@ -51,6 +51,16 @@ def grid_search(values):
|
||||
return {"grid_search": values}
|
||||
|
||||
|
||||
class function(object):
|
||||
"""Wraps `func` to make sure it is not expanded during resolution."""
|
||||
|
||||
def __init__(self, func):
|
||||
self.func = func
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
|
||||
_STANDARD_IMPORTS = {
|
||||
"random": random,
|
||||
"np": numpy,
|
||||
|
||||
Reference in New Issue
Block a user