[Tune, Ray SGD] Update PTL integrations (#11271)

This commit is contained in:
Amog Kamsetty
2020-10-08 13:43:07 -07:00
committed by GitHub
parent a6f91664c1
commit 1027bfd4b8
4 changed files with 25 additions and 24 deletions
@@ -77,27 +77,23 @@ class LightningMNISTClassifier(pl.LightningModule):
loss = self.cross_entropy_loss(logits, y)
accuracy = self.accuracy(logits, y)
logs = {"ptl/train_loss": loss, "ptl/train_accuracy": accuracy}
return {"loss": loss, "log": logs}
self.log("ptl/train_loss", loss)
self.log("ptl/train_accuracy", accuracy)
return loss
def validation_step(self, val_batch, batch_idx):
x, y = val_batch
logits = self.forward(x)
loss = self.cross_entropy_loss(logits, y)
accuracy = self.accuracy(logits, y)
return {"val_loss": loss, "val_accuracy": accuracy}
def validation_epoch_end(self, outputs):
avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
logs = {"ptl/val_loss": avg_loss, "ptl/val_accuracy": avg_acc}
self.log("ptl/val_loss", avg_loss)
self.log("ptl/val_accuracy", avg_acc)
return {
"val_loss": avg_loss,
"val_accuracy": avg_acc,
"log": logs
}
@staticmethod
def download_data(data_dir):
@@ -144,12 +140,11 @@ def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0):
callbacks=[
TuneReportCallback(
{
"loss": "val_loss",
"mean_accuracy": "val_accuracy"
"loss": "ptl/val_loss",
"mean_accuracy": "ptl/val_accuracy"
},
on="validation_end")
])
trainer.fit(model)
# __tune_train_end__
@@ -169,8 +164,8 @@ def train_mnist_tune_checkpoint(config,
callbacks=[
TuneReportCheckpointCallback(
metrics={
"loss": "val_loss",
"mean_accuracy": "val_accuracy"
"loss": "ptl/val_loss",
"mean_accuracy": "ptl/val_accuracy"
},
filename="checkpoint",
on="validation_end")
@@ -170,6 +170,9 @@ class TuneReportCallback(TuneCallback):
self._metrics = metrics
def _handle(self, trainer: Trainer, pl_module: LightningModule):
# Don't report if just doing initial validation sanity checks.
if trainer.running_sanity_check:
return
report_dict = {}
for key in self._metrics:
if isinstance(self._metrics, dict):
@@ -206,6 +209,8 @@ class _TuneCheckpointCallback(TuneCallback):
self._filename = filename
def _handle(self, trainer: Trainer, pl_module: LightningModule):
if trainer.running_sanity_check:
return
with tune.checkpoint_dir(step=trainer.global_step) as checkpoint_dir:
trainer.save_checkpoint(
os.path.join(checkpoint_dir, self._filename))
@@ -46,8 +46,8 @@ class _MockModule(pl.LightningModule):
def validation_epoch_end(self, outputs):
avg_val_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
avg_val_acc = torch.stack([x["val_acc"] for x in outputs]).mean()
return {"avg_val_loss": avg_val_loss, "avg_val_acc": avg_val_acc}
self.log("avg_val_loss", avg_val_loss)
self.log("avg_val_acc", avg_val_acc)
def configure_optimizers(self):
return None
+9 -8
View File
@@ -5,6 +5,7 @@ import torch
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.overrides.data_parallel import \
LightningDistributedDataParallel
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
import pytorch_lightning as ptl
@@ -39,8 +40,8 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
assert len(models) == 1
model = models[0]
assert isinstance(model, ptl.LightningModule)
# This will default to LightningDistributedDataParallel.
model = model.configure_ddp(model=model, device_ids=device_ids)
model = LightningDistributedDataParallel(
model, device_ids=device_ids, find_unused_parameters=True)
return [model]
@property
@@ -110,7 +111,7 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
# Call model.setup.
ptl_module.setup("fit")
if not self.is_overridden("configure_optimizers", ptl_module):
if not is_overridden("configure_optimizers", ptl_module):
raise MisconfigurationException(
"No `configure_optimizers()` method defined.")
@@ -232,7 +233,7 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
break
processed_outputs = None
if self.is_overridden("training_epoch_end", model):
if is_overridden("training_epoch_end", model):
raw_outputs = [eo["raw_output"] for eo in epoch_outputs]
processed_outputs = model.training_epoch_end(raw_outputs)
@@ -316,7 +317,7 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
# allow any mode to define training_step_end
# do something will all the dp outputs (like softmax)
if self.is_overridden("training_step_end", model):
if is_overridden("training_step_end", model):
output = model.training_step_end(output)
# Extract loss from output if dictionary.
@@ -397,7 +398,7 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
val_outputs.append(batch_output)
processed_outputs = None
if self.is_overridden("validation_epoch_end", model):
if is_overridden("validation_epoch_end", model):
raw_outputs = [vo["raw_output"] for vo in val_outputs]
processed_outputs = model.training_epoch_end(raw_outputs)
@@ -440,7 +441,7 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
def validate_batch(self, batch, batch_info):
model = self.get_model()
batch_idx = batch_info["batch_idx"]
if self.is_overridden("on_validation_batch_start", model):
if is_overridden("on_validation_batch_start", model):
model.on_validation_batch_start(
batch=batch, batch_idx=batch_idx, dataloader_idx=0)
args = [batch, batch_idx]
@@ -462,7 +463,7 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
raise ValueError("EvalResult objects are not supported. Please "
"return a dictionary instead.")
if self.is_overridden("on_validation_step_end", model):
if is_overridden("on_validation_step_end", model):
output = model.validation_step_end(output)
if self.is_function_implemented("on_validation_batch_end", model):