diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index b8456dfa3..353f04546 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -17,6 +17,7 @@ from ray.serve.exceptions import RayServeException, batch_annotation_not_found from ray.serve.backend_config import BackendConfig from ray.serve.policy import RoutePolicy from ray.serve.queues import Query +from ray.serve.request_params import RequestMetadata global_state = None @@ -124,6 +125,9 @@ def init( # Register serialization context once ray.register_custom_serializer(Query, Query.ray_serialize, Query.ray_deserialize) + ray.register_custom_serializer(RequestMetadata, + RequestMetadata.ray_serialize, + RequestMetadata.ray_deserialize) if kv_store_path is None: _, kv_store_path = mkstemp() diff --git a/python/ray/serve/queues.py b/python/ray/serve/queues.py index 674180e8e..f687318d5 100644 --- a/python/ray/serve/queues.py +++ b/python/ray/serve/queues.py @@ -2,7 +2,7 @@ import asyncio import copy from collections import defaultdict from typing import DefaultDict, List -import pickle +import ray.cloudpickle as pickle # Note on choosing blist instead of stdlib heapq # 1. pop operation should be O(1) (amortized) @@ -22,12 +22,13 @@ class Query: request_kwargs, request_context, request_slo_ms, - call_method="__call__"): + call_method="__call__", + async_future=None): self.request_args = request_args self.request_kwargs = request_kwargs self.request_context = request_context - self.async_future = asyncio.get_event_loop().create_future() + self.async_future = async_future # Service level objective in milliseconds. This is expected to be the # absolute time since unix epoch. @@ -41,14 +42,14 @@ class Query: # replica worker the async_future is still needed to retrieve the final # result. Therefore we need a way to pass the information to replica # worker without removing async_future. - clone = copy.copy(self) - clone.async_future = None - # We can't use cloudpickle due to a recursion issue - return pickle.dumps(clone) + clone = copy.copy(self).__dict__ + clone.pop("async_future") + return pickle.dumps(clone, protocol=5) @staticmethod def ray_deserialize(value): - return pickle.loads(value) + kwargs = pickle.loads(value) + return Query(**kwargs) # adding comparator fn for maintaining an # ascending order sorted list w.r.t request_slo_ms @@ -169,24 +170,25 @@ class CentralizedQueues: for backend_name, queue in self.buffer_queues.items() } - async def enqueue_request(self, request_in_object, *request_args, + async def enqueue_request(self, request_meta, *request_args, **request_kwargs): - service = request_in_object.service + service = request_meta.service logger.debug("Received a request for service {}".format(service)) # check if the slo specified is directly the # wall clock time - if request_in_object.absolute_slo_ms is not None: - request_slo_ms = request_in_object.absolute_slo_ms + if request_meta.absolute_slo_ms is not None: + request_slo_ms = request_meta.absolute_slo_ms else: - request_slo_ms = request_in_object.adjust_relative_slo_ms() - request_context = request_in_object.request_context + request_slo_ms = request_meta.adjust_relative_slo_ms() + request_context = request_meta.request_context query = Query( request_args, request_kwargs, request_context, request_slo_ms, - call_method=request_in_object.call_method) + call_method=request_meta.call_method, + async_future=asyncio.get_event_loop().create_future()) await self.service_queues[service].put(query) await self.flush() diff --git a/python/ray/serve/request_params.py b/python/ray/serve/request_params.py index be17ef5ce..0d5015f25 100644 --- a/python/ray/serve/request_params.py +++ b/python/ray/serve/request_params.py @@ -1,5 +1,6 @@ import time from ray.serve.constants import DEFAULT_LATENCY_SLO_MS +import ray.cloudpickle as pickle class RequestMetadata: @@ -37,3 +38,11 @@ class RequestMetadata: slo_ms = DEFAULT_LATENCY_SLO_MS current_time_ms = time.time() * 1000 return current_time_ms + slo_ms + + def ray_serialize(self): + return pickle.dumps(self.__dict__, protocol=5) + + @staticmethod + def ray_deserialize(value): + kwargs = pickle.loads(value) + return RequestMetadata(**kwargs) diff --git a/python/ray/serve/server.py b/python/ray/serve/server.py index c19067788..8e86871cc 100644 --- a/python/ray/serve/server.py +++ b/python/ray/serve/server.py @@ -146,13 +146,8 @@ class HTTPProxy: await error_sender(str(e), 400) return - # create objects necessary for enqueue - # enclosing http_body_bytes to list due to - # https://github.com/ray-project/ray/issues/6944 - # TODO(alind): remove list enclosing after issue is fixed - args = (scope, [http_body_bytes]) headers = {k.decode(): v.decode() for k, v in scope["headers"]} - request_in_object = RequestMetadata( + request_metadata = RequestMetadata( endpoint_name, TaskContext.Web, relative_slo_ms=relative_slo_ms, @@ -161,7 +156,8 @@ class HTTPProxy: try: result = await (self.serve_global_state.init_or_get_router() - .enqueue_request.remote(request_in_object, *args)) + .enqueue_request.remote(request_metadata, scope, + http_body_bytes)) await Response(result).send(scope, receive, send) except Exception as e: error_message = "Internal Error. Traceback: {}.".format(e) diff --git a/python/ray/serve/utils.py b/python/ray/serve/utils.py index 6baa907fb..567adfcd4 100644 --- a/python/ray/serve/utils.py +++ b/python/ray/serve/utils.py @@ -38,11 +38,7 @@ def parse_request_item(request_item): is_web_context = True asgi_scope, body_bytes = request_item.request_args - # http_body_bytes enclosed in list due to - # https://github.com/ray-project/ray/issues/6944 - # TODO(alind): remove list enclosing after issue is fixed - flask_request = build_flask_request(asgi_scope, - io.BytesIO(body_bytes[0])) + flask_request = build_flask_request(asgi_scope, io.BytesIO(body_bytes)) args = (flask_request, ) kwargs = {} else: