diff --git a/python/ray/serve/master.py b/python/ray/serve/master.py index 87a6f099f..3e3f5cf55 100644 --- a/python/ray/serve/master.py +++ b/python/ray/serve/master.py @@ -219,20 +219,30 @@ class ServeMaster: return expand( self.route_table.list_service(include_headless=True).values()) + def get_all_routes(self): + return expand(self.route_table.list_service().keys()) + + def get_all_backends(self): + return self.backend_table.list_backends() + async def set_traffic(self, endpoint_name, traffic_policy_dictionary): - assert endpoint_name in expand( - self.route_table.list_service(include_headless=True).values()) + assert endpoint_name in self.get_all_endpoints(), \ + "Attempted to assign traffic for an endpoint '{}'" \ + " that is not registered.".format(endpoint_name) assert isinstance(traffic_policy_dictionary, dict), "Traffic policy must be dictionary" prob = 0 + existing_backends = set(self.get_all_backends()) for backend, weight in traffic_policy_dictionary.items(): prob += weight - assert (backend in self.backend_table.list_backends() - ), "backend {} is not registered".format(backend) + assert backend in existing_backends, \ + "Attempted to assign traffic to a backend '{}' that " \ + "is not registered.".format(backend) + assert np.isclose( prob, 1, atol=0.02 - ), "weights must sum to 1, currently it sums to {}".format(prob) + ), "weights must sum to 1, currently they sum to {}".format(prob) self.policy_table.register_traffic_policy(endpoint_name, traffic_policy_dictionary) @@ -242,6 +252,12 @@ class ServeMaster: traffic_policy_dictionary) async def create_endpoint(self, route, endpoint_name, methods): + err_prefix = "Cannot create endpoint. " + assert route not in self.get_all_routes(), \ + "{} Route '{}' is already registered.".format(err_prefix, route) + assert endpoint_name not in self.get_all_endpoints(), \ + "{} Endpoint '{}' is already registered.".format(err_prefix, + endpoint_name) self.route_table.register_service( route, endpoint_name, methods=methods) [http_proxy] = self.get_http_proxy() @@ -255,6 +271,10 @@ class ServeMaster: backend_config_dict = dict(backend_config) backend_worker = create_backend_worker(func_or_class) + assert backend_tag not in self.get_all_backends(), \ + "Cannot create backend '{}' because a backend" \ + "with that name already exists.".format(backend_tag) + # Save creator which starts replicas. self.backend_table.register_backend(backend_tag, backend_worker) diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index dda46bfc5..ffb599d60 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -53,6 +53,38 @@ def test_no_route(serve_instance): assert result == 1 +def test_reject_duplicate_backend_tag(serve_instance): + backend_name = "foo" + serve.create_backend(lambda foo: foo, backend_name) + with pytest.raises(AssertionError): + serve.create_backend(lambda foo: foo, backend_name) + + +def test_reject_duplicate_route(serve_instance): + route = "/foo" + serve.create_endpoint("bar", route=route) + with pytest.raises(AssertionError): + serve.create_endpoint("foo", route=route) + + +def test_reject_duplicate_endpoint(serve_instance): + endpoint_name = "foo" + serve.create_endpoint(endpoint_name, route="/ok") + with pytest.raises(AssertionError): + serve.create_endpoint(endpoint_name, route="/different") + + +def test_set_traffic_missing_data(serve_instance): + endpoint_name = "foobar" + backend_name = "foo_backend" + serve.create_endpoint(endpoint_name) + serve.create_backend(lambda: 5, backend_name) + with pytest.raises(AssertionError): + serve.set_traffic(endpoint_name, {"nonexistant_backend": 1.0}) + with pytest.raises(AssertionError): + serve.set_traffic("nonexistant_endpoint_name", {backend_name: 1.0}) + + def test_scaling_replicas(serve_instance): class Counter: def __init__(self): @@ -105,10 +137,10 @@ def test_batching(serve_instance): batch_size = serve.context.batch_size return [self.count] * batch_size - serve.create_endpoint("counter1", "/increment") + serve.create_endpoint("counter1", "/increment2") # Keep checking the routing table until /increment is populated - while "/increment" not in requests.get( + while "/increment2" not in requests.get( "http://127.0.0.1:8000/-/routes").json(): time.sleep(0.2)