mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 23:08:32 +08:00
[Tune, Ray SGD] Update PTL integrations (#11271)
This commit is contained in:
@@ -77,27 +77,23 @@ class LightningMNISTClassifier(pl.LightningModule):
|
||||
loss = self.cross_entropy_loss(logits, y)
|
||||
accuracy = self.accuracy(logits, y)
|
||||
|
||||
logs = {"ptl/train_loss": loss, "ptl/train_accuracy": accuracy}
|
||||
return {"loss": loss, "log": logs}
|
||||
self.log("ptl/train_loss", loss)
|
||||
self.log("ptl/train_accuracy", accuracy)
|
||||
return loss
|
||||
|
||||
def validation_step(self, val_batch, batch_idx):
|
||||
x, y = val_batch
|
||||
logits = self.forward(x)
|
||||
loss = self.cross_entropy_loss(logits, y)
|
||||
accuracy = self.accuracy(logits, y)
|
||||
|
||||
return {"val_loss": loss, "val_accuracy": accuracy}
|
||||
|
||||
def validation_epoch_end(self, outputs):
|
||||
avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
|
||||
avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
|
||||
logs = {"ptl/val_loss": avg_loss, "ptl/val_accuracy": avg_acc}
|
||||
self.log("ptl/val_loss", avg_loss)
|
||||
self.log("ptl/val_accuracy", avg_acc)
|
||||
|
||||
return {
|
||||
"val_loss": avg_loss,
|
||||
"val_accuracy": avg_acc,
|
||||
"log": logs
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def download_data(data_dir):
|
||||
@@ -144,12 +140,11 @@ def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0):
|
||||
callbacks=[
|
||||
TuneReportCallback(
|
||||
{
|
||||
"loss": "val_loss",
|
||||
"mean_accuracy": "val_accuracy"
|
||||
"loss": "ptl/val_loss",
|
||||
"mean_accuracy": "ptl/val_accuracy"
|
||||
},
|
||||
on="validation_end")
|
||||
])
|
||||
|
||||
trainer.fit(model)
|
||||
# __tune_train_end__
|
||||
|
||||
@@ -169,8 +164,8 @@ def train_mnist_tune_checkpoint(config,
|
||||
callbacks=[
|
||||
TuneReportCheckpointCallback(
|
||||
metrics={
|
||||
"loss": "val_loss",
|
||||
"mean_accuracy": "val_accuracy"
|
||||
"loss": "ptl/val_loss",
|
||||
"mean_accuracy": "ptl/val_accuracy"
|
||||
},
|
||||
filename="checkpoint",
|
||||
on="validation_end")
|
||||
|
||||
@@ -170,6 +170,9 @@ class TuneReportCallback(TuneCallback):
|
||||
self._metrics = metrics
|
||||
|
||||
def _handle(self, trainer: Trainer, pl_module: LightningModule):
|
||||
# Don't report if just doing initial validation sanity checks.
|
||||
if trainer.running_sanity_check:
|
||||
return
|
||||
report_dict = {}
|
||||
for key in self._metrics:
|
||||
if isinstance(self._metrics, dict):
|
||||
@@ -206,6 +209,8 @@ class _TuneCheckpointCallback(TuneCallback):
|
||||
self._filename = filename
|
||||
|
||||
def _handle(self, trainer: Trainer, pl_module: LightningModule):
|
||||
if trainer.running_sanity_check:
|
||||
return
|
||||
with tune.checkpoint_dir(step=trainer.global_step) as checkpoint_dir:
|
||||
trainer.save_checkpoint(
|
||||
os.path.join(checkpoint_dir, self._filename))
|
||||
|
||||
@@ -46,8 +46,8 @@ class _MockModule(pl.LightningModule):
|
||||
def validation_epoch_end(self, outputs):
|
||||
avg_val_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
|
||||
avg_val_acc = torch.stack([x["val_acc"] for x in outputs]).mean()
|
||||
|
||||
return {"avg_val_loss": avg_val_loss, "avg_val_acc": avg_val_acc}
|
||||
self.log("avg_val_loss", avg_val_loss)
|
||||
self.log("avg_val_acc", avg_val_acc)
|
||||
|
||||
def configure_optimizers(self):
|
||||
return None
|
||||
|
||||
@@ -5,6 +5,7 @@ import torch
|
||||
from pytorch_lightning.core.step_result import Result
|
||||
from pytorch_lightning.overrides.data_parallel import \
|
||||
LightningDistributedDataParallel
|
||||
from pytorch_lightning.utilities.model_utils import is_overridden
|
||||
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
|
||||
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
|
||||
import pytorch_lightning as ptl
|
||||
@@ -39,8 +40,8 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
|
||||
assert len(models) == 1
|
||||
model = models[0]
|
||||
assert isinstance(model, ptl.LightningModule)
|
||||
# This will default to LightningDistributedDataParallel.
|
||||
model = model.configure_ddp(model=model, device_ids=device_ids)
|
||||
model = LightningDistributedDataParallel(
|
||||
model, device_ids=device_ids, find_unused_parameters=True)
|
||||
return [model]
|
||||
|
||||
@property
|
||||
@@ -110,7 +111,7 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
|
||||
# Call model.setup.
|
||||
ptl_module.setup("fit")
|
||||
|
||||
if not self.is_overridden("configure_optimizers", ptl_module):
|
||||
if not is_overridden("configure_optimizers", ptl_module):
|
||||
raise MisconfigurationException(
|
||||
"No `configure_optimizers()` method defined.")
|
||||
|
||||
@@ -232,7 +233,7 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
|
||||
break
|
||||
|
||||
processed_outputs = None
|
||||
if self.is_overridden("training_epoch_end", model):
|
||||
if is_overridden("training_epoch_end", model):
|
||||
raw_outputs = [eo["raw_output"] for eo in epoch_outputs]
|
||||
processed_outputs = model.training_epoch_end(raw_outputs)
|
||||
|
||||
@@ -316,7 +317,7 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
|
||||
|
||||
# allow any mode to define training_step_end
|
||||
# do something with all the dp outputs (like softmax)
|
||||
if self.is_overridden("training_step_end", model):
|
||||
if is_overridden("training_step_end", model):
|
||||
output = model.training_step_end(output)
|
||||
|
||||
# Extract loss from output if dictionary.
|
||||
@@ -397,7 +398,7 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
|
||||
val_outputs.append(batch_output)
|
||||
|
||||
processed_outputs = None
|
||||
if self.is_overridden("validation_epoch_end", model):
|
||||
if is_overridden("validation_epoch_end", model):
|
||||
raw_outputs = [vo["raw_output"] for vo in val_outputs]
|
||||
processed_outputs = model.training_epoch_end(raw_outputs)
|
||||
|
||||
@@ -440,7 +441,7 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
|
||||
def validate_batch(self, batch, batch_info):
|
||||
model = self.get_model()
|
||||
batch_idx = batch_info["batch_idx"]
|
||||
if self.is_overridden("on_validation_batch_start", model):
|
||||
if is_overridden("on_validation_batch_start", model):
|
||||
model.on_validation_batch_start(
|
||||
batch=batch, batch_idx=batch_idx, dataloader_idx=0)
|
||||
args = [batch, batch_idx]
|
||||
@@ -462,7 +463,7 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
|
||||
raise ValueError("EvalResult objects are not supported. Please "
|
||||
"return a dictionary instead.")
|
||||
|
||||
if self.is_overridden("on_validation_step_end", model):
|
||||
if is_overridden("on_validation_step_end", model):
|
||||
output = model.validation_step_end(output)
|
||||
|
||||
if self.is_function_implemented("on_validation_batch_end", model):
|
||||
|
||||
Reference in New Issue
Block a user