From a6138ca31f5e9b54d9ee30fda2fc34325a320760 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Tue, 2 Feb 2021 09:44:01 -0600 Subject: [PATCH] [serve] Support batches for ImportedBackends (#13843) --- python/ray/serve/backends.py | 8 ++++++++ python/ray/serve/tests/test_imported_backend.py | 2 +- python/ray/serve/utils.py | 13 ++++++++----- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/python/ray/serve/backends.py b/python/ray/serve/backends.py index 086755500..5f58ad2c9 100644 --- a/python/ray/serve/backends.py +++ b/python/ray/serve/backends.py @@ -1,3 +1,4 @@ +from ray import serve from ray.serve.utils import import_class @@ -26,6 +27,13 @@ class ImportedBackend: # proxy it manually. return self.wrapped.reconfigure(*args, **kwargs) + # We mark 'accept_batch' here just so this will always pass the + # check we make during create_backend(). Unfortunately this means + # that validation won't happen until the replica is created. + @serve.accept_batch + def __call__(self, *args, **kwargs): + return self.wrapped(*args, **kwargs) + def __getattr__(self, attr): """Proxy all other methods to the wrapper class.""" return getattr(self.wrapped, attr) diff --git a/python/ray/serve/tests/test_imported_backend.py b/python/ray/serve/tests/test_imported_backend.py index cc575dd94..99f08a04b 100644 --- a/python/ray/serve/tests/test_imported_backend.py +++ b/python/ray/serve/tests/test_imported_backend.py @@ -7,7 +7,7 @@ def test_imported_backend(serve_instance): client = serve_instance backend_class = ImportedBackend("ray.serve.utils.MockImportedBackend") - config = BackendConfig(user_config="config") + config = BackendConfig(user_config="config", max_batch_size=2) client.create_backend( "imported", backend_class, "input_arg", config=config) client.create_endpoint("imported", backend="imported") diff --git a/python/ray/serve/utils.py b/python/ray/serve/utils.py index a594b94dd..b4fdbf497 100644 --- a/python/ray/serve/utils.py +++ b/python/ray/serve/utils.py @@ -392,11 +392,14 @@ class MockImportedBackend: def reconfigure(self, config): self.config = config - def __call__(self, *args): - return {"arg": self.arg, "config": self.config} + def __call__(self, batch): + return [{ + "arg": self.arg, + "config": self.config + } for _ in range(len(batch))] - async def other_method(self, request): - return await request.body() + async def other_method(self, batch): + return [await request.body() for request in batch] def compute_iterable_delta(old: Iterable, @@ -406,7 +409,7 @@ def compute_iterable_delta(old: Iterable, Usage: >>> old = {"a", "b"} >>> new = {"a", "d"} - >>> compute_dict_delta(old, new) + >>> compute_iterable_delta(old, new) ({"d"}, {"b"}, {"a"}) """ old_keys, new_keys = set(old), set(new)