From f87a4aa45dfb6b00c2074af867dd5352c27a2e28 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Fri, 14 Aug 2020 17:52:30 -0700 Subject: [PATCH] [Tune] Pbt Function API (#9958) * adding function convnet example * add unit test * update test * update example * wip * move error from experiment to tune * wip * Fix checkpoint deletion * updating code * adding smoke test * updating pbt guide * formatting * fix build * add best checkpoint analysis util * update test * add comments * remove class api * fix example * add setup and teardown to tests * formatting * Update python/ray/tune/tests/test_trial_scheduler_pbt.py Co-authored-by: Kai Fricke Co-authored-by: Richard Liaw --- .../_tutorials/tune-advanced-tutorial.rst | 49 ++++--- python/ray/tune/BUILD | 9 ++ .../ray/tune/analysis/experiment_analysis.py | 16 +++ python/ray/tune/checkpoint_manager.py | 2 +- .../ray/tune/examples/pbt_convnet_example.py | 5 +- .../examples/pbt_convnet_function_example.py | 129 ++++++++++++++++++ .../ray/tune/tests/test_checkpoint_manager.py | 31 +++++ .../tune/tests/test_experiment_analysis.py | 7 + .../tune/tests/test_trial_scheduler_pbt.py | 58 +++++++- 9 files changed, 274 insertions(+), 32 deletions(-) create mode 100644 python/ray/tune/examples/pbt_convnet_function_example.py diff --git a/doc/source/tune/_tutorials/tune-advanced-tutorial.rst b/doc/source/tune/_tutorials/tune-advanced-tutorial.rst index ad62f59d1..e74061fca 100644 --- a/doc/source/tune/_tutorials/tune-advanced-tutorial.rst +++ b/doc/source/tune/_tutorials/tune-advanced-tutorial.rst @@ -14,9 +14,8 @@ hyperparameters and allocate resources to promising models. Let's walk through h :local: :backlinks: none - -Trainable API with Population Based Training --------------------------------------------- +Function API with Population Based Training +------------------------------------------- PBT takes its inspiration from genetic algorithms where each member of the population can exploit information from the remainder of the population. For example, a worker might @@ -31,23 +30,24 @@ This means that PBT can quickly exploit good hyperparameters, can dedicate more promising models and, crucially, can adapt the hyperparameter values throughout training, leading to automatic learning of the best configurations. -First, we define a Trainable that wraps a ConvNet model. +First we define a training function that trains a ConvNet model using SGD. -.. literalinclude:: /../../python/ray/tune/examples/pbt_convnet_example.py - :language: python - :start-after: __trainable_begin__ - :end-before: __trainable_end__ +.. literalinclude:: /../../python/ray/tune/examples/pbt_convnet_function_example.py + :language: python + :start-after: __train_begin__ + :end-before: __train_end__ The example reuses some of the functions in ray/tune/examples/mnist_pytorch.py, and is also a good demo for how to decouple the tuning logic and original training code. -Here, we also override ``reset_config``. This method is optional but can be implemented to speed -up algorithms such as PBT, and to allow performance optimizations such as running experiments -with ``reuse_actors=True``. +Here, we also need to take in a ``checkpoint_dir`` arg since checkpointing is required for the exploitation process in PBT. +We have to both load in the checkpoint if one is provided, and periodically save our +model state in a checkpoint- in this case every 5 iterations. With SGD, there's no need to checkpoint the optimizer +since it does not depend on previous states, but this is necessary with other optimizers like Adam. Then, we define a PBT scheduler: -.. literalinclude:: /../../python/ray/tune/examples/pbt_convnet_example.py +.. literalinclude:: /../../python/ray/tune/examples/pbt_convnet_function_example.py :language: python :start-after: __pbt_begin__ :end-before: __pbt_end__ @@ -67,7 +67,7 @@ Some of the most important parameters are: Now we can kick off the tuning process by invoking tune.run: -.. literalinclude:: /../../python/ray/tune/examples/pbt_convnet_example.py +.. literalinclude:: /../../python/ray/tune/examples/pbt_convnet_function_example.py :language: python :start-after: __tune_begin__ :end-before: __tune_end__ @@ -77,19 +77,19 @@ During the training, we can constantly check the status of the models from conso .. code-block:: bash == Status == - Memory usage on this node: 11.6/16.0 GiB - PopulationBasedTraining: 5 checkpoints, 4 perturbs - Resources requested: 0/16 CPUs, 0/0 GPUs, 0.0/3.96 GiB heap, 0.0/1.37 GiB objects + Memory usage on this node: 11.2/16.0 GiB + PopulationBasedTraining: 12 checkpoints, 5 perturbs + Resources requested: 0/16 CPUs, 0/0 GPUs, 0.0/4.83 GiB heap, 0.0/1.66 GiB objects Result logdir: /Users/foo/ray_results/pbt_test Number of trials: 4 (4 TERMINATED) - +------------------------------+------------+-------+-----------+------------+----------+--------+------------------+ - | Trial name | status | loc | lr | momentum | acc | iter | total time (s) | - |------------------------------+------------+-------+-----------+------------+----------+--------+------------------| - | PytorchTrainable_ba982_00000 | TERMINATED | | 0.0457501 | 0.99 | 0.6375 | 25 | 5.35712 | - | PytorchTrainable_ba982_00001 | TERMINATED | | 0.175808 | 0.0667043 | 0.909375 | 29 | 6.18802 | - | PytorchTrainable_ba982_00002 | TERMINATED | | 0.21097 | 0.99 | 0.040625 | 29 | 6.19634 | - | PytorchTrainable_ba982_00003 | TERMINATED | | 0.0571876 | 0.852088 | 0.96875 | 30 | 6.37298 | - +------------------------------+------------+-------+-----------+------------+----------+--------+------------------+ + +---------------------------+------------+-------+-----------+------------+----------+--------+------------------+ + | Trial name | status | loc | lr | momentum | acc | iter | total time (s) | + |---------------------------+------------+-------+-----------+------------+----------+--------+------------------| + | train_convnet_b2732_00000 | TERMINATED | | 0.221776 | 0.608416 | 0.95625 | 59 | 13.0862 | + | train_convnet_b2732_00001 | TERMINATED | | 0.0734679 | 0.1484 | 0.934375 | 59 | 13.1084 | + | train_convnet_b2732_00002 | TERMINATED | | 0.0376862 | 0.8 | 0.971875 | 46 | 10.2909 | + | train_convnet_b2732_00003 | TERMINATED | | 0.0471078 | 0.8 | 0.95 | 51 | 11.3355 | + +---------------------------+------------+-------+-----------+------------+----------+--------+------------------+ In {LOG_DIR}/{MY_EXPERIMENT_NAME}/, all mutations are logged in pbt_global.txt and individual policy perturbations are recorded in pbt_policy_{i}.txt. Tune logs: @@ -145,7 +145,6 @@ thus just use the same ``Trainable`` for the replay run. scheduler=replay, stop={"training_iteration": 100}) - DCGAN with Trainable and PBT ---------------------------- diff --git a/python/ray/tune/BUILD b/python/ray/tune/BUILD index 742997fc3..99b10b280 100644 --- a/python/ray/tune/BUILD +++ b/python/ray/tune/BUILD @@ -446,6 +446,15 @@ py_test( args = ["--smoke-test"] ) +py_test( + name = "pbt_convnet_function_example", + size = "medium", + srcs = ["examples/pbt_convnet_function_example.py"], + deps = [":tune_lib"], + tags = ["exclusive", "example"], + args = ["--smoke-test"] +) + py_test( name = "pbt_example", size = "medium", diff --git a/python/ray/tune/analysis/experiment_analysis.py b/python/ray/tune/analysis/experiment_analysis.py index c2780a036..c504f462a 100644 --- a/python/ray/tune/analysis/experiment_analysis.py +++ b/python/ray/tune/analysis/experiment_analysis.py @@ -167,6 +167,22 @@ class Analysis: else: raise ValueError("trial should be a string or a Trial instance.") + def get_best_checkpoint(self, trial, metric=TRAINING_ITERATION): + """Gets best persistent checkpoint path of provided trial. + + Args: + trial (Trial): The log directory of a trial, or a trial instance. + metric (str): key of trial info to return, e.g. "mean_accuracy". + "training_iteration" is used by default. + + Returns: + Path for best checkpoint of trial determined by metric + """ + + return max( + self.get_trial_checkpoints_paths(trial, metric), + key=lambda x: x[1])[0] + def _retrieve_rows(self, metric=None, mode=None): assert mode is None or mode in ["max", "min"] rows = {} diff --git a/python/ray/tune/checkpoint_manager.py b/python/ray/tune/checkpoint_manager.py index 4d76dadcf..df50eb185 100644 --- a/python/ray/tune/checkpoint_manager.py +++ b/python/ray/tune/checkpoint_manager.py @@ -137,7 +137,7 @@ class CheckpointManager: self._membership.remove(worst) # Don't delete the newest checkpoint. It will be deleted on the # next on_checkpoint() call since it isn't in self._membership. - if worst != checkpoint: + if worst.value != checkpoint.value: self.delete(worst) def best_checkpoints(self): diff --git a/python/ray/tune/examples/pbt_convnet_example.py b/python/ray/tune/examples/pbt_convnet_example.py index 656899966..c9f455ccc 100644 --- a/python/ray/tune/examples/pbt_convnet_example.py +++ b/python/ray/tune/examples/pbt_convnet_example.py @@ -132,10 +132,9 @@ if __name__ == "__main__": # __tune_end__ best_trial = analysis.get_best_trial("mean_accuracy") - best_checkpoint = max( - analysis.get_trial_checkpoints_paths(best_trial, "mean_accuracy")) + best_checkpoint = analysis.get_best_checkpoint(best_trial, metric="mean_accuracy") restored_trainable = PytorchTrainable() - restored_trainable.restore(best_checkpoint[0]) + restored_trainable.restore(best_checkpoint) best_model = restored_trainable.model # Note that test only runs on a small random set of the test data, thus the # accuracy may be different from metrics shown in tuning process. diff --git a/python/ray/tune/examples/pbt_convnet_function_example.py b/python/ray/tune/examples/pbt_convnet_function_example.py new file mode 100644 index 000000000..758148a29 --- /dev/null +++ b/python/ray/tune/examples/pbt_convnet_function_example.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python + +# __tutorial_imports_begin__ +import argparse +import os +import numpy as np +import torch +import torch.optim as optim +from torchvision import datasets +from ray.tune.examples.mnist_pytorch import train, test, ConvNet,\ + get_data_loaders + +import ray +from ray import tune +from ray.tune.schedulers import PopulationBasedTraining +from ray.tune.trial import ExportFormat + +# __tutorial_imports_end__ + + +# __train_begin__ +def train_convnet(config, checkpoint_dir=None): + # Create our data loaders, model, and optmizer. + step = 0 + train_loader, test_loader = get_data_loaders() + model = ConvNet() + optimizer = optim.SGD( + model.parameters(), + lr=config.get("lr", 0.01), + momentum=config.get("momentum", 0.9)) + + # If checkpoint_dir is not None, then we are resuming from a checkpoint. + # Load model state and iteration step from checkpoint. + if checkpoint_dir: + print("Loading from checkpoint.") + path = os.path.join(checkpoint_dir, "checkpoint") + checkpoint = torch.load(path) + model.load_state_dict(checkpoint["model_state_dict"]) + step = checkpoint["step"] + + while True: + train(model, optimizer, train_loader) + acc = test(model, test_loader) + if step % 5 == 0: + # Every 5 steps, checkpoint our current state. + # First get the checkpoint directory from tune. + with tune.checkpoint_dir(step=step) as checkpoint_dir: + # Then create a checkpoint file in this directory. + path = os.path.join(checkpoint_dir, "checkpoint") + # Save state to checkpoint file. + # No need to save optimizer for SGD. + torch.save({ + "step": step, + "model_state_dict": model.state_dict(), + "mean_accuracy": acc + }, path) + step += 1 + tune.report(mean_accuracy=acc) + + +# __train_end__ + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--smoke-test", action="store_true", help="Finish quickly for testing") + args, _ = parser.parse_known_args() + + ray.init() + datasets.MNIST("~/data", train=True, download=True) + + # __pbt_begin__ + scheduler = PopulationBasedTraining( + time_attr="training_iteration", + metric="mean_accuracy", + mode="max", + perturbation_interval=5, + hyperparam_mutations={ + # distribution for resampling + "lr": lambda: np.random.uniform(0.0001, 1), + # allow perturbations within this set of categorical values + "momentum": [0.8, 0.9, 0.99], + }) + + # __pbt_end__ + + # __tune_begin__ + class CustomStopper(tune.Stopper): + def __init__(self): + self.should_stop = False + + def __call__(self, trial_id, result): + max_iter = 5 if args.smoke_test else 100 + if not self.should_stop and result["mean_accuracy"] > 0.96: + self.should_stop = True + return self.should_stop or result["training_iteration"] >= max_iter + + def stop_all(self): + return self.should_stop + + stopper = CustomStopper() + + analysis = tune.run( + train_convnet, + name="pbt_test", + scheduler=scheduler, + verbose=1, + stop=stopper, + export_formats=[ExportFormat.MODEL], + checkpoint_score_attr="mean_accuracy", + keep_checkpoints_num=4, + num_samples=4, + config={ + "lr": tune.uniform(0.001, 1), + "momentum": tune.uniform(0.001, 1), + }) + # __tune_end__ + + best_trial = analysis.get_best_trial("mean_accuracy") + best_checkpoint_path = analysis.get_best_checkpoint( + best_trial, metric="mean_accuracy") + best_model = ConvNet() + best_checkpoint = torch.load( + os.path.join(best_checkpoint_path, "checkpoint")) + best_model.load_state_dict(best_checkpoint["model_state_dict"]) + # Note that test only runs on a small random set of the test data, thus the + # accuracy may be different from metrics shown in tuning process. + test_acc = test(best_model, get_data_loaders()[1]) + print("best model accuracy: ", test_acc) diff --git a/python/ray/tune/tests/test_checkpoint_manager.py b/python/ray/tune/tests/test_checkpoint_manager.py index d171f5974..216341f12 100644 --- a/python/ray/tune/tests/test_checkpoint_manager.py +++ b/python/ray/tune/tests/test_checkpoint_manager.py @@ -1,6 +1,8 @@ # coding: utf-8 +import os import random import sys +import tempfile import unittest from unittest.mock import patch @@ -128,6 +130,35 @@ class CheckpointManagerTest(unittest.TestCase): self.assertEqual(newest, checkpoints[1]) self.assertEqual(checkpoint_manager.best_checkpoints(), []) + def testSameCheckpoint(self): + checkpoint_manager = CheckpointManager( + 1, "i", delete_fn=lambda c: os.remove(c.value)) + + tmpfiles = [] + for i in range(3): + tmpfile = tempfile.mktemp() + with open(tmpfile, "wt") as fp: + fp.write("") + tmpfiles.append(tmpfile) + + checkpoints = [ + Checkpoint(Checkpoint.PERSISTENT, tmpfiles[0], + self.mock_result(5)), + Checkpoint(Checkpoint.PERSISTENT, tmpfiles[1], + self.mock_result(10)), + Checkpoint(Checkpoint.PERSISTENT, tmpfiles[2], + self.mock_result(0)), + Checkpoint(Checkpoint.PERSISTENT, tmpfiles[1], + self.mock_result(20)) + ] + for checkpoint in checkpoints: + checkpoint_manager.on_checkpoint(checkpoint) + self.assertTrue(os.path.exists(checkpoint.value)) + + for tmpfile in tmpfiles: + if os.path.exists(tmpfile): + os.remove(tmpfile) + if __name__ == "__main__": import pytest diff --git a/python/ray/tune/tests/test_experiment_analysis.py b/python/ray/tune/tests/test_experiment_analysis.py index 1b6f6a286..5b1f4c7bb 100644 --- a/python/ray/tune/tests/test_experiment_analysis.py +++ b/python/ray/tune/tests/test_experiment_analysis.py @@ -126,6 +126,13 @@ class ExperimentAnalysisSuite(unittest.TestCase): assert paths[0][0] == expected_path assert paths[0][1] == best_trial.metric_analysis[self.metric]["last"] + def testGetBestCheckpoint(self): + best_trial = self.ea.get_best_trial(self.metric) + checkpoints_metrics = self.ea.get_trial_checkpoints_paths(best_trial) + expected_path = max(checkpoints_metrics, key=lambda x: x[1])[0] + best_checkpoint = self.ea.get_best_checkpoint(best_trial, self.metric) + assert expected_path == best_checkpoint + def testAllDataframes(self): dataframes = self.ea.trial_dataframes self.assertTrue(len(dataframes) == self.num_samples) diff --git a/python/ray/tune/tests/test_trial_scheduler_pbt.py b/python/ray/tune/tests/test_trial_scheduler_pbt.py index 5886c3a6a..523d6e778 100644 --- a/python/ray/tune/tests/test_trial_scheduler_pbt.py +++ b/python/ray/tune/tests/test_trial_scheduler_pbt.py @@ -5,6 +5,7 @@ import random import unittest import sys +import ray from ray import tune from ray.tune.schedulers import PopulationBasedTraining @@ -23,13 +24,32 @@ class MockTrainable(tune.Trainable): def save_checkpoint(self, tmp_checkpoint_dir): checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.mock") with open(checkpoint_path, "wb") as fp: - pickle.dump((self.a, self.b), fp) + pickle.dump((self.a, self.b, self.iter), fp) return tmp_checkpoint_dir def load_checkpoint(self, tmp_checkpoint_dir): checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.mock") with open(checkpoint_path, "rb") as fp: - self.a, self.b = pickle.load(fp) + self.a, self.b, self.iter = pickle.load(fp) + + +def MockTrainingFunc(config, checkpoint_dir=None): + iter = 0 + a = config["a"] + b = config["b"] + + if checkpoint_dir: + checkpoint_path = os.path.join(checkpoint_dir, "model.mock") + with open(checkpoint_path, "rb") as fp: + a, b, iter = pickle.load(fp) + + while True: + iter += 1 + with tune.checkpoint_dir(step=iter) as checkpoint_dir: + checkpoint_path = os.path.join(checkpoint_dir, "model.mock") + with open(checkpoint_path, "wb") as fp: + pickle.dump((a, b, iter), fp) + tune.report(mean_accuracy=(a - iter) * b) class MockParam(object): @@ -44,6 +64,12 @@ class MockParam(object): class PopulationBasedTrainingResumeTest(unittest.TestCase): + def setUp(self): + ray.init() + + def tearDown(self): + ray.shutdown() + def testPermutationContinuation(self): """ Tests continuation of runs after permutation. @@ -74,7 +100,6 @@ class PopulationBasedTrainingResumeTest(unittest.TestCase): }, fail_fast=True, num_samples=20, - global_checkpoint_period=1, checkpoint_freq=1, checkpoint_at_end=True, keep_checkpoints_num=1, @@ -83,6 +108,33 @@ class PopulationBasedTrainingResumeTest(unittest.TestCase): name="testPermutationContinuation", stop={"training_iteration": 5}) + def testPermutationContinuationFunc(self): + scheduler = PopulationBasedTraining( + time_attr="training_iteration", + metric="mean_accuracy", + mode="max", + perturbation_interval=1, + log_config=True, + hyperparam_mutations={"c": lambda: 1}) + param_a = MockParam([10, 20, 30, 40]) + param_b = MockParam([1.2, 0.9, 1.1, 0.8]) + random.seed(100) + np.random.seed(1000) + tune.run( + MockTrainingFunc, + config={ + "a": tune.sample_from(lambda _: param_a()), + "b": tune.sample_from(lambda _: param_b()), + "c": 1 + }, + fail_fast=True, + num_samples=4, + keep_checkpoints_num=1, + checkpoint_score_attr="min-training_iteration", + scheduler=scheduler, + name="testPermutationContinuationFunc", + stop={"training_iteration": 3}) + if __name__ == "__main__": import pytest