Improve Serve API Input Validations (#8124)

* Add additional validation to endpoint and backend creation that ensures there are not duplicates created of either of these. In addition, adds additional validation to split_traffic to make sure both the endpoint and backends exist.

* Fix test to deal with removed serve.link

* Address PR feedback

Co-authored-by: Max Fitton <max@semprehealth.com>
This commit is contained in:
Max Fitton
2020-04-21 19:45:29 -07:00
committed by GitHub
parent 95e8ec8c47
commit c486b56c58
2 changed files with 59 additions and 7 deletions
+25 -5
View File
@@ -219,20 +219,30 @@ class ServeMaster:
return expand(
self.route_table.list_service(include_headless=True).values())
def get_all_routes(self):
return expand(self.route_table.list_service().keys())
def get_all_backends(self):
return self.backend_table.list_backends()
async def set_traffic(self, endpoint_name, traffic_policy_dictionary):
assert endpoint_name in expand(
self.route_table.list_service(include_headless=True).values())
assert endpoint_name in self.get_all_endpoints(), \
"Attempted to assign traffic for an endpoint '{}'" \
" that is not registered.".format(endpoint_name)
assert isinstance(traffic_policy_dictionary,
dict), "Traffic policy must be dictionary"
prob = 0
existing_backends = set(self.get_all_backends())
for backend, weight in traffic_policy_dictionary.items():
prob += weight
assert (backend in self.backend_table.list_backends()
), "backend {} is not registered".format(backend)
assert backend in existing_backends, \
"Attempted to assign traffic to a backend '{}' that " \
"is not registered.".format(backend)
assert np.isclose(
prob, 1, atol=0.02
), "weights must sum to 1, currently it sums to {}".format(prob)
), "weights must sum to 1, currently they sum to {}".format(prob)
self.policy_table.register_traffic_policy(endpoint_name,
traffic_policy_dictionary)
@@ -242,6 +252,12 @@ class ServeMaster:
traffic_policy_dictionary)
async def create_endpoint(self, route, endpoint_name, methods):
err_prefix = "Cannot create endpoint. "
assert route not in self.get_all_routes(), \
"{} Route '{}' is already registered.".format(err_prefix, route)
assert endpoint_name not in self.get_all_endpoints(), \
"{} Endpoint '{}' is already registered.".format(err_prefix,
endpoint_name)
self.route_table.register_service(
route, endpoint_name, methods=methods)
[http_proxy] = self.get_http_proxy()
@@ -255,6 +271,10 @@ class ServeMaster:
backend_config_dict = dict(backend_config)
backend_worker = create_backend_worker(func_or_class)
assert backend_tag not in self.get_all_backends(), \
"Cannot create backend '{}' because a backend" \
"with that name already exists.".format(backend_tag)
# Save creator which starts replicas.
self.backend_table.register_backend(backend_tag, backend_worker)
+34 -2
View File
@@ -53,6 +53,38 @@ def test_no_route(serve_instance):
assert result == 1
def test_reject_duplicate_backend_tag(serve_instance):
backend_name = "foo"
serve.create_backend(lambda foo: foo, backend_name)
with pytest.raises(AssertionError):
serve.create_backend(lambda foo: foo, backend_name)
def test_reject_duplicate_route(serve_instance):
route = "/foo"
serve.create_endpoint("bar", route=route)
with pytest.raises(AssertionError):
serve.create_endpoint("foo", route=route)
def test_reject_duplicate_endpoint(serve_instance):
endpoint_name = "foo"
serve.create_endpoint(endpoint_name, route="/ok")
with pytest.raises(AssertionError):
serve.create_endpoint(endpoint_name, route="/different")
def test_set_traffic_missing_data(serve_instance):
endpoint_name = "foobar"
backend_name = "foo_backend"
serve.create_endpoint(endpoint_name)
serve.create_backend(lambda: 5, backend_name)
with pytest.raises(AssertionError):
serve.set_traffic(endpoint_name, {"nonexistant_backend": 1.0})
with pytest.raises(AssertionError):
serve.set_traffic("nonexistant_endpoint_name", {backend_name: 1.0})
def test_scaling_replicas(serve_instance):
class Counter:
def __init__(self):
@@ -105,10 +137,10 @@ def test_batching(serve_instance):
batch_size = serve.context.batch_size
return [self.count] * batch_size
serve.create_endpoint("counter1", "/increment")
serve.create_endpoint("counter1", "/increment2")
# Keep checking the routing table until /increment is populated
while "/increment" not in requests.get(
while "/increment2" not in requests.get(
"http://127.0.0.1:8000/-/routes").json():
time.sleep(0.2)