mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 17:02:43 +08:00
[tune] Test example checkpointing (#4728)
This commit is contained in:
@@ -28,8 +28,8 @@ class MyTrainableClass(Trainable):
|
||||
|
||||
def _train(self):
|
||||
self.timestep += 1
|
||||
v = np.tanh(float(self.timestep) / self.config["width"])
|
||||
v *= self.config["height"]
|
||||
v = np.tanh(float(self.timestep) / self.config.get("width", 1))
|
||||
v *= self.config.get("height", 1)
|
||||
|
||||
# Here we use `episode_reward_mean`, but you can also report other
|
||||
# objectives such as loss or accuracy.
|
||||
|
||||
@@ -28,8 +28,8 @@ class MyTrainableClass(Trainable):
|
||||
|
||||
def _train(self):
|
||||
self.timestep += 1
|
||||
v = np.tanh(float(self.timestep) / self.config["width"])
|
||||
v *= self.config["height"]
|
||||
v = np.tanh(float(self.timestep) / self.config.get("width", 1))
|
||||
v *= self.config.get("height", 1)
|
||||
|
||||
# Here we use `episode_reward_mean`, but you can also report other
|
||||
# objectives such as loss or accuracy.
|
||||
|
||||
@@ -36,8 +36,8 @@ class MyTrainableClass(Trainable):
|
||||
|
||||
def _train(self):
|
||||
self.timestep += 1
|
||||
v = np.tanh(float(self.timestep) / self.config["width"])
|
||||
v *= self.config["height"]
|
||||
v = np.tanh(float(self.timestep) / self.config.get("width", 1))
|
||||
v *= self.config.get("height", 1)
|
||||
|
||||
# Here we use `episode_reward_mean`, but you can also report other
|
||||
# objectives such as loss or accuracy.
|
||||
|
||||
@@ -12,6 +12,10 @@ from torchvision import datasets, transforms
|
||||
|
||||
from ray.tune import Trainable
|
||||
|
||||
# Change these values if you want the training to run quicker or slower.
|
||||
EPOCH_SIZE = 512
|
||||
TEST_SIZE = 256
|
||||
|
||||
# Training settings
|
||||
parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
|
||||
parser.add_argument(
|
||||
@@ -85,7 +89,7 @@ class Net(nn.Module):
|
||||
|
||||
class TrainMNIST(Trainable):
|
||||
def _setup(self, config):
|
||||
args = config.pop("args")
|
||||
args = config.pop("args", parser.parse_args([]))
|
||||
vars(args).update(config)
|
||||
args.cuda = not args.no_cuda and torch.cuda.is_available()
|
||||
|
||||
@@ -98,7 +102,7 @@ class TrainMNIST(Trainable):
|
||||
datasets.MNIST(
|
||||
"~/data",
|
||||
train=True,
|
||||
download=False,
|
||||
download=True,
|
||||
transform=transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307, ), (0.3081, ))
|
||||
@@ -129,6 +133,8 @@ class TrainMNIST(Trainable):
|
||||
def _train_iteration(self):
|
||||
self.model.train()
|
||||
for batch_idx, (data, target) in enumerate(self.train_loader):
|
||||
if batch_idx * len(data) > EPOCH_SIZE:
|
||||
return
|
||||
if self.args.cuda:
|
||||
data, target = data.cuda(), target.cuda()
|
||||
self.optimizer.zero_grad()
|
||||
@@ -142,7 +148,9 @@ class TrainMNIST(Trainable):
|
||||
test_loss = 0
|
||||
correct = 0
|
||||
with torch.no_grad():
|
||||
for data, target in self.test_loader:
|
||||
for batch_idx, (data, target) in enumerate(self.test_loader):
|
||||
if batch_idx * len(data) > TEST_SIZE:
|
||||
break
|
||||
if self.args.cuda:
|
||||
data, target = data.cuda(), target.cuda()
|
||||
output = self.model(data)
|
||||
|
||||
@@ -19,7 +19,7 @@ import tensorflow as tf
|
||||
from tensorflow.python.keras.datasets import cifar10
|
||||
from tensorflow.python.keras.layers import Input, Dense, Dropout, Flatten
|
||||
from tensorflow.python.keras.layers import Convolution2D, MaxPooling2D
|
||||
from tensorflow.python.keras.models import Model
|
||||
from tensorflow.python.keras.models import Model, load_model
|
||||
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
|
||||
|
||||
import ray
|
||||
@@ -28,6 +28,7 @@ from ray.tune import Trainable
|
||||
from ray.tune.schedulers import PopulationBasedTraining
|
||||
|
||||
num_classes = 10
|
||||
NUM_SAMPLES = 128
|
||||
|
||||
|
||||
class Cifar10Model(Trainable):
|
||||
@@ -98,7 +99,7 @@ class Cifar10Model(Trainable):
|
||||
y = MaxPooling2D(pool_size=2, strides=2, padding="same")(y)
|
||||
|
||||
y = Flatten()(y)
|
||||
y = Dropout(self.config["dropout"])(y)
|
||||
y = Dropout(self.config.get("dropout", 0.5))(y)
|
||||
y = Dense(
|
||||
units=10, activation="softmax", kernel_initializer="he_normal")(y)
|
||||
|
||||
@@ -111,7 +112,8 @@ class Cifar10Model(Trainable):
|
||||
model = self._build_model(x_train.shape[1:])
|
||||
|
||||
opt = tf.keras.optimizers.Adadelta(
|
||||
lr=self.config["lr"], decay=self.config["decay"])
|
||||
lr=self.config.get("lr", 1e-4),
|
||||
decay=self.config.get("decay", 1e-4))
|
||||
model.compile(
|
||||
loss="categorical_crossentropy",
|
||||
optimizer=opt,
|
||||
@@ -120,7 +122,9 @@ class Cifar10Model(Trainable):
|
||||
|
||||
def _train(self):
|
||||
x_train, y_train = self.train_data
|
||||
x_train, y_train = x_train[:NUM_SAMPLES], y_train[:NUM_SAMPLES]
|
||||
x_test, y_test = self.test_data
|
||||
x_test, y_test = x_test[:NUM_SAMPLES], y_test[:NUM_SAMPLES]
|
||||
|
||||
aug_gen = ImageDataGenerator(
|
||||
# set input mean to 0 over the dataset
|
||||
@@ -146,12 +150,11 @@ class Cifar10Model(Trainable):
|
||||
)
|
||||
|
||||
aug_gen.fit(x_train)
|
||||
gen = aug_gen.flow(
|
||||
x_train, y_train, batch_size=self.config["batch_size"])
|
||||
batch_size = self.config.get("batch_size", 64)
|
||||
gen = aug_gen.flow(x_train, y_train, batch_size=batch_size)
|
||||
self.model.fit_generator(
|
||||
generator=gen,
|
||||
steps_per_epoch=50000 // self.config["batch_size"],
|
||||
epochs=self.config["epochs"],
|
||||
epochs=self.config.get("epochs", 1),
|
||||
validation_data=None)
|
||||
|
||||
# loss, accuracy
|
||||
@@ -160,11 +163,13 @@ class Cifar10Model(Trainable):
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
file_path = checkpoint_dir + "/model"
|
||||
self.model.save_weights(file_path)
|
||||
self.model.save(file_path)
|
||||
return file_path
|
||||
|
||||
def _restore(self, path):
|
||||
self.model.load_weights(path)
|
||||
# See https://stackoverflow.com/a/42763323
|
||||
del self.model
|
||||
self.model = load_model(path)
|
||||
|
||||
def _stop(self):
|
||||
# If need, save your model when exit.
|
||||
|
||||
@@ -27,11 +27,12 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune import grid_search, Trainable, sample_from
|
||||
from ray.tune import Trainable, sample_from
|
||||
from ray.tune.schedulers import HyperBandScheduler
|
||||
from tensorflow.examples.tutorials.mnist import input_data
|
||||
|
||||
@@ -148,7 +149,7 @@ class TrainMNIST(Trainable):
|
||||
self.x = tf.placeholder(tf.float32, [None, 784])
|
||||
self.y_ = tf.placeholder(tf.float32, [None, 10])
|
||||
|
||||
activation_fn = getattr(tf.nn, config["activation"])
|
||||
activation_fn = getattr(tf.nn, config.get("activation", "relu"))
|
||||
|
||||
# Build the graph for the deep net
|
||||
y_conv, self.keep_prob = setupCNN(self.x)
|
||||
@@ -160,7 +161,7 @@ class TrainMNIST(Trainable):
|
||||
|
||||
with tf.name_scope("adam_optimizer"):
|
||||
train_step = tf.train.AdamOptimizer(
|
||||
config["learning_rate"]).minimize(cross_entropy)
|
||||
config.get("learning_rate", 1e-4)).minimize(cross_entropy)
|
||||
|
||||
self.train_step = train_step
|
||||
|
||||
@@ -172,8 +173,7 @@ class TrainMNIST(Trainable):
|
||||
|
||||
self.sess = tf.Session()
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
self.iterations = 0
|
||||
self.saver = tf.train.Saver()
|
||||
self.saver = tf.train.Saver(save_relative_paths=True)
|
||||
|
||||
def _train(self):
|
||||
for i in range(10):
|
||||
@@ -194,18 +194,14 @@ class TrainMNIST(Trainable):
|
||||
self.y_: batch[1],
|
||||
self.keep_prob: 1.0
|
||||
})
|
||||
|
||||
self.iterations += 1
|
||||
return {"mean_accuracy": train_accuracy}
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
prefix = self.saver.save(
|
||||
self.sess, checkpoint_dir + "/save", global_step=self.iterations)
|
||||
return {"prefix": prefix}
|
||||
path = self.saver.save(self.sess, os.path.join(checkpoint_dir, "save"))
|
||||
return path
|
||||
|
||||
def _restore(self, ckpt_data):
|
||||
prefix = ckpt_data["prefix"]
|
||||
return self.saver.restore(self.sess, prefix)
|
||||
def _restore(self, checkpoint_path):
|
||||
self.saver.restore(self.sess, checkpoint_path)
|
||||
|
||||
|
||||
# !!! Example of using the ray.tune Python API !!!
|
||||
@@ -222,14 +218,14 @@ if __name__ == "__main__":
|
||||
"config": {
|
||||
"learning_rate": sample_from(
|
||||
lambda spec: 10**np.random.uniform(-5, -3)),
|
||||
"activation": grid_search(["relu", "elu", "tanh"]),
|
||||
"activation": "relu",
|
||||
},
|
||||
"num_samples": 10,
|
||||
}
|
||||
|
||||
if args.smoke_test:
|
||||
mnist_spec["stop"]["training_iteration"] = 20
|
||||
mnist_spec["num_samples"] = 2
|
||||
mnist_spec["num_samples"] = 1
|
||||
|
||||
ray.init()
|
||||
hyperband = HyperBandScheduler(
|
||||
|
||||
@@ -10,7 +10,7 @@ import unittest
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune.util import recursive_fnmatch
|
||||
from ray.tune.util import recursive_fnmatch, validate_save_restore
|
||||
from ray.rllib import _register_all
|
||||
|
||||
|
||||
@@ -53,6 +53,49 @@ class TuneRestoreTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
|
||||
class TuneExampleTest(unittest.TestCase):
    """Checks that the bundled Tune example Trainables checkpoint and resume.

    Each test imports one example Trainable and runs
    ``validate_save_restore`` on it twice: once with the default
    path-based checkpoint and once with ``use_object_store=True``.
    """

    def setUp(self):
        # Every test gets a fresh Ray instance.
        ray.init()

    def tearDown(self):
        ray.shutdown()
        _register_all()

    def _validate_both_modes(self, trainable_cls):
        """Run the save/restore validation in both checkpointing modes."""
        validate_save_restore(trainable_cls)
        validate_save_restore(trainable_cls, use_object_store=True)

    def testTensorFlowMNIST(self):
        from ray.tune.examples.tune_mnist_ray_hyperband import TrainMNIST

        self._validate_both_modes(TrainMNIST)

    def testPBTKeras(self):
        from ray.tune.examples.pbt_tune_cifar10_with_keras import Cifar10Model
        from tensorflow.python.keras.datasets import cifar10

        # Load the dataset up front, before any Trainable is created.
        cifar10.load_data()
        self._validate_both_modes(Cifar10Model)

    def testPyTorchMNIST(self):
        from ray.tune.examples.mnist_pytorch_trainable import TrainMNIST
        from torchvision import datasets

        # Download the dataset up front, before any Trainable is created.
        datasets.MNIST("~/data", train=True, download=True)
        self._validate_both_modes(TrainMNIST)

    def testLogging(self):
        from ray.tune.examples.logging_example import MyTrainableClass

        self._validate_both_modes(MyTrainableClass)

    def testHyperbandExample(self):
        from ray.tune.examples.hyperband_example import MyTrainableClass

        self._validate_both_modes(MyTrainableClass)

    def testAsyncHyperbandExample(self):
        from ray.tune.examples.async_hyperband_example import MyTrainableClass

        self._validate_both_modes(MyTrainableClass)
|
||||
|
||||
|
||||
class AutoInitTest(unittest.TestCase):
|
||||
def testTuneRestore(self):
|
||||
self.assertFalse(ray.is_initialized())
|
||||
|
||||
@@ -238,7 +238,7 @@ class Trainable(object):
|
||||
checkpoint_dir (str): Optional dir to place the checkpoint.
|
||||
|
||||
Returns:
|
||||
Checkpoint path that may be passed to restore().
|
||||
Checkpoint path or prefix that may be passed to restore().
|
||||
"""
|
||||
|
||||
checkpoint_dir = os.path.join(checkpoint_dir or self.logdir,
|
||||
@@ -248,16 +248,11 @@ class Trainable(object):
|
||||
checkpoint = self._save(checkpoint_dir)
|
||||
saved_as_dict = False
|
||||
if isinstance(checkpoint, string_types):
|
||||
if (not checkpoint.startswith(checkpoint_dir)
|
||||
or checkpoint == checkpoint_dir):
|
||||
if not checkpoint.startswith(checkpoint_dir):
|
||||
raise ValueError(
|
||||
"The returned checkpoint path must be within the "
|
||||
"given checkpoint dir {}: {}".format(
|
||||
checkpoint_dir, checkpoint))
|
||||
if not os.path.exists(checkpoint):
|
||||
raise ValueError(
|
||||
"The returned checkpoint path does not exist: {}".format(
|
||||
checkpoint))
|
||||
checkpoint_path = checkpoint
|
||||
elif isinstance(checkpoint, dict):
|
||||
saved_as_dict = True
|
||||
@@ -265,9 +260,9 @@ class Trainable(object):
|
||||
with open(checkpoint_path, "wb") as f:
|
||||
pickle.dump(checkpoint, f)
|
||||
else:
|
||||
raise ValueError(
|
||||
"`_save` must return a dict or string type: {}".format(
|
||||
str(type(checkpoint))))
|
||||
raise ValueError("Returned unexpected type {}. "
|
||||
"Expected str or dict.".format(type(checkpoint)))
|
||||
|
||||
with open(checkpoint_path + ".tune_metadata", "wb") as f:
|
||||
pickle.dump({
|
||||
"experiment_id": self._experiment_id,
|
||||
@@ -288,25 +283,25 @@ class Trainable(object):
|
||||
"""
|
||||
|
||||
tmpdir = tempfile.mkdtemp("save_to_object", dir=self.logdir)
|
||||
checkpoint_prefix = self.save(tmpdir)
|
||||
checkpoint_path = self.save(tmpdir)
|
||||
|
||||
# Save all files in subtree.
|
||||
data = {}
|
||||
base_dir = os.path.dirname(checkpoint_prefix)
|
||||
for path in os.listdir(base_dir):
|
||||
path = os.path.join(base_dir, path)
|
||||
if path.startswith(checkpoint_prefix):
|
||||
for basedir, _, file_names in os.walk(tmpdir):
|
||||
for file_name in file_names:
|
||||
path = os.path.join(basedir, file_name)
|
||||
|
||||
with open(path, "rb") as f:
|
||||
data[os.path.basename(path)] = f.read()
|
||||
data[os.path.relpath(path, tmpdir)] = f.read()
|
||||
|
||||
out = io.BytesIO()
|
||||
data_dict = pickle.dumps({
|
||||
"checkpoint_name": os.path.basename(checkpoint_prefix),
|
||||
"checkpoint_name": os.path.relpath(checkpoint_path, tmpdir),
|
||||
"data": data,
|
||||
})
|
||||
if len(data_dict) > 10e6: # getting pretty large
|
||||
logger.info("Checkpoint size is {} bytes".format(len(data_dict)))
|
||||
out.write(data_dict)
|
||||
|
||||
shutil.rmtree(tmpdir)
|
||||
return out.getvalue()
|
||||
|
||||
@@ -318,7 +313,6 @@ class Trainable(object):
|
||||
Subclasses should override ``_restore()`` instead to restore state.
|
||||
This method restores additional metadata saved with the checkpoint.
|
||||
"""
|
||||
|
||||
with open(checkpoint_path + ".tune_metadata", "rb") as f:
|
||||
metadata = pickle.load(f)
|
||||
self._experiment_id = metadata["experiment_id"]
|
||||
@@ -330,6 +324,7 @@ class Trainable(object):
|
||||
if saved_as_dict:
|
||||
with open(checkpoint_path, "rb") as loaded_state:
|
||||
checkpoint_dict = pickle.load(loaded_state)
|
||||
checkpoint_dict.update(tune_checkpoint_path=checkpoint_path)
|
||||
self._restore(checkpoint_dict)
|
||||
else:
|
||||
self._restore(checkpoint_path)
|
||||
@@ -343,14 +338,18 @@ class Trainable(object):
|
||||
|
||||
These checkpoints are returned from calls to save_to_object().
|
||||
"""
|
||||
|
||||
info = pickle.loads(obj)
|
||||
data = info["data"]
|
||||
tmpdir = tempfile.mkdtemp("restore_from_object", dir=self.logdir)
|
||||
checkpoint_path = os.path.join(tmpdir, info["checkpoint_name"])
|
||||
|
||||
for file_name, file_contents in data.items():
|
||||
with open(os.path.join(tmpdir, file_name), "wb") as f:
|
||||
for relpath_name, file_contents in data.items():
|
||||
path = os.path.join(tmpdir, relpath_name)
|
||||
|
||||
# This may be a subdirectory, hence not just using tmpdir
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
with open(path, "wb") as f:
|
||||
f.write(file_contents)
|
||||
|
||||
self.restore(checkpoint_path)
|
||||
@@ -412,7 +411,7 @@ class Trainable(object):
|
||||
|
||||
Returns:
|
||||
checkpoint (str | dict): If string, the return value is
|
||||
expected to be the checkpoint path that will be passed to
|
||||
expected to be the checkpoint path or prefix to be passed to
|
||||
`_restore()`. If dict, the return value will be automatically
|
||||
serialized by Tune and passed to `_restore()`.
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ def run(run_or_experiment,
|
||||
verbose=2,
|
||||
resume=False,
|
||||
queue_trials=False,
|
||||
reuse_actors=True,
|
||||
reuse_actors=False,
|
||||
trial_executor=None,
|
||||
raise_on_failed_trial=True,
|
||||
return_trials=True,
|
||||
|
||||
@@ -223,6 +223,46 @@ def recursive_fnmatch(dirpath, pattern):
|
||||
return matches
|
||||
|
||||
|
||||
def validate_save_restore(trainable_cls, config=None, use_object_store=False):
    """Helper method to check if your Trainable class will resume correctly.

    Two remote copies of ``trainable_cls`` are created. The first is trained
    for three iterations and checkpointed; the checkpoint is restored into
    the second copy, which is then trained twice more and must report
    ``training_iteration`` values of 4 and 5.

    Args:
        trainable_cls: Trainable class for evaluation.
        config (dict): Config to pass to Trainable when testing.
        use_object_store (bool): Whether to save and restore to Ray's object
            store. Recommended to set this to True if planning to use
            algorithms that pause training (e.g., PBT, HyperBand).

    Returns:
        True if all checks pass.

    Raises:
        AssertionError: If Ray is not initialized, if the training result
            does not contain ``training_iteration``, or if the restored
            copy does not continue from the saved iteration count.
    """
    assert ray.is_initialized(), "Need Ray to be initialized."

    # Imported inside the function, as in the original code; presumably to
    # avoid a circular import -- confirm before hoisting to module level.
    from ray.tune.result import TRAINING_ITERATION

    # Iterations run on the first copy before checkpointing, and on the
    # second copy after restoring. The expected counters are derived from
    # these so the numbers cannot drift apart.
    warmup_iterations = 3
    resumed_iterations = 2

    remote_cls = ray.remote(trainable_cls)
    trainable_1 = remote_cls.remote(config=config)
    trainable_2 = remote_cls.remote(config=config)

    for _ in range(warmup_iterations):
        res = ray.get(trainable_1.train.remote())

    assert res.get(TRAINING_ITERATION), (
        "Validation will not pass because it requires `training_iteration` "
        "to be returned.")

    # Block on the restore so that any exception raised while saving or
    # restoring surfaces here rather than in the next train() call.
    if use_object_store:
        ray.get(
            trainable_2.restore_from_object.remote(
                trainable_1.save_to_object.remote()))
    else:
        ray.get(trainable_2.restore.remote(trainable_1.save.remote()))

    for step in range(1, resumed_iterations + 1):
        expected = warmup_iterations + step
        res = ray.get(trainable_2.train.remote())
        assert res[TRAINING_ITERATION] == expected, (
            "Expected `training_iteration` to be {} after restoring, "
            "but got {}.".format(expected, res[TRAINING_ITERATION]))

    return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ray.init()
|
||||
X = pin_in_object_store("hello")
|
||||
|
||||
Reference in New Issue
Block a user