From 5d3c9c8861ee658b7931bb374603f86987c17433 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Sat, 19 Dec 2020 00:40:02 -0800 Subject: [PATCH] [Tune] Mlflow Integration (#12840) Co-authored-by: Kai Fricke Co-authored-by: Richard Liaw --- doc/source/conf.py | 1 + doc/source/images/mlflow.png | Bin 0 -> 3851 bytes doc/source/tune/_tutorials/overview.rst | 12 +- doc/source/tune/_tutorials/tune-mlflow.rst | 47 +++ doc/source/tune/api_docs/logging.rst | 17 +- doc/source/tune/examples/index.rst | 1 + .../tune/examples/mlflow_ptl_example.rst | 6 + python/ray/tune/BUILD | 32 +- python/ray/tune/examples/mlflow_example.py | 96 ++++- python/ray/tune/examples/mlflow_ptl.py | 93 +++++ python/ray/tune/function_runner.py | 9 +- python/ray/tune/integration/mlflow.py | 366 ++++++++++++++++++ python/ray/tune/integration/wandb.py | 5 +- python/ray/tune/logger.py | 38 +- python/ray/tune/tests/test_dependency.py | 1 + .../ray/tune/tests/test_integration_mlflow.py | 306 +++++++++++++++ python/ray/tune/tests/test_sample.py | 2 +- python/requirements_tune.txt | 1 + 18 files changed, 958 insertions(+), 75 deletions(-) create mode 100644 doc/source/images/mlflow.png create mode 100644 doc/source/tune/_tutorials/tune-mlflow.rst create mode 100644 doc/source/tune/examples/mlflow_ptl_example.rst create mode 100644 python/ray/tune/examples/mlflow_ptl.py create mode 100644 python/ray/tune/integration/mlflow.py create mode 100644 python/ray/tune/tests/test_integration_mlflow.py diff --git a/doc/source/conf.py b/doc/source/conf.py index c69f73760..13f075b29 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -41,6 +41,7 @@ MOCK_MODULES = [ "horovod", "horovod.ray", "kubernetes", + "mlflow", "mxnet", "mxnet.model", "psutil", diff --git a/doc/source/images/mlflow.png b/doc/source/images/mlflow.png new file mode 100644 index 0000000000000000000000000000000000000000..03b96d5daf39ad485a8d4cd614abbc9af21b8518 GIT binary patch literal 3851 zcmV+m5A^VfP)gW07J;^0FK}Qkl+A}-}d(Q0E^$~=;%SH=Y)iW+}zxKeSH9w;p6oDsn_$Vsj2w= z{&l_V?fCuj`u;J+>$KYR)9U%m=lG@5^7Q=v@$vCDH#cIl=}xZbIi%)|#O{%ikwryC zz~1#)&h8_g0;Wv z5(R3zE=?h230s>$LjaSuX%gCWyY2hG)nyBwGa5(7HD42 z&dW)qSga(Y^7)`u-S>nJ6xG=%j>53GMi_=sT$${9RtJdoc@%7D-k?(3_q+}c)$>T@ zc`n%(`<~drp<9uNE*69CzGrq|sEs3$=($w4_dU16LbI%6x>(uw+zt!PzN|bJLVdpP z$sHETTBg0K-Y(7qEz`xwd(~USX{=}3tLkl{A$rqM92b*H9FO-sy+gveLUkA=gJEO& z;n2H!+e1fj^P-E+z9)5AbOHinj#=d8C zuozAt6P zz-W%Vr}vLA0sq4mzKr4&zmEK|yev^<_`)>}V8XI_aFco~mEKln{A}S&HCn~Gd2lDp z{#Xtk&Q!zlZDd)n?UGmE=q;!#)eXos1(pR{x>^&owlm{r3s@AITYbrsE7fQfE$WwgE5U3#6MnvMq#CVaXt9d)x%jx! zk?J<2^(Vpnu+9!;N1 zbh@n>KU)~79+x?afWGp1nO5J>=VG&9eE-Iggqx9QV2w;xhDw?S_6;aEe zDw!5h`t5aQsweJL)1YSH(|Q0C)kEzTF1=I}JwelYkuOy>Qr$J|chN9SJZ)HntQ>n< zuVAKnpx?x$qnhXmx(FHe|LKp)CrnhU+c0d1Nur}gt>V$A^$KRHN6u7Nh*wS1t+s`+ zqkxHOcU4VYI-1t=o6ZYHsvEx|1QYT|K}>Sekk<3n?gb;&wJO6YhUp{>OIA`Y5<(s#!#1SXAZ!SNrM%jijA@JQL0uRzcH+js0sC$C{Z>Ej3ac`V>UHFzZ9Ap2lL08@Lon^XK^gZUp< zHzu~*-BSY1g&W=09K&I!(`mQcn*C>;)A^X|#i03^MuGBhyBzT4eM!MBy~4?NbwKG2 ze;ki+p&Cm-%<*3jswj17z240?LUT&@-DmBETQM~beR514^I1$oF%gx(O)}OIEZ>l}(}<^zcF;dd2qLZ_Rf)V4O#yexGF7 ztz>% zm&U5lcXZKK1F8O^<*`)yvH1GSd&7t~(hVJ`MyqJairVWEZVWsx`=$+ZL3AS<@A4!XA>lCJElb~mf5LpfK{{I)aEI5O2Gd`70vt{#^9;_vFOpJZe^i4^w%x(TTBn5y|mRW4xT$l4VZ z*!sRUZtwnR6~>2kM04_ie{#uhnw@Hxy=ASiNhOc#njRziF*009HH{OyQCJMP`e_@} z15IJ+IaEr!sn`VT_eNpyo~ASYh(v!UT0#kFcz~XnYP5>6IVm(W)g^$NN1VF5g)>dA z+q_ubl3Ttiu*h}8SX$2D-d(P&>yuJSC#Gv;r5ci=a_S)k-3G=D7h(F& zc#+g=V~aST$|C9ul4v`a5qp99t}dr%fCw=S`UZvwSq^p)_E}b{u{V}30WJI(R#;7Z z>NB$!Q>$oh2fN&=XsMu~0SrqB zrTPP674epld<|F}pvr}01ELuWqX{n>d# z8L^4BGQQsCmKc6lAqkRNjV4Jwj+#*3O`cUlBhEkzFK?{K7ATxi(_)wo(0oA)q7T>?hV zsjAlJ{x?XiD|#_Bah?i1S)K`(i4)?Eq^}`Oq@x(*g=(~lBki~)sIp;h;X_Kbeo_1p zX)XY@P+5#c>%{$7iaIgQskb*HZyX49Zm1rBRkZZh1G9y0VVXWYvv&theUKQ^dk}&= zzBS~%TdzfsZ0klT*T@^ebX#7icCm_8Y89DVNCFsp*;sAk$;4nnXVPjs$!n}r0n`O~ z%3|p@v5sZ@p zfxr?nr7I0Z*iSt}>Vm9%LW5A}g=*LpW3h@gcr5Hg;(^#JPKjB>|HsN5NJyS0r^npe zYZ-((FI2l&MViP|M2ApqTozvlbSN|pNn_69#52|ZeQDnA? zU3~yUv(?rV2)PnpiGU)GD$XxT-G*Ywp&3KgsHt!OD zVsMK##3j1ze2e(8ZxZU9Pz~#fwm!=A4NzBmuhMjrEnIAD?#jDi3@{vRQkRG33{XC( zW)t4bd{8}gu!@i^QS;iQ+Qz&xK(n1Od$+r0J3@5x+}kS~g*qoxqg6ETL9{e{ zBTI{C`VzU>b`(y|&GzDXk< zV{<0ct&oOyMfHIxy3Ndl;uL9|Zo^a_sCJ1JJLUjp4e((ei+hX7^t_+2 zR(^q8@~R%rEB72}tK+n?k_W2623h*mK$V4M!~gmlDY-(+M9opK4RvpTC<8ThoOTyr z4ybmriWimuhNDq7wbBW-JtqaJ77c^qg6I`s4yblH-Uvf@`#Q@|kz;tQeQjxvLSZ1c z98tVN{npkTP~CE{iV(oq53Q>k+fzmcv%vQ@>(T|hf1P1nGzV0>#EPc61O?T^s60Yb z8)m+{NM{#LnWi!9*UkaeXcf(EsGSWc$D!Vx>|c$PJ$p6b^?AR41+3v|gW1yPaybW_z`uSlOW1*AJYA*~S$LuPE$85X2x+jfRk19cnYd6#~XP(1(1e3`ZHH|3a6uRJ)Qz|mMq5` zc{~A44CcI5+#N7olN|e*tGfhv)!r)Z5ZIr`A6c|RU|QJHoYUh8;0wD)Qz99#29ij^YI(sk$rjk>VCHHPp|AmUc}ZPLTKVtf}@VzHbo|1JgeHu~it( zcQ%x?c>D(_b+jAh5$2})<1hdlhJfiLL-oIh1FkijgoAUR`q(nWqbI^>bI||* literal 0 HcmV?d00001 diff --git a/doc/source/tune/_tutorials/overview.rst b/doc/source/tune/_tutorials/overview.rst index dfbd986b0..0517c2f0a 100644 --- a/doc/source/tune/_tutorials/overview.rst +++ b/doc/source/tune/_tutorials/overview.rst @@ -70,6 +70,11 @@ Take a look at any of the below tutorials to get started with Tune. :figure: /images/wandb_logo.png :description: :doc:`Track your experiment process with the Weights & Biases tools ` +.. customgalleryitem:: + :tooltip: Use MLFlow with Ray Tune. + :figure: /images/mlflow.png + :description: :doc:`Log and track your hyperparameter sweep with MLFlow Tracking & AutoLogging ` + .. raw:: html @@ -81,12 +86,13 @@ Take a look at any of the below tutorials to get started with Tune. tune-tutorial.rst tune-advanced-tutorial.rst - tune-lifecycle.rst tune-distributed.rst - tune-sklearn.rst + tune-lifecycle.rst + tune-mlflow.rst tune-pytorch-cifar.rst tune-pytorch-lightning.rst tune-serve-integration-mnist.rst + tune-sklearn.rst tune-xgboost.rst tune-wandb.rst @@ -156,4 +162,4 @@ Check out: .. _tune-faq: -.. include:: _faq.rst \ No newline at end of file +.. include:: _faq.rst diff --git a/doc/source/tune/_tutorials/tune-mlflow.rst b/doc/source/tune/_tutorials/tune-mlflow.rst new file mode 100644 index 000000000..6b5519e3f --- /dev/null +++ b/doc/source/tune/_tutorials/tune-mlflow.rst @@ -0,0 +1,47 @@ +.. _tune-mlflow: + +Using MLFlow with Tune +====================== + +`MLFlow `_ is an open source platform to manage the ML lifecycle, including experimentation, +reproducibility, deployment, and a central model registry. It currently offers four components, including +MLFlow Tracking to record and query experiments, including code, data, config, and results. + +.. image:: /images/mlflow.png + :height: 80px + :alt: MLflow + :align: center + :target: https://www.mlflow.org/ + +Ray Tune currently offers two lightweight integrations for MLFlow Tracking. +One is the :ref:`MLFlowLoggerCallback `, which automatically logs +metrics reported to Tune to the MLFlow Tracking API. + +The other one is the :ref:`@mlflow_mixin ` decorator, which can be +used with the function API. It automatically +initializes the MLFlow API with Tune's training information and creates a run for each Tune trial. +Then within your training function, you can just use the +MLFlow like you would normally do, e.g. using ``mlflow.log_metrics()`` or even ``mlflow.autolog()`` +to log to your training process. + +Please :doc:`see here ` for a full example on how you can use either the +MLFlowLoggerCallback or the mlflow_mixin. + +MLFlow AutoLogging +------------------ +You can also check out :doc:`here ` for an example on how you can leverage MLflow +autologging, in this case with Pytorch Lightning + +MLFlow Logger API +----------------- +.. _tune-mlflow-logger: + +.. autoclass:: ray.tune.integration.mlflow.MLFlowLoggerCallback + :noindex: + +MLFlow Mixin API +---------------- +.. _tune-mlflow-mixin: + +.. autofunction:: ray.tune.integration.mlflow.mlflow_mixin + :noindex: diff --git a/doc/source/tune/api_docs/logging.rst b/doc/source/tune/api_docs/logging.rst index 240ec97af..7e8784829 100644 --- a/doc/source/tune/api_docs/logging.rst +++ b/doc/source/tune/api_docs/logging.rst @@ -70,11 +70,7 @@ An example of creating a custom logger can be found in :doc:`/tune/examples/logg Trainable Logging ----------------- -By default, Tune only logs the *training result dictionaries* from your Trainable. However, you may want to visualize the model weights, model graph, or use a custom logging library that requires multi-process logging. For example, you may want to do this if: - - * you're using `Weights and Biases `_ - * you're using `MLFlow `__ - * you're trying to log images to Tensorboard. +By default, Tune only logs the *training result dictionaries* from your Trainable. However, you may want to visualize the model weights, model graph, or use a custom logging library that requires multi-process logging. For example, you may want to do this if you're trying to log images to Tensorboard. You can do this in the trainable, as shown below: @@ -163,12 +159,17 @@ CSVLogger .. autoclass:: ray.tune.logger.CSVLoggerCallback -MLFLowLogger +MLFlowLogger ------------ -Tune also provides a default logger for `MLFlow `_. You can install MLFlow via ``pip install mlflow``. An example can be found in :doc:`/tune/examples/mlflow_example`. Note that this currently does not include artifact logging support. For this, you can use the native MLFlow APIs inside your Trainable definition. +Tune also provides a default logger for `MLFlow `_. You can install MLFlow via ``pip install mlflow``. +You can see the :doc:`tutorial here `. -.. autoclass:: ray.tune.logger.MLFLowLogger +WandbLogger +----------- + +Tune also provides a default logger for `Weights & Biases `_. You can install Wandb via ``pip install wandb``. +You can see the :doc:`tutorial here ` .. _logger-interface: diff --git a/doc/source/tune/examples/index.rst b/doc/source/tune/examples/index.rst index 54852d550..4f3594cf6 100644 --- a/doc/source/tune/examples/index.rst +++ b/doc/source/tune/examples/index.rst @@ -88,6 +88,7 @@ Wandb, MLFlow - :ref:`Tutorial ` for using `wandb `__ with Ray Tune - :doc:`/tune/examples/wandb_example`: Example for using `Weights and Biases `__ with Ray Tune. - :doc:`/tune/examples/mlflow_example`: Example for using `MLFlow `__ with Ray Tune. +- :doc:`/tune/examples/mlflow_ptl_example`: Example for using `MLFlow `__ and `Pytorch Lightning `_ with Ray Tune. Tensorflow/Keras ~~~~~~~~~~~~~~~~ diff --git a/doc/source/tune/examples/mlflow_ptl_example.rst b/doc/source/tune/examples/mlflow_ptl_example.rst new file mode 100644 index 000000000..73c07499b --- /dev/null +++ b/doc/source/tune/examples/mlflow_ptl_example.rst @@ -0,0 +1,6 @@ +:orphan: + +mlflow_ptl_example +~~~~~~~~~~~~~~~~~~ + +.. literalinclude:: /../../python/ray/tune/examples/mlflow_ptl.py diff --git a/python/ray/tune/BUILD b/python/ray/tune/BUILD index 70fd218d5..3b5757eb3 100644 --- a/python/ray/tune/BUILD +++ b/python/ray/tune/BUILD @@ -125,6 +125,14 @@ py_test( tags = ["exclusive"], ) +py_test( + name = "test_integration_mlflow", + size = "small", + srcs = ["tests/test_integration_mlflow.py"], + deps = [":tune_lib"], + tags = ["exclusive"] +) + py_test( name = "test_logger", size = "small", @@ -473,15 +481,23 @@ py_test( args = ["--smoke-test"] ) -# Commenting out for now because it is not idempotent -# py_test( -# name = "mlflow_example", -# size = "medium", -# srcs = ["examples/mlflow_example.py"], -# deps = [":tune_lib"], -# tags = ["exclusive", "example"] -# ) +py_test( + name = "mlflow_example", + size = "medium", + srcs = ["examples/mlflow_example.py"], + deps = [":tune_lib"], + tags = ["exclusive", "example"] +) +# Comment out for now until we sort out our dependencies. +#py_test( +# name = "mlflow_ptl", +# size = "medium", +# srcs = ["examples/mlflow_ptl.py"], +# deps = [":tune_lib"], +# tags = ["exclusive", "example", "py37", "pytorch"], +# args = ["--smoke-test"] +#) py_test( name = "mnist_pytorch", size = "small", diff --git a/python/ray/tune/examples/mlflow_example.py b/python/ray/tune/examples/mlflow_example.py index 875c7837b..e0f290b29 100644 --- a/python/ray/tune/examples/mlflow_example.py +++ b/python/ray/tune/examples/mlflow_example.py @@ -1,17 +1,14 @@ #!/usr/bin/env python -"""Simple MLFLow Logger example. - -This uses a simple MLFlow logger. One limitation of this is that there is -no artifact support; to save artifacts with Tune and MLFlow, you will need to -start a MLFlow run inside the Trainable function/class. - +"""Examples using MLFlowLoggerCallback and mlflow_mixin. """ -import mlflow -from mlflow.tracking import MlflowClient +import os +import tempfile import time +import mlflow + from ray import tune -from ray.tune.logger import MLFLowLogger, DEFAULT_LOGGERS +from ray.tune.integration.mlflow import MLFlowLoggerCallback, mlflow_mixin def evaluation_fn(step, width, height): @@ -25,27 +22,84 @@ def easy_objective(config): for step in range(config.get("steps", 100)): # Iterative training function - can be any arbitrary training procedure intermediate_score = evaluation_fn(step, width, height) - # Feed the score back back to Tune. + # Feed the score back to Tune. tune.report(iterations=step, mean_loss=intermediate_score) time.sleep(0.1) -if __name__ == "__main__": - client = MlflowClient() - experiment_id = client.create_experiment("test") - - trials = tune.run( +def tune_function(mlflow_tracking_uri, finish_fast=False): + tune.run( easy_objective, name="mlflow", num_samples=5, - loggers=DEFAULT_LOGGERS + (MLFLowLogger, ), + callbacks=[ + MLFlowLoggerCallback( + tracking_uri=mlflow_tracking_uri, + experiment_name="example", + save_artifact=True) + ], config={ - "logger_config": { - "mlflow_experiment_id": experiment_id, - }, "width": tune.randint(10, 100), "height": tune.randint(0, 100), + "steps": 5 if finish_fast else 100, }) - df = mlflow.search_runs([experiment_id]) - print(df) + +@mlflow_mixin +def decorated_easy_objective(config): + # Hyperparameters + width, height = config["width"], config["height"] + + for step in range(config.get("steps", 100)): + # Iterative training function - can be any arbitrary training procedure + intermediate_score = evaluation_fn(step, width, height) + # Log the metrics to mlflow + mlflow.log_metrics(dict(mean_loss=intermediate_score), step=step) + # Feed the score back to Tune. + tune.report(iterations=step, mean_loss=intermediate_score) + time.sleep(0.1) + + +def tune_decorated(mlflow_tracking_uri, finish_fast=False): + # Set the experiment, or create a new one if does not exist yet. + mlflow.set_tracking_uri(mlflow_tracking_uri) + mlflow.set_experiment(experiment_name="mixin_example") + tune.run( + decorated_easy_objective, + name="mlflow", + num_samples=5, + config={ + "width": tune.randint(10, 100), + "height": tune.randint(0, 100), + "steps": 5 if finish_fast else 100, + "mlflow": { + "experiment_name": "mixin_example", + "tracking_uri": mlflow.get_tracking_uri() + } + }) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--smoke-test", action="store_true", help="Finish quickly for testing") + args, _ = parser.parse_known_args() + + if args.smoke_test: + mlflow_tracking_uri = os.path.join(tempfile.gettempdir(), "mlruns") + else: + mlflow_tracking_uri = None + + tune_function(mlflow_tracking_uri, finish_fast=args.smoke_test) + if not args.smoke_test: + df = mlflow.search_runs( + [mlflow.get_experiment_by_name("example").experiment_id]) + print(df) + + tune_decorated(mlflow_tracking_uri, finish_fast=args.smoke_test) + if not args.smoke_test: + df = mlflow.search_runs( + [mlflow.get_experiment_by_name("mixin_example").experiment_id]) + print(df) diff --git a/python/ray/tune/examples/mlflow_ptl.py b/python/ray/tune/examples/mlflow_ptl.py new file mode 100644 index 000000000..5957a7e34 --- /dev/null +++ b/python/ray/tune/examples/mlflow_ptl.py @@ -0,0 +1,93 @@ +"""An example showing how to use Pytorch Lightning training, Ray Tune +HPO, and MLFlow autologging all together.""" +import os +import tempfile + +import pytorch_lightning as pl +from pl_bolts.datamodules import MNISTDataModule + +import mlflow + +from ray import tune +from ray.tune.integration.mlflow import mlflow_mixin +from ray.tune.integration.pytorch_lightning import TuneReportCallback +from ray.tune.examples.mnist_ptl_mini import LightningMNISTClassifier + + +@mlflow_mixin +def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0): + model = LightningMNISTClassifier(config, data_dir) + dm = MNISTDataModule( + data_dir=data_dir, num_workers=1, batch_size=config["batch_size"]) + metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"} + mlflow.pytorch.autolog() + trainer = pl.Trainer( + max_epochs=num_epochs, + gpus=num_gpus, + progress_bar_refresh_rate=0, + callbacks=[TuneReportCallback(metrics, on="validation_end")]) + trainer.fit(model, dm) + + +def tune_mnist(num_samples=10, + num_epochs=10, + gpus_per_trial=0, + tracking_uri=None): + data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_") + # Download data + MNISTDataModule(data_dir=data_dir).prepare_data() + + # Set the MLFlow experiment, or create it if it does not exist. + mlflow.set_tracking_uri(tracking_uri) + mlflow.set_experiment("ptl_autologging_test") + + config = { + "layer_1": tune.choice([32, 64, 128]), + "layer_2": tune.choice([64, 128, 256]), + "lr": tune.loguniform(1e-4, 1e-1), + "batch_size": tune.choice([32, 64, 128]), + "mlflow": { + "experiment_name": "ptl_autologging_test", + "tracking_uri": mlflow.get_tracking_uri() + }, + "data_dir": os.path.join(tempfile.gettempdir(), "mnist_data_"), + "num_epochs": num_epochs + } + + trainable = tune.with_parameters( + train_mnist_tune, + data_dir=data_dir, + num_epochs=num_epochs, + num_gpus=gpus_per_trial) + + analysis = tune.run( + trainable, + resources_per_trial={ + "cpu": 1, + "gpu": gpus_per_trial + }, + metric="loss", + mode="min", + config=config, + num_samples=num_samples, + name="tune_mnist") + + print("Best hyperparameters found were: ", analysis.best_config) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--smoke-test", action="store_true", help="Finish quickly for testing") + args, _ = parser.parse_known_args() + + if args.smoke_test: + tune_mnist( + num_samples=1, + num_epochs=1, + gpus_per_trial=0, + tracking_uri=os.path.join(tempfile.gettempdir(), "mlruns")) + else: + tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0) diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py index 79e0f5da9..9da6b2601 100644 --- a/python/ray/tune/function_runner.py +++ b/python/ray/tune/function_runner.py @@ -509,8 +509,9 @@ class FunctionRunner(Trainable): try: err_tb_str = self._error_queue.get( block=block, timeout=ERROR_FETCH_TIMEOUT) - raise TuneError(("Trial raised an exception. Traceback:\n{}" - .format(err_tb_str))) + raise TuneError( + ("Trial raised an exception. Traceback:\n{}".format(err_tb_str) + )) except queue.Empty: pass @@ -649,6 +650,10 @@ def with_parameters(fn, **kwargs): def _inner(config): inner(config, checkpoint_dir=None) + if hasattr(fn, "__mixins__"): + _inner.__mixins__ = fn.__mixins__ return _inner + if hasattr(fn, "__mixins__"): + inner.__mixins__ = fn.__mixins__ return inner diff --git a/python/ray/tune/integration/mlflow.py b/python/ray/tune/integration/mlflow.py new file mode 100644 index 000000000..1d1e01d48 --- /dev/null +++ b/python/ray/tune/integration/mlflow.py @@ -0,0 +1,366 @@ +import os +from typing import Dict, Callable, Optional +import logging + +from ray.tune.trainable import Trainable +from ray.tune.logger import Logger, LoggerCallback +from ray.tune.result import TRAINING_ITERATION +from ray.tune.trial import Trial + +logger = logging.getLogger(__name__) + + +def _import_mlflow(): + try: + import mlflow + except ImportError: + mlflow = None + return mlflow + + +class MLFlowLoggerCallback(LoggerCallback): + """MLFlow Logger to automatically log Tune results and config to MLFlow. + + MLFlow (https://mlflow.org) Tracking is an open source library for + recording and querying experiments. This Ray Tune ``LoggerCallback`` + sends information (config parameters, training results & metrics, + and artifacts) to MLFlow for automatic experiment tracking. + + Args: + tracking_uri (str): The tracking URI for where to manage experiments + and runs. This can either be a local file path or a remote server. + This arg gets passed directly to mlflow.tracking.MlflowClient + initialization. When using Tune in a multi-node setting, make sure + to set this to a remote server and not a local file path. + registry_uri (str): The registry URI that gets passed directly to + mlflow.tracking.MlflowClient initialization. + experiment_name (str): The experiment name to use for this Tune run. + If None is passed in here, the Logger will automatically then + check the MLFLOW_EXPERIMENT_NAME and then the MLFLOW_EXPERIMENT_ID + environment variables to determine the experiment name. + If the experiment with the name already exists with MlFlow, + it will be reused. If not, a new experiment will be created with + that name. + save_artifact (bool): If set to True, automatically save the entire + contents of the Tune local_dir as an artifact to the + corresponding run in MlFlow. + + Example: + + .. code-block:: python + + from ray.tune.integration.mlflow import MLFlowLoggerCallback + tune.run( + train_fn, + config={ + # define search space here + "parameter_1": tune.choice([1, 2, 3]), + "parameter_2": tune.choice([4, 5, 6]), + }, + callbacks=[MLFlowLoggerCallback( + experiment_name="experiment1", + save_artifact=True)]) + + """ + + def __init__(self, + tracking_uri: Optional[str] = None, + registry_uri: Optional[str] = None, + experiment_name: Optional[str] = None, + save_artifact: bool = False): + + mlflow = _import_mlflow() + if mlflow is None: + raise RuntimeError("MLFlow has not been installed. Please `pip " + "install mlflow` to use the MLFlowLogger.") + + from mlflow.tracking import MlflowClient + self.client = MlflowClient( + tracking_uri=tracking_uri, registry_uri=registry_uri) + + if experiment_name is None: + # If no name is passed in, then check env vars. + # First check if experiment_name env var is set. + experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") + + if experiment_name is not None: + # First check if experiment with name exists. + experiment = self.client.get_experiment_by_name(experiment_name) + if experiment is not None: + # If it already exists then get the id. + experiment_id = experiment.experiment_id + else: + # If it does not exist, create the experiment. + experiment_id = self.client.create_experiment( + name=experiment_name) + else: + # No experiment_name is passed in and name env var is not set. + # Now check the experiment id env var. + experiment_id = os.environ.get("MLFLOW_EXPERIMENT_ID") + # Confirm that an experiment with this id exists. + if experiment_id is None or self.client.get_experiment( + experiment_id) is None: + raise ValueError("No experiment_name passed, " + "MLFLOW_EXPERIMENT_NAME env var is not " + "set, and MLFLOW_EXPERIMENT_ID either " + "is not set or does not exist. Please " + "set one of these to use the " + "MLFlowLoggerCallback.") + + # At this point, experiment_id should be set. + self.experiment_id = experiment_id + self.save_artifact = save_artifact + + self._trial_runs = {} + + def log_trial_start(self, trial: "Trial"): + # Create run if not already exists. + if trial not in self._trial_runs: + run = self.client.create_run( + experiment_id=self.experiment_id, + tags={"trial_name": str(trial)}) + self._trial_runs[trial] = run.info.run_id + + run_id = self._trial_runs[trial] + + # Log the config parameters. + config = trial.config + + for key, value in config.items(): + self.client.log_param(run_id=run_id, key=key, value=value) + + def log_trial_result(self, iteration: int, trial: "Trial", result: Dict): + run_id = self._trial_runs[trial] + for key, value in result.items(): + try: + value = float(value) + except (ValueError, TypeError): + logger.debug("Cannot log key {} with value {} since the " + "value cannot be converted to float.".format( + key, value)) + continue + self.client.log_metric( + run_id=run_id, key=key, value=value, step=iteration) + + def log_trial_end(self, trial: "Trial", failed: bool = False): + run_id = self._trial_runs[trial] + + # Log the artifact if set_artifact is set to True. + if self.save_artifact: + self.client.log_artifacts(run_id, local_dir=trial.logdir) + + # Stop the run once trial finishes. + status = "FINISHED" if not failed else "FAILED" + self.client.set_terminated(run_id=run_id, status=status) + + +class MLFlowLogger(Logger): + """MLFlow logger using the deprecated Logger API. + + Requires the experiment configuration to have a MLFlow Experiment ID + or manually set the proper environment variables. + """ + + _experiment_logger_cls = MLFlowLoggerCallback + + def _init(self): + mlflow = _import_mlflow() + logger_config = self.config.pop("logger_config", {}) + tracking_uri = logger_config.get("mlflow_tracking_uri") + registry_uri = logger_config.get("mlflow_registry_uri") + + experiment_id = logger_config.get("mlflow_experiment_id") + if experiment_id is None or not mlflow.get_experiment(experiment_id): + raise ValueError( + "You must provide a valid `mlflow_experiment_id` " + "in your `logger_config` dict in the `config` " + "dict passed to `tune.run`. " + "Are you sure you passed in a `experiment_id` and " + "the experiment exists?") + else: + experiment_name = mlflow.get_experiment(experiment_id).name + + self._trial_experiment_logger = self._experiment_logger_cls( + tracking_uri, registry_uri, experiment_name) + + self._trial_experiment_logger.log_trial_start(self.trial) + + def on_result(self, result: Dict): + self._trial_experiment_logger.log_trial_result( + iteration=result.get(TRAINING_ITERATION), + trial=self.trial, + result=result) + + def close(self): + self._trial_experiment_logger.log_trial_end( + trial=self.trial, failed=False) + del self._trial_experiment_logger + + +def mlflow_mixin(func: Callable): + """mlflow_mixin + + MLFlow (https://mlflow.org) Tracking is an open source library for + recording and querying experiments. This Ray Tune Trainable mixin helps + initialize the MLflow API for use with the ``Trainable`` class or the + ``@mlflow_mixin`` function API. This mixin automatically configures MLFlow + and creates a run in the same process as each Tune trial. You can then + use the mlflow API inside the your training function and it will + automatically get reported to the correct run. + + For basic usage, just prepend your training function with the + ``@mlflow_mixin`` decorator: + + .. code-block:: python + + from ray.tune.integration.mlflow import mlflow_mixin + + @mlflow_mixin + def train_fn(config): + ... + mlflow.log_metric(...) + + You can also use MlFlow's autologging feature if using a training + framework like Pytorch Lightning, XGBoost, etc. More information can be + found here (https://mlflow.org/docs/latest/tracking.html#automatic + -logging). + + .. code-block:: python + + from ray.tune.integration.mlflow import mlflow_mixin + + @mlflow_mixin + def train_fn(config): + mlflow.autolog() + xgboost_results = xgb.train(config, ...) + + The MlFlow configuration is done by passing a ``mlflow`` key to + the ``config`` parameter of ``tune.run()`` (see example below). + + The content of the ``mlflow`` config entry is used to + configure MlFlow. Here are the keys you can pass in to this config entry: + + Args: + tracking_uri (str): The tracking URI for MLflow tracking. If using + Tune in a multi-node setting, make sure to use a remote server for + tracking. + experiment_id (str): The id of an already created MLflow experiment. + All logs from all trials in ``tune.run`` will be reported to this + experiment. If this is not provided or the experiment with this + id does not exist, you must provide an``experiment_name``. This + parameter takes precedence over ``experiment_name``. + experiment_name (str): The name of an already existing MLflow + experiment. All logs from all trials in ``tune.run`` will be + reported to this experiment. If this is not provided, you must + provide a valid ``experiment_id``. + + Example: + + .. code-block:: python + + from ray import tune + from ray.tune.integration.mlflow import mlflow_mixin + + import mlflow + + # Create the MlFlow expriment. + mlflow.create_experiment("my_experiment") + + @mlflow_mixin + def train_fn(config): + for i in range(10): + loss = self.config["a"] + self.config["b"] + mlflow.log_metric(key="loss", value=loss}) + tune.report(loss=loss, done=True) + + tune.run( + train_fn, + config={ + # define search space here + "a": tune.choice([1, 2, 3]), + "b": tune.choice([4, 5, 6]), + # mlflow configuration + "mlflow": { + "experiment_name": "my_experiment", + "tracking_uri": mlflow.get_tracking_uri() + } + }) + """ + if _import_mlflow() is None: + raise RuntimeError("MLFlow has not been installed. Please `pip " + "install mlflow` to use the mlflow_mixin.") + if hasattr(func, "__mixins__"): + func.__mixins__ = func.__mixins__ + (MLFlowTrainableMixin, ) + else: + func.__mixins__ = (MLFlowTrainableMixin, ) + return func + + +class MLFlowTrainableMixin: + def __init__(self, config: Dict, *args, **kwargs): + self._mlflow = _import_mlflow() + + if not isinstance(self, Trainable): + raise ValueError( + "The `MLFlowTrainableMixin` can only be used as a mixin " + "for `tune.Trainable` classes. Please make sure your " + "class inherits from both. For example: " + "`class YourTrainable(MLFlowTrainableMixin)`.") + + super().__init__(config, *args, **kwargs) + _config = config.copy() + try: + mlflow_config = _config.pop("mlflow").copy() + except KeyError as e: + raise ValueError( + "MLFlow mixin specified but no configuration has been passed. " + "Make sure to include a `mlflow` key in your `config` dict " + "containing at least a `tracking_uri` and either " + "`experiment_name` or `experiment_id` specification.") from e + + tracking_uri = mlflow_config.pop("tracking_uri", None) + if tracking_uri is None: + raise ValueError("MLFlow mixin specified but no " + "tracking_uri has been " + "passed in. Make sure to include a `mlflow` " + "key in your `config` dict containing at " + "least a `tracking_uri`") + self._mlflow.set_tracking_uri(tracking_uri) + + # First see if experiment_id is passed in. + experiment_id = mlflow_config.pop("experiment_id", None) + if experiment_id is None or self._mlflow.get_experiment( + experiment_id) is None: + logger.debug("Either no experiment_id is passed in, or the " + "experiment with the given id does not exist. " + "Checking experiment_name") + # Check for name. + experiment_name = mlflow_config.pop("experiment_name", None) + if experiment_name is None: + raise ValueError( + "MLFlow mixin specified but no " + "experiment_name or experiment_id has been " + "passed in. Make sure to include a `mlflow` " + "key in your `config` dict containing at " + "least a `experiment_name` or `experiment_id` " + "specification.") + experiment = self._mlflow.get_experiment_by_name(experiment_name) + if experiment is not None: + # Experiment with this name exists. + experiment_id = experiment.experiment_id + else: + raise ValueError("No experiment with the given " + "name: {} or id: {} currently exists. Make " + "sure to first start the MLFlow experiment " + "before calling tune.run.".format( + experiment_name, experiment_id)) + + self.experiment_id = experiment_id + + run_name = self.trial_name + "_" + self.trial_id + run_name = run_name.replace("/", "_") + self._mlflow.start_run( + experiment_id=self.experiment_id, run_name=run_name) + + def stop(self): + self._mlflow.end_run() diff --git a/python/ray/tune/integration/wandb.py b/python/ray/tune/integration/wandb.py index 82327fffc..4cd2cdee8 100644 --- a/python/ray/tune/integration/wandb.py +++ b/python/ray/tune/integration/wandb.py @@ -139,7 +139,10 @@ def wandb_mixin(func: Callable): }) """ - func.__mixins__ = (WandbTrainableMixin, ) + if hasattr(func, "__mixins__"): + func.__mixins__ = func.__mixins__ + (WandbTrainableMixin, ) + else: + func.__mixins__ = (WandbTrainableMixin, ) return func diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index b4ff76bae..3029f2000 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -77,37 +77,6 @@ class NoopLogger(Logger): pass -class MLFLowLogger(Logger): - """MLFlow logger. - - Requires the experiment configuration to have a MLFlow Experiment ID - or manually set the proper environment variables. - - """ - - def _init(self): - logger_config = self.config.get("logger_config", {}) - from mlflow.tracking import MlflowClient - client = MlflowClient( - tracking_uri=logger_config.get("mlflow_tracking_uri"), - registry_uri=logger_config.get("mlflow_registry_uri")) - run = client.create_run(logger_config.get("mlflow_experiment_id")) - self._run_id = run.info.run_id - for key, value in self.config.items(): - client.log_param(self._run_id, key, value) - self.client = client - - def on_result(self, result: Dict): - for key, value in result.items(): - if not isinstance(value, float): - continue - self.client.log_metric( - self._run_id, key, value, step=result.get(TRAINING_ITERATION)) - - def close(self): - self.client.set_terminated(self._run_id) - - class JsonLogger(Logger): """Logs trial results in json format. @@ -734,6 +703,13 @@ class TBXLoggerCallback(LoggerCallback): "in the hyperparameter values.") +# Maintain backwards compatibility. +from ray.tune.integration.mlflow import MLFlowLogger as _MLFlowLogger # noqa: E402, E501 +MLFlowLogger = _MLFlowLogger +# The capital L is a typo, but needs to remain for backwards compatibility. +MLFLowLogger = _MLFlowLogger + + def pretty_print(result): result = result.copy() result.update(config=None) # drop config from pretty print diff --git a/python/ray/tune/tests/test_dependency.py b/python/ray/tune/tests/test_dependency.py index c2e2f2c7c..d626b0e36 100644 --- a/python/ray/tune/tests/test_dependency.py +++ b/python/ray/tune/tests/test_dependency.py @@ -23,3 +23,4 @@ if __name__ == "__main__": } }) assert "ray.rllib" not in sys.modules, "RLlib should not be imported" + assert "mlflow" not in sys.modules, "MLFlow should not be imported" diff --git a/python/ray/tune/tests/test_integration_mlflow.py b/python/ray/tune/tests/test_integration_mlflow.py new file mode 100644 index 000000000..6613e0229 --- /dev/null +++ b/python/ray/tune/tests/test_integration_mlflow.py @@ -0,0 +1,306 @@ +import os +import unittest +from collections import namedtuple +from unittest.mock import patch + +from ray.tune.function_runner import wrap_function +from ray.tune.integration.mlflow import MLFlowLoggerCallback, MLFlowLogger, \ + mlflow_mixin, MLFlowTrainableMixin + + +class MockTrial( + namedtuple("MockTrial", + ["config", "trial_name", "trial_id", "logdir"])): + def __hash__(self): + return hash(self.trial_id) + + def __str__(self): + return self.trial_name + + +MockRunInfo = namedtuple("MockRunInfo", ["run_id"]) + + +class MockRun: + def __init__(self, run_id, tags=None): + self.run_id = run_id + self.tags = tags + self.info = MockRunInfo(run_id) + self.params = [] + self.metrics = [] + self.artifacts = [] + + def log_param(self, key, value): + self.params.append({key: value}) + + def log_metric(self, key, value): + self.metrics.append({key: value}) + + def log_artifact(self, artifact): + self.artifacts.append(artifact) + + def set_terminated(self, status): + self.terminated = True + self.status = status + + +MockExperiment = namedtuple("MockExperiment", ["name", "experiment_id"]) + + +class MockMlflowClient: + def __init__(self, tracking_uri=None, registry_uri=None): + self.tracking_uri = tracking_uri + self.registry_uri = registry_uri + self.experiments = [MockExperiment("existing_experiment", 0)] + self.runs = {0: []} + self.active_run = None + + def set_tracking_uri(self, tracking_uri): + self.tracking_uri = tracking_uri + + def get_experiment_by_name(self, name): + try: + index = self.experiment_names.index(name) + return self.experiments[index] + except ValueError: + return None + + def get_experiment(self, experiment_id): + experiment_id = int(experiment_id) + try: + return self.experiments[experiment_id] + except IndexError: + return None + + def create_experiment(self, name): + experiment_id = len(self.experiments) + self.experiments.append(MockExperiment(name, experiment_id)) + self.runs[experiment_id] = [] + return experiment_id + + def create_run(self, experiment_id, tags=None): + experiment_runs = self.runs[experiment_id] + run_id = (experiment_id, len(experiment_runs)) + run = MockRun(run_id=run_id, tags=tags) + experiment_runs.append(run) + return run + + def start_run(self, experiment_id, run_name): + # Creates new run and sets it as active. + run = self.create_run(experiment_id) + self.active_run = run + + def get_mock_run(self, run_id): + return self.runs[run_id[0]][run_id[1]] + + def log_param(self, run_id, key, value): + run = self.get_mock_run(run_id) + run.log_param(key, value) + + def log_metric(self, run_id, key, value, step): + run = self.get_mock_run(run_id) + run.log_metric(key, value) + + def log_artifacts(self, run_id, local_dir): + run = self.get_mock_run(run_id) + run.log_artifact(local_dir) + + def set_terminated(self, run_id, status): + run = self.get_mock_run(run_id) + run.set_terminated(status) + + @property + def experiment_names(self): + return [e.name for e in self.experiments] + + +def clear_env_vars(): + if "MLFLOW_EXPERIMENT_NAME" in os.environ: + del os.environ["MLFLOW_EXPERIMENT_NAME"] + if "MLFLOW_EXPERIMENT_ID" in os.environ: + del os.environ["MLFLOW_EXPERIMENT_ID"] + + +class MLFlowTest(unittest.TestCase): + @patch("mlflow.tracking.MlflowClient", MockMlflowClient) + def testMlFlowLoggerCallbackConfig(self): + # Explicitly pass in all args. + logger = MLFlowLoggerCallback( + tracking_uri="test1", + registry_uri="test2", + experiment_name="test_exp") + self.assertEqual(logger.client.tracking_uri, "test1") + self.assertEqual(logger.client.registry_uri, "test2") + self.assertListEqual(logger.client.experiment_names, + ["existing_experiment", "test_exp"]) + self.assertEqual(logger.experiment_id, 1) + + # Check if client recognizes already existing experiment. + logger = MLFlowLoggerCallback(experiment_name="existing_experiment") + self.assertListEqual(logger.client.experiment_names, + ["existing_experiment"]) + self.assertEqual(logger.experiment_id, 0) + + # Pass in experiment name as env var. + clear_env_vars() + os.environ["MLFLOW_EXPERIMENT_NAME"] = "test_exp" + logger = MLFlowLoggerCallback() + self.assertListEqual(logger.client.experiment_names, + ["existing_experiment", "test_exp"]) + self.assertEqual(logger.experiment_id, 1) + + # Pass in existing experiment name as env var. + clear_env_vars() + os.environ["MLFLOW_EXPERIMENT_NAME"] = "existing_experiment" + logger = MLFlowLoggerCallback() + self.assertListEqual(logger.client.experiment_names, + ["existing_experiment"]) + self.assertEqual(logger.experiment_id, 0) + + # Pass in existing experiment id as env var. + clear_env_vars() + os.environ["MLFLOW_EXPERIMENT_ID"] = "0" + logger = MLFlowLoggerCallback() + self.assertListEqual(logger.client.experiment_names, + ["existing_experiment"]) + self.assertEqual(logger.experiment_id, "0") + + # Pass in non existing experiment id as env var. + clear_env_vars() + os.environ["MLFLOW_EXPERIMENT_ID"] = "500" + with self.assertRaises(ValueError): + logger = MLFlowLoggerCallback() + + # Experiment name env var should take precedence over id env var. + clear_env_vars() + os.environ["MLFLOW_EXPERIMENT_NAME"] = "test_exp" + os.environ["MLFLOW_EXPERIMENT_ID"] = "0" + logger = MLFlowLoggerCallback() + self.assertListEqual(logger.client.experiment_names, + ["existing_experiment", "test_exp"]) + self.assertEqual(logger.experiment_id, 1) + + @patch("mlflow.tracking.MlflowClient", MockMlflowClient) + def testMlFlowLoggerLogging(self): + clear_env_vars() + trial_config = {"par1": 4, "par2": 9.} + trial = MockTrial(trial_config, "trial1", 0, "artifact") + + logger = MLFlowLoggerCallback( + experiment_name="test1", save_artifact=True) + + # Check if run is created. + logger.on_trial_start(iteration=0, trials=[], trial=trial) + # New run should be created for this trial with correct tag. + mock_run = logger.client.runs[1][0] + self.assertDictEqual(mock_run.tags, {"trial_name": "trial1"}) + self.assertTupleEqual(mock_run.run_id, (1, 0)) + self.assertTupleEqual(logger._trial_runs[trial], mock_run.run_id) + # Params should be logged. + self.assertListEqual(mock_run.params, [{"par1": 4}, {"par2": 9}]) + + # When same trial is started again, new run should not be created. + logger.on_trial_start(iteration=0, trials=[], trial=trial) + self.assertEqual(len(logger.client.runs[1]), 1) + + # Check metrics are logged properly. + result = {"metric1": 0.8, "metric2": 1, "metric3": None} + logger.on_trial_result(0, [], trial, result) + mock_run = logger.client.runs[1][0] + # metric3 is not logged since it cannot be converted to float. + self.assertListEqual(mock_run.metrics, [{ + "metric1": 0.8 + }, { + "metric2": 1.0 + }]) + + # Check that artifact is logged on termination. + logger.on_trial_complete(0, [], trial) + mock_run = logger.client.runs[1][0] + self.assertListEqual(mock_run.artifacts, ["artifact"]) + self.assertTrue(mock_run.terminated) + self.assertEqual(mock_run.status, "FINISHED") + + @patch("mlflow.tracking.MlflowClient", MockMlflowClient) + def testMlFlowLegacyLoggerConfig(self): + mlflow = MockMlflowClient() + with patch.dict("sys.modules", mlflow=mlflow): + clear_env_vars() + trial_config = {"par1": 4, "par2": 9.} + trial = MockTrial(trial_config, "trial1", 0, "artifact") + + # No experiment_id is passed in config, should raise an error. + with self.assertRaises(ValueError): + logger = MLFlowLogger(trial_config, "/tmp", trial) + + trial_config.update({ + "logger_config": { + "mlflow_tracking_uri": "test_tracking_uri", + "mlflow_experiment_id": 0 + } + }) + trial = MockTrial(trial_config, "trial2", 1, "artifact") + logger = MLFlowLogger(trial_config, "/tmp", trial) + experiment_logger = logger._trial_experiment_logger + client = experiment_logger.client + self.assertEqual(client.tracking_uri, "test_tracking_uri") + # Check to make sure that a run was created on experiment_id 0. + self.assertEqual(len(client.runs[0]), 1) + mock_run = client.runs[0][0] + self.assertDictEqual(mock_run.tags, {"trial_name": "trial2"}) + self.assertListEqual(mock_run.params, [{"par1": 4}, {"par2": 9}]) + + @patch("ray.tune.integration.mlflow._import_mlflow", + lambda: MockMlflowClient()) + def testMlFlowMixinConfig(self): + clear_env_vars() + trial_config = {"par1": 4, "par2": 9.} + + @mlflow_mixin + def train_fn(config): + return 1 + + train_fn.__mixins__ = (MLFlowTrainableMixin, ) + + # No MLFlow config passed in. + with self.assertRaises(ValueError): + wrapped = wrap_function(train_fn)(trial_config) + + trial_config.update({"mlflow": {}}) + # No tracking uri or experiment_id/name passed in. + with self.assertRaises(ValueError): + wrapped = wrap_function(train_fn)(trial_config) + + # Invalid experiment-id + trial_config["mlflow"].update({"experiment_id": "500"}) + # No tracking uri or experiment_id/name passed in. + with self.assertRaises(ValueError): + wrapped = wrap_function(train_fn)(trial_config) + + trial_config["mlflow"].update({ + "tracking_uri": "test_tracking_uri", + "experiment_name": "existing_experiment" + }) + wrapped = wrap_function(train_fn)(trial_config) + client = wrapped._mlflow + self.assertEqual(client.tracking_uri, "test_tracking_uri") + self.assertTupleEqual(client.active_run.run_id, (0, 0)) + + with patch("ray.tune.integration.mlflow._import_mlflow", + lambda: client): + train_fn.__mixins__ = (MLFlowTrainableMixin, ) + wrapped = wrap_function(train_fn)(trial_config) + client = wrapped._mlflow + self.assertTupleEqual(client.active_run.run_id, (0, 1)) + + # Set to experiment that does not already exist. + # New experiment should be created. + trial_config["mlflow"]["experiment_name"] = "new_experiment" + with self.assertRaises(ValueError): + wrapped = wrap_function(train_fn)(trial_config) + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tune/tests/test_sample.py b/python/ray/tune/tests/test_sample.py index 921e0c9ca..8a06be5d0 100644 --- a/python/ray/tune/tests/test_sample.py +++ b/python/ray/tune/tests/test_sample.py @@ -972,4 +972,4 @@ class SearchSpaceTest(unittest.TestCase): if __name__ == "__main__": import pytest import sys - sys.exit(pytest.main(["-v", __file__])) + sys.exit(pytest.main(["-v", __file__] + sys.argv[1:])) diff --git a/python/requirements_tune.txt b/python/requirements_tune.txt index d68d3b3d3..9be5ee118 100644 --- a/python/requirements_tune.txt +++ b/python/requirements_tune.txt @@ -3,6 +3,7 @@ bayesian-optimization ConfigSpace==0.4.10 dragonfly-opt gluoncv +gorilla # Need this because bug in mlflow. Should be fixed in v1.12.2 gym[atari] GPy h5py