[serve] Add serve.shutdown() (#8766)

This commit is contained in:
Edward Oakes
2020-06-23 13:42:03 -05:00
committed by GitHub
parent b6d425526d
commit c9010eb8ad
8 changed files with 77 additions and 13 deletions
+5 -5
View File
@@ -1,7 +1,5 @@
import ray
_local = {} # dict for local mode
def _internal_kv_initialized():
worker = ray.worker.global_worker
@@ -11,9 +9,7 @@ def _internal_kv_initialized():
def _internal_kv_get(key):
"""Fetch the value of a binary key."""
worker = ray.worker.global_worker
return worker.redis_client.hget(key, "value")
return ray.worker.global_worker.redis_client.hget(key, "value")
def _internal_kv_put(key, value, overwrite=False):
@@ -32,3 +28,7 @@ def _internal_kv_put(key, value, overwrite=False):
else:
updated = worker.redis_client.hsetnx(key, "value", value)
return updated == 0 # already exists
def _internal_kv_del(key):
return ray.worker.global_worker.redis_client.delete(key)
+2 -1
View File
@@ -1,7 +1,7 @@
from ray.serve.api import (
init, create_backend, delete_backend, create_endpoint, delete_endpoint,
set_traffic, get_handle, stat, update_backend_config, get_backend_config,
accept_batch, list_backends, list_endpoints) # noqa: E402
accept_batch, list_backends, list_endpoints, shutdown) # noqa: E402
__all__ = [
"init",
@@ -17,4 +17,5 @@ __all__ = [
"accept_batch",
"list_backends",
"list_endpoints",
"shutdown",
]
+12
View File
@@ -119,6 +119,18 @@ def init(name=None,
@_ensure_connected
def shutdown():
"""Completely shut down the connected Serve instance.
Shuts down all processes and deletes all state associated with the Serve
instance that's currently connected to (via serve.init).
"""
global master_actor
ray.get(master_actor.shutdown.remote())
ray.kill(master_actor, no_restart=True)
master_actor = None
def create_endpoint(endpoint_name,
*,
backend=None,
+11
View File
@@ -45,3 +45,14 @@ class RayInternalKVStore:
raise TypeError("key must be a string, got: {}.".format(type(key)))
return ray_kv._internal_kv_get(self._format_key(key))
def delete(self, key):
"""Delete the value associated with the given key from the store.
Args:
key (str)
"""
if not isinstance(key, str):
raise TypeError("key must be a string, got: {}.".format(type(key)))
return ray_kv._internal_kv_del(self._format_key(key))
+13 -5
View File
@@ -61,7 +61,7 @@ class ServeMaster:
# namespace child actors and checkpoints.
self.instance_name = instance_name
# Used to read/write checkpoints.
self.kv_store = RayInternalKVStore()
self.kv_store = RayInternalKVStore(namespace=instance_name)
# path -> (endpoint, methods).
self.routes = {}
# backend -> (backend_worker, backend_config, replica_config).
@@ -112,10 +112,7 @@ class ServeMaster:
# a checkpoint to the event loop. Other state-changing calls acquire
# this lock and will be blocked until recovering from the checkpoint
# finishes.
checkpoint_key = CHECKPOINT_KEY
if self.instance_name is not None:
checkpoint_key = "{}:{}".format(self.instance_name, checkpoint_key)
checkpoint = self.kv_store.get(checkpoint_key)
checkpoint = self.kv_store.get(CHECKPOINT_KEY)
if checkpoint is None:
logger.debug("No checkpoint found")
else:
@@ -712,3 +709,14 @@ class ServeMaster:
assert (backend_tag in self.backends
), "Backend {} is not registered.".format(backend_tag)
return self.backends[backend_tag][2]
async def shutdown(self):
"""Shuts down the serve instance completely."""
async with self.write_lock:
ray.kill(self.http_proxy, no_restart=True)
ray.kill(self.router, no_restart=True)
ray.kill(self.metric_exporter, no_restart=True)
for replica_dict in self.workers.values():
for replica in replica_dict.values():
ray.kill(replica, no_restart=True)
self.kv_store.delete(CHECKPOINT_KEY)
+2
View File
@@ -20,6 +20,8 @@ def _shared_serve_instance():
def serve_instance(_shared_serve_instance):
serve.init()
yield
# Re-init if necessary.
serve.init()
master = serve.api._get_master_actor()
# Clear all state between tests to avoid naming collisions.
for endpoint in ray.get(master.get_all_endpoints.remote()):
+31 -1
View File
@@ -6,8 +6,10 @@ import requests
import ray
from ray import serve
from ray.serve.utils import get_random_letters
from ray.test_utils import wait_for_condition
from ray.serve import constants
from ray.serve.exceptions import RayServeException
from ray.serve.utils import format_actor_name, get_random_letters
def test_e2e(serve_instance):
@@ -546,6 +548,34 @@ def test_create_infeasible_error(serve_instance):
assert len(replicas) == 0
def test_shutdown(serve_instance):
def f():
pass
instance_name = "shutdown"
serve.init(name=instance_name, http_port=8002)
serve.create_backend("backend", f)
serve.create_endpoint("endpoint", backend="backend")
serve.shutdown()
with pytest.raises(RayServeException, match="Please run serve.init"):
serve.list_backends()
def check_dead():
for actor_name in [
constants.SERVE_MASTER_NAME, constants.SERVE_PROXY_NAME,
constants.SERVE_ROUTER_NAME, constants.SERVE_METRIC_SINK_NAME
]:
try:
ray.get_actor(format_actor_name(actor_name, instance_name))
return False
except ValueError:
pass
return True
assert wait_for_condition(check_dead)
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", "-s", __file__]))