# Mirror of https://github.com/wassname/attentive-neural-processes.git
# Synced 2026-06-27 16:29:15 +08:00
# 175 lines, 6.2 KiB, Python
from pytorch_lightning.callbacks import EarlyStopping
|
|
from optuna.integration.pytorch_lightning import _check_pytorch_lightning_availability
|
|
from pathlib import Path
|
|
import optuna
|
|
import pytorch_lightning as pl
|
|
from matplotlib import pyplot as plt
|
|
import torch
|
|
from .dict_logger import DictLogger
|
|
from .utils import PyTorchLightningPruningCallback
|
|
from .plot import plot_from_loader
|
|
from .logger import logger
|
|
|
|
|
|
def main(
    trial: optuna.Trial,
    PL_MODEL_CLS: pl.LightningModule,
    name: str,
    MODEL_DIR: Path = Path("./lightning_logs"),
    train=True,
    prune=True,
    PERCENT_TEST_EXAMPLES=0.5,
):
    """Create the model and trainer for one trial and optionally fit it.

    Args:
        trial: Source of hyperparameters; ``trial.params`` and
            ``trial.user_attrs`` are merged into one ``hparams`` dict, and
            ``trial.number`` is used as the run's version.
        PL_MODEL_CLS: Model class; it is instantiated as ``PL_MODEL_CLS(hparams)``.
        name: Experiment name, used as a sub-directory of ``MODEL_DIR``.
        MODEL_DIR: Root directory for logs and checkpoints.
        train: When true, ``trainer.fit(model)`` is called before returning.
        prune: When true the Optuna pruning callback is used as the early-stop
            callback; otherwise a plain ``EarlyStopping`` with twice
            ``hparams["patience"]`` is used.
        PERCENT_TEST_EXAMPLES: Value forwarded as ``val_percent_check``.

    Returns:
        The ``(model, trainer)`` pair.
    """
    # Checkpoints go to a per-trial directory: Lightning would otherwise try to
    # restore weights from an earlier trial whose checkpoint filename matches.
    version_dir = f"version_{trial.number}"
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        MODEL_DIR / name / version_dir / "chk",
        monitor="val_loss",
        mode="min",
    )

    # Use an in-memory logger rather than Lightning's default TensorBoard event
    # files, so the final metrics can be read back once optimisation is done.
    dict_logger = DictLogger(MODEL_DIR, name=name, version=trial.number)

    hparams = dict(**trial.params, **trial.user_attrs)

    trainer_kwargs = dict(
        logger=dict_logger,
        val_percent_check=PERCENT_TEST_EXAMPLES,
        checkpoint_callback=checkpoint_callback,
        max_epochs=hparams["max_nb_epochs"],
        weights_summary='top',
        gpus=-1 if torch.cuda.is_available() else None,
    )
    if prune:
        trainer_kwargs["early_stop_callback"] = PyTorchLightningPruningCallback(
            trial, monitor="val_loss"
        )
    else:
        trainer_kwargs["early_stop_callback"] = EarlyStopping(
            patience=hparams["patience"] * 2, monitor="val_loss", verbose=True
        )
    trainer = pl.Trainer(**trainer_kwargs)

    model = PL_MODEL_CLS(hparams)
    if train:
        trainer.fit(model)
    return model, trainer
|
|
|
|
|
|
def objective(trial, PL_MODEL_CLS, name, user_attrs):
    """Optuna objective: train one trial, test it and return its test score.

    Args:
        trial: The Optuna trial; ``PL_MODEL_CLS.add_suggest`` is applied to it
            and ``user_attrs`` are copied onto it.
        PL_MODEL_CLS: Model class providing ``add_suggest``.
        name: Experiment name forwarded to ``main``.
        user_attrs: Mapping of extra attributes to set on the trial.

    Returns:
        ``agg_test_score`` from the last entry of the model logger's metrics.
    """
    # see https://github.com/optuna/optuna/blob/cf6f02d/examples/pytorch_lightning_simple.py
    trial = PL_MODEL_CLS.add_suggest(trial)
    for key, value in user_attrs.items():
        trial.set_user_attr(key, value)

    logger.debug(dict(number=trial.number, params=trial.params, user_attrs=trial.user_attrs))

    model, trainer = main(trial, PL_MODEL_CLS=PL_MODEL_CLS, name=name)

    # Reload weights from the checkpoint whose path sorts last, if any exist,
    # keeping the model on whatever device it was trained on.
    checkpoints = sorted(Path(trainer.checkpoint_callback.dirpath).glob("*.ckpt"))
    if checkpoints:
        checkpoint = checkpoints[-1]
        device = next(model.parameters()).device
        logger.info(f"Loading checkpoint {checkpoint}")
        model = model.load_from_checkpoint(checkpoint).to(device)

    trainer.test(model)

    # also report to tensorboard & print
    logger.info("logger.metrics %s", model.logger.metrics[-1:])
    model.logger.experiment.add_hparams(trial.params, model.logger.metrics[-1])
    model.logger.save()

    return model.logger.metrics[-1]["agg_test_score"]
|
|
|
|
|
|
def add_number(trial: "optuna.Trial", model_dir: Path):
    """Give ``trial`` a fresh negative version number for a manual run.

    Every entry of ``model_dir`` matching ``version_*`` is inspected and the
    text after the last underscore of its stem is read as an integer. The
    trial gets one less than the smallest number found, with ``-1`` always
    included as a floor, so the first manual run in an empty directory is
    numbered ``-2`` and later runs count further down.

    Entries whose suffix is not an integer (e.g. ``version_backup``) are
    skipped rather than aborting the run.

    Args:
        trial: Object whose ``number`` attribute is overwritten.
        model_dir: Directory scanned for existing ``version_*`` entries.

    Returns:
        The same ``trial`` object, with ``number`` set.
    """
    versions = [-1]
    for entry in model_dir.glob("version_*"):
        suffix = entry.stem.split("_")[-1]
        try:
            versions.append(int(suffix))
        except ValueError:
            # Not a numbered version directory; ignore it.
            continue
    trial.number = min(versions) - 1
    return trial
|
|
|
|
|
|
def _plot_splits(model, plot_fn, title_prefix="", index=670):
    """Plot example ``index`` from the val, train and test loaders, in that order."""
    for split in ("val", "train", "test"):
        loader = getattr(model, f"{split}_dataloader")()
        plot_fn(loader, model, i=index, title=f"{title_prefix}{split} {index}")
        plt.show()


def run_trial(
    name: str,
    PL_MODEL_CLS: pl.LightningModule,
    params: dict = {},
    user_attrs: dict = {},
    MODEL_DIR: Path = Path("./lightning_logs"),
    plot_from_loader=plot_from_loader,
    number=None,
):
    """Run (or resume) a single manually specified trial, then plot and test it.

    Args:
        name: Experiment name; ``MODEL_DIR / name`` is created if missing.
        PL_MODEL_CLS: Model class. Its ``DEFAULT_ARGS`` (if present and
            non-empty) are used as defaults underneath ``params``.
        params: Fixed hyperparameters for the trial. Not modified.
        user_attrs: Extra attributes set on the trial. Not modified.
        MODEL_DIR: Root directory for logs and checkpoints.
        plot_from_loader: Callable invoked as
            ``plot_from_loader(loader, model, i=..., title=...)``.
        number: Version number of the trial. When ``None`` a new number is
            allocated with ``add_number`` and training always runs. When
            given, training is skipped if that version already has
            checkpoints.

    Returns:
        The ``(trial, trainer, model)`` triple.
    """
    logger.info(f"now run `tensorboard --logdir {MODEL_DIR}`")
    (MODEL_DIR / name).mkdir(parents=True, exist_ok=True)

    default_args = getattr(PL_MODEL_CLS, 'DEFAULT_ARGS', None)
    if default_args:
        # add default args
        params = {**default_args, **params}
    else:
        logger.warning(f"No default args on {PL_MODEL_CLS}")
        # Copy so the trial never holds the caller's (or the shared default) dict.
        params = dict(params)

    # Make trial
    trial = optuna.trial.FixedTrial(params=params)
    trial = PL_MODEL_CLS.add_suggest(trial)

    if number is None:
        trial = add_number(trial, MODEL_DIR / name)
    else:
        trial.number = number

    # Add user attributes
    for key, value in user_attrs.items():
        trial.set_user_attr(key, value)

    # Build only; whether to fit is decided below from the checkpoint state.
    model, trainer = main(
        trial, PL_MODEL_CLS, name=name, MODEL_DIR=MODEL_DIR, train=False, prune=False
    )
    logger.info('trial number=%s name=%s, trial=%s params=%s attrs=%s', trial.number, trainer.logger.name, trial, trial.params, trial.user_attrs)

    checkpoints = sorted(Path(trainer.checkpoint_callback.dirpath).glob("*.ckpt"))
    if not checkpoints or number is None:
        try:
            trainer.fit(model)
        except KeyboardInterrupt:
            # Ctrl-C ends training early but still lets plotting/testing run.
            logger.warning('KeyboardInterrupt, skipping rest of training')

        # Plot the model as it stands right after fitting.
        # NOTE(review): these plots are assumed to belong to the training
        # branch (they mirror the post-reload plots below) - confirm upstream.
        _plot_splits(model, plot_from_loader, title_prefix='overfit ')

    # Load checkpoint: the path that sorts last, on the model's current device.
    checkpoints = sorted(Path(trainer.checkpoint_callback.dirpath).glob("*.ckpt"))
    if checkpoints:
        checkpoint = checkpoints[-1]
        device = next(model.parameters()).device
        logger.info(f"Loading checkpoint {checkpoint}")
        model = model.load_from_checkpoint(checkpoint).to(device)

        # Plot again with the reloaded weights.
        _plot_splits(model, plot_from_loader)
    else:
        logger.warning('no checkpoints')

    try:
        trainer.test(model)
    except KeyboardInterrupt:
        logger.warning('KeyboardInterrupt, skipping rest of testing')
    return trial, trainer, model
|