[Tune] Mlflow Integration (#12840)

Co-authored-by: Kai Fricke <krfricke@users.noreply.github.com>
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
Amog Kamsetty
2020-12-19 00:40:02 -08:00
committed by GitHub
parent 5d987f5988
commit 5d3c9c8861
18 changed files with 958 additions and 75 deletions
+24 -8
View File
@@ -125,6 +125,14 @@ py_test(
tags = ["exclusive"],
)
py_test(
name = "test_integration_mlflow",
size = "small",
srcs = ["tests/test_integration_mlflow.py"],
deps = [":tune_lib"],
tags = ["exclusive"]
)
py_test(
name = "test_logger",
size = "small",
@@ -473,15 +481,23 @@ py_test(
args = ["--smoke-test"]
)
# Commenting out for now because it is not idempotent
# py_test(
# name = "mlflow_example",
# size = "medium",
# srcs = ["examples/mlflow_example.py"],
# deps = [":tune_lib"],
# tags = ["exclusive", "example"]
# )
py_test(
name = "mlflow_example",
size = "medium",
srcs = ["examples/mlflow_example.py"],
deps = [":tune_lib"],
tags = ["exclusive", "example"]
)
# Commented out for now until we sort out our dependencies.
#py_test(
# name = "mlflow_ptl",
# size = "medium",
# srcs = ["examples/mlflow_ptl.py"],
# deps = [":tune_lib"],
# tags = ["exclusive", "example", "py37", "pytorch"],
# args = ["--smoke-test"]
#)
py_test(
name = "mnist_pytorch",
size = "small",
+75 -21
View File
@@ -1,17 +1,14 @@
#!/usr/bin/env python
"""Simple MLFLow Logger example.
This uses a simple MLFlow logger. One limitation of this is that there is
no artifact support; to save artifacts with Tune and MLFlow, you will need to
start a MLFlow run inside the Trainable function/class.
"""Examples using MLFlowLoggerCallback and mlflow_mixin.
"""
import mlflow
from mlflow.tracking import MlflowClient
import os
import tempfile
import time
import mlflow
from ray import tune
from ray.tune.logger import MLFLowLogger, DEFAULT_LOGGERS
from ray.tune.integration.mlflow import MLFlowLoggerCallback, mlflow_mixin
def evaluation_fn(step, width, height):
@@ -25,27 +22,84 @@ def easy_objective(config):
for step in range(config.get("steps", 100)):
# Iterative training function - can be any arbitrary training procedure
intermediate_score = evaluation_fn(step, width, height)
# Feed the score back back to Tune.
# Feed the score back to Tune.
tune.report(iterations=step, mean_loss=intermediate_score)
time.sleep(0.1)
if __name__ == "__main__":
client = MlflowClient()
experiment_id = client.create_experiment("test")
trials = tune.run(
def tune_function(mlflow_tracking_uri, finish_fast=False):
tune.run(
easy_objective,
name="mlflow",
num_samples=5,
loggers=DEFAULT_LOGGERS + (MLFLowLogger, ),
callbacks=[
MLFlowLoggerCallback(
tracking_uri=mlflow_tracking_uri,
experiment_name="example",
save_artifact=True)
],
config={
"logger_config": {
"mlflow_experiment_id": experiment_id,
},
"width": tune.randint(10, 100),
"height": tune.randint(0, 100),
"steps": 5 if finish_fast else 100,
})
df = mlflow.search_runs([experiment_id])
print(df)
@mlflow_mixin
def decorated_easy_objective(config):
    """Toy training loop that logs each step to both MLflow and Tune.

    Reads ``width`` and ``height`` (required) and ``steps`` (optional,
    defaults to 100) from ``config``. Each step scores the pair with
    ``evaluation_fn``, logs the score to MLflow as ``mean_loss`` and
    reports it to Tune.
    """
    width = config["width"]
    height = config["height"]
    total_steps = config.get("steps", 100)

    for step in range(total_steps):
        # Stand-in for one iteration of an arbitrary training procedure.
        loss = evaluation_fn(step, width, height)
        # Log to MLflow first, then hand the same score to Tune.
        mlflow.log_metrics({"mean_loss": loss}, step=step)
        tune.report(iterations=step, mean_loss=loss)
        time.sleep(0.1)
def tune_decorated(mlflow_tracking_uri, finish_fast=False):
    """Run the ``@mlflow_mixin`` example under Tune.

    Args:
        mlflow_tracking_uri: Tracking URI handed to
            ``mlflow.set_tracking_uri``.
        finish_fast: When true, each trial runs 5 steps instead of 100.
    """
    experiment_name = "mixin_example"

    # Point MLflow at the tracking location and select the experiment
    # (creating it if it does not exist yet) before trials are launched.
    mlflow.set_tracking_uri(mlflow_tracking_uri)
    mlflow.set_experiment(experiment_name=experiment_name)

    search_space = {
        "width": tune.randint(10, 100),
        "height": tune.randint(0, 100),
        "steps": 5 if finish_fast else 100,
        # Settings read by the mixin inside each trial.
        "mlflow": {
            "experiment_name": experiment_name,
            "tracking_uri": mlflow.get_tracking_uri()
        }
    }
    tune.run(
        decorated_easy_objective,
        name="mlflow",
        num_samples=5,
        config=search_space)
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    # Unknown flags are ignored rather than rejected.
    args, _ = parser.parse_known_args()

    # Smoke tests write their runs into the temp directory; otherwise no
    # explicit tracking URI is passed.
    mlflow_tracking_uri = (os.path.join(tempfile.gettempdir(), "mlruns")
                           if args.smoke_test else None)

    examples = (
        (tune_function, "example"),
        (tune_decorated, "mixin_example"),
    )
    for run_example, experiment_name in examples:
        run_example(mlflow_tracking_uri, finish_fast=args.smoke_test)
        if args.smoke_test:
            continue
        # Outside of smoke tests, print the runs recorded for the experiment.
        experiment = mlflow.get_experiment_by_name(experiment_name)
        df = mlflow.search_runs([experiment.experiment_id])
        print(df)
+93
View File
@@ -0,0 +1,93 @@
"""An example showing how to use Pytorch Lightning training, Ray Tune
HPO, and MLFlow autologging all together."""
import os
import tempfile
import pytorch_lightning as pl
from pl_bolts.datamodules import MNISTDataModule
import mlflow
from ray import tune
from ray.tune.integration.mlflow import mlflow_mixin
from ray.tune.integration.pytorch_lightning import TuneReportCallback
from ray.tune.examples.mnist_ptl_mini import LightningMNISTClassifier
@mlflow_mixin
def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0):
    """Train the Lightning MNIST classifier for one Tune trial.

    Args:
        config: Trial configuration; ``batch_size`` is read here and the
            whole dict is passed to ``LightningMNISTClassifier``.
        data_dir: Directory passed to the model and the data module.
        num_epochs: Value used for the trainer's ``max_epochs``.
        num_gpus: Value used for the trainer's ``gpus``.
    """
    classifier = LightningMNISTClassifier(config, data_dir)
    datamodule = MNISTDataModule(
        data_dir=data_dir, num_workers=1, batch_size=config["batch_size"])

    # Tune metric name -> Lightning metric name, reported on validation end.
    reported_metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}
    report_callback = TuneReportCallback(reported_metrics, on="validation_end")

    # Enable MLflow autologging before the trainer is created.
    mlflow.pytorch.autolog()

    trainer = pl.Trainer(
        max_epochs=num_epochs,
        gpus=num_gpus,
        progress_bar_refresh_rate=0,
        callbacks=[report_callback])
    trainer.fit(classifier, datamodule)
def tune_mnist(num_samples=10,
               num_epochs=10,
               gpus_per_trial=0,
               tracking_uri=None):
    """Run a Tune search over the MNIST classifier with MLflow autologging.

    Args:
        num_samples: Number of trials passed to ``tune.run``.
        num_epochs: Epochs per trial.
        gpus_per_trial: GPUs requested per trial.
        tracking_uri: Tracking URI passed to ``mlflow.set_tracking_uri``.
    """
    experiment_name = "ptl_autologging_test"
    data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_")

    # Download data
    MNISTDataModule(data_dir=data_dir).prepare_data()

    # Set the MLFlow experiment, or create it if it does not exist.
    mlflow.set_tracking_uri(tracking_uri)
    mlflow.set_experiment(experiment_name)

    mlflow_settings = {
        "experiment_name": experiment_name,
        "tracking_uri": mlflow.get_tracking_uri()
    }
    config = {
        "layer_1": tune.choice([32, 64, 128]),
        "layer_2": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([32, 64, 128]),
        "mlflow": mlflow_settings,
        "data_dir": data_dir,
        "num_epochs": num_epochs
    }

    # Bind the non-searchable arguments to the training function.
    trainable = tune.with_parameters(
        train_mnist_tune,
        data_dir=data_dir,
        num_epochs=num_epochs,
        num_gpus=gpus_per_trial)

    resources = {"cpu": 1, "gpu": gpus_per_trial}
    analysis = tune.run(
        trainable,
        resources_per_trial=resources,
        metric="loss",
        mode="min",
        config=config,
        num_samples=num_samples,
        name="tune_mnist")

    print("Best hyperparameters found were: ", analysis.best_config)
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    # Unknown flags are ignored rather than rejected.
    args, _ = parser.parse_known_args()

    if args.smoke_test:
        # One short trial, with runs tracked in the temp directory.
        run_kwargs = dict(
            num_samples=1,
            num_epochs=1,
            gpus_per_trial=0,
            tracking_uri=os.path.join(tempfile.gettempdir(), "mlruns"))
    else:
        run_kwargs = dict(num_samples=10, num_epochs=10, gpus_per_trial=0)

    tune_mnist(**run_kwargs)
+7 -2
View File
@@ -509,8 +509,9 @@ class FunctionRunner(Trainable):
try:
err_tb_str = self._error_queue.get(
block=block, timeout=ERROR_FETCH_TIMEOUT)
raise TuneError(("Trial raised an exception. Traceback:\n{}"
.format(err_tb_str)))
raise TuneError(
("Trial raised an exception. Traceback:\n{}".format(err_tb_str)
))
except queue.Empty:
pass
@@ -649,6 +650,10 @@ def with_parameters(fn, **kwargs):
def _inner(config):
inner(config, checkpoint_dir=None)
if hasattr(fn, "__mixins__"):
_inner.__mixins__ = fn.__mixins__
return _inner
if hasattr(fn, "__mixins__"):
inner.__mixins__ = fn.__mixins__
return inner
+366
View File
@@ -0,0 +1,366 @@
import os
from typing import Dict, Callable, Optional
import logging
from ray.tune.trainable import Trainable
from ray.tune.logger import Logger, LoggerCallback
from ray.tune.result import TRAINING_ITERATION
from ray.tune.trial import Trial
logger = logging.getLogger(__name__)
def _import_mlflow():
try:
import mlflow
except ImportError:
mlflow = None
return mlflow
class MLFlowLoggerCallback(LoggerCallback):
    """MLFlow Logger to automatically log Tune results and config to MLFlow.

    MLFlow (https://mlflow.org) Tracking is an open source library for
    recording and querying experiments. This Ray Tune ``LoggerCallback``
    sends information (config parameters, training results & metrics,
    and artifacts) to MLFlow for automatic experiment tracking.

    Args:
        tracking_uri (str): The tracking URI for where to manage experiments
            and runs. This can either be a local file path or a remote server.
            This arg gets passed directly to mlflow.tracking.MlflowClient
            initialization. When using Tune in a multi-node setting, make sure
            to set this to a remote server and not a local file path.
        registry_uri (str): The registry URI that gets passed directly to
            mlflow.tracking.MlflowClient initialization.
        experiment_name (str): The experiment name to use for this Tune run.
            If None is passed in here, the Logger will automatically then
            check the MLFLOW_EXPERIMENT_NAME and then the MLFLOW_EXPERIMENT_ID
            environment variables to determine the experiment.
            If an experiment with the name already exists in MLFlow,
            it will be reused. If not, a new experiment will be created with
            that name.
        save_artifact (bool): If set to True, automatically save the entire
            contents of each trial's log directory (``trial.logdir``) as an
            artifact to the corresponding run in MLFlow.

    Raises:
        RuntimeError: If mlflow is not installed.
        ValueError: If no experiment name is given or found in the
            environment and MLFLOW_EXPERIMENT_ID is unset or unknown.

    Example:

    .. code-block:: python

        from ray.tune.integration.mlflow import MLFlowLoggerCallback
        tune.run(
            train_fn,
            config={
                # define search space here
                "parameter_1": tune.choice([1, 2, 3]),
                "parameter_2": tune.choice([4, 5, 6]),
            },
            callbacks=[MLFlowLoggerCallback(
                experiment_name="experiment1",
                save_artifact=True)])

    """

    def __init__(self,
                 tracking_uri: Optional[str] = None,
                 registry_uri: Optional[str] = None,
                 experiment_name: Optional[str] = None,
                 save_artifact: bool = False):
        mlflow = _import_mlflow()
        if mlflow is None:
            raise RuntimeError("MLFlow has not been installed. Please `pip "
                               "install mlflow` to use the MLFlowLogger.")

        # Imported here so that mlflow stays an optional dependency.
        from mlflow.tracking import MlflowClient
        self.client = MlflowClient(
            tracking_uri=tracking_uri, registry_uri=registry_uri)

        self.experiment_id = self._resolve_experiment_id(experiment_name)
        self.save_artifact = save_artifact

        # Maps each started trial to the id of its MLFlow run.
        self._trial_runs = {}

    def _resolve_experiment_id(self, experiment_name: Optional[str]):
        """Return the id of the experiment to log to, creating it if needed.

        Lookup order: the explicit ``experiment_name``, then the
        MLFLOW_EXPERIMENT_NAME env var, then the MLFLOW_EXPERIMENT_ID env
        var. A named experiment that does not exist yet is created; an id
        taken from the environment must already exist.
        """
        if experiment_name is None:
            # If no name is passed in, fall back to the name env var.
            experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME")

        if experiment_name is not None:
            experiment = self.client.get_experiment_by_name(experiment_name)
            if experiment is not None:
                # Reuse the already existing experiment.
                return experiment.experiment_id
            # If it does not exist, create the experiment.
            return self.client.create_experiment(name=experiment_name)

        # No name available anywhere: fall back to the experiment id env var
        # and confirm that an experiment with this id exists.
        experiment_id = os.environ.get("MLFLOW_EXPERIMENT_ID")
        if experiment_id is None or self.client.get_experiment(
                experiment_id) is None:
            raise ValueError("No experiment_name passed, "
                             "MLFLOW_EXPERIMENT_NAME env var is not "
                             "set, and MLFLOW_EXPERIMENT_ID either "
                             "is not set or does not exist. Please "
                             "set one of these to use the "
                             "MLFlowLoggerCallback.")
        return experiment_id

    def log_trial_start(self, trial: "Trial"):
        """Create the MLFlow run for ``trial`` (once) and log its config."""
        # Create run if not already exists.
        if trial not in self._trial_runs:
            run = self.client.create_run(
                experiment_id=self.experiment_id,
                tags={"trial_name": str(trial)})
            self._trial_runs[trial] = run.info.run_id

        run_id = self._trial_runs[trial]

        # Log the config parameters.
        config = trial.config

        for key, value in config.items():
            self.client.log_param(run_id=run_id, key=key, value=value)

    def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
        """Log every float-convertible entry of ``result`` as a metric."""
        run_id = self._trial_runs[trial]
        for key, value in result.items():
            try:
                value = float(value)
            except (ValueError, TypeError):
                # Values that cannot be converted to float are skipped.
                logger.debug("Cannot log key {} with value {} since the "
                             "value cannot be converted to float.".format(
                                 key, value))
                continue
            self.client.log_metric(
                run_id=run_id, key=key, value=value, step=iteration)

    def log_trial_end(self, trial: "Trial", failed: bool = False):
        """Upload artifacts if enabled and mark the trial's run terminated."""
        run_id = self._trial_runs[trial]

        # Log the artifact if save_artifact is set to True.
        if self.save_artifact:
            self.client.log_artifacts(run_id, local_dir=trial.logdir)

        # Stop the run once trial finishes.
        status = "FINISHED" if not failed else "FAILED"
        self.client.set_terminated(run_id=run_id, status=status)
class MLFlowLogger(Logger):
    """MLFlow logger using the deprecated Logger API.

    Requires the experiment configuration to have a MLFlow Experiment ID
    or manually set the proper environment variables.

    All work is delegated to a per-trial ``MLFlowLoggerCallback``.
    """

    _experiment_logger_cls = MLFlowLoggerCallback

    def _init(self):
        """Validate ``logger_config`` and start the run for this trial.

        Raises:
            ValueError: If ``mlflow_experiment_id`` is missing from
                ``logger_config`` or does not refer to an experiment.
            RuntimeError: If an experiment id is given but mlflow is not
                installed.
        """
        mlflow = _import_mlflow()

        # ``logger_config`` is popped, so it is removed from this logger's
        # config dict before the trial config is logged.
        logger_config = self.config.pop("logger_config", {})
        tracking_uri = logger_config.get("mlflow_tracking_uri")
        registry_uri = logger_config.get("mlflow_registry_uri")
        experiment_id = logger_config.get("mlflow_experiment_id")

        experiment = None
        if experiment_id is not None:
            if mlflow is None:
                # Fail with the install hint instead of an AttributeError
                # on ``None.get_experiment``.
                raise RuntimeError(
                    "MLFlow has not been installed. Please `pip "
                    "install mlflow` to use the MLFlowLogger.")
            # Look the experiment up once and reuse the result.
            experiment = mlflow.get_experiment(experiment_id)

        if not experiment:
            raise ValueError(
                "You must provide a valid `mlflow_experiment_id` "
                "in your `logger_config` dict in the `config` "
                "dict passed to `tune.run`. "
                "Are you sure you passed in a `experiment_id` and "
                "the experiment exists?")

        self._trial_experiment_logger = self._experiment_logger_cls(
            tracking_uri, registry_uri, experiment.name)

        self._trial_experiment_logger.log_trial_start(self.trial)

    def on_result(self, result: Dict):
        """Forward a result dict to the underlying callback logger."""
        self._trial_experiment_logger.log_trial_result(
            iteration=result.get(TRAINING_ITERATION),
            trial=self.trial,
            result=result)

    def close(self):
        """End the trial's run (always reported as not failed)."""
        self._trial_experiment_logger.log_trial_end(
            trial=self.trial, failed=False)
        del self._trial_experiment_logger
def mlflow_mixin(func: Callable):
    """mlflow_mixin

    MLFlow (https://mlflow.org) Tracking is an open source library for
    recording and querying experiments. This Ray Tune Trainable mixin helps
    initialize the MLflow API for use with the ``Trainable`` class or the
    ``@mlflow_mixin`` function API. This mixin automatically configures MLFlow
    and creates a run in the same process as each Tune trial. You can then
    use the mlflow API inside your training function and it will
    automatically get reported to the correct run.

    For basic usage, just prepend your training function with the
    ``@mlflow_mixin`` decorator:

    .. code-block:: python

        from ray.tune.integration.mlflow import mlflow_mixin

        @mlflow_mixin
        def train_fn(config):
            ...
            mlflow.log_metric(...)

    You can also use MLflow's autologging feature if using a training
    framework like Pytorch Lightning, XGBoost, etc. More information can be
    found here (https://mlflow.org/docs/latest/tracking.html#automatic
    -logging).

    .. code-block:: python

        from ray.tune.integration.mlflow import mlflow_mixin

        @mlflow_mixin
        def train_fn(config):
            mlflow.autolog()
            xgboost_results = xgb.train(config, ...)

    The MLflow configuration is done by passing a ``mlflow`` key to
    the ``config`` parameter of ``tune.run()`` (see example below).

    The content of the ``mlflow`` config entry is used to
    configure MLflow. Here are the keys you can pass in to this config entry:

    Args:
        tracking_uri (str): The tracking URI for MLflow tracking. If using
            Tune in a multi-node setting, make sure to use a remote server for
            tracking.
        experiment_id (str): The id of an already created MLflow experiment.
            All logs from all trials in ``tune.run`` will be reported to this
            experiment. If this is not provided or the experiment with this
            id does not exist, you must provide an ``experiment_name``. This
            parameter takes precedence over ``experiment_name``.
        experiment_name (str): The name of an already existing MLflow
            experiment. All logs from all trials in ``tune.run`` will be
            reported to this experiment. If this is not provided, you must
            provide a valid ``experiment_id``.

    Example:

    .. code-block:: python

        from ray import tune
        from ray.tune.integration.mlflow import mlflow_mixin

        import mlflow

        # Create the MLflow experiment.
        mlflow.create_experiment("my_experiment")

        @mlflow_mixin
        def train_fn(config):
            for i in range(10):
                loss = config["a"] + config["b"]
                mlflow.log_metric(key="loss", value=loss)
            tune.report(loss=loss, done=True)

        tune.run(
            train_fn,
            config={
                # define search space here
                "a": tune.choice([1, 2, 3]),
                "b": tune.choice([4, 5, 6]),
                # mlflow configuration
                "mlflow": {
                    "experiment_name": "my_experiment",
                    "tracking_uri": mlflow.get_tracking_uri()
                }
            })

    """
    if _import_mlflow() is None:
        raise RuntimeError("MLFlow has not been installed. Please `pip "
                           "install mlflow` to use the mlflow_mixin.")

    # Append to any mixins already attached by other decorators.
    existing_mixins = getattr(func, "__mixins__", ())
    func.__mixins__ = existing_mixins + (MLFlowTrainableMixin, )
    return func
class MLFlowTrainableMixin:
    """Mixin that starts an MLflow run when a Trainable is constructed.

    The MLflow settings are read from the ``mlflow`` entry of the trial
    config: a ``tracking_uri`` plus either an ``experiment_id`` or an
    ``experiment_name`` of an existing experiment. The run is ended in
    ``stop``.
    """

    def __init__(self, config: Dict, *args, **kwargs):
        """Configure MLflow from ``config["mlflow"]`` and start a run.

        Raises:
            ValueError: If the class is not also a ``Trainable``, the
                ``mlflow`` config entry or its ``tracking_uri`` is missing,
                or no existing experiment can be resolved from
                ``experiment_id`` / ``experiment_name``.
        """
        self._mlflow = _import_mlflow()

        if not isinstance(self, Trainable):
            raise ValueError(
                "The `MLFlowTrainableMixin` can only be used as a mixin "
                "for `tune.Trainable` classes. Please make sure your "
                "class inherits from both. For example: "
                "`class YourTrainable(MLFlowTrainableMixin, Trainable)`.")

        super().__init__(config, *args, **kwargs)

        # Work on copies so the caller's config dict is left untouched.
        _config = config.copy()
        try:
            mlflow_config = _config.pop("mlflow").copy()
        except KeyError as e:
            raise ValueError(
                "MLFlow mixin specified but no configuration has been passed. "
                "Make sure to include a `mlflow` key in your `config` dict "
                "containing at least a `tracking_uri` and either "
                "`experiment_name` or `experiment_id` specification.") from e

        tracking_uri = mlflow_config.pop("tracking_uri", None)
        if tracking_uri is None:
            raise ValueError("MLFlow mixin specified but no "
                             "tracking_uri has been "
                             "passed in. Make sure to include a `mlflow` "
                             "key in your `config` dict containing at "
                             "least a `tracking_uri`")
        self._mlflow.set_tracking_uri(tracking_uri)

        # First see if experiment_id is passed in.
        experiment_id = mlflow_config.pop("experiment_id", None)
        if experiment_id is None or self._mlflow.get_experiment(
                experiment_id) is None:
            logger.debug("Either no experiment_id is passed in, or the "
                         "experiment with the given id does not exist. "
                         "Checking experiment_name")
            # Check for name.
            experiment_name = mlflow_config.pop("experiment_name", None)
            if experiment_name is None:
                raise ValueError(
                    "MLFlow mixin specified but no "
                    "experiment_name or experiment_id has been "
                    "passed in. Make sure to include a `mlflow` "
                    "key in your `config` dict containing at "
                    "least a `experiment_name` or `experiment_id` "
                    "specification.")
            experiment = self._mlflow.get_experiment_by_name(experiment_name)
            if experiment is not None:
                # Experiment with this name exists.
                experiment_id = experiment.experiment_id
            else:
                # The mixin never creates experiments on its own.
                raise ValueError("No experiment with the given "
                                 "name: {} or id: {} currently exists. Make "
                                 "sure to first start the MLFlow experiment "
                                 "before calling tune.run.".format(
                                     experiment_name, experiment_id))

        self.experiment_id = experiment_id

        # Run name combines trial name and id; "/" is replaced by "_".
        run_name = self.trial_name + "_" + self.trial_id
        run_name = run_name.replace("/", "_")
        self._mlflow.start_run(
            experiment_id=self.experiment_id, run_name=run_name)

    def stop(self):
        """End the active MLflow run, then delegate to the next ``stop``.

        This method overrides ``stop`` of the class that follows the mixin
        in the MRO, so that implementation is called explicitly; otherwise
        it would never run for classes using this mixin.
        """
        self._mlflow.end_run()
        parent_stop = getattr(super(), "stop", None)
        if parent_stop is not None:
            parent_stop()
+4 -1
View File
@@ -139,7 +139,10 @@ def wandb_mixin(func: Callable):
})
"""
func.__mixins__ = (WandbTrainableMixin, )
if hasattr(func, "__mixins__"):
func.__mixins__ = func.__mixins__ + (WandbTrainableMixin, )
else:
func.__mixins__ = (WandbTrainableMixin, )
return func
+7 -31
View File
@@ -77,37 +77,6 @@ class NoopLogger(Logger):
pass
class MLFLowLogger(Logger):
"""MLFlow logger.
Requires the experiment configuration to have a MLFlow Experiment ID
or manually set the proper environment variables.
"""
def _init(self):
logger_config = self.config.get("logger_config", {})
from mlflow.tracking import MlflowClient
client = MlflowClient(
tracking_uri=logger_config.get("mlflow_tracking_uri"),
registry_uri=logger_config.get("mlflow_registry_uri"))
run = client.create_run(logger_config.get("mlflow_experiment_id"))
self._run_id = run.info.run_id
for key, value in self.config.items():
client.log_param(self._run_id, key, value)
self.client = client
def on_result(self, result: Dict):
for key, value in result.items():
if not isinstance(value, float):
continue
self.client.log_metric(
self._run_id, key, value, step=result.get(TRAINING_ITERATION))
def close(self):
self.client.set_terminated(self._run_id)
class JsonLogger(Logger):
"""Logs trial results in json format.
@@ -734,6 +703,13 @@ class TBXLoggerCallback(LoggerCallback):
"in the hyperparameter values.")
# Maintain backwards compatibility.
from ray.tune.integration.mlflow import MLFlowLogger as _MLFlowLogger # noqa: E402, E501
MLFlowLogger = _MLFlowLogger
# The capital L is a typo, but needs to remain for backwards compatibility.
MLFLowLogger = _MLFlowLogger
def pretty_print(result):
result = result.copy()
result.update(config=None) # drop config from pretty print
+1
View File
@@ -23,3 +23,4 @@ if __name__ == "__main__":
}
})
assert "ray.rllib" not in sys.modules, "RLlib should not be imported"
assert "mlflow" not in sys.modules, "MLFlow should not be imported"
@@ -0,0 +1,306 @@
import os
import unittest
from collections import namedtuple
from unittest.mock import patch
from ray.tune.function_runner import wrap_function
from ray.tune.integration.mlflow import MLFlowLoggerCallback, MLFlowLogger, \
mlflow_mixin, MLFlowTrainableMixin
class MockTrial(
        namedtuple("MockTrial", "config trial_name trial_id logdir")):
    """Minimal trial stand-in: a named tuple keyed by its trial id.

    Hashing uses only ``trial_id`` (the ``config`` field is typically an
    unhashable dict), and ``str()`` yields the trial name.
    """

    def __str__(self):
        return self.trial_name

    def __hash__(self):
        return hash(self.trial_id)
MockRunInfo = namedtuple("MockRunInfo", ["run_id"])
class MockRun:
    """In-memory record of what was logged to a single (mock) run.

    ``terminated`` and ``status`` only exist after ``set_terminated`` has
    been called.
    """

    def __init__(self, run_id, tags=None):
        self.run_id, self.tags = run_id, tags
        self.info = MockRunInfo(run_id)
        # Each logged item is appended in call order.
        self.params, self.metrics, self.artifacts = [], [], []

    def log_param(self, key, value):
        entry = {key: value}
        self.params.append(entry)

    def log_metric(self, key, value):
        entry = {key: value}
        self.metrics.append(entry)

    def log_artifact(self, artifact):
        self.artifacts.append(artifact)

    def set_terminated(self, status):
        self.terminated, self.status = True, status
MockExperiment = namedtuple("MockExperiment", ["name", "experiment_id"])
class MockMlflowClient:
    """In-memory stand-in for the MLflow client / module used in tests.

    Experiments live in a list whose index doubles as the experiment id;
    runs are stored per experiment and addressed by
    ``(experiment_id, run_index)`` tuples. A fresh client starts with one
    experiment named ``existing_experiment`` with id 0.
    """

    def __init__(self, tracking_uri=None, registry_uri=None):
        self.tracking_uri = tracking_uri
        self.registry_uri = registry_uri
        self.experiments = [MockExperiment("existing_experiment", 0)]
        self.runs = {0: []}
        self.active_run = None

    @property
    def experiment_names(self):
        names = []
        for experiment in self.experiments:
            names.append(experiment.name)
        return names

    def set_tracking_uri(self, tracking_uri):
        self.tracking_uri = tracking_uri

    def get_experiment_by_name(self, name):
        # First experiment with a matching name, or None.
        for experiment in self.experiments:
            if experiment.name == name:
                return experiment
        return None

    def get_experiment(self, experiment_id):
        # Ids may arrive as strings (e.g. from env vars).
        index = int(experiment_id)
        try:
            found = self.experiments[index]
        except IndexError:
            found = None
        return found

    def create_experiment(self, name):
        new_id = len(self.experiments)
        self.experiments.append(MockExperiment(name, new_id))
        self.runs[new_id] = []
        return new_id

    def create_run(self, experiment_id, tags=None):
        siblings = self.runs[experiment_id]
        new_run = MockRun(run_id=(experiment_id, len(siblings)), tags=tags)
        siblings.append(new_run)
        return new_run

    def start_run(self, experiment_id, run_name):
        # Creates new run and sets it as active.
        self.active_run = self.create_run(experiment_id)

    def get_mock_run(self, run_id):
        experiment_runs = self.runs[run_id[0]]
        return experiment_runs[run_id[1]]

    def log_param(self, run_id, key, value):
        self.get_mock_run(run_id).log_param(key, value)

    def log_metric(self, run_id, key, value, step):
        # The step is accepted but not recorded by the mock run.
        self.get_mock_run(run_id).log_metric(key, value)

    def log_artifacts(self, run_id, local_dir):
        self.get_mock_run(run_id).log_artifact(local_dir)

    def set_terminated(self, run_id, status):
        self.get_mock_run(run_id).set_terminated(status)
def clear_env_vars():
    """Remove the MLflow experiment env vars if they are set."""
    for name in ("MLFLOW_EXPERIMENT_NAME", "MLFLOW_EXPERIMENT_ID"):
        os.environ.pop(name, None)
class MLFlowTest(unittest.TestCase):
    """Tests for the MLflow logger callback, legacy logger and mixin.

    All tests run against ``MockMlflowClient`` instead of a real MLflow
    backend.
    """

    def setUp(self):
        # Start every test from a clean environment.
        clear_env_vars()

    def tearDown(self):
        # The tests set MLFLOW_EXPERIMENT_* in os.environ; make sure these
        # values never outlive the test that set them.
        clear_env_vars()

    @patch("mlflow.tracking.MlflowClient", MockMlflowClient)
    def testMlFlowLoggerCallbackConfig(self):
        # Explicitly pass in all args.
        logger = MLFlowLoggerCallback(
            tracking_uri="test1",
            registry_uri="test2",
            experiment_name="test_exp")
        self.assertEqual(logger.client.tracking_uri, "test1")
        self.assertEqual(logger.client.registry_uri, "test2")
        self.assertListEqual(logger.client.experiment_names,
                             ["existing_experiment", "test_exp"])
        self.assertEqual(logger.experiment_id, 1)

        # Check if client recognizes already existing experiment.
        logger = MLFlowLoggerCallback(experiment_name="existing_experiment")
        self.assertListEqual(logger.client.experiment_names,
                             ["existing_experiment"])
        self.assertEqual(logger.experiment_id, 0)

        # Pass in experiment name as env var.
        clear_env_vars()
        os.environ["MLFLOW_EXPERIMENT_NAME"] = "test_exp"
        logger = MLFlowLoggerCallback()
        self.assertListEqual(logger.client.experiment_names,
                             ["existing_experiment", "test_exp"])
        self.assertEqual(logger.experiment_id, 1)

        # Pass in existing experiment name as env var.
        clear_env_vars()
        os.environ["MLFLOW_EXPERIMENT_NAME"] = "existing_experiment"
        logger = MLFlowLoggerCallback()
        self.assertListEqual(logger.client.experiment_names,
                             ["existing_experiment"])
        self.assertEqual(logger.experiment_id, 0)

        # Pass in existing experiment id as env var.
        # The id is taken from the environment as-is, i.e. as a string.
        clear_env_vars()
        os.environ["MLFLOW_EXPERIMENT_ID"] = "0"
        logger = MLFlowLoggerCallback()
        self.assertListEqual(logger.client.experiment_names,
                             ["existing_experiment"])
        self.assertEqual(logger.experiment_id, "0")

        # Pass in non existing experiment id as env var.
        clear_env_vars()
        os.environ["MLFLOW_EXPERIMENT_ID"] = "500"
        with self.assertRaises(ValueError):
            logger = MLFlowLoggerCallback()

        # Experiment name env var should take precedence over id env var.
        clear_env_vars()
        os.environ["MLFLOW_EXPERIMENT_NAME"] = "test_exp"
        os.environ["MLFLOW_EXPERIMENT_ID"] = "0"
        logger = MLFlowLoggerCallback()
        self.assertListEqual(logger.client.experiment_names,
                             ["existing_experiment", "test_exp"])
        self.assertEqual(logger.experiment_id, 1)

    @patch("mlflow.tracking.MlflowClient", MockMlflowClient)
    def testMlFlowLoggerLogging(self):
        clear_env_vars()
        trial_config = {"par1": 4, "par2": 9.}
        trial = MockTrial(trial_config, "trial1", 0, "artifact")

        logger = MLFlowLoggerCallback(
            experiment_name="test1", save_artifact=True)

        # Check if run is created.
        logger.on_trial_start(iteration=0, trials=[], trial=trial)
        # New run should be created for this trial with correct tag.
        mock_run = logger.client.runs[1][0]
        self.assertDictEqual(mock_run.tags, {"trial_name": "trial1"})
        self.assertTupleEqual(mock_run.run_id, (1, 0))
        self.assertTupleEqual(logger._trial_runs[trial], mock_run.run_id)
        # Params should be logged.
        self.assertListEqual(mock_run.params, [{"par1": 4}, {"par2": 9}])

        # When same trial is started again, new run should not be created.
        logger.on_trial_start(iteration=0, trials=[], trial=trial)
        self.assertEqual(len(logger.client.runs[1]), 1)

        # Check metrics are logged properly.
        result = {"metric1": 0.8, "metric2": 1, "metric3": None}
        logger.on_trial_result(0, [], trial, result)
        mock_run = logger.client.runs[1][0]
        # metric3 is not logged since it cannot be converted to float.
        self.assertListEqual(mock_run.metrics, [{
            "metric1": 0.8
        }, {
            "metric2": 1.0
        }])

        # Check that artifact is logged on termination.
        logger.on_trial_complete(0, [], trial)
        mock_run = logger.client.runs[1][0]
        self.assertListEqual(mock_run.artifacts, ["artifact"])
        self.assertTrue(mock_run.terminated)
        self.assertEqual(mock_run.status, "FINISHED")

    @patch("mlflow.tracking.MlflowClient", MockMlflowClient)
    def testMlFlowLegacyLoggerConfig(self):
        # The legacy logger imports the mlflow module itself, so the mock
        # client is also installed as the ``mlflow`` module.
        mlflow = MockMlflowClient()
        with patch.dict("sys.modules", mlflow=mlflow):
            clear_env_vars()
            trial_config = {"par1": 4, "par2": 9.}
            trial = MockTrial(trial_config, "trial1", 0, "artifact")

            # No experiment_id is passed in config, should raise an error.
            with self.assertRaises(ValueError):
                logger = MLFlowLogger(trial_config, "/tmp", trial)

            trial_config.update({
                "logger_config": {
                    "mlflow_tracking_uri": "test_tracking_uri",
                    "mlflow_experiment_id": 0
                }
            })
            trial = MockTrial(trial_config, "trial2", 1, "artifact")
            logger = MLFlowLogger(trial_config, "/tmp", trial)

            experiment_logger = logger._trial_experiment_logger
            client = experiment_logger.client
            self.assertEqual(client.tracking_uri, "test_tracking_uri")
            # Check to make sure that a run was created on experiment_id 0.
            self.assertEqual(len(client.runs[0]), 1)
            mock_run = client.runs[0][0]
            self.assertDictEqual(mock_run.tags, {"trial_name": "trial2"})
            self.assertListEqual(mock_run.params, [{"par1": 4}, {"par2": 9}])

    @patch("ray.tune.integration.mlflow._import_mlflow",
           lambda: MockMlflowClient())
    def testMlFlowMixinConfig(self):
        clear_env_vars()
        trial_config = {"par1": 4, "par2": 9.}

        @mlflow_mixin
        def train_fn(config):
            return 1

        train_fn.__mixins__ = (MLFlowTrainableMixin, )

        # No MLFlow config passed in.
        with self.assertRaises(ValueError):
            wrapped = wrap_function(train_fn)(trial_config)

        trial_config.update({"mlflow": {}})
        # No tracking uri or experiment_id/name passed in.
        with self.assertRaises(ValueError):
            wrapped = wrap_function(train_fn)(trial_config)

        # Invalid experiment-id, and still no tracking uri.
        trial_config["mlflow"].update({"experiment_id": "500"})
        with self.assertRaises(ValueError):
            wrapped = wrap_function(train_fn)(trial_config)

        trial_config["mlflow"].update({
            "tracking_uri": "test_tracking_uri",
            "experiment_name": "existing_experiment"
        })
        wrapped = wrap_function(train_fn)(trial_config)
        client = wrapped._mlflow
        self.assertEqual(client.tracking_uri, "test_tracking_uri")
        self.assertTupleEqual(client.active_run.run_id, (0, 0))

        # Reuse the same mock client so that run indices keep counting up.
        with patch("ray.tune.integration.mlflow._import_mlflow",
                   lambda: client):
            train_fn.__mixins__ = (MLFlowTrainableMixin, )
            wrapped = wrap_function(train_fn)(trial_config)
            client = wrapped._mlflow
            self.assertTupleEqual(client.active_run.run_id, (0, 1))

            # Set to experiment that does not already exist.
            # The mixin does not create experiments, so this must raise.
            trial_config["mlflow"]["experiment_name"] = "new_experiment"
            with self.assertRaises(ValueError):
                wrapped = wrap_function(train_fn)(trial_config)
if __name__ == "__main__":
    import pytest
    import sys

    # Forward extra command line arguments (e.g. ``-k <expr>``) to pytest,
    # so single tests can be selected when running this file directly.
    sys.exit(pytest.main(["-v", __file__] + sys.argv[1:]))
+1 -1
View File
@@ -972,4 +972,4 @@ class SearchSpaceTest(unittest.TestCase):
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))
sys.exit(pytest.main(["-v", __file__] + sys.argv[1:]))
+1
View File
@@ -3,6 +3,7 @@ bayesian-optimization
ConfigSpace==0.4.10
dragonfly-opt
gluoncv
gorilla # Needed because of a bug in mlflow. Should be fixed in v1.12.2.
gym[atari]
GPy
h5py