From f7120d2a18a18a425ede565e1b6700aea027b956 Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Thu, 15 Oct 2020 01:50:34 +0100 Subject: [PATCH] [tune] Make `metrics` parameter optional in pytorch lightning integration (#11402) --- .../ray/tune/integration/pytorch_lightning.py | 24 ++++++++++++------- .../test_integration_pytorch_lightning.py | 24 +++++++++++++++---- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/python/ray/tune/integration/pytorch_lightning.py b/python/ray/tune/integration/pytorch_lightning.py index 21b9a60ee..866b669ae 100644 --- a/python/ray/tune/integration/pytorch_lightning.py +++ b/python/ray/tune/integration/pytorch_lightning.py @@ -162,7 +162,7 @@ class TuneReportCallback(TuneCallback): """ def __init__(self, - metrics: Union[str, List[str], Dict[str, str]], + metrics: Union[None, str, List[str], Dict[str, str]] = None, on: Union[str, List[str]] = "validation_end"): super(TuneReportCallback, self).__init__(on) if isinstance(metrics, str): @@ -173,13 +173,19 @@ class TuneReportCallback(TuneCallback): # Don't report if just doing initial validation sanity checks. if trainer.running_sanity_check: return - report_dict = {} - for key in self._metrics: - if isinstance(self._metrics, dict): - metric = self._metrics[key] - else: - metric = key - report_dict[key] = trainer.callback_metrics[metric].item() + if not self._metrics: + report_dict = { + k: v.item() + for k, v in trainer.callback_metrics.items() + } + else: + report_dict = {} + for key in self._metrics: + if isinstance(self._metrics, dict): + metric = self._metrics[key] + else: + metric = key + report_dict[key] = trainer.callback_metrics[metric].item() tune.report(**report_dict) @@ -253,7 +259,7 @@ class TuneReportCheckpointCallback(TuneCallback): """ def __init__(self, - metrics: Union[str, List[str], Dict[str, str]], + metrics: Union[None, str, List[str], Dict[str, str]] = None, filename: str = "checkpoint", on: Union[str, List[str]] = "validation_end"): super(TuneReportCheckpointCallback, self).__init__(on) diff --git a/python/ray/tune/tests/test_integration_pytorch_lightning.py b/python/ray/tune/tests/test_integration_pytorch_lightning.py index 14d88c30a..c1ec09f1d 100644 --- a/python/ray/tune/tests/test_integration_pytorch_lightning.py +++ b/python/ray/tune/tests/test_integration_pytorch_lightning.py @@ -66,9 +66,22 @@ class PyTorchLightningIntegrationTest(unittest.TestCase): def tearDown(self): pass - def testReportCallback(self): + def testReportCallbackUnnamed(self): def train(config): - module = _MockModule(10, 20) + module = _MockModule(10., 20.) + trainer = pl.Trainer( + max_epochs=1, + callbacks=[TuneReportCallback(on="validation_end")]) + trainer.fit(module) + + analysis = tune.run(train, stop={TRAINING_ITERATION: 1}) + + self.assertEqual(analysis.trials[0].last_result["avg_val_loss"], + 10. * 1.1) + + def testReportCallbackNamed(self): + def train(config): + module = _MockModule(10., 20.) trainer = pl.Trainer( max_epochs=1, callbacks=[ @@ -81,14 +94,15 @@ class PyTorchLightningIntegrationTest(unittest.TestCase): analysis = tune.run(train, stop={TRAINING_ITERATION: 1}) - self.assertEqual(analysis.trials[0].last_result["tune_loss"], 10 * 1.1) + self.assertEqual(analysis.trials[0].last_result["tune_loss"], + 10. * 1.1) def testCheckpointCallback(self): tmpdir = tempfile.mkdtemp() self.addCleanup(lambda: shutil.rmtree(tmpdir)) def train(config): - module = _MockModule(10, 20) + module = _MockModule(10., 20.) trainer = pl.Trainer( max_epochs=1, callbacks=[ @@ -115,7 +129,7 @@ class PyTorchLightningIntegrationTest(unittest.TestCase): self.addCleanup(lambda: shutil.rmtree(tmpdir)) def train(config): - module = _MockModule(10, 20) + module = _MockModule(10., 20.) trainer = pl.Trainer( max_epochs=1, callbacks=[