[tune] extend PTL template (GPU, typing fixes, tensorboard) (#9451)

Co-authored-by: Kai Fricke <kai@anyscale.com>
This commit is contained in:
krfricke
2020-07-15 19:30:20 +02:00
committed by GitHub
parent aa8928fac2
commit 5a40299d42
3 changed files with 119 additions and 39 deletions
@@ -13,8 +13,10 @@ import os
# __import_tune_begin__
import shutil
from functools import partial
from tempfile import mkdtemp
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.cloud_io import load as pl_load
from ray import tune
from ray.tune import CLIReporter
@@ -74,7 +76,7 @@ class LightningMNISTClassifier(pl.LightningModule):
loss = self.cross_entropy_loss(logits, y)
accuracy = self.accuracy(logits, y)
logs = {"train_loss": loss, "train_accuracy": accuracy}
logs = {"ptl/train_loss": loss, "ptl/train_accuracy": accuracy}
return {"loss": loss, "log": logs}
def validation_step(self, val_batch, batch_idx):
@@ -88,12 +90,12 @@ class LightningMNISTClassifier(pl.LightningModule):
def validation_epoch_end(self, outputs):
    """Aggregate the per-batch validation results of one epoch.

    Args:
        outputs: Iterable of per-batch dicts, each holding a
            ``"val_loss"`` and a ``"val_accuracy"`` tensor.

    Returns:
        A dict with the epoch means under ``"avg_val_loss"`` and
        ``"avg_val_accuracy"``, plus a ``"log"`` dict carrying the same
        values under the ``ptl/``-prefixed names.
    """
    avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
    avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
    # The "ptl/" prefix groups these scalars under one namespace in the log.
    logs = {"ptl/val_loss": avg_loss, "ptl/val_accuracy": avg_acc}
    return {
        "avg_val_loss": avg_loss,
        "avg_val_accuracy": avg_acc,
        "log": logs
    }
@staticmethod
@@ -133,16 +135,19 @@ def train_mnist(config):
class TuneReportCallback(Callback):
    """Lightning callback that forwards validation metrics to Tune."""

    def on_validation_end(self, trainer, pl_module):
        """Report the averaged validation loss and accuracy via ``tune.report``.

        Args:
            trainer: Trainer whose ``callback_metrics`` holds
                ``"avg_val_loss"`` and ``"avg_val_accuracy"``.
            pl_module: The module being trained (unused here).
        """
        metrics = trainer.callback_metrics
        # NOTE(review): the metrics are presumably 0-dim tensors; `.item()`
        # hands Tune plain Python numbers instead -- confirm.
        tune.report(
            loss=metrics["avg_val_loss"].item(),
            mean_accuracy=metrics["avg_val_accuracy"].item())
# __tune_callback_end__
# __tune_train_begin__
def train_mnist_tune(config):
model = LightningMNISTClassifier(config, config["data_dir"])
def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0):
model = LightningMNISTClassifier(config, data_dir)
trainer = pl.Trainer(
max_epochs=10,
max_epochs=num_epochs,
gpus=num_gpus,
logger=TensorBoardLogger(
save_dir=tune.get_trial_dir(), name="", version="."),
progress_bar_refresh_rate=0,
callbacks=[TuneReportCallback()])
@@ -160,9 +165,17 @@ class CheckpointCallback(Callback):
# __tune_train_checkpoint_begin__
def train_mnist_tune_checkpoint(config, checkpoint=None):
def train_mnist_tune_checkpoint(
config,
checkpoint=None,
data_dir=None,
num_epochs=10,
num_gpus=0):
trainer = pl.Trainer(
max_epochs=10,
max_epochs=num_epochs,
gpus=num_gpus,
logger=TensorBoardLogger(
save_dir=tune.get_trial_dir(), name="", version="."),
progress_bar_refresh_rate=0,
callbacks=[CheckpointCallback(),
TuneReportCallback()])
@@ -178,54 +191,64 @@ def train_mnist_tune_checkpoint(config, checkpoint=None):
trainer.current_epoch = ckpt["epoch"]
else:
model = LightningMNISTClassifier(
config=config, data_dir=config["data_dir"])
config=config, data_dir=data_dir)
trainer.fit(model)
# __tune_train_checkpoint_end__
# __tune_asha_begin__
def tune_mnist_asha(num_samples=10, num_epochs=10, gpus_per_trial=0):
    """Tune the MNIST classifier using the ASHA scheduler.

    Args:
        num_samples: Passed to ``tune.run`` as ``num_samples``.
        num_epochs: Used both as the scheduler's ``max_t`` and as the
            ``num_epochs`` argument of each training run.
        gpus_per_trial: GPU count requested per trial and passed to the
            trainable as ``num_gpus``.
    """
    data_dir = mkdtemp(prefix="mnist_data_")
    try:
        LightningMNISTClassifier.download_data(data_dir)

        config = {
            "layer_1_size": tune.choice([32, 64, 128]),
            "layer_2_size": tune.choice([64, 128, 256]),
            "lr": tune.loguniform(1e-4, 1e-1),
            "batch_size": tune.choice([32, 64, 128]),
        }

        scheduler = ASHAScheduler(
            metric="loss",
            mode="min",
            max_t=num_epochs,
            grace_period=1,
            reduction_factor=2)

        reporter = CLIReporter(
            parameter_columns=[
                "layer_1_size", "layer_2_size", "lr", "batch_size"
            ],
            metric_columns=["loss", "mean_accuracy", "training_iteration"])

        # Constants are bound with `partial` so they stay out of the
        # search space held in `config`.
        tune.run(
            partial(
                train_mnist_tune,
                data_dir=data_dir,
                num_epochs=num_epochs,
                num_gpus=gpus_per_trial),
            resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
            config=config,
            num_samples=num_samples,
            scheduler=scheduler,
            progress_reporter=reporter,
            name="tune_mnist_asha")
    finally:
        # Remove the temporary dataset directory even if tuning fails.
        shutil.rmtree(data_dir)
# __tune_asha_end__
# __tune_pbt_begin__
def tune_mnist_pbt():
def tune_mnist_pbt(num_samples=10, num_epochs=10, gpus_per_trial=0):
data_dir = mkdtemp(prefix="mnist_data_")
LightningMNISTClassifier.download_data(data_dir)
config = {
"layer_1_size": tune.choice([32, 64, 128]),
"layer_2_size": tune.choice([64, 128, 256]),
"lr": 1e-3,
"batch_size": 64,
"data_dir": data_dir
}
scheduler = PopulationBasedTraining(
time_attr="training_iteration",
metric="loss",
@@ -235,16 +258,24 @@ def tune_mnist_pbt():
"lr": lambda: tune.loguniform(1e-4, 1e-1).func(None),
"batch_size": [32, 64, 128]
})
reporter = CLIReporter(
parameter_columns=["layer_1_size", "layer_2_size", "lr", "batch_size"],
metric_columns=["loss", "mean_accuracy", "training_iteration"])
tune.run(
train_mnist_tune_checkpoint,
resources_per_trial={"cpu": 1},
partial(
train_mnist_tune_checkpoint,
data_dir=data_dir,
num_epochs=num_epochs,
num_gpus=gpus_per_trial),
resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
config=config,
num_samples=10,
num_samples=num_samples,
scheduler=scheduler,
progress_reporter=reporter)
progress_reporter=reporter,
name="tune_mnist_pbt")
shutil.rmtree(data_dir)
# __tune_pbt_end__
@@ -258,7 +289,10 @@ if __name__ == "__main__":
args, _ = parser.parse_known_args()
if args.smoke_test:
tune_mnist_asha(1, 1)
tune_mnist_asha(num_samples=1, num_epochs=1, gpus_per_trial=0)
tune_mnist_pbt(num_samples=1, num_epochs=1, gpus_per_trial=0)
else:
tune_mnist_asha() # ASHA scheduler
tune_mnist_pbt() # population based training
# ASHA scheduler
tune_mnist_asha(num_samples=10, num_epochs=10, gpus_per_trial=0)
# Population based training
tune_mnist_pbt(num_samples=10, num_epochs=10, gpus_per_trial=0)
+6 -4
View File
@@ -1,4 +1,6 @@
import logging
import random
import numpy as np
logger = logging.getLogger(__name__)
@@ -56,13 +58,13 @@ def loguniform(min_bound, max_bound, base=10):
def choice(*args, **kwargs):
    """Wraps tune.sample_from around ``random.choice``.

    ``tune.choice([1, 2])`` is equivalent to
    ``tune.sample_from(lambda _: random.choice([1, 2]))``

    All arguments are forwarded unchanged to ``random.choice`` each time
    the returned sampler is evaluated.
    """
    return sample_from(lambda _: random.choice(*args, **kwargs))
def randint(*args, **kwargs):