mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:23:10 +08:00
[Tune] Rename MLFlow to MLflow (#13301)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
#!/usr/bin/env python
|
||||
"""Examples using MLFlowLoggerCallback and mlflow_mixin.
|
||||
"""Examples using MLfowLoggerCallback and mlflow_mixin.
|
||||
"""
|
||||
import os
|
||||
import tempfile
|
||||
@@ -8,7 +8,7 @@ import time
|
||||
import mlflow
|
||||
|
||||
from ray import tune
|
||||
from ray.tune.integration.mlflow import MLFlowLoggerCallback, mlflow_mixin
|
||||
from ray.tune.integration.mlflow import MLflowLoggerCallback, mlflow_mixin
|
||||
|
||||
|
||||
def evaluation_fn(step, width, height):
|
||||
@@ -33,7 +33,7 @@ def tune_function(mlflow_tracking_uri, finish_fast=False):
|
||||
name="mlflow",
|
||||
num_samples=5,
|
||||
callbacks=[
|
||||
MLFlowLoggerCallback(
|
||||
MLflowLoggerCallback(
|
||||
tracking_uri=mlflow_tracking_uri,
|
||||
experiment_name="example",
|
||||
save_artifact=True)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""An example showing how to use Pytorch Lightning training, Ray Tune
|
||||
HPO, and MLFlow autologging all together."""
|
||||
HPO, and MLflow autologging all together."""
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
@@ -37,7 +37,7 @@ def tune_mnist(num_samples=10,
|
||||
# Download data
|
||||
MNISTDataModule(data_dir=data_dir).prepare_data()
|
||||
|
||||
# Set the MLFlow experiment, or create it if it does not exist.
|
||||
# Set the MLflow experiment, or create it if it does not exist.
|
||||
mlflow.set_tracking_uri(tracking_uri)
|
||||
mlflow.set_experiment("ptl_autologging_test")
|
||||
|
||||
|
||||
@@ -18,13 +18,13 @@ def _import_mlflow():
|
||||
return mlflow
|
||||
|
||||
|
||||
class MLFlowLoggerCallback(LoggerCallback):
|
||||
"""MLFlow Logger to automatically log Tune results and config to MLFlow.
|
||||
class MLflowLoggerCallback(LoggerCallback):
|
||||
"""MLflow Logger to automatically log Tune results and config to MLflow.
|
||||
|
||||
MLFlow (https://mlflow.org) Tracking is an open source library for
|
||||
MLflow (https://mlflow.org) Tracking is an open source library for
|
||||
recording and querying experiments. This Ray Tune ``LoggerCallback``
|
||||
sends information (config parameters, training results & metrics,
|
||||
and artifacts) to MLFlow for automatic experiment tracking.
|
||||
and artifacts) to MLflow for automatic experiment tracking.
|
||||
|
||||
Args:
|
||||
tracking_uri (str): The tracking URI for where to manage experiments
|
||||
@@ -49,7 +49,7 @@ class MLFlowLoggerCallback(LoggerCallback):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from ray.tune.integration.mlflow import MLFlowLoggerCallback
|
||||
from ray.tune.integration.mlflow import MLflowLoggerCallback
|
||||
tune.run(
|
||||
train_fn,
|
||||
config={
|
||||
@@ -57,7 +57,7 @@ class MLFlowLoggerCallback(LoggerCallback):
|
||||
"parameter_1": tune.choice([1, 2, 3]),
|
||||
"parameter_2": tune.choice([4, 5, 6]),
|
||||
},
|
||||
callbacks=[MLFlowLoggerCallback(
|
||||
callbacks=[MLflowLoggerCallback(
|
||||
experiment_name="experiment1",
|
||||
save_artifact=True)])
|
||||
|
||||
@@ -71,8 +71,8 @@ class MLFlowLoggerCallback(LoggerCallback):
|
||||
|
||||
mlflow = _import_mlflow()
|
||||
if mlflow is None:
|
||||
raise RuntimeError("MLFlow has not been installed. Please `pip "
|
||||
"install mlflow` to use the MLFlowLogger.")
|
||||
raise RuntimeError("MLflow has not been installed. Please `pip "
|
||||
"install mlflow` to use the MLflowLogger.")
|
||||
|
||||
from mlflow.tracking import MlflowClient
|
||||
self.client = MlflowClient(
|
||||
@@ -105,7 +105,7 @@ class MLFlowLoggerCallback(LoggerCallback):
|
||||
"set, and MLFLOW_EXPERIMENT_ID either "
|
||||
"is not set or does not exist. Please "
|
||||
"set one of these to use the "
|
||||
"MLFlowLoggerCallback.")
|
||||
"MLflowLoggerCallback.")
|
||||
|
||||
# At this point, experiment_id should be set.
|
||||
self.experiment_id = experiment_id
|
||||
@@ -154,14 +154,14 @@ class MLFlowLoggerCallback(LoggerCallback):
|
||||
self.client.set_terminated(run_id=run_id, status=status)
|
||||
|
||||
|
||||
class MLFlowLogger(Logger):
|
||||
"""MLFlow logger using the deprecated Logger API.
|
||||
class MLflowLogger(Logger):
|
||||
"""MLflow logger using the deprecated Logger API.
|
||||
|
||||
Requires the experiment configuration to have a MLFlow Experiment ID
|
||||
Requires the experiment configuration to have a MLflow Experiment ID
|
||||
or manually set the proper environment variables.
|
||||
"""
|
||||
|
||||
_experiment_logger_cls = MLFlowLoggerCallback
|
||||
_experiment_logger_cls = MLflowLoggerCallback
|
||||
|
||||
def _init(self):
|
||||
mlflow = _import_mlflow()
|
||||
@@ -200,10 +200,10 @@ class MLFlowLogger(Logger):
|
||||
def mlflow_mixin(func: Callable):
|
||||
"""mlflow_mixin
|
||||
|
||||
MLFlow (https://mlflow.org) Tracking is an open source library for
|
||||
MLflow (https://mlflow.org) Tracking is an open source library for
|
||||
recording and querying experiments. This Ray Tune Trainable mixin helps
|
||||
initialize the MLflow API for use with the ``Trainable`` class or the
|
||||
``@mlflow_mixin`` function API. This mixin automatically configures MLFlow
|
||||
``@mlflow_mixin`` function API. This mixin automatically configures MLflow
|
||||
and creates a run in the same process as each Tune trial. You can then
|
||||
use the mlflow API inside the your training function and it will
|
||||
automatically get reported to the correct run.
|
||||
@@ -287,25 +287,25 @@ def mlflow_mixin(func: Callable):
|
||||
})
|
||||
"""
|
||||
if _import_mlflow() is None:
|
||||
raise RuntimeError("MLFlow has not been installed. Please `pip "
|
||||
raise RuntimeError("MLflow has not been installed. Please `pip "
|
||||
"install mlflow` to use the mlflow_mixin.")
|
||||
if hasattr(func, "__mixins__"):
|
||||
func.__mixins__ = func.__mixins__ + (MLFlowTrainableMixin, )
|
||||
func.__mixins__ = func.__mixins__ + (MLflowTrainableMixin, )
|
||||
else:
|
||||
func.__mixins__ = (MLFlowTrainableMixin, )
|
||||
func.__mixins__ = (MLflowTrainableMixin, )
|
||||
return func
|
||||
|
||||
|
||||
class MLFlowTrainableMixin:
|
||||
class MLflowTrainableMixin:
|
||||
def __init__(self, config: Dict, *args, **kwargs):
|
||||
self._mlflow = _import_mlflow()
|
||||
|
||||
if not isinstance(self, Trainable):
|
||||
raise ValueError(
|
||||
"The `MLFlowTrainableMixin` can only be used as a mixin "
|
||||
"The `MLflowTrainableMixin` can only be used as a mixin "
|
||||
"for `tune.Trainable` classes. Please make sure your "
|
||||
"class inherits from both. For example: "
|
||||
"`class YourTrainable(MLFlowTrainableMixin)`.")
|
||||
"`class YourTrainable(MLflowTrainableMixin)`.")
|
||||
|
||||
super().__init__(config, *args, **kwargs)
|
||||
_config = config.copy()
|
||||
@@ -313,14 +313,14 @@ class MLFlowTrainableMixin:
|
||||
mlflow_config = _config.pop("mlflow").copy()
|
||||
except KeyError as e:
|
||||
raise ValueError(
|
||||
"MLFlow mixin specified but no configuration has been passed. "
|
||||
"MLflow mixin specified but no configuration has been passed. "
|
||||
"Make sure to include a `mlflow` key in your `config` dict "
|
||||
"containing at least a `tracking_uri` and either "
|
||||
"`experiment_name` or `experiment_id` specification.") from e
|
||||
|
||||
tracking_uri = mlflow_config.pop("tracking_uri", None)
|
||||
if tracking_uri is None:
|
||||
raise ValueError("MLFlow mixin specified but no "
|
||||
raise ValueError("MLflow mixin specified but no "
|
||||
"tracking_uri has been "
|
||||
"passed in. Make sure to include a `mlflow` "
|
||||
"key in your `config` dict containing at "
|
||||
@@ -338,7 +338,7 @@ class MLFlowTrainableMixin:
|
||||
experiment_name = mlflow_config.pop("experiment_name", None)
|
||||
if experiment_name is None:
|
||||
raise ValueError(
|
||||
"MLFlow mixin specified but no "
|
||||
"MLflow mixin specified but no "
|
||||
"experiment_name or experiment_id has been "
|
||||
"passed in. Make sure to include a `mlflow` "
|
||||
"key in your `config` dict containing at "
|
||||
@@ -351,7 +351,7 @@ class MLFlowTrainableMixin:
|
||||
else:
|
||||
raise ValueError("No experiment with the given "
|
||||
"name: {} or id: {} currently exists. Make "
|
||||
"sure to first start the MLFlow experiment "
|
||||
"sure to first start the MLflow experiment "
|
||||
"before calling tune.run.".format(
|
||||
experiment_name, experiment_id))
|
||||
|
||||
|
||||
@@ -704,10 +704,10 @@ class TBXLoggerCallback(LoggerCallback):
|
||||
|
||||
|
||||
# Maintain backwards compatibility.
|
||||
from ray.tune.integration.mlflow import MLFlowLogger as _MLFlowLogger # noqa: E402, E501
|
||||
MLFlowLogger = _MLFlowLogger
|
||||
from ray.tune.integration.mlflow import MLflowLogger as _MLflowLogger # noqa: E402, E501
|
||||
MLflowLogger = _MLflowLogger
|
||||
# The capital L is a typo, but needs to remain for backwards compatibility.
|
||||
MLFLowLogger = _MLFlowLogger
|
||||
MLFLowLogger = _MLflowLogger
|
||||
|
||||
|
||||
def pretty_print(result):
|
||||
|
||||
@@ -23,4 +23,4 @@ if __name__ == "__main__":
|
||||
}
|
||||
})
|
||||
assert "ray.rllib" not in sys.modules, "RLlib should not be imported"
|
||||
assert "mlflow" not in sys.modules, "MLFlow should not be imported"
|
||||
assert "mlflow" not in sys.modules, "MLflow should not be imported"
|
||||
|
||||
@@ -4,8 +4,8 @@ from collections import namedtuple
|
||||
from unittest.mock import patch
|
||||
|
||||
from ray.tune.function_runner import wrap_function
|
||||
from ray.tune.integration.mlflow import MLFlowLoggerCallback, MLFlowLogger, \
|
||||
mlflow_mixin, MLFlowTrainableMixin
|
||||
from ray.tune.integration.mlflow import MLflowLoggerCallback, MLflowLogger, \
|
||||
mlflow_mixin, MLflowTrainableMixin
|
||||
|
||||
|
||||
class MockTrial(
|
||||
@@ -121,11 +121,11 @@ def clear_env_vars():
|
||||
del os.environ["MLFLOW_EXPERIMENT_ID"]
|
||||
|
||||
|
||||
class MLFlowTest(unittest.TestCase):
|
||||
class MLflowTest(unittest.TestCase):
|
||||
@patch("mlflow.tracking.MlflowClient", MockMlflowClient)
|
||||
def testMlFlowLoggerCallbackConfig(self):
|
||||
# Explicitly pass in all args.
|
||||
logger = MLFlowLoggerCallback(
|
||||
logger = MLflowLoggerCallback(
|
||||
tracking_uri="test1",
|
||||
registry_uri="test2",
|
||||
experiment_name="test_exp")
|
||||
@@ -136,7 +136,7 @@ class MLFlowTest(unittest.TestCase):
|
||||
self.assertEqual(logger.experiment_id, 1)
|
||||
|
||||
# Check if client recognizes already existing experiment.
|
||||
logger = MLFlowLoggerCallback(experiment_name="existing_experiment")
|
||||
logger = MLflowLoggerCallback(experiment_name="existing_experiment")
|
||||
self.assertListEqual(logger.client.experiment_names,
|
||||
["existing_experiment"])
|
||||
self.assertEqual(logger.experiment_id, 0)
|
||||
@@ -144,7 +144,7 @@ class MLFlowTest(unittest.TestCase):
|
||||
# Pass in experiment name as env var.
|
||||
clear_env_vars()
|
||||
os.environ["MLFLOW_EXPERIMENT_NAME"] = "test_exp"
|
||||
logger = MLFlowLoggerCallback()
|
||||
logger = MLflowLoggerCallback()
|
||||
self.assertListEqual(logger.client.experiment_names,
|
||||
["existing_experiment", "test_exp"])
|
||||
self.assertEqual(logger.experiment_id, 1)
|
||||
@@ -152,7 +152,7 @@ class MLFlowTest(unittest.TestCase):
|
||||
# Pass in existing experiment name as env var.
|
||||
clear_env_vars()
|
||||
os.environ["MLFLOW_EXPERIMENT_NAME"] = "existing_experiment"
|
||||
logger = MLFlowLoggerCallback()
|
||||
logger = MLflowLoggerCallback()
|
||||
self.assertListEqual(logger.client.experiment_names,
|
||||
["existing_experiment"])
|
||||
self.assertEqual(logger.experiment_id, 0)
|
||||
@@ -160,7 +160,7 @@ class MLFlowTest(unittest.TestCase):
|
||||
# Pass in existing experiment id as env var.
|
||||
clear_env_vars()
|
||||
os.environ["MLFLOW_EXPERIMENT_ID"] = "0"
|
||||
logger = MLFlowLoggerCallback()
|
||||
logger = MLflowLoggerCallback()
|
||||
self.assertListEqual(logger.client.experiment_names,
|
||||
["existing_experiment"])
|
||||
self.assertEqual(logger.experiment_id, "0")
|
||||
@@ -169,13 +169,13 @@ class MLFlowTest(unittest.TestCase):
|
||||
clear_env_vars()
|
||||
os.environ["MLFLOW_EXPERIMENT_ID"] = "500"
|
||||
with self.assertRaises(ValueError):
|
||||
logger = MLFlowLoggerCallback()
|
||||
logger = MLflowLoggerCallback()
|
||||
|
||||
# Experiment name env var should take precedence over id env var.
|
||||
clear_env_vars()
|
||||
os.environ["MLFLOW_EXPERIMENT_NAME"] = "test_exp"
|
||||
os.environ["MLFLOW_EXPERIMENT_ID"] = "0"
|
||||
logger = MLFlowLoggerCallback()
|
||||
logger = MLflowLoggerCallback()
|
||||
self.assertListEqual(logger.client.experiment_names,
|
||||
["existing_experiment", "test_exp"])
|
||||
self.assertEqual(logger.experiment_id, 1)
|
||||
@@ -186,7 +186,7 @@ class MLFlowTest(unittest.TestCase):
|
||||
trial_config = {"par1": 4, "par2": 9.}
|
||||
trial = MockTrial(trial_config, "trial1", 0, "artifact")
|
||||
|
||||
logger = MLFlowLoggerCallback(
|
||||
logger = MLflowLoggerCallback(
|
||||
experiment_name="test1", save_artifact=True)
|
||||
|
||||
# Check if run is created.
|
||||
@@ -231,7 +231,7 @@ class MLFlowTest(unittest.TestCase):
|
||||
|
||||
# No experiment_id is passed in config, should raise an error.
|
||||
with self.assertRaises(ValueError):
|
||||
logger = MLFlowLogger(trial_config, "/tmp", trial)
|
||||
logger = MLflowLogger(trial_config, "/tmp", trial)
|
||||
|
||||
trial_config.update({
|
||||
"logger_config": {
|
||||
@@ -240,7 +240,7 @@ class MLFlowTest(unittest.TestCase):
|
||||
}
|
||||
})
|
||||
trial = MockTrial(trial_config, "trial2", 1, "artifact")
|
||||
logger = MLFlowLogger(trial_config, "/tmp", trial)
|
||||
logger = MLflowLogger(trial_config, "/tmp", trial)
|
||||
experiment_logger = logger._trial_experiment_logger
|
||||
client = experiment_logger.client
|
||||
self.assertEqual(client.tracking_uri, "test_tracking_uri")
|
||||
@@ -260,9 +260,9 @@ class MLFlowTest(unittest.TestCase):
|
||||
def train_fn(config):
|
||||
return 1
|
||||
|
||||
train_fn.__mixins__ = (MLFlowTrainableMixin, )
|
||||
train_fn.__mixins__ = (MLflowTrainableMixin, )
|
||||
|
||||
# No MLFlow config passed in.
|
||||
# No MLflow config passed in.
|
||||
with self.assertRaises(ValueError):
|
||||
wrapped = wrap_function(train_fn)(trial_config)
|
||||
|
||||
@@ -288,7 +288,7 @@ class MLFlowTest(unittest.TestCase):
|
||||
|
||||
with patch("ray.tune.integration.mlflow._import_mlflow",
|
||||
lambda: client):
|
||||
train_fn.__mixins__ = (MLFlowTrainableMixin, )
|
||||
train_fn.__mixins__ = (MLflowTrainableMixin, )
|
||||
wrapped = wrap_function(train_fn)(trial_config)
|
||||
client = wrapped._mlflow
|
||||
self.assertTupleEqual(client.active_run.run_id, (0, 1))
|
||||
|
||||
Reference in New Issue
Block a user