diff --git a/doc/source/tune/api_docs/integration.rst b/doc/source/tune/api_docs/integration.rst index 7eebe4fcb..54f8c9e30 100644 --- a/doc/source/tune/api_docs/integration.rst +++ b/doc/source/tune/api_docs/integration.rst @@ -32,6 +32,18 @@ Kubernetes (tune.integration.kubernetes) .. autofunction:: ray.tune.integration.kubernetes.NamespacedKubernetesSyncer +.. _tune-integration-mlflow: + +MLflow (tune.integration.mlflow) +-------------------------------- + +:ref:`See also here `. + +.. autoclass:: ray.tune.integration.mlflow.MLflowLoggerCallback + +.. autofunction:: ray.tune.integration.mlflow.mlflow_mixin + + .. _tune-integration-mxnet: MXNet (tune.integration.mxnet) diff --git a/python/ray/tune/integration/mlflow.py b/python/ray/tune/integration/mlflow.py index 77467b999..cbd3811d4 100644 --- a/python/ray/tune/integration/mlflow.py +++ b/python/ray/tune/integration/mlflow.py @@ -253,6 +253,11 @@ def mlflow_mixin(func: Callable): experiment. All logs from all trials in ``tune.run`` will be reported to this experiment. If this is not provided, you must provide a valid ``experiment_id``. + token (optional, str): A token to use for HTTP authentication when + logging to a remote tracking server. This is useful when you + want to log to a Databricks server, for example. This value will + be used to set the MLFLOW_TRACKING_TOKEN environment variable on + all the remote training processes. Example: @@ -327,6 +332,11 @@ class MLflowTrainableMixin: "least a `tracking_uri`") self._mlflow.set_tracking_uri(tracking_uri) + # Set the tracking token if one is passed in. + tracking_token = mlflow_config.pop("token", None) + if tracking_token is not None: + os.environ["MLFLOW_TRACKING_TOKEN"] = tracking_token + # First see if experiment_id is passed in. experiment_id = mlflow_config.pop("experiment_id", None) if experiment_id is None or self._mlflow.get_experiment(