mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:23:10 +08:00
[tune] extend PTL template (GPU, typing fixes, tensorboard) (#9451)
Co-authored-by: Kai Fricke <kai@anyscale.com>
This commit is contained in:
@@ -13,8 +13,10 @@ import os
|
||||
|
||||
# __import_tune_begin__
|
||||
import shutil
|
||||
from functools import partial
|
||||
from tempfile import mkdtemp
|
||||
from pytorch_lightning.callbacks import Callback
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
from pytorch_lightning.utilities.cloud_io import load as pl_load
|
||||
from ray import tune
|
||||
from ray.tune import CLIReporter
|
||||
@@ -74,7 +76,7 @@ class LightningMNISTClassifier(pl.LightningModule):
|
||||
loss = self.cross_entropy_loss(logits, y)
|
||||
accuracy = self.accuracy(logits, y)
|
||||
|
||||
logs = {"train_loss": loss, "train_accuracy": accuracy}
|
||||
logs = {"ptl/train_loss": loss, "ptl/train_accuracy": accuracy}
|
||||
return {"loss": loss, "log": logs}
|
||||
|
||||
def validation_step(self, val_batch, batch_idx):
|
||||
@@ -88,12 +90,12 @@ class LightningMNISTClassifier(pl.LightningModule):
|
||||
def validation_epoch_end(self, outputs):
|
||||
avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
|
||||
avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
|
||||
tensorboard_logs = {"val_loss": avg_loss, "val_accuracy": avg_acc}
|
||||
logs = {"ptl/val_loss": avg_loss, "ptl/val_accuracy": avg_acc}
|
||||
|
||||
return {
|
||||
"avg_val_loss": avg_loss,
|
||||
"avg_val_accuracy": avg_acc,
|
||||
"log": tensorboard_logs
|
||||
"log": logs
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@@ -133,16 +135,19 @@ def train_mnist(config):
|
||||
class TuneReportCallback(Callback):
|
||||
def on_validation_end(self, trainer, pl_module):
|
||||
tune.report(
|
||||
loss=trainer.callback_metrics["avg_val_loss"],
|
||||
mean_accuracy=trainer.callback_metrics["avg_val_accuracy"])
|
||||
loss=trainer.callback_metrics["avg_val_loss"].item(),
|
||||
mean_accuracy=trainer.callback_metrics["avg_val_accuracy"].item())
|
||||
# __tune_callback_end__
|
||||
|
||||
|
||||
# __tune_train_begin__
|
||||
def train_mnist_tune(config):
|
||||
model = LightningMNISTClassifier(config, config["data_dir"])
|
||||
def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0):
|
||||
model = LightningMNISTClassifier(config, data_dir)
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=10,
|
||||
max_epochs=num_epochs,
|
||||
gpus=num_gpus,
|
||||
logger=TensorBoardLogger(
|
||||
save_dir=tune.get_trial_dir(), name="", version="."),
|
||||
progress_bar_refresh_rate=0,
|
||||
callbacks=[TuneReportCallback()])
|
||||
|
||||
@@ -160,9 +165,17 @@ class CheckpointCallback(Callback):
|
||||
|
||||
|
||||
# __tune_train_checkpoint_begin__
|
||||
def train_mnist_tune_checkpoint(config, checkpoint=None):
|
||||
def train_mnist_tune_checkpoint(
|
||||
config,
|
||||
checkpoint=None,
|
||||
data_dir=None,
|
||||
num_epochs=10,
|
||||
num_gpus=0):
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=10,
|
||||
max_epochs=num_epochs,
|
||||
gpus=num_gpus,
|
||||
logger=TensorBoardLogger(
|
||||
save_dir=tune.get_trial_dir(), name="", version="."),
|
||||
progress_bar_refresh_rate=0,
|
||||
callbacks=[CheckpointCallback(),
|
||||
TuneReportCallback()])
|
||||
@@ -178,54 +191,64 @@ def train_mnist_tune_checkpoint(config, checkpoint=None):
|
||||
trainer.current_epoch = ckpt["epoch"]
|
||||
else:
|
||||
model = LightningMNISTClassifier(
|
||||
config=config, data_dir=config["data_dir"])
|
||||
config=config, data_dir=data_dir)
|
||||
|
||||
trainer.fit(model)
|
||||
# __tune_train_checkpoint_end__
|
||||
|
||||
|
||||
# __tune_asha_begin__
|
||||
def tune_mnist_asha(num_samples=10, max_num_epochs=10):
|
||||
def tune_mnist_asha(num_samples=10, num_epochs=10, gpus_per_trial=0):
|
||||
data_dir = mkdtemp(prefix="mnist_data_")
|
||||
LightningMNISTClassifier.download_data(data_dir)
|
||||
|
||||
config = {
|
||||
"layer_1_size": tune.choice([32, 64, 128]),
|
||||
"layer_2_size": tune.choice([64, 128, 256]),
|
||||
"lr": tune.loguniform(1e-4, 1e-1),
|
||||
"batch_size": tune.choice([32, 64, 128]),
|
||||
"data_dir": data_dir
|
||||
}
|
||||
|
||||
scheduler = ASHAScheduler(
|
||||
metric="loss",
|
||||
mode="min",
|
||||
max_t=max_num_epochs,
|
||||
max_t=num_epochs,
|
||||
grace_period=1,
|
||||
reduction_factor=2)
|
||||
|
||||
reporter = CLIReporter(
|
||||
parameter_columns=["layer_1_size", "layer_2_size", "lr", "batch_size"],
|
||||
metric_columns=["loss", "mean_accuracy", "training_iteration"])
|
||||
|
||||
tune.run(
|
||||
train_mnist_tune,
|
||||
resources_per_trial={"cpu": 1},
|
||||
partial(
|
||||
train_mnist_tune,
|
||||
data_dir=data_dir,
|
||||
num_epochs=num_epochs,
|
||||
num_gpus=gpus_per_trial),
|
||||
resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
|
||||
config=config,
|
||||
num_samples=num_samples,
|
||||
scheduler=scheduler,
|
||||
progress_reporter=reporter)
|
||||
progress_reporter=reporter,
|
||||
name="tune_mnist_asha")
|
||||
|
||||
shutil.rmtree(data_dir)
|
||||
# __tune_asha_end__
|
||||
|
||||
|
||||
# __tune_pbt_begin__
|
||||
def tune_mnist_pbt():
|
||||
def tune_mnist_pbt(num_samples=10, num_epochs=10, gpus_per_trial=0):
|
||||
data_dir = mkdtemp(prefix="mnist_data_")
|
||||
LightningMNISTClassifier.download_data(data_dir)
|
||||
|
||||
config = {
|
||||
"layer_1_size": tune.choice([32, 64, 128]),
|
||||
"layer_2_size": tune.choice([64, 128, 256]),
|
||||
"lr": 1e-3,
|
||||
"batch_size": 64,
|
||||
"data_dir": data_dir
|
||||
}
|
||||
|
||||
scheduler = PopulationBasedTraining(
|
||||
time_attr="training_iteration",
|
||||
metric="loss",
|
||||
@@ -235,16 +258,24 @@ def tune_mnist_pbt():
|
||||
"lr": lambda: tune.loguniform(1e-4, 1e-1).func(None),
|
||||
"batch_size": [32, 64, 128]
|
||||
})
|
||||
|
||||
reporter = CLIReporter(
|
||||
parameter_columns=["layer_1_size", "layer_2_size", "lr", "batch_size"],
|
||||
metric_columns=["loss", "mean_accuracy", "training_iteration"])
|
||||
|
||||
tune.run(
|
||||
train_mnist_tune_checkpoint,
|
||||
resources_per_trial={"cpu": 1},
|
||||
partial(
|
||||
train_mnist_tune_checkpoint,
|
||||
data_dir=data_dir,
|
||||
num_epochs=num_epochs,
|
||||
num_gpus=gpus_per_trial),
|
||||
resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
|
||||
config=config,
|
||||
num_samples=10,
|
||||
num_samples=num_samples,
|
||||
scheduler=scheduler,
|
||||
progress_reporter=reporter)
|
||||
progress_reporter=reporter,
|
||||
name="tune_mnist_pbt")
|
||||
|
||||
shutil.rmtree(data_dir)
|
||||
# __tune_pbt_end__
|
||||
|
||||
@@ -258,7 +289,10 @@ if __name__ == "__main__":
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
if args.smoke_test:
|
||||
tune_mnist_asha(1, 1)
|
||||
tune_mnist_asha(num_samples=1, num_epochs=1, gpus_per_trial=0)
|
||||
tune_mnist_pbt(num_samples=1, num_epochs=1, gpus_per_trial=0)
|
||||
else:
|
||||
tune_mnist_asha() # ASHA scheduler
|
||||
tune_mnist_pbt() # population based training
|
||||
# ASHA scheduler
|
||||
tune_mnist_asha(num_samples=10, num_epochs=10, gpus_per_trial=0)
|
||||
# Population based training
|
||||
tune_mnist_pbt(num_samples=10, num_epochs=10, gpus_per_trial=0)
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import logging
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -56,13 +58,13 @@ def loguniform(min_bound, max_bound, base=10):
|
||||
|
||||
|
||||
def choice(*args, **kwargs):
|
||||
"""Wraps tune.sample_from around ``np.random.choice``.
|
||||
"""Wraps tune.sample_from around ``random.choice``.
|
||||
|
||||
``tune.choice(10)`` is equivalent to
|
||||
``tune.sample_from(lambda _: np.random.choice(10))``
|
||||
``tune.choice([1, 2])`` is equivalent to
|
||||
``tune.sample_from(lambda _: random.choice([1, 2]))``
|
||||
|
||||
"""
|
||||
return sample_from(lambda _: np.random.choice(*args, **kwargs))
|
||||
return sample_from(lambda _: random.choice(*args, **kwargs))
|
||||
|
||||
|
||||
def randint(*args, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user