Files
ray/python/ray/tune/examples/pbt_convnet_function_example.py
T
Amog Kamsetty f87a4aa45d [Tune] Pbt Function API (#9958)
* adding function convnet example

* add unit test

* update test

* update example

* wip

* move error from experiment to tune

* wip

* Fix checkpoint deletion

* updating code

* adding smoke test

* updating pbt guide

* formatting

* fix build

* add best checkpoint analysis util

* update test

* add comments

* remove class api

* fix example

* add setup and teardown to tests

* formatting

* Update python/ray/tune/tests/test_trial_scheduler_pbt.py

Co-authored-by: Kai Fricke <kai@anyscale.com>
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
2020-08-14 17:52:30 -07:00

130 lines
4.1 KiB
Python

#!/usr/bin/env python
# __tutorial_imports_begin__
import argparse
import os
import numpy as np
import torch
import torch.optim as optim
from torchvision import datasets
from ray.tune.examples.mnist_pytorch import train, test, ConvNet,\
get_data_loaders
import ray
from ray import tune
from ray.tune.schedulers import PopulationBasedTraining
from ray.tune.trial import ExportFormat
# __tutorial_imports_end__
# __train_begin__
def train_convnet(config, checkpoint_dir=None):
    """Train a ConvNet in an endless loop, reporting accuracy to Tune.

    Each iteration runs one ``train`` pass and one ``test`` pass, writes a
    checkpoint whenever the step counter is a multiple of 5, and then
    reports ``mean_accuracy`` through ``tune.report``. The function never
    returns on its own.

    Args:
        config: Mapping of hyperparameters. ``"lr"`` (default 0.01) and
            ``"momentum"`` (default 0.9) are read for the SGD optimizer.
        checkpoint_dir: Directory holding a ``"checkpoint"`` file to resume
            from. When falsy, training starts from scratch at step 0.
    """
    # Create our data loaders, model, and optimizer.
    train_loader, test_loader = get_data_loaders()
    model = ConvNet()
    learning_rate = config.get("lr", 0.01)
    momentum = config.get("momentum", 0.9)
    optimizer = optim.SGD(
        model.parameters(), lr=learning_rate, momentum=momentum)

    # A non-empty checkpoint_dir means we are resuming: restore the model
    # weights and the iteration step stored in the checkpoint file.
    step = 0
    if checkpoint_dir:
        print("Loading from checkpoint.")
        restored = torch.load(os.path.join(checkpoint_dir, "checkpoint"))
        model.load_state_dict(restored["model_state_dict"])
        step = restored["step"]

    while True:
        train(model, optimizer, train_loader)
        acc = test(model, test_loader)

        checkpoint_due = step % 5 == 0
        if checkpoint_due:
            # Ask Tune for the directory belonging to this step, then write
            # a single checkpoint file into it.
            with tune.checkpoint_dir(step=step) as save_dir:
                target = os.path.join(save_dir, "checkpoint")
                # Only the step, model weights and accuracy are stored;
                # the optimizer state is not checkpointed.
                snapshot = {
                    "step": step,
                    "model_state_dict": model.state_dict(),
                    "mean_accuracy": acc
                }
                torch.save(snapshot, target)

        step += 1
        tune.report(mean_accuracy=acc)
# __train_end__
if __name__ == "__main__":
    # Parse the command line; arguments we do not recognize are ignored.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    args, _ = arg_parser.parse_known_args()

    ray.init()
    # Fetch the MNIST training set into ~/data before tuning starts.
    datasets.MNIST("~/data", train=True, download=True)

    # __pbt_begin__
    mutations = {
        # distribution for resampling
        "lr": lambda: np.random.uniform(0.0001, 1),
        # allow perturbations within this set of categorical values
        "momentum": [0.8, 0.9, 0.99],
    }
    scheduler = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="mean_accuracy",
        mode="max",
        perturbation_interval=5,
        hyperparam_mutations=mutations)

    # __pbt_end__

    # __tune_begin__
    class CustomStopper(tune.Stopper):
        """Stop on an accuracy threshold or an iteration cap.

        The first result whose ``mean_accuracy`` exceeds 0.96 latches
        ``should_stop``; from then on both ``__call__`` and ``stop_all``
        return True. Until that happens, a trial is stopped once its
        ``training_iteration`` reaches the cap (5 with ``--smoke-test``,
        100 otherwise).
        """

        def __init__(self):
            self.should_stop = False

        def __call__(self, trial_id, result):
            # Once latched, stop without inspecting the result any further.
            if self.should_stop:
                return True
            if result["mean_accuracy"] > 0.96:
                self.should_stop = True
                return True
            if args.smoke_test:
                iteration_cap = 5
            else:
                iteration_cap = 100
            return result["training_iteration"] >= iteration_cap

        def stop_all(self):
            return self.should_stop

    stopper = CustomStopper()
    # Initial hyperparameter distributions sampled for each of the trials.
    search_space = {
        "lr": tune.uniform(0.001, 1),
        "momentum": tune.uniform(0.001, 1),
    }

    analysis = tune.run(
        train_convnet,
        name="pbt_test",
        scheduler=scheduler,
        verbose=1,
        stop=stopper,
        export_formats=[ExportFormat.MODEL],
        checkpoint_score_attr="mean_accuracy",
        keep_checkpoints_num=4,
        num_samples=4,
        config=search_space)
    # __tune_end__

    # Reload the weights from the best checkpoint of the best trial.
    best_trial = analysis.get_best_trial("mean_accuracy")
    best_checkpoint_path = analysis.get_best_checkpoint(
        best_trial, metric="mean_accuracy")
    best_model = ConvNet()
    checkpoint_file = os.path.join(best_checkpoint_path, "checkpoint")
    best_checkpoint = torch.load(checkpoint_file)
    best_model.load_state_dict(best_checkpoint["model_state_dict"])

    # Note that test only runs on a small random set of the test data, thus
    # the accuracy may be different from metrics shown in tuning process.
    test_loader = get_data_loaders()[1]
    test_acc = test(best_model, test_loader)
    print("best model accuracy: ", test_acc)