From 34bda32054316dcb8c10c1acfdc2ec66ffab2dca Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Fri, 4 Sep 2020 19:11:58 -0500 Subject: [PATCH] [tune/serve] Fix tune/serve integration script broken by serve API change (#10586) --- .../tune-serve-integration-mnist.py | 23 ++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/doc/source/tune/_tutorials/tune-serve-integration-mnist.py b/doc/source/tune/_tutorials/tune-serve-integration-mnist.py index 11f42f472..04690129a 100644 --- a/doc/source/tune/_tutorials/tune-serve-integration-mnist.py +++ b/doc/source/tune/_tutorials/tune-serve-integration-mnist.py @@ -98,6 +98,7 @@ import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from ray import tune, serve +from ray.serve.exceptions import RayServeException from ray.tune import CLIReporter from ray.tune.schedulers import ASHAScheduler @@ -449,28 +450,34 @@ def serve_new_model(model_dir, checkpoint, config, metrics, day, gpu=False): checkpoint_path = _move_checkpoint_to_model_dir(model_dir, checkpoint, config, metrics) - serve.init() + try: + # Try to connect to an existing cluster. + client = serve.connect() + except RayServeException: + # If this is the first run, need to start the cluster. + client = serve.start(detached=True) + backend_name = "mnist:day_{}".format(day) - serve.create_backend(backend_name, MNISTBackend, checkpoint_path, config, - metrics, gpu) + client.create_backend(backend_name, MNISTBackend, checkpoint_path, config, + metrics, gpu) - if "mnist" not in serve.list_endpoints(): + if "mnist" not in client.list_endpoints(): # First time we serve a model - create endpoint - serve.create_endpoint( + client.create_endpoint( "mnist", backend=backend_name, route="/mnist", methods=["POST"]) else: # The endpoint already exists, route all traffic to the new model # Here you could also implement an incremental rollout, where only # a part of the traffic is sent to the new backend and the # rest is sent to the existing backends. - serve.set_traffic("mnist", {backend_name: 1.0}) + client.set_traffic("mnist", {backend_name: 1.0}) # Delete previous existing backends - for existing_backend in serve.list_backends(): + for existing_backend in client.list_backends(): if existing_backend.startswith("mnist:day") and \ existing_backend != backend_name: - serve.delete_backend(existing_backend) + client.delete_backend(existing_backend) return True