mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 14:14:56 +08:00
[tune] MLFlow Logger (#5438)
This commit is contained in:
@@ -0,0 +1,49 @@
|
||||
#!/usr/bin/env python
|
||||
"""Simple MLFLow Logger example.
|
||||
|
||||
This uses a simple MLFlow logger. One limitation of this is that there is
|
||||
no artifact support; to save artifacts with Tune and MLFlow, you will need to
|
||||
start a MLFlow run inside the Trainable function/class.
|
||||
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
import time
|
||||
import random
|
||||
|
||||
from ray import tune
|
||||
from ray.tune.logger import MLFLowLogger, DEFAULT_LOGGERS
|
||||
|
||||
|
||||
def easy_objective(config):
|
||||
for i in range(20):
|
||||
result = dict(
|
||||
timesteps_total=i,
|
||||
mean_loss=(config["height"] - 14)**2 - abs(config["width"] - 3))
|
||||
tune.track.log(**result)
|
||||
time.sleep(0.02)
|
||||
tune.track.log(done=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
client = MlflowClient()
|
||||
experiment_id = client.create_experiment("test")
|
||||
|
||||
trials = tune.run(
|
||||
easy_objective,
|
||||
name="mlflow",
|
||||
num_samples=5,
|
||||
loggers=DEFAULT_LOGGERS + (MLFLowLogger, ),
|
||||
config={
|
||||
"mlflow_experiment_id": experiment_id,
|
||||
"width": tune.sample_from(
|
||||
lambda spec: 10 + int(90 * random.random())),
|
||||
"height": tune.sample_from(lambda spec: int(100 * random.random()))
|
||||
})
|
||||
|
||||
df = mlflow.search_runs([experiment_id])
|
||||
print(df)
|
||||
@@ -7,7 +7,17 @@ from ray.tune import track
|
||||
|
||||
|
||||
class TuneReporterCallback(keras.callbacks.Callback):
|
||||
"""Tune Callback for Keras."""
|
||||
|
||||
def __init__(self, reporter=None, freq="batch", logs={}):
|
||||
"""Initializer.
|
||||
|
||||
Args:
|
||||
reporter (StatusReporter|tune.track.log|None): Tune object for
|
||||
returning results.
|
||||
freq (str): Sets the frequency of reporting intermediate results.
|
||||
One of ["batch", "epoch"].
|
||||
"""
|
||||
self.reporter = reporter or track.log
|
||||
self.iteration = 0
|
||||
if freq not in ["batch", "epoch"]:
|
||||
|
||||
@@ -72,6 +72,34 @@ class NoopLogger(Logger):
|
||||
pass
|
||||
|
||||
|
||||
class MLFLowLogger(Logger):
|
||||
"""MLFlow logger.
|
||||
|
||||
Requires the experiment configuration to have a MLFlow Experiment ID
|
||||
or manually set the proper environment variables.
|
||||
|
||||
"""
|
||||
|
||||
def _init(self):
|
||||
from mlflow.tracking import MlflowClient
|
||||
client = MlflowClient()
|
||||
run = client.create_run(self.config.get("mlflow_experiment_id"))
|
||||
self._run_id = run.info.run_id
|
||||
for key, value in self.config.items():
|
||||
client.log_param(self._run_id, key, value)
|
||||
self.client = client
|
||||
|
||||
def on_result(self, result):
|
||||
for key, value in result.items():
|
||||
if not isinstance(value, float):
|
||||
continue
|
||||
self.client.log_metric(
|
||||
self._run_id, key, value, step=result.get(TRAINING_ITERATION))
|
||||
|
||||
def close(self):
|
||||
self.client.set_terminated(self._run_id)
|
||||
|
||||
|
||||
class JsonLogger(Logger):
|
||||
def _init(self):
|
||||
self.update_config(self.config)
|
||||
|
||||
Reference in New Issue
Block a user