mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 09:41:11 +08:00
[tune] Make metrics parameter optional in pytorch lightning integration (#11402)
This commit is contained in:
@@ -162,7 +162,7 @@ class TuneReportCallback(TuneCallback):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
metrics: Union[str, List[str], Dict[str, str]],
|
||||
metrics: Union[None, str, List[str], Dict[str, str]] = None,
|
||||
on: Union[str, List[str]] = "validation_end"):
|
||||
super(TuneReportCallback, self).__init__(on)
|
||||
if isinstance(metrics, str):
|
||||
@@ -173,13 +173,19 @@ class TuneReportCallback(TuneCallback):
|
||||
# Don't report if just doing initial validation sanity checks.
|
||||
if trainer.running_sanity_check:
|
||||
return
|
||||
report_dict = {}
|
||||
for key in self._metrics:
|
||||
if isinstance(self._metrics, dict):
|
||||
metric = self._metrics[key]
|
||||
else:
|
||||
metric = key
|
||||
report_dict[key] = trainer.callback_metrics[metric].item()
|
||||
if not self._metrics:
|
||||
report_dict = {
|
||||
k: v.item()
|
||||
for k, v in trainer.callback_metrics.items()
|
||||
}
|
||||
else:
|
||||
report_dict = {}
|
||||
for key in self._metrics:
|
||||
if isinstance(self._metrics, dict):
|
||||
metric = self._metrics[key]
|
||||
else:
|
||||
metric = key
|
||||
report_dict[key] = trainer.callback_metrics[metric].item()
|
||||
tune.report(**report_dict)
|
||||
|
||||
|
||||
@@ -253,7 +259,7 @@ class TuneReportCheckpointCallback(TuneCallback):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
metrics: Union[str, List[str], Dict[str, str]],
|
||||
metrics: Union[None, str, List[str], Dict[str, str]] = None,
|
||||
filename: str = "checkpoint",
|
||||
on: Union[str, List[str]] = "validation_end"):
|
||||
super(TuneReportCheckpointCallback, self).__init__(on)
|
||||
|
||||
@@ -66,9 +66,22 @@ class PyTorchLightningIntegrationTest(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
def testReportCallback(self):
|
||||
def testReportCallbackUnnamed(self):
|
||||
def train(config):
|
||||
module = _MockModule(10, 20)
|
||||
module = _MockModule(10., 20.)
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=1,
|
||||
callbacks=[TuneReportCallback(on="validation_end")])
|
||||
trainer.fit(module)
|
||||
|
||||
analysis = tune.run(train, stop={TRAINING_ITERATION: 1})
|
||||
|
||||
self.assertEqual(analysis.trials[0].last_result["avg_val_loss"],
|
||||
10. * 1.1)
|
||||
|
||||
def testReportCallbackNamed(self):
|
||||
def train(config):
|
||||
module = _MockModule(10., 20.)
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=1,
|
||||
callbacks=[
|
||||
@@ -81,14 +94,15 @@ class PyTorchLightningIntegrationTest(unittest.TestCase):
|
||||
|
||||
analysis = tune.run(train, stop={TRAINING_ITERATION: 1})
|
||||
|
||||
self.assertEqual(analysis.trials[0].last_result["tune_loss"], 10 * 1.1)
|
||||
self.assertEqual(analysis.trials[0].last_result["tune_loss"],
|
||||
10. * 1.1)
|
||||
|
||||
def testCheckpointCallback(self):
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
self.addCleanup(lambda: shutil.rmtree(tmpdir))
|
||||
|
||||
def train(config):
|
||||
module = _MockModule(10, 20)
|
||||
module = _MockModule(10., 20.)
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=1,
|
||||
callbacks=[
|
||||
@@ -115,7 +129,7 @@ class PyTorchLightningIntegrationTest(unittest.TestCase):
|
||||
self.addCleanup(lambda: shutil.rmtree(tmpdir))
|
||||
|
||||
def train(config):
|
||||
module = _MockModule(10, 20)
|
||||
module = _MockModule(10., 20.)
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=1,
|
||||
callbacks=[
|
||||
|
||||
Reference in New Issue
Block a user