mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 21:46:57 +08:00
[Tune] Pbt Function API (#9958)
* adding function convnet example
* add unit test
* update test
* update example
* wip
* move error from experiment to tune
* wip
* Fix checkpoint deletion
* updating code
* adding smoke test
* updating pbt guide
* formatting
* fix build
* add best checkpoint analysis util
* update test
* add comments
* remove class api
* fix example
* add setup and teardown to tests
* formatting
* Update python/ray/tune/tests/test_trial_scheduler_pbt.py

Co-authored-by: Kai Fricke <kai@anyscale.com>
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
@@ -446,6 +446,15 @@ py_test(
|
||||
args = ["--smoke-test"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "pbt_convnet_function_example",
|
||||
size = "medium",
|
||||
srcs = ["examples/pbt_convnet_function_example.py"],
|
||||
deps = [":tune_lib"],
|
||||
tags = ["exclusive", "example"],
|
||||
args = ["--smoke-test"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "pbt_example",
|
||||
size = "medium",
|
||||
|
||||
@@ -167,6 +167,22 @@ class Analysis:
|
||||
else:
|
||||
raise ValueError("trial should be a string or a Trial instance.")
|
||||
|
||||
def get_best_checkpoint(self, trial, metric=TRAINING_ITERATION,
                        mode="max"):
    """Gets best persistent checkpoint path of provided trial.

    Args:
        trial (Trial|str): A trial instance, or the log directory of a
            trial. It is passed unchanged to
            ``get_trial_checkpoints_paths``.
        metric (str): Key of the value used to rank the trial's
            checkpoints, e.g. "mean_accuracy".
            "training_iteration" is used by default.
        mode (str): One of "max" or "min". Selects whether the
            checkpoint with the highest or the lowest ``metric`` value
            is considered best. Defaults to "max".

    Returns:
        Path for best checkpoint of trial determined by metric.

    Raises:
        ValueError: If ``mode`` is not "max" or "min", or if no
            checkpoint is available for the trial.
    """
    if mode not in ("max", "min"):
        raise ValueError(
            "mode should be 'max' or 'min', got {!r}.".format(mode))

    # Each entry is indexed as (checkpoint path, metric value).
    checkpoint_paths = self.get_trial_checkpoints_paths(trial, metric)
    if not checkpoint_paths:
        # Fail with an actionable message instead of the opaque
        # "max() arg is an empty sequence" raised by the builtin.
        raise ValueError(
            "No checkpoints have been found for trial {}.".format(trial))

    select = max if mode == "max" else min
    return select(checkpoint_paths, key=lambda entry: entry[1])[0]
|
||||
|
||||
def _retrieve_rows(self, metric=None, mode=None):
|
||||
assert mode is None or mode in ["max", "min"]
|
||||
rows = {}
|
||||
|
||||
@@ -137,7 +137,7 @@ class CheckpointManager:
|
||||
self._membership.remove(worst)
|
||||
# Don't delete the newest checkpoint. It will be deleted on the
|
||||
# next on_checkpoint() call since it isn't in self._membership.
|
||||
if worst != checkpoint:
|
||||
if worst.value != checkpoint.value:
|
||||
self.delete(worst)
|
||||
|
||||
def best_checkpoints(self):
|
||||
|
||||
@@ -132,10 +132,9 @@ if __name__ == "__main__":
|
||||
# __tune_end__
|
||||
|
||||
best_trial = analysis.get_best_trial("mean_accuracy")
|
||||
best_checkpoint = max(
|
||||
analysis.get_trial_checkpoints_paths(best_trial, "mean_accuracy"))
|
||||
best_checkpoint = analysis.get_best_checkpoint(best_trial, metric="mean_accuracy")
|
||||
restored_trainable = PytorchTrainable()
|
||||
restored_trainable.restore(best_checkpoint[0])
|
||||
restored_trainable.restore(best_checkpoint)
|
||||
best_model = restored_trainable.model
|
||||
# Note that test only runs on a small random set of the test data, thus the
|
||||
# accuracy may be different from metrics shown in tuning process.
|
||||
|
||||
@@ -0,0 +1,129 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# __tutorial_imports_begin__
|
||||
import argparse
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
from torchvision import datasets
|
||||
from ray.tune.examples.mnist_pytorch import train, test, ConvNet,\
|
||||
get_data_loaders
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune.schedulers import PopulationBasedTraining
|
||||
from ray.tune.trial import ExportFormat
|
||||
|
||||
# __tutorial_imports_end__
|
||||
|
||||
|
||||
# __train_begin__
|
||||
def train_convnet(config, checkpoint_dir=None):
    """Train a ConvNet on MNIST, checkpointing and reporting to Tune.

    The loop has no exit condition of its own; every pass trains, evaluates
    and reports ``mean_accuracy`` through ``tune.report``.

    Args:
        config (dict): Hyperparameters. "lr" and "momentum" are read, with
            fallbacks of 0.01 and 0.9.
        checkpoint_dir (str): Directory containing a "checkpoint" file to
            resume from. A falsy value starts training from step 0.
    """
    # Create our data loaders, model, and optimizer.
    train_loader, test_loader = get_data_loaders()
    model = ConvNet()
    optimizer = optim.SGD(
        model.parameters(),
        lr=config.get("lr", 0.01),
        momentum=config.get("momentum", 0.9))

    # A truthy checkpoint_dir means we are resuming: restore the model
    # weights and the step counter that were saved below.
    step = 0
    if checkpoint_dir:
        print("Loading from checkpoint.")
        saved_state = torch.load(os.path.join(checkpoint_dir, "checkpoint"))
        model.load_state_dict(saved_state["model_state_dict"])
        step = saved_state["step"]

    while True:
        train(model, optimizer, train_loader)
        accuracy = test(model, test_loader)

        if step % 5 == 0:
            # Every 5 steps, ask Tune for a checkpoint directory and write
            # the current state into a "checkpoint" file inside it.
            with tune.checkpoint_dir(step=step) as save_dir:
                # Only the step, the model weights and the accuracy are
                # stored; the optimizer is rebuilt from config on resume.
                snapshot = {
                    "step": step,
                    "model_state_dict": model.state_dict(),
                    "mean_accuracy": accuracy
                }
                torch.save(snapshot, os.path.join(save_dir, "checkpoint"))

        step += 1
        tune.report(mean_accuracy=accuracy)
|
||||
|
||||
|
||||
# __train_end__
|
||||
|
||||
if __name__ == "__main__":
    # "--smoke-test" shortens the run; unknown arguments are ignored.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    args, _ = parser.parse_known_args()

    ray.init()
    # Fetch MNIST once in the driver before any trial is launched.
    datasets.MNIST("~/data", train=True, download=True)

    # __pbt_begin__
    mutations = {
        # distribution for resampling
        "lr": lambda: np.random.uniform(0.0001, 1),
        # allow perturbations within this set of categorical values
        "momentum": [0.8, 0.9, 0.99],
    }
    scheduler = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="mean_accuracy",
        mode="max",
        perturbation_interval=5,
        hyperparam_mutations=mutations)

    # __pbt_end__

    # __tune_begin__
    class CustomStopper(tune.Stopper):
        """Stops everything once any trial exceeds 0.96 mean accuracy.

        Individual trials are also stopped when they reach the iteration
        cap (5 with --smoke-test, otherwise 100).
        """

        def __init__(self):
            self.should_stop = False

        def __call__(self, trial_id, result):
            iteration_cap = 5 if args.smoke_test else 100
            # Once the flag is latched, every trial is told to stop.
            if self.should_stop:
                return True
            if result["mean_accuracy"] > 0.96:
                self.should_stop = True
                return True
            return result["training_iteration"] >= iteration_cap

        def stop_all(self):
            return self.should_stop

    search_space = {
        "lr": tune.uniform(0.001, 1),
        "momentum": tune.uniform(0.001, 1),
    }

    analysis = tune.run(
        train_convnet,
        name="pbt_test",
        scheduler=scheduler,
        verbose=1,
        stop=CustomStopper(),
        export_formats=[ExportFormat.MODEL],
        checkpoint_score_attr="mean_accuracy",
        keep_checkpoints_num=4,
        num_samples=4,
        config=search_space)
    # __tune_end__

    # Reload the weights from the best checkpoint of the best trial.
    best_trial = analysis.get_best_trial("mean_accuracy")
    best_checkpoint_path = analysis.get_best_checkpoint(
        best_trial, metric="mean_accuracy")
    best_checkpoint = torch.load(
        os.path.join(best_checkpoint_path, "checkpoint"))
    best_model = ConvNet()
    best_model.load_state_dict(best_checkpoint["model_state_dict"])

    # Note that test only runs on a small random set of the test data, thus
    # the accuracy may be different from metrics shown in tuning process.
    _, test_loader = get_data_loaders()
    test_acc = test(best_model, test_loader)
    print("best model accuracy: ", test_acc)
|
||||
@@ -1,6 +1,8 @@
|
||||
# coding: utf-8
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -128,6 +130,35 @@ class CheckpointManagerTest(unittest.TestCase):
|
||||
self.assertEqual(newest, checkpoints[1])
|
||||
self.assertEqual(checkpoint_manager.best_checkpoints(), [])
|
||||
|
||||
def testSameCheckpoint(self):
    """Checks that re-reporting an existing checkpoint path keeps its file.

    Four checkpoints are reported to a manager constructed with a limit
    of 1 and score attribute "i"; the last one reuses the path of the
    second. After every ``on_checkpoint`` call the file behind the
    checkpoint that was just reported must still exist on disk.
    """
    checkpoint_manager = CheckpointManager(
        1, "i", delete_fn=lambda c: os.remove(c.value))

    tmpfiles = []
    for _ in range(3):
        # mkstemp() creates the (empty) file atomically; the deprecated
        # mktemp() only returns a name and is open to races.
        fd, tmpfile = tempfile.mkstemp()
        os.close(fd)
        tmpfiles.append(tmpfile)

    try:
        checkpoints = [
            Checkpoint(Checkpoint.PERSISTENT, tmpfiles[0],
                       self.mock_result(5)),
            Checkpoint(Checkpoint.PERSISTENT, tmpfiles[1],
                       self.mock_result(10)),
            Checkpoint(Checkpoint.PERSISTENT, tmpfiles[2],
                       self.mock_result(0)),
            Checkpoint(Checkpoint.PERSISTENT, tmpfiles[1],
                       self.mock_result(20))
        ]
        for checkpoint in checkpoints:
            checkpoint_manager.on_checkpoint(checkpoint)
            self.assertTrue(os.path.exists(checkpoint.value))
    finally:
        # Always clean up, even when an assertion above fails. Some files
        # may already have been removed by the manager's delete_fn.
        for tmpfile in tmpfiles:
            if os.path.exists(tmpfile):
                os.remove(tmpfile)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
|
||||
@@ -126,6 +126,13 @@ class ExperimentAnalysisSuite(unittest.TestCase):
|
||||
assert paths[0][0] == expected_path
|
||||
assert paths[0][1] == best_trial.metric_analysis[self.metric]["last"]
|
||||
|
||||
def testGetBestCheckpoint(self):
    """Checks get_best_checkpoint against a manual max over checkpoints.

    The expected path is computed from the same (path, metric value)
    entries and the same metric that get_best_checkpoint is asked to
    rank by, so both sides of the comparison use one ranking key.
    """
    best_trial = self.ea.get_best_trial(self.metric)
    # Pass the metric explicitly: without it the entries would be scored
    # by the helper's default metric rather than by self.metric.
    checkpoints_metrics = self.ea.get_trial_checkpoints_paths(
        best_trial, self.metric)
    expected_path = max(checkpoints_metrics, key=lambda x: x[1])[0]
    best_checkpoint = self.ea.get_best_checkpoint(best_trial, self.metric)
    self.assertEqual(expected_path, best_checkpoint)
|
||||
|
||||
def testAllDataframes(self):
|
||||
dataframes = self.ea.trial_dataframes
|
||||
self.assertTrue(len(dataframes) == self.num_samples)
|
||||
|
||||
@@ -5,6 +5,7 @@ import random
|
||||
import unittest
|
||||
import sys
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune.schedulers import PopulationBasedTraining
|
||||
|
||||
@@ -23,13 +24,32 @@ class MockTrainable(tune.Trainable):
|
||||
def save_checkpoint(self, tmp_checkpoint_dir):
|
||||
checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.mock")
|
||||
with open(checkpoint_path, "wb") as fp:
|
||||
pickle.dump((self.a, self.b), fp)
|
||||
pickle.dump((self.a, self.b, self.iter), fp)
|
||||
return tmp_checkpoint_dir
|
||||
|
||||
def load_checkpoint(self, tmp_checkpoint_dir):
|
||||
checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.mock")
|
||||
with open(checkpoint_path, "rb") as fp:
|
||||
self.a, self.b = pickle.load(fp)
|
||||
self.a, self.b, self.iter = pickle.load(fp)
|
||||
|
||||
|
||||
def MockTrainingFunc(config, checkpoint_dir=None):
    """Function-API mock trainable that checkpoints on every iteration.

    Args:
        config (dict): Must provide "a" and "b".
        checkpoint_dir (str): Directory holding a "model.mock" pickle of
            ``(a, b, iteration)`` to resume from. A falsy value starts
            from iteration 0 with the values from ``config``.
    """
    a, b = config["a"], config["b"]
    iteration = 0

    if checkpoint_dir:
        # Resuming: the pickled state overrides the values from config.
        with open(os.path.join(checkpoint_dir, "model.mock"), "rb") as src:
            a, b, iteration = pickle.load(src)

    while True:
        iteration += 1
        # Persist the state for this step before reporting it.
        with tune.checkpoint_dir(step=iteration) as step_dir:
            with open(os.path.join(step_dir, "model.mock"), "wb") as dst:
                pickle.dump((a, b, iteration), dst)
        tune.report(mean_accuracy=(a - iteration) * b)
|
||||
|
||||
|
||||
class MockParam(object):
|
||||
@@ -44,6 +64,12 @@ class MockParam(object):
|
||||
|
||||
|
||||
class PopulationBasedTrainingResumeTest(unittest.TestCase):
|
||||
def setUp(self):
    """Starts Ray before each test."""
    ray.init()
|
||||
|
||||
def tearDown(self):
    """Shuts Ray down after each test."""
    ray.shutdown()
|
||||
|
||||
def testPermutationContinuation(self):
|
||||
"""
|
||||
Tests continuation of runs after permutation.
|
||||
@@ -74,7 +100,6 @@ class PopulationBasedTrainingResumeTest(unittest.TestCase):
|
||||
},
|
||||
fail_fast=True,
|
||||
num_samples=20,
|
||||
global_checkpoint_period=1,
|
||||
checkpoint_freq=1,
|
||||
checkpoint_at_end=True,
|
||||
keep_checkpoints_num=1,
|
||||
@@ -83,6 +108,33 @@ class PopulationBasedTrainingResumeTest(unittest.TestCase):
|
||||
name="testPermutationContinuation",
|
||||
stop={"training_iteration": 5})
|
||||
|
||||
def testPermutationContinuationFunc(self):
    """Runs PBT on the function-API mock trainable.

    Four samples of MockTrainingFunc are run for 3 training iterations
    with a perturbation interval of 1, keeping a single checkpoint per
    trial scored by "min-training_iteration". With fail_fast=True the
    run is expected to complete without a trial error.
    """
    pbt = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="mean_accuracy",
        mode="max",
        perturbation_interval=1,
        log_config=True,
        hyperparam_mutations={"c": lambda: 1})

    a_source = MockParam([10, 20, 30, 40])
    b_source = MockParam([1.2, 0.9, 1.1, 0.8])

    # Seed both RNGs right before the run, as in the class-API test.
    random.seed(100)
    np.random.seed(1000)

    search_space = {
        "a": tune.sample_from(lambda _: a_source()),
        "b": tune.sample_from(lambda _: b_source()),
        "c": 1
    }
    tune.run(
        MockTrainingFunc,
        config=search_space,
        fail_fast=True,
        num_samples=4,
        keep_checkpoints_num=1,
        checkpoint_score_attr="min-training_iteration",
        scheduler=pbt,
        name="testPermutationContinuationFunc",
        stop={"training_iteration": 3})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
|
||||
Reference in New Issue
Block a user