mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 23:33:18 +08:00
Improve Serve API Input Validations (#8124)
* Add additional validation to endpoint and backend creation that ensures there are not duplicates created of either of these. In addition, adds additional validation to split_traffic to make sure both the endpoint and backends exist. * Fix test to deal with removed serve.link * Address PR feedback Co-authored-by: Max Fitton <max@semprehealth.com>
This commit is contained in:
@@ -219,20 +219,30 @@ class ServeMaster:
|
||||
return expand(
|
||||
self.route_table.list_service(include_headless=True).values())
|
||||
|
||||
def get_all_routes(self):
|
||||
return expand(self.route_table.list_service().keys())
|
||||
|
||||
def get_all_backends(self):
|
||||
return self.backend_table.list_backends()
|
||||
|
||||
async def set_traffic(self, endpoint_name, traffic_policy_dictionary):
|
||||
assert endpoint_name in expand(
|
||||
self.route_table.list_service(include_headless=True).values())
|
||||
assert endpoint_name in self.get_all_endpoints(), \
|
||||
"Attempted to assign traffic for an endpoint '{}'" \
|
||||
" that is not registered.".format(endpoint_name)
|
||||
|
||||
assert isinstance(traffic_policy_dictionary,
|
||||
dict), "Traffic policy must be dictionary"
|
||||
prob = 0
|
||||
existing_backends = set(self.get_all_backends())
|
||||
for backend, weight in traffic_policy_dictionary.items():
|
||||
prob += weight
|
||||
assert (backend in self.backend_table.list_backends()
|
||||
), "backend {} is not registered".format(backend)
|
||||
assert backend in existing_backends, \
|
||||
"Attempted to assign traffic to a backend '{}' that " \
|
||||
"is not registered.".format(backend)
|
||||
|
||||
assert np.isclose(
|
||||
prob, 1, atol=0.02
|
||||
), "weights must sum to 1, currently it sums to {}".format(prob)
|
||||
), "weights must sum to 1, currently they sum to {}".format(prob)
|
||||
|
||||
self.policy_table.register_traffic_policy(endpoint_name,
|
||||
traffic_policy_dictionary)
|
||||
@@ -242,6 +252,12 @@ class ServeMaster:
|
||||
traffic_policy_dictionary)
|
||||
|
||||
async def create_endpoint(self, route, endpoint_name, methods):
|
||||
err_prefix = "Cannot create endpoint. "
|
||||
assert route not in self.get_all_routes(), \
|
||||
"{} Route '{}' is already registered.".format(err_prefix, route)
|
||||
assert endpoint_name not in self.get_all_endpoints(), \
|
||||
"{} Endpoint '{}' is already registered.".format(err_prefix,
|
||||
endpoint_name)
|
||||
self.route_table.register_service(
|
||||
route, endpoint_name, methods=methods)
|
||||
[http_proxy] = self.get_http_proxy()
|
||||
@@ -255,6 +271,10 @@ class ServeMaster:
|
||||
backend_config_dict = dict(backend_config)
|
||||
backend_worker = create_backend_worker(func_or_class)
|
||||
|
||||
assert backend_tag not in self.get_all_backends(), \
|
||||
"Cannot create backend '{}' because a backend" \
|
||||
"with that name already exists.".format(backend_tag)
|
||||
|
||||
# Save creator which starts replicas.
|
||||
self.backend_table.register_backend(backend_tag, backend_worker)
|
||||
|
||||
|
||||
@@ -53,6 +53,38 @@ def test_no_route(serve_instance):
|
||||
assert result == 1
|
||||
|
||||
|
||||
def test_reject_duplicate_backend_tag(serve_instance):
|
||||
backend_name = "foo"
|
||||
serve.create_backend(lambda foo: foo, backend_name)
|
||||
with pytest.raises(AssertionError):
|
||||
serve.create_backend(lambda foo: foo, backend_name)
|
||||
|
||||
|
||||
def test_reject_duplicate_route(serve_instance):
|
||||
route = "/foo"
|
||||
serve.create_endpoint("bar", route=route)
|
||||
with pytest.raises(AssertionError):
|
||||
serve.create_endpoint("foo", route=route)
|
||||
|
||||
|
||||
def test_reject_duplicate_endpoint(serve_instance):
|
||||
endpoint_name = "foo"
|
||||
serve.create_endpoint(endpoint_name, route="/ok")
|
||||
with pytest.raises(AssertionError):
|
||||
serve.create_endpoint(endpoint_name, route="/different")
|
||||
|
||||
|
||||
def test_set_traffic_missing_data(serve_instance):
|
||||
endpoint_name = "foobar"
|
||||
backend_name = "foo_backend"
|
||||
serve.create_endpoint(endpoint_name)
|
||||
serve.create_backend(lambda: 5, backend_name)
|
||||
with pytest.raises(AssertionError):
|
||||
serve.set_traffic(endpoint_name, {"nonexistant_backend": 1.0})
|
||||
with pytest.raises(AssertionError):
|
||||
serve.set_traffic("nonexistant_endpoint_name", {backend_name: 1.0})
|
||||
|
||||
|
||||
def test_scaling_replicas(serve_instance):
|
||||
class Counter:
|
||||
def __init__(self):
|
||||
@@ -105,10 +137,10 @@ def test_batching(serve_instance):
|
||||
batch_size = serve.context.batch_size
|
||||
return [self.count] * batch_size
|
||||
|
||||
serve.create_endpoint("counter1", "/increment")
|
||||
serve.create_endpoint("counter1", "/increment2")
|
||||
|
||||
# Keep checking the routing table until /increment is populated
|
||||
while "/increment" not in requests.get(
|
||||
while "/increment2" not in requests.get(
|
||||
"http://127.0.0.1:8000/-/routes").json():
|
||||
time.sleep(0.2)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user