mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 23:08:32 +08:00
@@ -1,6 +1,5 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
@@ -11,14 +10,6 @@ from ray.tune import Trainable, run
|
||||
from ray.tune.schedulers.hb_bohb import HyperBandForBOHB
|
||||
from ray.tune.suggest.bohb import TuneBOHB
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
parser.add_argument(
|
||||
"--ray-address",
|
||||
help="Address of Ray cluster for seamless distributed execution.")
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
|
||||
class MyTrainableClass(Trainable):
|
||||
"""Example agent whose learning curve is a random sigmoid.
|
||||
@@ -52,7 +43,7 @@ class MyTrainableClass(Trainable):
|
||||
|
||||
if __name__ == "__main__":
|
||||
import ConfigSpace as CS
|
||||
ray.init(address=args.ray_address)
|
||||
ray.init(num_cpus=8)
|
||||
|
||||
# BOHB uses ConfigSpace for their hyperparameter search space
|
||||
config_space = CS.ConfigurationSpace()
|
||||
@@ -75,4 +66,4 @@ if __name__ == "__main__":
|
||||
scheduler=bohb_hyperband,
|
||||
search_alg=bohb_search,
|
||||
num_samples=10,
|
||||
stop={"training_iteration": 10 if args.smoke_test else 100})
|
||||
stop={"training_iteration": 100})
|
||||
|
||||
@@ -47,7 +47,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
ray.init()
|
||||
ray.init(num_cpus=4 if args.smoke_test else None)
|
||||
|
||||
# Hyperband early stopping, configured with `episode_reward_mean` as the
|
||||
# objective and `training_iteration` as the time unit,
|
||||
@@ -56,7 +56,7 @@ if __name__ == "__main__":
|
||||
time_attr="training_iteration",
|
||||
metric="episode_reward_mean",
|
||||
mode="max",
|
||||
max_t=100)
|
||||
max_t=200)
|
||||
|
||||
run(MyTrainableClass,
|
||||
name="hyperband_test",
|
||||
|
||||
@@ -110,9 +110,9 @@ if __name__ == "__main__":
|
||||
reuse_actors=True,
|
||||
verbose=False,
|
||||
stop={
|
||||
"training_iteration": 2000,
|
||||
"training_iteration": 200,
|
||||
},
|
||||
num_samples=4,
|
||||
num_samples=8,
|
||||
config={
|
||||
"lr": 0.0001,
|
||||
# note: this parameter is perturbed but has no effect on
|
||||
|
||||
Reference in New Issue
Block a user