mirror of
https://github.com/wassname/ray.git
synced 2026-07-06 05:16:30 +08:00
[Tune] Mlflow Integration (#12840)
Co-authored-by: Kai Fricke <krfricke@users.noreply.github.com> Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
+24
-8
@@ -125,6 +125,14 @@ py_test(
|
||||
tags = ["exclusive"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_integration_mlflow",
|
||||
size = "small",
|
||||
srcs = ["tests/test_integration_mlflow.py"],
|
||||
deps = [":tune_lib"],
|
||||
tags = ["exclusive"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_logger",
|
||||
size = "small",
|
||||
@@ -473,15 +481,23 @@ py_test(
|
||||
args = ["--smoke-test"]
|
||||
)
|
||||
|
||||
# Commenting out for now because it is not idempotent
|
||||
# py_test(
|
||||
# name = "mlflow_example",
|
||||
# size = "medium",
|
||||
# srcs = ["examples/mlflow_example.py"],
|
||||
# deps = [":tune_lib"],
|
||||
# tags = ["exclusive", "example"]
|
||||
# )
|
||||
py_test(
|
||||
name = "mlflow_example",
|
||||
size = "medium",
|
||||
srcs = ["examples/mlflow_example.py"],
|
||||
deps = [":tune_lib"],
|
||||
tags = ["exclusive", "example"]
|
||||
)
|
||||
|
||||
# Commented out for now until we sort out our dependencies.
|
||||
#py_test(
|
||||
# name = "mlflow_ptl",
|
||||
# size = "medium",
|
||||
# srcs = ["examples/mlflow_ptl.py"],
|
||||
# deps = [":tune_lib"],
|
||||
# tags = ["exclusive", "example", "py37", "pytorch"],
|
||||
# args = ["--smoke-test"]
|
||||
#)
|
||||
py_test(
|
||||
name = "mnist_pytorch",
|
||||
size = "small",
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
#!/usr/bin/env python
|
||||
"""Simple MLFLow Logger example.
|
||||
|
||||
This uses a simple MLFlow logger. One limitation of this is that there is
|
||||
no artifact support; to save artifacts with Tune and MLFlow, you will need to
|
||||
start a MLFlow run inside the Trainable function/class.
|
||||
|
||||
"""Examples using MLFlowLoggerCallback and mlflow_mixin.
|
||||
"""
|
||||
import mlflow
|
||||
from mlflow.tracking import MlflowClient
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
import mlflow
|
||||
|
||||
from ray import tune
|
||||
from ray.tune.logger import MLFLowLogger, DEFAULT_LOGGERS
|
||||
from ray.tune.integration.mlflow import MLFlowLoggerCallback, mlflow_mixin
|
||||
|
||||
|
||||
def evaluation_fn(step, width, height):
|
||||
@@ -25,27 +22,84 @@ def easy_objective(config):
|
||||
for step in range(config.get("steps", 100)):
|
||||
# Iterative training function - can be any arbitrary training procedure
|
||||
intermediate_score = evaluation_fn(step, width, height)
|
||||
# Feed the score back back to Tune.
|
||||
# Feed the score back to Tune.
|
||||
tune.report(iterations=step, mean_loss=intermediate_score)
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
client = MlflowClient()
|
||||
experiment_id = client.create_experiment("test")
|
||||
|
||||
trials = tune.run(
|
||||
def tune_function(mlflow_tracking_uri, finish_fast=False):
|
||||
tune.run(
|
||||
easy_objective,
|
||||
name="mlflow",
|
||||
num_samples=5,
|
||||
loggers=DEFAULT_LOGGERS + (MLFLowLogger, ),
|
||||
callbacks=[
|
||||
MLFlowLoggerCallback(
|
||||
tracking_uri=mlflow_tracking_uri,
|
||||
experiment_name="example",
|
||||
save_artifact=True)
|
||||
],
|
||||
config={
|
||||
"logger_config": {
|
||||
"mlflow_experiment_id": experiment_id,
|
||||
},
|
||||
"width": tune.randint(10, 100),
|
||||
"height": tune.randint(0, 100),
|
||||
"steps": 5 if finish_fast else 100,
|
||||
})
|
||||
|
||||
df = mlflow.search_runs([experiment_id])
|
||||
print(df)
|
||||
|
||||
@mlflow_mixin
def decorated_easy_objective(config):
    """Toy training loop that logs to MLflow itself and reports to Tune.

    Reads ``width`` and ``height`` (required) and ``steps`` (optional,
    defaults to 100) from ``config``. On every step the score returned by
    ``evaluation_fn`` is logged to MLflow as ``mean_loss`` and then reported
    to Tune, followed by a short sleep.
    """
    width = config["width"]
    height = config["height"]
    num_steps = config.get("steps", 100)

    for step in range(num_steps):
        # Stand-in for an arbitrary iterative training procedure.
        score = evaluation_fn(step, width, height)
        # Send the metric to MLflow through the regular mlflow API...
        mlflow.log_metrics({"mean_loss": score}, step=step)
        # ...and hand the same score back to Tune.
        tune.report(iterations=step, mean_loss=score)
        time.sleep(0.1)
|
||||
|
||||
|
||||
def tune_decorated(mlflow_tracking_uri, finish_fast=False):
    """Run the ``@mlflow_mixin`` example objective with Tune.

    Args:
        mlflow_tracking_uri: Tracking URI handed to
            ``mlflow.set_tracking_uri`` before the run is launched.
        finish_fast: If True, each trial runs 5 steps instead of 100.
    """
    # Select the experiment; per the original comment this creates it if it
    # does not exist yet.
    mlflow.set_tracking_uri(mlflow_tracking_uri)
    mlflow.set_experiment(experiment_name="mixin_example")

    search_space = {
        "width": tune.randint(10, 100),
        "height": tune.randint(0, 100),
        "steps": 5 if finish_fast else 100,
        # Settings read by the mlflow mixin.
        "mlflow": {
            "experiment_name": "mixin_example",
            "tracking_uri": mlflow.get_tracking_uri()
        }
    }

    tune.run(
        decorated_easy_objective,
        name="mlflow",
        num_samples=5,
        config=search_space)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    args, _ = parser.parse_known_args()

    # Smoke tests write to a directory under the system temp dir; otherwise
    # the tracking URI is left as None.
    mlflow_tracking_uri = (os.path.join(tempfile.gettempdir(), "mlruns")
                           if args.smoke_test else None)

    # Example 1: the logger callback.
    tune_function(mlflow_tracking_uri, finish_fast=args.smoke_test)
    if not args.smoke_test:
        experiment = mlflow.get_experiment_by_name("example")
        df = mlflow.search_runs([experiment.experiment_id])
        print(df)

    # Example 2: the mixin-decorated training function.
    tune_decorated(mlflow_tracking_uri, finish_fast=args.smoke_test)
    if not args.smoke_test:
        experiment = mlflow.get_experiment_by_name("mixin_example")
        df = mlflow.search_runs([experiment.experiment_id])
        print(df)
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
"""An example showing how to use Pytorch Lightning training, Ray Tune
|
||||
HPO, and MLFlow autologging all together."""
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pl_bolts.datamodules import MNISTDataModule
|
||||
|
||||
import mlflow
|
||||
|
||||
from ray import tune
|
||||
from ray.tune.integration.mlflow import mlflow_mixin
|
||||
from ray.tune.integration.pytorch_lightning import TuneReportCallback
|
||||
from ray.tune.examples.mnist_ptl_mini import LightningMNISTClassifier
|
||||
|
||||
|
||||
@mlflow_mixin
def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0):
    """Train the Lightning MNIST classifier for one Tune trial.

    MLflow PyTorch autologging is switched on before the trainer is built,
    and validation loss/accuracy are reported to Tune through a
    ``TuneReportCallback`` at the end of each validation run.

    Args:
        config: Trial hyperparameters; ``batch_size`` is read here and the
            whole dict is passed to the classifier.
        data_dir: Directory passed to the classifier and the data module.
        num_epochs: Value passed as ``max_epochs`` to the trainer.
        num_gpus: Value passed as ``gpus`` to the trainer.
    """
    classifier = LightningMNISTClassifier(config, data_dir)
    datamodule = MNISTDataModule(
        data_dir=data_dir, num_workers=1, batch_size=config["batch_size"])

    # Tune metric name -> Lightning metric name.
    reported_metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}

    mlflow.pytorch.autolog()

    report_callback = TuneReportCallback(
        reported_metrics, on="validation_end")
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        gpus=num_gpus,
        progress_bar_refresh_rate=0,
        callbacks=[report_callback])
    trainer.fit(classifier, datamodule)
|
||||
|
||||
|
||||
def tune_mnist(num_samples=10,
               num_epochs=10,
               gpus_per_trial=0,
               tracking_uri=None):
    """Run the MNIST hyperparameter search with MLflow autologging.

    Args:
        num_samples: Number of trials Tune samples.
        num_epochs: Epochs per trial.
        gpus_per_trial: GPUs requested per trial and passed to the trainer.
        tracking_uri: MLflow tracking URI to set before the search starts.
    """
    mnist_dir = os.path.join(tempfile.gettempdir(), "mnist_data_")
    # Download the data once, before any trial starts.
    MNISTDataModule(data_dir=mnist_dir).prepare_data()

    # Select the MLflow experiment; per the original comment this creates
    # it if it does not exist.
    mlflow.set_tracking_uri(tracking_uri)
    mlflow.set_experiment("ptl_autologging_test")

    search_space = {
        "layer_1": tune.choice([32, 64, 128]),
        "layer_2": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([32, 64, 128]),
        # Settings read by the mlflow mixin.
        "mlflow": {
            "experiment_name": "ptl_autologging_test",
            "tracking_uri": mlflow.get_tracking_uri()
        },
        "data_dir": mnist_dir,
        "num_epochs": num_epochs
    }

    bound_trainable = tune.with_parameters(
        train_mnist_tune,
        data_dir=mnist_dir,
        num_epochs=num_epochs,
        num_gpus=gpus_per_trial)

    trial_resources = {"cpu": 1, "gpu": gpus_per_trial}

    analysis = tune.run(
        bound_trainable,
        resources_per_trial=trial_resources,
        metric="loss",
        mode="min",
        config=search_space,
        num_samples=num_samples,
        name="tune_mnist")

    print("Best hyperparameters found were: ", analysis.best_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    args, _ = parser.parse_known_args()

    if not args.smoke_test:
        # Full search: tracking_uri is left at its default.
        tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0)
    else:
        # Single short trial, tracked under the system temp directory.
        smoke_tracking_uri = os.path.join(tempfile.gettempdir(), "mlruns")
        tune_mnist(
            num_samples=1,
            num_epochs=1,
            gpus_per_trial=0,
            tracking_uri=smoke_tracking_uri)
|
||||
@@ -509,8 +509,9 @@ class FunctionRunner(Trainable):
|
||||
try:
|
||||
err_tb_str = self._error_queue.get(
|
||||
block=block, timeout=ERROR_FETCH_TIMEOUT)
|
||||
raise TuneError(("Trial raised an exception. Traceback:\n{}"
|
||||
.format(err_tb_str)))
|
||||
raise TuneError(
|
||||
("Trial raised an exception. Traceback:\n{}".format(err_tb_str)
|
||||
))
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
@@ -649,6 +650,10 @@ def with_parameters(fn, **kwargs):
|
||||
def _inner(config):
|
||||
inner(config, checkpoint_dir=None)
|
||||
|
||||
if hasattr(fn, "__mixins__"):
|
||||
_inner.__mixins__ = fn.__mixins__
|
||||
return _inner
|
||||
|
||||
if hasattr(fn, "__mixins__"):
|
||||
inner.__mixins__ = fn.__mixins__
|
||||
return inner
|
||||
|
||||
@@ -0,0 +1,366 @@
|
||||
import os
|
||||
from typing import Dict, Callable, Optional
|
||||
import logging
|
||||
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.logger import Logger, LoggerCallback
|
||||
from ray.tune.result import TRAINING_ITERATION
|
||||
from ray.tune.trial import Trial
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _import_mlflow():
|
||||
try:
|
||||
import mlflow
|
||||
except ImportError:
|
||||
mlflow = None
|
||||
return mlflow
|
||||
|
||||
|
||||
class MLFlowLoggerCallback(LoggerCallback):
    """Tune ``LoggerCallback`` that mirrors trials into MLflow runs.

    MLflow (https://mlflow.org) Tracking is an open source library for
    recording and querying experiments. For every trial this callback
    creates one MLflow run, logs the trial config as parameters, logs every
    float-convertible result value as a metric and, optionally, uploads the
    trial's log directory as artifacts when the trial ends.

    Args:
        tracking_uri (str): The tracking URI for where to manage experiments
            and runs. This can either be a local file path or a remote
            server. It is passed directly to
            ``mlflow.tracking.MlflowClient``. When using Tune in a
            multi-node setting, make sure to set this to a remote server
            and not a local file path.
        registry_uri (str): The registry URI, passed directly to
            ``mlflow.tracking.MlflowClient``.
        experiment_name (str): The experiment name to use for this Tune run.
            If None, the ``MLFLOW_EXPERIMENT_NAME`` environment variable is
            checked first and then ``MLFLOW_EXPERIMENT_ID``. An experiment
            that already exists under the name is reused; otherwise a new
            one is created with that name.
        save_artifact (bool): If True, the contents of the trial's
            ``logdir`` are logged as artifacts of the corresponding MLflow
            run when the trial ends.

    Raises:
        RuntimeError: If mlflow cannot be imported.
        ValueError: If no experiment name is available and
            ``MLFLOW_EXPERIMENT_ID`` is unset or does not resolve to an
            experiment.

    Example:

    .. code-block:: python

        from ray.tune.integration.mlflow import MLFlowLoggerCallback
        tune.run(
            train_fn,
            config={
                # define search space here
                "parameter_1": tune.choice([1, 2, 3]),
                "parameter_2": tune.choice([4, 5, 6]),
            },
            callbacks=[MLFlowLoggerCallback(
                experiment_name="experiment1",
                save_artifact=True)])

    """

    def __init__(self,
                 tracking_uri: Optional[str] = None,
                 registry_uri: Optional[str] = None,
                 experiment_name: Optional[str] = None,
                 save_artifact: bool = False):
        if _import_mlflow() is None:
            raise RuntimeError("MLFlow has not been installed. Please `pip "
                               "install mlflow` to use the MLFlowLogger.")

        from mlflow.tracking import MlflowClient
        self.client = MlflowClient(
            tracking_uri=tracking_uri, registry_uri=registry_uri)

        self.experiment_id = self._resolve_experiment_id(experiment_name)
        self.save_artifact = save_artifact

        # Maps each trial to the id of the MLflow run created for it.
        self._trial_runs = {}

    def _resolve_experiment_id(self, experiment_name):
        """Return the experiment id from the name argument or env vars."""
        name = experiment_name
        if name is None:
            # No explicit name: fall back to the name env var.
            name = os.environ.get("MLFLOW_EXPERIMENT_NAME")

        if name is None:
            # Still no name: the id env var is the last resort, and it must
            # point at an experiment the client can find.
            env_id = os.environ.get("MLFLOW_EXPERIMENT_ID")
            if env_id is not None and self.client.get_experiment(
                    env_id) is not None:
                return env_id
            raise ValueError("No experiment_name passed, "
                             "MLFLOW_EXPERIMENT_NAME env var is not "
                             "set, and MLFLOW_EXPERIMENT_ID either "
                             "is not set or does not exist. Please "
                             "set one of these to use the "
                             "MLFlowLoggerCallback.")

        existing = self.client.get_experiment_by_name(name)
        if existing is None:
            # Unknown name: create the experiment.
            return self.client.create_experiment(name=name)
        return existing.experiment_id

    def log_trial_start(self, trial: "Trial"):
        """Create the trial's run if needed and log its config as params."""
        if trial not in self._trial_runs:
            new_run = self.client.create_run(
                experiment_id=self.experiment_id,
                tags={"trial_name": str(trial)})
            self._trial_runs[trial] = new_run.info.run_id

        run_id = self._trial_runs[trial]
        for param_name, param_value in trial.config.items():
            self.client.log_param(
                run_id=run_id, key=param_name, value=param_value)

    def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
        """Log every float-convertible entry of ``result`` as a metric."""
        run_id = self._trial_runs[trial]
        for metric_name, raw_value in result.items():
            try:
                metric_value = float(raw_value)
            except (ValueError, TypeError):
                # Values that are not float-convertible are skipped.
                logger.debug("Cannot log key {} with value {} since the "
                             "value cannot be converted to float.".format(
                                 metric_name, raw_value))
                continue
            self.client.log_metric(
                run_id=run_id,
                key=metric_name,
                value=metric_value,
                step=iteration)

    def log_trial_end(self, trial: "Trial", failed: bool = False):
        """Optionally upload artifacts, then mark the run as terminated."""
        run_id = self._trial_runs[trial]

        # Upload the trial directory only when save_artifact is enabled.
        if self.save_artifact:
            self.client.log_artifacts(run_id, local_dir=trial.logdir)

        # Close the run with a status reflecting the trial outcome.
        status = "FAILED" if failed else "FINISHED"
        self.client.set_terminated(run_id=run_id, status=status)
|
||||
|
||||
|
||||
class MLFlowLogger(Logger):
    """MLFlow logger using the deprecated Logger API.

    Requires the experiment configuration to have a MLFlow Experiment ID
    or manually set the proper environment variables.

    The ``logger_config`` entry of the trial config is read for the keys
    ``mlflow_tracking_uri``, ``mlflow_registry_uri`` and
    ``mlflow_experiment_id``. All logging is delegated to an instance of
    ``_experiment_logger_cls``.
    """

    _experiment_logger_cls = MLFlowLoggerCallback

    def _init(self):
        """Validate the configured experiment and start logging the trial.

        Raises:
            RuntimeError: If mlflow cannot be imported.
            ValueError: If ``mlflow_experiment_id`` is missing or the lookup
                returns no experiment.
        """
        mlflow = _import_mlflow()
        if mlflow is None:
            # Fail with the same message the callback uses instead of an
            # AttributeError on None further down.
            raise RuntimeError("MLFlow has not been installed. Please `pip "
                               "install mlflow` to use the MLFlowLogger.")

        # pop() removes logger_config from the trial config, so it is not
        # among the config entries logged as params later.
        logger_config = self.config.pop("logger_config", {})
        tracking_uri = logger_config.get("mlflow_tracking_uri")
        registry_uri = logger_config.get("mlflow_registry_uri")

        experiment_id = logger_config.get("mlflow_experiment_id")
        # Look the experiment up once and reuse the result for its name.
        # NOTE(review): this lookup goes through the module-level mlflow
        # API, not through `tracking_uri` -- confirm that is intended.
        experiment = None
        if experiment_id is not None:
            experiment = mlflow.get_experiment(experiment_id)
        if not experiment:
            raise ValueError(
                "You must provide a valid `mlflow_experiment_id` "
                "in your `logger_config` dict in the `config` "
                "dict passed to `tune.run`. "
                "Are you sure you passed in a `experiment_id` and "
                "the experiment exists?")

        self._trial_experiment_logger = self._experiment_logger_cls(
            tracking_uri, registry_uri, experiment.name)

        self._trial_experiment_logger.log_trial_start(self.trial)

    def on_result(self, result: Dict):
        """Forward a result dict to the wrapped callback as metrics."""
        self._trial_experiment_logger.log_trial_result(
            iteration=result.get(TRAINING_ITERATION),
            trial=self.trial,
            result=result)

    def close(self):
        """End the trial's run as not failed and drop the callback."""
        self._trial_experiment_logger.log_trial_end(
            trial=self.trial, failed=False)
        del self._trial_experiment_logger
|
||||
|
||||
|
||||
def mlflow_mixin(func: Callable):
    """mlflow_mixin

    MLFlow (https://mlflow.org) Tracking is an open source library for
    recording and querying experiments. This Ray Tune Trainable mixin helps
    initialize the MLflow API for use with the ``Trainable`` class or the
    ``@mlflow_mixin`` function API. This mixin automatically configures
    MLFlow and creates a run in the same process as each Tune trial. You can
    then use the mlflow API inside your training function and it will
    automatically get reported to the correct run.

    For basic usage, just prepend your training function with the
    ``@mlflow_mixin`` decorator:

    .. code-block:: python

        from ray.tune.integration.mlflow import mlflow_mixin

        @mlflow_mixin
        def train_fn(config):
            ...
            mlflow.log_metric(...)

    You can also use MlFlow's autologging feature if using a training
    framework like Pytorch Lightning, XGBoost, etc. More information can be
    found here
    (https://mlflow.org/docs/latest/tracking.html#automatic-logging).

    .. code-block:: python

        from ray.tune.integration.mlflow import mlflow_mixin

        @mlflow_mixin
        def train_fn(config):
            mlflow.autolog()
            xgboost_results = xgb.train(config, ...)

    The MlFlow configuration is done by passing a ``mlflow`` key to
    the ``config`` parameter of ``tune.run()`` (see example below).

    The content of the ``mlflow`` config entry is used to
    configure MlFlow. Here are the keys you can pass in to this config entry:

    Args:
        tracking_uri (str): The tracking URI for MLflow tracking. If using
            Tune in a multi-node setting, make sure to use a remote server
            for tracking.
        experiment_id (str): The id of an already created MLflow experiment.
            All logs from all trials in ``tune.run`` will be reported to
            this experiment. If this is not provided or the experiment with
            this id does not exist, you must provide an
            ``experiment_name``. This parameter takes precedence over
            ``experiment_name``.
        experiment_name (str): The name of an already existing MLflow
            experiment. All logs from all trials in ``tune.run`` will be
            reported to this experiment. If this is not provided, you must
            provide a valid ``experiment_id``.

    Returns:
        The same ``func`` object, with ``MLFlowTrainableMixin`` appended to
        its ``__mixins__`` tuple (created if it did not exist).

    Raises:
        RuntimeError: If mlflow cannot be imported.

    Example:

    .. code-block:: python

        from ray import tune
        from ray.tune.integration.mlflow import mlflow_mixin

        import mlflow

        # Create the MlFlow experiment.
        mlflow.create_experiment("my_experiment")

        @mlflow_mixin
        def train_fn(config):
            for i in range(10):
                loss = config["a"] + config["b"]
                mlflow.log_metric(key="loss", value=loss)
            tune.report(loss=loss, done=True)

        tune.run(
            train_fn,
            config={
                # define search space here
                "a": tune.choice([1, 2, 3]),
                "b": tune.choice([4, 5, 6]),
                # mlflow configuration
                "mlflow": {
                    "experiment_name": "my_experiment",
                    "tracking_uri": mlflow.get_tracking_uri()
                }
            })
    """
    if _import_mlflow() is None:
        raise RuntimeError("MLFlow has not been installed. Please `pip "
                           "install mlflow` to use the mlflow_mixin.")
    # Append to any mixins already attached by another decorator instead of
    # overwriting them.
    existing_mixins = getattr(func, "__mixins__", ())
    func.__mixins__ = existing_mixins + (MLFlowTrainableMixin, )
    return func
|
||||
|
||||
|
||||
class MLFlowTrainableMixin:
    """Mixin that starts an MLflow run when the Trainable is constructed.

    The mixin reads the ``mlflow`` entry of the trial ``config``. It requires
    a ``tracking_uri`` plus either an ``experiment_id`` or the
    ``experiment_name`` of an experiment that already exists. The run is
    started in ``__init__`` and ended in ``stop``.
    """

    def __init__(self, config: Dict, *args, **kwargs):
        """Configure mlflow from ``config["mlflow"]`` and start a run.

        Raises:
            ValueError: If the instance is not a ``Trainable``, if the
                ``mlflow`` key or its ``tracking_uri`` is missing, or if no
                existing experiment can be resolved from ``experiment_id``
                or ``experiment_name``.
        """
        self._mlflow = _import_mlflow()

        if not isinstance(self, Trainable):
            raise ValueError(
                "The `MLFlowTrainableMixin` can only be used as a mixin "
                "for `tune.Trainable` classes. Please make sure your "
                "class inherits from both. For example: "
                "`class YourTrainable(MLFlowTrainableMixin, Trainable)`.")

        super().__init__(config, *args, **kwargs)
        # Work on copies so the caller's config dict is left untouched.
        _config = config.copy()
        try:
            mlflow_config = _config.pop("mlflow").copy()
        except KeyError as e:
            raise ValueError(
                "MLFlow mixin specified but no configuration has been passed. "
                "Make sure to include a `mlflow` key in your `config` dict "
                "containing at least a `tracking_uri` and either "
                "`experiment_name` or `experiment_id` specification.") from e

        tracking_uri = mlflow_config.pop("tracking_uri", None)
        if tracking_uri is None:
            raise ValueError("MLFlow mixin specified but no "
                             "tracking_uri has been "
                             "passed in. Make sure to include a `mlflow` "
                             "key in your `config` dict containing at "
                             "least a `tracking_uri`")
        self._mlflow.set_tracking_uri(tracking_uri)

        # First see if experiment_id is passed in.
        experiment_id = mlflow_config.pop("experiment_id", None)
        if experiment_id is None or self._mlflow.get_experiment(
                experiment_id) is None:
            logger.debug("Either no experiment_id is passed in, or the "
                         "experiment with the given id does not exist. "
                         "Checking experiment_name")
            # Fall back to looking the experiment up by name.
            experiment_name = mlflow_config.pop("experiment_name", None)
            if experiment_name is None:
                raise ValueError(
                    "MLFlow mixin specified but no "
                    "experiment_name or experiment_id has been "
                    "passed in. Make sure to include a `mlflow` "
                    "key in your `config` dict containing at "
                    "least a `experiment_name` or `experiment_id` "
                    "specification.")
            experiment = self._mlflow.get_experiment_by_name(experiment_name)
            if experiment is not None:
                # Experiment with this name exists.
                experiment_id = experiment.experiment_id
            else:
                # The mixin never creates experiments itself.
                raise ValueError("No experiment with the given "
                                 "name: {} or id: {} currently exists. Make "
                                 "sure to first start the MLFlow experiment "
                                 "before calling tune.run.".format(
                                     experiment_name, experiment_id))

        self.experiment_id = experiment_id

        # Run name is "<trial_name>_<trial_id>" with slashes replaced.
        run_name = self.trial_name + "_" + self.trial_id
        run_name = run_name.replace("/", "_")
        self._mlflow.start_run(
            experiment_id=self.experiment_id, run_name=run_name)

    def stop(self):
        """End the MLflow run, then continue with the next ``stop`` in the MRO.

        The mixin precedes the Trainable base in the MRO, so without the
        ``super()`` call this override would replace the base ``stop``.
        """
        self._mlflow.end_run()
        if hasattr(super(), "stop"):
            super().stop()
|
||||
@@ -139,7 +139,10 @@ def wandb_mixin(func: Callable):
|
||||
})
|
||||
|
||||
"""
|
||||
func.__mixins__ = (WandbTrainableMixin, )
|
||||
if hasattr(func, "__mixins__"):
|
||||
func.__mixins__ = func.__mixins__ + (WandbTrainableMixin, )
|
||||
else:
|
||||
func.__mixins__ = (WandbTrainableMixin, )
|
||||
return func
|
||||
|
||||
|
||||
|
||||
@@ -77,37 +77,6 @@ class NoopLogger(Logger):
|
||||
pass
|
||||
|
||||
|
||||
class MLFLowLogger(Logger):
|
||||
"""MLFlow logger.
|
||||
|
||||
Requires the experiment configuration to have a MLFlow Experiment ID
|
||||
or manually set the proper environment variables.
|
||||
|
||||
"""
|
||||
|
||||
def _init(self):
|
||||
logger_config = self.config.get("logger_config", {})
|
||||
from mlflow.tracking import MlflowClient
|
||||
client = MlflowClient(
|
||||
tracking_uri=logger_config.get("mlflow_tracking_uri"),
|
||||
registry_uri=logger_config.get("mlflow_registry_uri"))
|
||||
run = client.create_run(logger_config.get("mlflow_experiment_id"))
|
||||
self._run_id = run.info.run_id
|
||||
for key, value in self.config.items():
|
||||
client.log_param(self._run_id, key, value)
|
||||
self.client = client
|
||||
|
||||
def on_result(self, result: Dict):
|
||||
for key, value in result.items():
|
||||
if not isinstance(value, float):
|
||||
continue
|
||||
self.client.log_metric(
|
||||
self._run_id, key, value, step=result.get(TRAINING_ITERATION))
|
||||
|
||||
def close(self):
|
||||
self.client.set_terminated(self._run_id)
|
||||
|
||||
|
||||
class JsonLogger(Logger):
|
||||
"""Logs trial results in json format.
|
||||
|
||||
@@ -734,6 +703,13 @@ class TBXLoggerCallback(LoggerCallback):
|
||||
"in the hyperparameter values.")
|
||||
|
||||
|
||||
# Maintain backwards compatibility.
|
||||
from ray.tune.integration.mlflow import MLFlowLogger as _MLFlowLogger # noqa: E402, E501
|
||||
MLFlowLogger = _MLFlowLogger
|
||||
# The capital L is a typo, but needs to remain for backwards compatibility.
|
||||
MLFLowLogger = _MLFlowLogger
|
||||
|
||||
|
||||
def pretty_print(result):
|
||||
result = result.copy()
|
||||
result.update(config=None) # drop config from pretty print
|
||||
|
||||
@@ -23,3 +23,4 @@ if __name__ == "__main__":
|
||||
}
|
||||
})
|
||||
assert "ray.rllib" not in sys.modules, "RLlib should not be imported"
|
||||
assert "mlflow" not in sys.modules, "MLFlow should not be imported"
|
||||
|
||||
@@ -0,0 +1,306 @@
|
||||
import os
|
||||
import unittest
|
||||
from collections import namedtuple
|
||||
from unittest.mock import patch
|
||||
|
||||
from ray.tune.function_runner import wrap_function
|
||||
from ray.tune.integration.mlflow import MLFlowLoggerCallback, MLFlowLogger, \
|
||||
mlflow_mixin, MLFlowTrainableMixin
|
||||
|
||||
|
||||
class MockTrial(
        namedtuple("MockTrial", "config trial_name trial_id logdir")):
    """Minimal trial stand-in with the four fields named above.

    ``__hash__`` is based on ``trial_id`` alone, so an instance stays
    hashable even when ``config`` is a dict. ``str()`` yields the trial
    name.
    """

    def __str__(self):
        return self.trial_name

    def __hash__(self):
        return hash(self.trial_id)
|
||||
|
||||
|
||||
MockRunInfo = namedtuple("MockRunInfo", ["run_id"])


class MockRun:
    """In-memory stand-in for an MLflow run that records what was logged.

    Attributes set in ``__init__``: ``run_id``, ``tags``, ``info`` (a
    ``MockRunInfo`` carrying the run id), the ``params`` / ``metrics`` /
    ``artifacts`` lists, and the termination state ``terminated`` /
    ``status``.
    """

    def __init__(self, run_id, tags=None):
        self.run_id = run_id
        self.tags = tags
        self.info = MockRunInfo(run_id)
        self.params = []
        self.metrics = []
        self.artifacts = []
        # Defined up front so the termination state can be inspected before
        # set_terminated() is called without raising AttributeError.
        self.terminated = False
        self.status = None

    def log_param(self, key, value):
        """Record a parameter as a single-entry dict."""
        self.params.append({key: value})

    def log_metric(self, key, value):
        """Record a metric as a single-entry dict."""
        self.metrics.append({key: value})

    def log_artifact(self, artifact):
        """Record an artifact."""
        self.artifacts.append(artifact)

    def set_terminated(self, status):
        """Mark the run as terminated with the given status."""
        self.terminated = True
        self.status = status
|
||||
|
||||
|
||||
MockExperiment = namedtuple("MockExperiment", ["name", "experiment_id"])


class MockMlflowClient:
    """In-memory stand-in for the MLflow client used by these tests.

    Starts with one experiment named ``existing_experiment`` (id 0).
    Experiment ids are list positions in ``experiments``; run ids are
    ``(experiment_id, position)`` tuples into ``runs``.
    """

    def __init__(self, tracking_uri=None, registry_uri=None):
        self.tracking_uri = tracking_uri
        self.registry_uri = registry_uri
        self.experiments = [MockExperiment("existing_experiment", 0)]
        self.runs = {0: []}
        self.active_run = None

    @property
    def experiment_names(self):
        """Names of all known experiments, in creation order."""
        return [experiment.name for experiment in self.experiments]

    def set_tracking_uri(self, tracking_uri):
        self.tracking_uri = tracking_uri

    def get_experiment_by_name(self, name):
        """Return the first experiment with this name, or None."""
        for experiment in self.experiments:
            if experiment.name == name:
                return experiment
        return None

    def get_experiment(self, experiment_id):
        """Return the experiment at this (int-convertible) id, or None."""
        position = int(experiment_id)
        count = len(self.experiments)
        # Same bounds that plain list indexing accepts.
        if -count <= position < count:
            return self.experiments[position]
        return None

    def create_experiment(self, name):
        """Append a new experiment and return its id."""
        new_id = len(self.experiments)
        self.experiments.append(MockExperiment(name, new_id))
        self.runs[new_id] = []
        return new_id

    def create_run(self, experiment_id, tags=None):
        """Create a run under the experiment and return it."""
        siblings = self.runs[experiment_id]
        new_run = MockRun(run_id=(experiment_id, len(siblings)), tags=tags)
        siblings.append(new_run)
        return new_run

    def start_run(self, experiment_id, run_name):
        """Create a new run and remember it as the active one."""
        self.active_run = self.create_run(experiment_id)

    def get_mock_run(self, run_id):
        """Resolve an ``(experiment_id, position)`` run id to its run."""
        return self.runs[run_id[0]][run_id[1]]

    def log_param(self, run_id, key, value):
        self.get_mock_run(run_id).log_param(key, value)

    def log_metric(self, run_id, key, value, step):
        # The step is accepted but not recorded.
        self.get_mock_run(run_id).log_metric(key, value)

    def log_artifacts(self, run_id, local_dir):
        self.get_mock_run(run_id).log_artifact(local_dir)

    def set_terminated(self, run_id, status):
        self.get_mock_run(run_id).set_terminated(status)
|
||||
|
||||
|
||||
def clear_env_vars():
    """Remove the MLflow experiment name/id environment variables if set."""
    for variable in ("MLFLOW_EXPERIMENT_NAME", "MLFLOW_EXPERIMENT_ID"):
        os.environ.pop(variable, None)
|
||||
|
||||
|
||||
class MLFlowTest(unittest.TestCase):
|
||||
@patch("mlflow.tracking.MlflowClient", MockMlflowClient)
|
||||
def testMlFlowLoggerCallbackConfig(self):
|
||||
# Explicitly pass in all args.
|
||||
logger = MLFlowLoggerCallback(
|
||||
tracking_uri="test1",
|
||||
registry_uri="test2",
|
||||
experiment_name="test_exp")
|
||||
self.assertEqual(logger.client.tracking_uri, "test1")
|
||||
self.assertEqual(logger.client.registry_uri, "test2")
|
||||
self.assertListEqual(logger.client.experiment_names,
|
||||
["existing_experiment", "test_exp"])
|
||||
self.assertEqual(logger.experiment_id, 1)
|
||||
|
||||
# Check if client recognizes already existing experiment.
|
||||
logger = MLFlowLoggerCallback(experiment_name="existing_experiment")
|
||||
self.assertListEqual(logger.client.experiment_names,
|
||||
["existing_experiment"])
|
||||
self.assertEqual(logger.experiment_id, 0)
|
||||
|
||||
# Pass in experiment name as env var.
|
||||
clear_env_vars()
|
||||
os.environ["MLFLOW_EXPERIMENT_NAME"] = "test_exp"
|
||||
logger = MLFlowLoggerCallback()
|
||||
self.assertListEqual(logger.client.experiment_names,
|
||||
["existing_experiment", "test_exp"])
|
||||
self.assertEqual(logger.experiment_id, 1)
|
||||
|
||||
# Pass in existing experiment name as env var.
|
||||
clear_env_vars()
|
||||
os.environ["MLFLOW_EXPERIMENT_NAME"] = "existing_experiment"
|
||||
logger = MLFlowLoggerCallback()
|
||||
self.assertListEqual(logger.client.experiment_names,
|
||||
["existing_experiment"])
|
||||
self.assertEqual(logger.experiment_id, 0)
|
||||
|
||||
# Pass in existing experiment id as env var.
|
||||
clear_env_vars()
|
||||
os.environ["MLFLOW_EXPERIMENT_ID"] = "0"
|
||||
logger = MLFlowLoggerCallback()
|
||||
self.assertListEqual(logger.client.experiment_names,
|
||||
["existing_experiment"])
|
||||
self.assertEqual(logger.experiment_id, "0")
|
||||
|
||||
# Pass in non existing experiment id as env var.
|
||||
clear_env_vars()
|
||||
os.environ["MLFLOW_EXPERIMENT_ID"] = "500"
|
||||
with self.assertRaises(ValueError):
|
||||
logger = MLFlowLoggerCallback()
|
||||
|
||||
# Experiment name env var should take precedence over id env var.
|
||||
clear_env_vars()
|
||||
os.environ["MLFLOW_EXPERIMENT_NAME"] = "test_exp"
|
||||
os.environ["MLFLOW_EXPERIMENT_ID"] = "0"
|
||||
logger = MLFlowLoggerCallback()
|
||||
self.assertListEqual(logger.client.experiment_names,
|
||||
["existing_experiment", "test_exp"])
|
||||
self.assertEqual(logger.experiment_id, 1)
|
||||
|
||||
@patch("mlflow.tracking.MlflowClient", MockMlflowClient)
def testMlFlowLoggerLogging(self):
    """Drive one trial through start, result and completion callbacks.

    The MLflow client is replaced by ``MockMlflowClient``; the assertions
    inspect the mock's recorded runs to confirm the tag, params, metrics,
    artifact and final status written by ``MLFlowLoggerCallback``.
    """
    clear_env_vars()
    params = {"par1": 4, "par2": 9.}
    trial = MockTrial(params, "trial1", 0, "artifact")
    callback = MLFlowLoggerCallback(
        experiment_name="test1", save_artifact=True)

    # Starting the trial should open a run tagged with the trial name.
    callback.on_trial_start(iteration=0, trials=[], trial=trial)
    run = callback.client.runs[1][0]
    self.assertDictEqual(run.tags, {"trial_name": "trial1"})
    self.assertTupleEqual(run.run_id, (1, 0))
    # The callback must remember which run belongs to this trial.
    self.assertTupleEqual(callback._trial_runs[trial], run.run_id)
    # The trial config is expected to be logged as params.
    expected_params = [{"par1": 4}, {"par2": 9}]
    self.assertListEqual(run.params, expected_params)

    # Starting the same trial a second time must not add another run.
    callback.on_trial_start(iteration=0, trials=[], trial=trial)
    self.assertEqual(len(callback.client.runs[1]), 1)

    # Report a result and check which metrics reach the run.
    reported = {"metric1": 0.8, "metric2": 1, "metric3": None}
    callback.on_trial_result(0, [], trial, reported)
    run = callback.client.runs[1][0]
    # metric3 is not logged since it cannot be converted to float.
    expected_metrics = [{"metric1": 0.8}, {"metric2": 1.0}]
    self.assertListEqual(run.metrics, expected_metrics)

    # Completing the trial should log the artifact and finish the run.
    callback.on_trial_complete(0, [], trial)
    run = callback.client.runs[1][0]
    self.assertListEqual(run.artifacts, ["artifact"])
    self.assertTrue(run.terminated)
    self.assertEqual(run.status, "FINISHED")
|
||||
|
||||
@patch("mlflow.tracking.MlflowClient", MockMlflowClient)
def testMlFlowLegacyLoggerConfig(self):
    """Check how the legacy ``MLFlowLogger`` reads its ``logger_config``.

    A ``MockMlflowClient`` instance is installed as the ``mlflow`` entry
    of ``sys.modules`` for the duration of the test.
    """
    fake_mlflow = MockMlflowClient()
    with patch.dict("sys.modules", mlflow=fake_mlflow):
        clear_env_vars()
        trial_config = {"par1": 4, "par2": 9.}
        trial = MockTrial(trial_config, "trial1", 0, "artifact")

        # No experiment_id is passed in config, should raise an error.
        with self.assertRaises(ValueError):
            MLFlowLogger(trial_config, "/tmp", trial)

        # Supply the legacy logger settings and build a second trial.
        trial_config["logger_config"] = {
            "mlflow_tracking_uri": "test_tracking_uri",
            "mlflow_experiment_id": 0,
        }
        trial = MockTrial(trial_config, "trial2", 1, "artifact")
        legacy_logger = MLFlowLogger(trial_config, "/tmp", trial)
        client = legacy_logger._trial_experiment_logger.client
        self.assertEqual(client.tracking_uri, "test_tracking_uri")

        # Check to make sure that a run was created on experiment_id 0.
        self.assertEqual(len(client.runs[0]), 1)
        run = client.runs[0][0]
        self.assertDictEqual(run.tags, {"trial_name": "trial2"})
        expected_params = [{"par1": 4}, {"par2": 9}]
        self.assertListEqual(run.params, expected_params)
|
||||
|
||||
@patch("ray.tune.integration.mlflow._import_mlflow",
       lambda: MockMlflowClient())
def testMlFlowMixinConfig(self):
    """Check how ``mlflow_mixin`` validates the ``mlflow`` config entry.

    ``_import_mlflow`` is patched to hand out ``MockMlflowClient``
    objects, so the assertions inspect the mock rather than a real
    MLflow installation.
    """
    clear_env_vars()
    trial_config = {"par1": 4, "par2": 9.}

    @mlflow_mixin
    def train_fn(config):
        return 1

    train_fn.__mixins__ = (MLFlowTrainableMixin, )

    # No MLFlow config passed in.
    with self.assertRaises(ValueError):
        wrap_function(train_fn)(trial_config)

    trial_config.update({"mlflow": {}})
    # No tracking uri or experiment_id/name passed in.
    with self.assertRaises(ValueError):
        wrap_function(train_fn)(trial_config)

    # Invalid experiment id; note that no tracking uri has been supplied
    # yet either.
    trial_config["mlflow"].update({"experiment_id": "500"})
    with self.assertRaises(ValueError):
        wrap_function(train_fn)(trial_config)

    # A tracking uri plus the name of an existing experiment is accepted.
    trial_config["mlflow"].update({
        "tracking_uri": "test_tracking_uri",
        "experiment_name": "existing_experiment"
    })
    wrapped = wrap_function(train_fn)(trial_config)
    client = wrapped._mlflow
    self.assertEqual(client.tracking_uri, "test_tracking_uri")
    self.assertTupleEqual(client.active_run.run_id, (0, 0))

    # Hand the same mock client back on the next import so that a second
    # wrap is recorded as a second run (run id (0, 1)).
    with patch("ray.tune.integration.mlflow._import_mlflow",
               lambda: client):
        train_fn.__mixins__ = (MLFlowTrainableMixin, )
        wrapped = wrap_function(train_fn)(trial_config)
        client = wrapped._mlflow
        self.assertTupleEqual(client.active_run.run_id, (0, 1))

        # Point the config at an experiment name that does not already
        # exist. The test expects a ValueError here, i.e. the experiment
        # is not created implicitly.
        trial_config["mlflow"]["experiment_name"] = "new_experiment"
        with self.assertRaises(ValueError):
            wrap_function(train_fn)(trial_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import sys

    import pytest

    # Run this file's tests verbosely and exit with pytest's return code.
    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|
||||
@@ -972,4 +972,4 @@ class SearchSpaceTest(unittest.TestCase):
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
sys.exit(pytest.main(["-v", __file__] + sys.argv[1:]))
|
||||
|
||||
@@ -3,6 +3,7 @@ bayesian-optimization
|
||||
ConfigSpace==0.4.10
|
||||
dragonfly-opt
|
||||
gluoncv
|
||||
gorilla # Need this because bug in mlflow. Should be fixed in v1.12.2
|
||||
gym[atari]
|
||||
GPy
|
||||
h5py
|
||||
|
||||
Reference in New Issue
Block a user