mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 14:48:54 +08:00
[tune] Improve PBT example (#4575)
This commit is contained in:
@@ -4,57 +4,83 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
|
||||
import ray
|
||||
from ray.tune import Trainable, run
|
||||
from ray.tune.schedulers import PopulationBasedTraining
|
||||
|
||||
|
||||
class MyTrainableClass(Trainable):
|
||||
"""Fake agent whose learning rate is determined by dummy factors."""
|
||||
class PBTBenchmarkExample(Trainable):
|
||||
"""Toy PBT problem for benchmarking adaptive learning rate.
|
||||
|
||||
The goal is to optimize this trainable's accuracy. The accuracy increases
|
||||
fastest at the optimal lr, which is a function of the current accuracy.
|
||||
|
||||
The optimal lr schedule for this problem is the triangle wave as follows.
|
||||
Note that many lr schedules for real models also follow this shape:
|
||||
|
||||
best lr
|
||||
^
|
||||
| /\
|
||||
| / \
|
||||
| / \
|
||||
| / \
|
||||
------------> accuracy
|
||||
|
||||
In this problem, using PBT with a population of 2-4 is sufficient to
|
||||
roughly approximate this lr schedule. Higher population sizes will yield
|
||||
faster convergence. Training will not converge without PBT.
|
||||
"""
|
||||
|
||||
def _setup(self, config):
|
||||
self.timestep = 0
|
||||
self.current_value = 0.0
|
||||
self.lr = config["lr"]
|
||||
self.accuracy = 0.0 # end = 1000
|
||||
|
||||
def _train(self):
|
||||
time.sleep(0.1)
|
||||
midpoint = 100 # lr starts decreasing after acc > midpoint
|
||||
q_tolerance = 3 # penalize exceeding lr by more than this multiple
|
||||
noise_level = 2 # add gaussian noise to the acc increase
|
||||
# triangle wave:
|
||||
# - start at 0.001 @ accuracy=0,
|
||||
# - peak at 0.01 @ accuracy=midpoint,
|
||||
# - end at 0.001 @ accuracy=midpoint * 2,
|
||||
if self.accuracy < midpoint:
|
||||
optimal_lr = 0.01 * self.accuracy / midpoint
|
||||
else:
|
||||
optimal_lr = 0.01 - 0.01 * (self.accuracy - midpoint) / midpoint
|
||||
optimal_lr = min(0.01, max(0.001, optimal_lr))
|
||||
|
||||
# Reward increase is parabolic as a function of factor_1, with a
|
||||
# maximum around factor_1=10.0.
|
||||
self.current_value += max(
|
||||
0.0, random.gauss(5.0 - (self.config["factor_1"] - 10.0)**2, 2.0))
|
||||
# compute accuracy increase
|
||||
q_err = max(self.lr, optimal_lr) / min(self.lr, optimal_lr)
|
||||
if q_err < q_tolerance:
|
||||
self.accuracy += (1.0 / q_err) * random.random()
|
||||
elif self.lr > optimal_lr:
|
||||
self.accuracy -= (q_err - q_tolerance) * random.random()
|
||||
self.accuracy += noise_level * np.random.normal()
|
||||
self.accuracy = max(0, self.accuracy)
|
||||
|
||||
# Flat increase by factor_2
|
||||
self.current_value += random.gauss(self.config["factor_2"], 1.0)
|
||||
|
||||
# Here we use `episode_reward_mean`, but you can also report other
|
||||
# objectives such as loss or accuracy.
|
||||
return {"episode_reward_mean": self.current_value}
|
||||
return {
|
||||
"mean_accuracy": self.accuracy,
|
||||
"cur_lr": self.lr,
|
||||
"optimal_lr": optimal_lr, # for debugging
|
||||
"q_err": q_err, # for debugging
|
||||
"done": self.accuracy > midpoint * 2,
|
||||
}
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
with open(path, "w") as f:
|
||||
f.write(
|
||||
json.dumps({
|
||||
"timestep": self.timestep,
|
||||
"value": self.current_value
|
||||
}))
|
||||
return path
|
||||
return {
|
||||
"accuracy": self.accuracy,
|
||||
"lr": self.lr,
|
||||
}
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
with open(checkpoint_path) as f:
|
||||
data = json.loads(f.read())
|
||||
self.timestep = data["timestep"]
|
||||
self.current_value = data["value"]
|
||||
def _restore(self, checkpoint):
|
||||
self.accuracy = checkpoint["accuracy"]
|
||||
|
||||
def reset_config(self, new_config):
|
||||
self.config = new_config
|
||||
self.lr = new_config["lr"]
|
||||
return True
|
||||
|
||||
|
||||
@@ -64,35 +90,36 @@ if __name__ == "__main__":
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
if args.smoke_test:
|
||||
ray.init(num_cpus=4) # force pausing to happen for test
|
||||
ray.init(num_cpus=2) # force pausing to happen for test
|
||||
else:
|
||||
ray.init()
|
||||
|
||||
pbt = PopulationBasedTraining(
|
||||
time_attr="training_iteration",
|
||||
reward_attr="episode_reward_mean",
|
||||
perturbation_interval=10,
|
||||
reward_attr="mean_accuracy",
|
||||
perturbation_interval=20,
|
||||
hyperparam_mutations={
|
||||
# Allow for scaling-based perturbations, with a uniform backing
|
||||
# distribution for resampling.
|
||||
"factor_1": lambda: random.uniform(0.0, 20.0),
|
||||
# Allow perturbations within this set of categorical values.
|
||||
"factor_2": [1, 2],
|
||||
# distribution for resampling
|
||||
"lr": lambda: random.uniform(0.0001, 0.02),
|
||||
# allow perturbations within this set of categorical values
|
||||
"some_other_factor": [1, 2],
|
||||
})
|
||||
|
||||
# Try to find the best factor 1 and factor 2
|
||||
run(MyTrainableClass,
|
||||
run(
|
||||
PBTBenchmarkExample,
|
||||
name="pbt_test",
|
||||
scheduler=pbt,
|
||||
reuse_actors=True,
|
||||
verbose=False,
|
||||
**{
|
||||
"stop": {
|
||||
"training_iteration": 20 if args.smoke_test else 99999
|
||||
"training_iteration": 2000,
|
||||
},
|
||||
"num_samples": 10,
|
||||
"num_samples": 4,
|
||||
"config": {
|
||||
"factor_1": 4.0,
|
||||
"factor_2": 1.0,
|
||||
"lr": 0.0001,
|
||||
# note: this parameter is perturbed but has no effect on
|
||||
# the model training in this example
|
||||
"some_other_factor": 1,
|
||||
},
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user