[Tune] Pbt Function API (#9958)

* adding function convnet example

* add unit test

* update test

* update example

* wip

* move error from experiment to tune

* wip

* Fix checkpoint deletion

* updating code

* adding smoke test

* updating pbt guide

* formatting

* fix build

* add best checkpoint analysis util

* update test

* add comments

* remove class api

* fix example

* add setup and teardown to tests

* formatting

* Update python/ray/tune/tests/test_trial_scheduler_pbt.py

Co-authored-by: Kai Fricke <kai@anyscale.com>
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
Amog Kamsetty
2020-08-14 17:52:30 -07:00
committed by GitHub
parent fba5906ce3
commit f87a4aa45d
9 changed files with 274 additions and 32 deletions
+9
View File
@@ -446,6 +446,15 @@ py_test(
args = ["--smoke-test"]
)
# Smoke test for the function-API PBT ConvNet example: runs
# examples/pbt_convnet_function_example.py with --smoke-test as a
# "medium"-sized test target.
py_test(
name = "pbt_convnet_function_example",
size = "medium",
srcs = ["examples/pbt_convnet_function_example.py"],
deps = [":tune_lib"],
tags = ["exclusive", "example"],
args = ["--smoke-test"]
)
py_test(
name = "pbt_example",
size = "medium",
@@ -167,6 +167,22 @@ class Analysis:
else:
raise ValueError("trial should be a string or a Trial instance.")
def get_best_checkpoint(self, trial, metric=TRAINING_ITERATION):
    """Gets best persistent checkpoint path of provided trial.

    Args:
        trial (Trial|str): The log directory of a trial, or a trial
            instance.
        metric (str): Key of the trial info used to rank the checkpoints,
            e.g. "mean_accuracy". "training_iteration" is used by default.

    Returns:
        Path for best checkpoint of trial determined by metric.

    Raises:
        ValueError: If no checkpoints are found for the trial.
    """
    checkpoint_paths = self.get_trial_checkpoints_paths(trial, metric)
    if not checkpoint_paths:
        # max() on an empty sequence would raise a ValueError that does not
        # mention the trial; fail with an actionable message instead.
        raise ValueError(
            "No checkpoints found for trial {}.".format(trial))
    # Entries are indexed as (path, metric value): rank on the value and
    # return the path.
    return max(checkpoint_paths, key=lambda x: x[1])[0]
def _retrieve_rows(self, metric=None, mode=None):
assert mode is None or mode in ["max", "min"]
rows = {}
+1 -1
View File
@@ -137,7 +137,7 @@ class CheckpointManager:
self._membership.remove(worst)
# Don't delete the newest checkpoint. It will be deleted on the
# next on_checkpoint() call since it isn't in self._membership.
if worst != checkpoint:
if worst.value != checkpoint.value:
self.delete(worst)
def best_checkpoints(self):
@@ -132,10 +132,9 @@ if __name__ == "__main__":
# __tune_end__
best_trial = analysis.get_best_trial("mean_accuracy")
best_checkpoint = max(
analysis.get_trial_checkpoints_paths(best_trial, "mean_accuracy"))
best_checkpoint = analysis.get_best_checkpoint(best_trial, metric="mean_accuracy")
restored_trainable = PytorchTrainable()
restored_trainable.restore(best_checkpoint[0])
restored_trainable.restore(best_checkpoint)
best_model = restored_trainable.model
# Note that the test only runs on a small random subset of the test data, so
# the accuracy may differ from the metrics shown during tuning.
@@ -0,0 +1,129 @@
#!/usr/bin/env python
# __tutorial_imports_begin__
import argparse
import os
import numpy as np
import torch
import torch.optim as optim
from torchvision import datasets
from ray.tune.examples.mnist_pytorch import train, test, ConvNet,\
get_data_loaders
import ray
from ray import tune
from ray.tune.schedulers import PopulationBasedTraining
from ray.tune.trial import ExportFormat
# __tutorial_imports_end__
# __train_begin__
def train_convnet(config, checkpoint_dir=None):
    """Train a ConvNet in an endless loop, reporting accuracy every step.

    Args:
        config (dict): Hyperparameters. "lr" and "momentum" are read, with
            fallbacks of 0.01 and 0.9 respectively.
        checkpoint_dir (str): Directory containing a "checkpoint" file to
            resume from. When falsy, training starts from scratch.
    """
    # Build the data loaders, the model and its optimizer.
    train_loader, test_loader = get_data_loaders()
    model = ConvNet()
    learning_rate = config.get("lr", 0.01)
    momentum = config.get("momentum", 0.9)
    optimizer = optim.SGD(
        model.parameters(), lr=learning_rate, momentum=momentum)
    step = 0

    # A checkpoint directory means we are resuming: restore the model
    # weights and the iteration counter from the saved state.
    if checkpoint_dir:
        print("Loading from checkpoint.")
        saved_state = torch.load(os.path.join(checkpoint_dir, "checkpoint"))
        model.load_state_dict(saved_state["model_state_dict"])
        step = saved_state["step"]

    while True:
        train(model, optimizer, train_loader)
        accuracy = test(model, test_loader)
        if step % 5 == 0:
            # Checkpoint every fifth step. Tune provides the directory;
            # the state is written to a "checkpoint" file inside it.
            with tune.checkpoint_dir(step=step) as save_dir:
                checkpoint_file = os.path.join(save_dir, "checkpoint")
                # The optimizer state is not saved (not needed for SGD).
                state = {
                    "step": step,
                    "model_state_dict": model.state_dict(),
                    "mean_accuracy": accuracy
                }
                torch.save(state, checkpoint_file)
        step += 1
        tune.report(mean_accuracy=accuracy)
# __train_end__
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    args = parser.parse_known_args()[0]
    ray.init()
    # Download MNIST into ~/data before any trial starts.
    datasets.MNIST("~/data", train=True, download=True)

    # __pbt_begin__
    mutations = {
        # distribution for resampling
        "lr": lambda: np.random.uniform(0.0001, 1),
        # allow perturbations within this set of categorical values
        "momentum": [0.8, 0.9, 0.99],
    }
    scheduler = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="mean_accuracy",
        mode="max",
        perturbation_interval=5,
        hyperparam_mutations=mutations)

    # __pbt_end__

    # __tune_begin__
    class CustomStopper(tune.Stopper):
        """Stopper driven by an accuracy threshold and an iteration cap.

        ``__call__`` returns True once any result has reported a
        mean_accuracy above 0.96, or when a result's training_iteration
        reaches the cap (5 with --smoke-test, 100 otherwise).
        ``stop_all`` returns whether the accuracy threshold was ever passed.
        """

        def __init__(self):
            self.should_stop = False

        def __call__(self, trial_id, result):
            iteration_cap = 5 if args.smoke_test else 100
            # Once the threshold has been passed, every call says stop.
            if self.should_stop:
                return True
            if result["mean_accuracy"] > 0.96:
                self.should_stop = True
                return True
            return result["training_iteration"] >= iteration_cap

        def stop_all(self):
            return self.should_stop

    stopper = CustomStopper()
    search_space = {
        "lr": tune.uniform(0.001, 1),
        "momentum": tune.uniform(0.001, 1),
    }

    analysis = tune.run(
        train_convnet,
        name="pbt_test",
        scheduler=scheduler,
        verbose=1,
        stop=stopper,
        export_formats=[ExportFormat.MODEL],
        checkpoint_score_attr="mean_accuracy",
        keep_checkpoints_num=4,
        num_samples=4,
        config=search_space)
    # __tune_end__

    # Rebuild the model from the best trial's best checkpoint.
    best_trial = analysis.get_best_trial("mean_accuracy")
    best_checkpoint_path = analysis.get_best_checkpoint(
        best_trial, metric="mean_accuracy")
    best_model = ConvNet()
    checkpoint_file = os.path.join(best_checkpoint_path, "checkpoint")
    best_checkpoint = torch.load(checkpoint_file)
    best_model.load_state_dict(best_checkpoint["model_state_dict"])

    # The test only runs on a small random subset of the test data, so the
    # accuracy may differ from the metrics shown during tuning.
    _, test_loader = get_data_loaders()
    test_acc = test(best_model, test_loader)
    print("best model accuracy: ", test_acc)
@@ -1,6 +1,8 @@
# coding: utf-8
import os
import random
import sys
import tempfile
import unittest
from unittest.mock import patch
@@ -128,6 +130,35 @@ class CheckpointManagerTest(unittest.TestCase):
self.assertEqual(newest, checkpoints[1])
self.assertEqual(checkpoint_manager.best_checkpoints(), [])
def testSameCheckpoint(self):
    """Reporting an already-known checkpoint path must not delete its file.

    Three temp files back four checkpoints; the second file is reported
    twice with different results. After every on_checkpoint() call the file
    of the checkpoint that was just reported must still exist on disk.
    """
    # NOTE(review): the first argument (1) is presumably the number of
    # checkpoints to keep and "i" the score attribute -- confirm against
    # CheckpointManager's signature.
    checkpoint_manager = CheckpointManager(
        1, "i", delete_fn=lambda c: os.remove(c.value))

    tmpfiles = []
    try:
        for _ in range(3):
            # mkstemp() creates the (empty) file atomically, unlike the
            # deprecated and race-prone mktemp().
            fd, tmpfile = tempfile.mkstemp()
            os.close(fd)
            tmpfiles.append(tmpfile)

        checkpoints = [
            Checkpoint(Checkpoint.PERSISTENT, tmpfiles[0],
                       self.mock_result(5)),
            Checkpoint(Checkpoint.PERSISTENT, tmpfiles[1],
                       self.mock_result(10)),
            Checkpoint(Checkpoint.PERSISTENT, tmpfiles[2],
                       self.mock_result(0)),
            Checkpoint(Checkpoint.PERSISTENT, tmpfiles[1],
                       self.mock_result(20))
        ]
        for checkpoint in checkpoints:
            checkpoint_manager.on_checkpoint(checkpoint)
            self.assertTrue(os.path.exists(checkpoint.value))
    finally:
        # Remove whatever the manager did not delete, even when an
        # assertion above failed.
        for tmpfile in tmpfiles:
            if os.path.exists(tmpfile):
                os.remove(tmpfile)
if __name__ == "__main__":
import pytest
@@ -126,6 +126,13 @@ class ExperimentAnalysisSuite(unittest.TestCase):
assert paths[0][0] == expected_path
assert paths[0][1] == best_trial.metric_analysis[self.metric]["last"]
def testGetBestCheckpoint(self):
    """get_best_checkpoint() returns the path with the highest metric."""
    best_trial = self.ea.get_best_trial(self.metric)
    # Build the reference ranking with the same metric that is handed to
    # get_best_checkpoint(). Omitting it would rank by whatever default
    # get_trial_checkpoints_paths() uses, i.e. possibly a different order.
    checkpoints_metrics = self.ea.get_trial_checkpoints_paths(
        best_trial, self.metric)
    expected_path = max(checkpoints_metrics, key=lambda x: x[1])[0]
    best_checkpoint = self.ea.get_best_checkpoint(best_trial, self.metric)
    assert expected_path == best_checkpoint
def testAllDataframes(self):
dataframes = self.ea.trial_dataframes
self.assertTrue(len(dataframes) == self.num_samples)
@@ -5,6 +5,7 @@ import random
import unittest
import sys
import ray
from ray import tune
from ray.tune.schedulers import PopulationBasedTraining
@@ -23,13 +24,32 @@ class MockTrainable(tune.Trainable):
# Pickles the trainable's state into a "model.mock" file inside
# tmp_checkpoint_dir and returns that directory.
# NOTE(review): the two consecutive pickle.dump calls below look like the
# removed/added pair of a rendered diff rather than intentional code;
# presumably only the three-field (a, b, iter) form should remain -- confirm.
def save_checkpoint(self, tmp_checkpoint_dir):
checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.mock")
with open(checkpoint_path, "wb") as fp:
pickle.dump((self.a, self.b), fp)
pickle.dump((self.a, self.b, self.iter), fp)
return tmp_checkpoint_dir
# Restores the trainable's state from the "model.mock" pickle inside
# tmp_checkpoint_dir.
# NOTE(review): the two consecutive pickle.load assignments below look like
# the removed/added pair of a rendered diff rather than intentional code;
# presumably only the three-field (a, b, iter) form should remain -- confirm.
def load_checkpoint(self, tmp_checkpoint_dir):
checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.mock")
with open(checkpoint_path, "rb") as fp:
self.a, self.b = pickle.load(fp)
self.a, self.b, self.iter = pickle.load(fp)
def MockTrainingFunc(config, checkpoint_dir=None):
    """Function-API mock trainable: checkpoint and report on every step.

    The state is the tuple (a, b, step). ``a`` and ``b`` come from
    ``config`` and the step starts at 0, unless ``checkpoint_dir`` is given,
    in which case all three are read from its "model.mock" pickle.
    """
    step = 0
    a, b = config["a"], config["b"]

    if checkpoint_dir:
        with open(os.path.join(checkpoint_dir, "model.mock"), "rb") as fp:
            a, b, step = pickle.load(fp)

    while True:
        step += 1
        # Persist the state for this step before reporting the metric.
        with tune.checkpoint_dir(step=step) as save_dir:
            with open(os.path.join(save_dir, "model.mock"), "wb") as fp:
                pickle.dump((a, b, step), fp)
        tune.report(mean_accuracy=(a - step) * b)
class MockParam(object):
@@ -44,6 +64,12 @@ class MockParam(object):
class PopulationBasedTrainingResumeTest(unittest.TestCase):
def setUp(self):
    """Initialize Ray before each test."""
    ray.init()
def tearDown(self):
    """Shut Ray down after each test."""
    ray.shutdown()
def testPermutationContinuation(self):
"""
Tests continuation of runs after permutation.
@@ -74,7 +100,6 @@ class PopulationBasedTrainingResumeTest(unittest.TestCase):
},
fail_fast=True,
num_samples=20,
global_checkpoint_period=1,
checkpoint_freq=1,
checkpoint_at_end=True,
keep_checkpoints_num=1,
@@ -83,6 +108,33 @@ class PopulationBasedTrainingResumeTest(unittest.TestCase):
name="testPermutationContinuation",
stop={"training_iteration": 5})
def testPermutationContinuationFunc(self):
    """Run PBT over MockTrainingFunc, perturbing at every iteration."""
    pbt = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="mean_accuracy",
        mode="max",
        perturbation_interval=1,
        log_config=True,
        hyperparam_mutations={"c": lambda: 1})

    a_values = MockParam([10, 20, 30, 40])
    b_values = MockParam([1.2, 0.9, 1.1, 0.8])

    # Fix both RNGs before the experiment is launched.
    random.seed(100)
    np.random.seed(1000)

    search_space = {
        "a": tune.sample_from(lambda _: a_values()),
        "b": tune.sample_from(lambda _: b_values()),
        "c": 1
    }
    tune.run(
        MockTrainingFunc,
        config=search_space,
        fail_fast=True,
        num_samples=4,
        keep_checkpoints_num=1,
        checkpoint_score_attr="min-training_iteration",
        scheduler=pbt,
        name="testPermutationContinuationFunc",
        stop={"training_iteration": 3})
if __name__ == "__main__":
import pytest