[tune] Test example checkpointing (#4728)

This commit is contained in:
Richard Liaw
2019-07-10 01:58:26 -07:00
committed by GitHub
parent e55c8ca165
commit 0b540ab492
14 changed files with 157 additions and 68 deletions
@@ -28,8 +28,8 @@ class MyTrainableClass(Trainable):
def _train(self):
self.timestep += 1
v = np.tanh(float(self.timestep) / self.config["width"])
v *= self.config["height"]
v = np.tanh(float(self.timestep) / self.config.get("width", 1))
v *= self.config.get("height", 1)
# Here we use `episode_reward_mean`, but you can also report other
# objectives such as loss or accuracy.
@@ -28,8 +28,8 @@ class MyTrainableClass(Trainable):
def _train(self):
self.timestep += 1
v = np.tanh(float(self.timestep) / self.config["width"])
v *= self.config["height"]
v = np.tanh(float(self.timestep) / self.config.get("width", 1))
v *= self.config.get("height", 1)
# Here we use `episode_reward_mean`, but you can also report other
# objectives such as loss or accuracy.
+2 -2
View File
@@ -36,8 +36,8 @@ class MyTrainableClass(Trainable):
def _train(self):
self.timestep += 1
v = np.tanh(float(self.timestep) / self.config["width"])
v *= self.config["height"]
v = np.tanh(float(self.timestep) / self.config.get("width", 1))
v *= self.config.get("height", 1)
# Here we use `episode_reward_mean`, but you can also report other
# objectives such as loss or accuracy.
@@ -12,6 +12,10 @@ from torchvision import datasets, transforms
from ray.tune import Trainable
# Change these values if you want the training to run quicker or slower.
EPOCH_SIZE = 512
TEST_SIZE = 256
# Training settings
parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
parser.add_argument(
@@ -85,7 +89,7 @@ class Net(nn.Module):
class TrainMNIST(Trainable):
def _setup(self, config):
args = config.pop("args")
args = config.pop("args", parser.parse_args([]))
vars(args).update(config)
args.cuda = not args.no_cuda and torch.cuda.is_available()
@@ -98,7 +102,7 @@ class TrainMNIST(Trainable):
datasets.MNIST(
"~/data",
train=True,
download=False,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307, ), (0.3081, ))
@@ -129,6 +133,8 @@ class TrainMNIST(Trainable):
def _train_iteration(self):
self.model.train()
for batch_idx, (data, target) in enumerate(self.train_loader):
if batch_idx * len(data) > EPOCH_SIZE:
return
if self.args.cuda:
data, target = data.cuda(), target.cuda()
self.optimizer.zero_grad()
@@ -142,7 +148,9 @@ class TrainMNIST(Trainable):
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in self.test_loader:
for batch_idx, (data, target) in enumerate(self.test_loader):
if batch_idx * len(data) > TEST_SIZE:
break
if self.args.cuda:
data, target = data.cuda(), target.cuda()
output = self.model(data)
@@ -19,7 +19,7 @@ import tensorflow as tf
from tensorflow.python.keras.datasets import cifar10
from tensorflow.python.keras.layers import Input, Dense, Dropout, Flatten
from tensorflow.python.keras.layers import Convolution2D, MaxPooling2D
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.models import Model, load_model
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
import ray
@@ -28,6 +28,7 @@ from ray.tune import Trainable
from ray.tune.schedulers import PopulationBasedTraining
num_classes = 10
NUM_SAMPLES = 128
class Cifar10Model(Trainable):
@@ -98,7 +99,7 @@ class Cifar10Model(Trainable):
y = MaxPooling2D(pool_size=2, strides=2, padding="same")(y)
y = Flatten()(y)
y = Dropout(self.config["dropout"])(y)
y = Dropout(self.config.get("dropout", 0.5))(y)
y = Dense(
units=10, activation="softmax", kernel_initializer="he_normal")(y)
@@ -111,7 +112,8 @@ class Cifar10Model(Trainable):
model = self._build_model(x_train.shape[1:])
opt = tf.keras.optimizers.Adadelta(
lr=self.config["lr"], decay=self.config["decay"])
lr=self.config.get("lr", 1e-4),
decay=self.config.get("decay", 1e-4))
model.compile(
loss="categorical_crossentropy",
optimizer=opt,
@@ -120,7 +122,9 @@ class Cifar10Model(Trainable):
def _train(self):
x_train, y_train = self.train_data
x_train, y_train = x_train[:NUM_SAMPLES], y_train[:NUM_SAMPLES]
x_test, y_test = self.test_data
x_test, y_test = x_test[:NUM_SAMPLES], y_test[:NUM_SAMPLES]
aug_gen = ImageDataGenerator(
# set input mean to 0 over the dataset
@@ -146,12 +150,11 @@ class Cifar10Model(Trainable):
)
aug_gen.fit(x_train)
gen = aug_gen.flow(
x_train, y_train, batch_size=self.config["batch_size"])
batch_size = self.config.get("batch_size", 64)
gen = aug_gen.flow(x_train, y_train, batch_size=batch_size)
self.model.fit_generator(
generator=gen,
steps_per_epoch=50000 // self.config["batch_size"],
epochs=self.config["epochs"],
epochs=self.config.get("epochs", 1),
validation_data=None)
# loss, accuracy
@@ -160,11 +163,13 @@ class Cifar10Model(Trainable):
def _save(self, checkpoint_dir):
file_path = checkpoint_dir + "/model"
self.model.save_weights(file_path)
self.model.save(file_path)
return file_path
def _restore(self, path):
self.model.load_weights(path)
# See https://stackoverflow.com/a/42763323
del self.model
self.model = load_model(path)
def _stop(self):
# If needed, save your model on exit.
@@ -27,11 +27,12 @@ from __future__ import division
from __future__ import print_function
import argparse
import os
import time
import ray
from ray import tune
from ray.tune import grid_search, Trainable, sample_from
from ray.tune import Trainable, sample_from
from ray.tune.schedulers import HyperBandScheduler
from tensorflow.examples.tutorials.mnist import input_data
@@ -148,7 +149,7 @@ class TrainMNIST(Trainable):
self.x = tf.placeholder(tf.float32, [None, 784])
self.y_ = tf.placeholder(tf.float32, [None, 10])
activation_fn = getattr(tf.nn, config["activation"])
activation_fn = getattr(tf.nn, config.get("activation", "relu"))
# Build the graph for the deep net
y_conv, self.keep_prob = setupCNN(self.x)
@@ -160,7 +161,7 @@ class TrainMNIST(Trainable):
with tf.name_scope("adam_optimizer"):
train_step = tf.train.AdamOptimizer(
config["learning_rate"]).minimize(cross_entropy)
config.get("learning_rate", 1e-4)).minimize(cross_entropy)
self.train_step = train_step
@@ -172,8 +173,7 @@ class TrainMNIST(Trainable):
self.sess = tf.Session()
self.sess.run(tf.global_variables_initializer())
self.iterations = 0
self.saver = tf.train.Saver()
self.saver = tf.train.Saver(save_relative_paths=True)
def _train(self):
for i in range(10):
@@ -194,18 +194,14 @@ class TrainMNIST(Trainable):
self.y_: batch[1],
self.keep_prob: 1.0
})
self.iterations += 1
return {"mean_accuracy": train_accuracy}
def _save(self, checkpoint_dir):
prefix = self.saver.save(
self.sess, checkpoint_dir + "/save", global_step=self.iterations)
return {"prefix": prefix}
path = self.saver.save(self.sess, os.path.join(checkpoint_dir, "save"))
return path
def _restore(self, ckpt_data):
prefix = ckpt_data["prefix"]
return self.saver.restore(self.sess, prefix)
def _restore(self, checkpoint_path):
self.saver.restore(self.sess, checkpoint_path)
# !!! Example of using the ray.tune Python API !!!
@@ -222,14 +218,14 @@ if __name__ == "__main__":
"config": {
"learning_rate": sample_from(
lambda spec: 10**np.random.uniform(-5, -3)),
"activation": grid_search(["relu", "elu", "tanh"]),
"activation": "relu",
},
"num_samples": 10,
}
if args.smoke_test:
mnist_spec["stop"]["training_iteration"] = 20
mnist_spec["num_samples"] = 2
mnist_spec["num_samples"] = 1
ray.init()
hyperband = HyperBandScheduler(
+44 -1
View File
@@ -10,7 +10,7 @@ import unittest
import ray
from ray import tune
from ray.tune.util import recursive_fnmatch
from ray.tune.util import recursive_fnmatch, validate_save_restore
from ray.rllib import _register_all
@@ -53,6 +53,49 @@ class TuneRestoreTest(unittest.TestCase):
)
class TuneExampleTest(unittest.TestCase):
    """Save/restore round-trip checks for the bundled Tune example Trainables.

    Each test imports one example Trainable and runs
    ``validate_save_restore`` on it twice: once with the default
    checkpoint path and once with ``use_object_store=True``.
    """

    def setUp(self):
        # Every test gets its own Ray session.
        ray.init()

    def tearDown(self):
        ray.shutdown()
        # NOTE(review): presumably re-registers the RLlib trainables that
        # the shutdown dropped -- confirm against ray.rllib.
        _register_all()

    def _check_save_restore(self, trainable_cls):
        """Validate ``trainable_cls`` with and without the object store."""
        validate_save_restore(trainable_cls)
        validate_save_restore(trainable_cls, use_object_store=True)

    def testTensorFlowMNIST(self):
        from ray.tune.examples.tune_mnist_ray_hyperband import TrainMNIST
        self._check_save_restore(TrainMNIST)

    def testPBTKeras(self):
        from ray.tune.examples.pbt_tune_cifar10_with_keras import Cifar10Model
        from tensorflow.python.keras.datasets import cifar10
        # Load the dataset up front, before the Trainable is validated.
        cifar10.load_data()
        self._check_save_restore(Cifar10Model)

    def testPyTorchMNIST(self):
        from ray.tune.examples.mnist_pytorch_trainable import TrainMNIST
        from torchvision import datasets
        # Download MNIST up front, before the Trainable is validated.
        datasets.MNIST("~/data", train=True, download=True)
        self._check_save_restore(TrainMNIST)

    def testLogging(self):
        from ray.tune.examples.logging_example import MyTrainableClass
        self._check_save_restore(MyTrainableClass)

    def testHyperbandExample(self):
        from ray.tune.examples.hyperband_example import MyTrainableClass
        self._check_save_restore(MyTrainableClass)

    def testAsyncHyperbandExample(self):
        from ray.tune.examples.async_hyperband_example import MyTrainableClass
        self._check_save_restore(MyTrainableClass)
class AutoInitTest(unittest.TestCase):
def testTuneRestore(self):
self.assertFalse(ray.is_initialized())
+22 -23
View File
@@ -238,7 +238,7 @@ class Trainable(object):
checkpoint_dir (str): Optional dir to place the checkpoint.
Returns:
Checkpoint path that may be passed to restore().
Checkpoint path or prefix that may be passed to restore().
"""
checkpoint_dir = os.path.join(checkpoint_dir or self.logdir,
@@ -248,16 +248,11 @@ class Trainable(object):
checkpoint = self._save(checkpoint_dir)
saved_as_dict = False
if isinstance(checkpoint, string_types):
if (not checkpoint.startswith(checkpoint_dir)
or checkpoint == checkpoint_dir):
if not checkpoint.startswith(checkpoint_dir):
raise ValueError(
"The returned checkpoint path must be within the "
"given checkpoint dir {}: {}".format(
checkpoint_dir, checkpoint))
if not os.path.exists(checkpoint):
raise ValueError(
"The returned checkpoint path does not exist: {}".format(
checkpoint))
checkpoint_path = checkpoint
elif isinstance(checkpoint, dict):
saved_as_dict = True
@@ -265,9 +260,9 @@ class Trainable(object):
with open(checkpoint_path, "wb") as f:
pickle.dump(checkpoint, f)
else:
raise ValueError(
"`_save` must return a dict or string type: {}".format(
str(type(checkpoint))))
raise ValueError("Returned unexpected type {}. "
"Expected str or dict.".format(type(checkpoint)))
with open(checkpoint_path + ".tune_metadata", "wb") as f:
pickle.dump({
"experiment_id": self._experiment_id,
@@ -288,25 +283,25 @@ class Trainable(object):
"""
tmpdir = tempfile.mkdtemp("save_to_object", dir=self.logdir)
checkpoint_prefix = self.save(tmpdir)
checkpoint_path = self.save(tmpdir)
# Save all files in subtree.
data = {}
base_dir = os.path.dirname(checkpoint_prefix)
for path in os.listdir(base_dir):
path = os.path.join(base_dir, path)
if path.startswith(checkpoint_prefix):
for basedir, _, file_names in os.walk(tmpdir):
for file_name in file_names:
path = os.path.join(basedir, file_name)
with open(path, "rb") as f:
data[os.path.basename(path)] = f.read()
data[os.path.relpath(path, tmpdir)] = f.read()
out = io.BytesIO()
data_dict = pickle.dumps({
"checkpoint_name": os.path.basename(checkpoint_prefix),
"checkpoint_name": os.path.relpath(checkpoint_path, tmpdir),
"data": data,
})
if len(data_dict) > 10e6: # getting pretty large
logger.info("Checkpoint size is {} bytes".format(len(data_dict)))
out.write(data_dict)
shutil.rmtree(tmpdir)
return out.getvalue()
@@ -318,7 +313,6 @@ class Trainable(object):
Subclasses should override ``_restore()`` instead to restore state.
This method restores additional metadata saved with the checkpoint.
"""
with open(checkpoint_path + ".tune_metadata", "rb") as f:
metadata = pickle.load(f)
self._experiment_id = metadata["experiment_id"]
@@ -330,6 +324,7 @@ class Trainable(object):
if saved_as_dict:
with open(checkpoint_path, "rb") as loaded_state:
checkpoint_dict = pickle.load(loaded_state)
checkpoint_dict.update(tune_checkpoint_path=checkpoint_path)
self._restore(checkpoint_dict)
else:
self._restore(checkpoint_path)
@@ -343,14 +338,18 @@ class Trainable(object):
These checkpoints are returned from calls to save_to_object().
"""
info = pickle.loads(obj)
data = info["data"]
tmpdir = tempfile.mkdtemp("restore_from_object", dir=self.logdir)
checkpoint_path = os.path.join(tmpdir, info["checkpoint_name"])
for file_name, file_contents in data.items():
with open(os.path.join(tmpdir, file_name), "wb") as f:
for relpath_name, file_contents in data.items():
path = os.path.join(tmpdir, relpath_name)
# This may be a subdirectory, hence not just using tmpdir
if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path))
with open(path, "wb") as f:
f.write(file_contents)
self.restore(checkpoint_path)
@@ -412,7 +411,7 @@ class Trainable(object):
Returns:
checkpoint (str | dict): If string, the return value is
expected to be the checkpoint path that will be passed to
expected to be the checkpoint path or prefix to be passed to
`_restore()`. If dict, the return value will be automatically
serialized by Tune and passed to `_restore()`.
+1 -1
View File
@@ -62,7 +62,7 @@ def run(run_or_experiment,
verbose=2,
resume=False,
queue_trials=False,
reuse_actors=True,
reuse_actors=False,
trial_executor=None,
raise_on_failed_trial=True,
return_trials=True,
+40
View File
@@ -223,6 +223,46 @@ def recursive_fnmatch(dirpath, pattern):
return matches
def validate_save_restore(trainable_cls, config=None, use_object_store=False):
    """Helper method to check if your Trainable class will resume correctly.

    Two remote copies of ``trainable_cls`` are created. The first is trained
    for three iterations and checkpointed; the checkpoint is restored into
    the second, which is then trained twice more and must report
    ``training_iteration`` values of 4 and 5.

    Args:
        trainable_cls: Trainable class for evaluation.
        config (dict): Config to pass to Trainable when testing.
        use_object_store (bool): Whether to save and restore to Ray's object
            store. Recommended to set this to True if planning to use
            algorithms that pause training (e.g., PBT, HyperBand).

    Returns:
        bool: True when every check passed.

    Raises:
        AssertionError: If Ray is not initialized, if ``training_iteration``
            is missing from the training result, or if the restored
            Trainable does not continue from the checkpointed iteration.
    """
    assert ray.is_initialized(), "Need Ray to be initialized."
    remote_cls = ray.remote(trainable_cls)
    trainable_1 = remote_cls.remote(config=config)
    trainable_2 = remote_cls.remote(config=config)

    from ray.tune.result import TRAINING_ITERATION

    # Advance the source Trainable so the checkpoint carries non-trivial
    # state (three completed iterations).
    for _ in range(3):
        res = ray.get(trainable_1.train.remote())

    assert res.get(TRAINING_ITERATION), (
        "Validation will not pass because it requires `training_iteration` "
        "to be returned.")

    if use_object_store:
        # The checkpoint blob is passed between the actors as an object ref.
        ray.get(
            trainable_2.restore_from_object.remote(
                trainable_1.save_to_object.remote()))
    else:
        ray.get(trainable_2.restore.remote(trainable_1.save.remote()))

    # The restored copy must pick up right after the three checkpointed
    # iterations.
    for expected_iteration in (4, 5):
        res = ray.get(trainable_2.train.remote())
        assert res[TRAINING_ITERATION] == expected_iteration, (
            "Restored Trainable reported `training_iteration` {} but {} "
            "was expected.".format(res[TRAINING_ITERATION],
                                   expected_iteration))
    return True
if __name__ == "__main__":
ray.init()
X = pin_in_object_store("hello")