[tune] Make metrics parameter optional in pytorch lightning integration (#11402)

This commit is contained in:
Kai Fricke
2020-10-15 01:50:34 +01:00
committed by GitHub
parent 34191107a3
commit f7120d2a18
2 changed files with 34 additions and 14 deletions
@@ -162,7 +162,7 @@ class TuneReportCallback(TuneCallback):
"""
def __init__(self,
metrics: Union[str, List[str], Dict[str, str]],
metrics: Union[None, str, List[str], Dict[str, str]] = None,
on: Union[str, List[str]] = "validation_end"):
super(TuneReportCallback, self).__init__(on)
if isinstance(metrics, str):
@@ -173,13 +173,19 @@ class TuneReportCallback(TuneCallback):
# Don't report if just doing initial validation sanity checks.
if trainer.running_sanity_check:
return
report_dict = {}
for key in self._metrics:
if isinstance(self._metrics, dict):
metric = self._metrics[key]
else:
metric = key
report_dict[key] = trainer.callback_metrics[metric].item()
if not self._metrics:
report_dict = {
k: v.item()
for k, v in trainer.callback_metrics.items()
}
else:
report_dict = {}
for key in self._metrics:
if isinstance(self._metrics, dict):
metric = self._metrics[key]
else:
metric = key
report_dict[key] = trainer.callback_metrics[metric].item()
tune.report(**report_dict)
@@ -253,7 +259,7 @@ class TuneReportCheckpointCallback(TuneCallback):
"""
def __init__(self,
metrics: Union[str, List[str], Dict[str, str]],
metrics: Union[None, str, List[str], Dict[str, str]] = None,
filename: str = "checkpoint",
on: Union[str, List[str]] = "validation_end"):
super(TuneReportCheckpointCallback, self).__init__(on)
@@ -66,9 +66,22 @@ class PyTorchLightningIntegrationTest(unittest.TestCase):
def tearDown(self):
pass
def testReportCallback(self):
def testReportCallbackUnnamed(self):
def train(config):
module = _MockModule(10, 20)
module = _MockModule(10., 20.)
trainer = pl.Trainer(
max_epochs=1,
callbacks=[TuneReportCallback(on="validation_end")])
trainer.fit(module)
analysis = tune.run(train, stop={TRAINING_ITERATION: 1})
self.assertEqual(analysis.trials[0].last_result["avg_val_loss"],
10. * 1.1)
def testReportCallbackNamed(self):
def train(config):
module = _MockModule(10., 20.)
trainer = pl.Trainer(
max_epochs=1,
callbacks=[
@@ -81,14 +94,15 @@ class PyTorchLightningIntegrationTest(unittest.TestCase):
analysis = tune.run(train, stop={TRAINING_ITERATION: 1})
self.assertEqual(analysis.trials[0].last_result["tune_loss"], 10 * 1.1)
self.assertEqual(analysis.trials[0].last_result["tune_loss"],
10. * 1.1)
def testCheckpointCallback(self):
tmpdir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(tmpdir))
def train(config):
module = _MockModule(10, 20)
module = _MockModule(10., 20.)
trainer = pl.Trainer(
max_epochs=1,
callbacks=[
@@ -115,7 +129,7 @@ class PyTorchLightningIntegrationTest(unittest.TestCase):
self.addCleanup(lambda: shutil.rmtree(tmpdir))
def train(config):
module = _MockModule(10, 20)
module = _MockModule(10., 20.)
trainer = pl.Trainer(
max_epochs=1,
callbacks=[