[Serve] Improve Serialization (#7688)

This commit is contained in:
Simon Mo
2020-03-29 14:57:19 -07:00
committed by GitHub
parent fc23f79f82
commit 353d7e107f
5 changed files with 34 additions and 27 deletions
+4
View File
@@ -17,6 +17,7 @@ from ray.serve.exceptions import RayServeException, batch_annotation_not_found
from ray.serve.backend_config import BackendConfig
from ray.serve.policy import RoutePolicy
from ray.serve.queues import Query
from ray.serve.request_params import RequestMetadata
global_state = None
@@ -124,6 +125,9 @@ def init(
# Register serialization context once
ray.register_custom_serializer(Query, Query.ray_serialize,
Query.ray_deserialize)
ray.register_custom_serializer(RequestMetadata,
RequestMetadata.ray_serialize,
RequestMetadata.ray_deserialize)
if kv_store_path is None:
_, kv_store_path = mkstemp()
+17 -15
View File
@@ -2,7 +2,7 @@ import asyncio
import copy
from collections import defaultdict
from typing import DefaultDict, List
import pickle
import ray.cloudpickle as pickle
# Note on choosing blist instead of stdlib heapq
# 1. pop operation should be O(1) (amortized)
@@ -22,12 +22,13 @@ class Query:
request_kwargs,
request_context,
request_slo_ms,
call_method="__call__"):
call_method="__call__",
async_future=None):
self.request_args = request_args
self.request_kwargs = request_kwargs
self.request_context = request_context
self.async_future = asyncio.get_event_loop().create_future()
self.async_future = async_future
# Service level objective in milliseconds. This is expected to be the
# absolute time since unix epoch.
@@ -41,14 +42,14 @@ class Query:
# replica worker the async_future is still needed to retrieve the final
# result. Therefore we need a way to pass the information to replica
# worker without removing async_future.
clone = copy.copy(self)
clone.async_future = None
# We can't use cloudpickle due to a recursion issue
return pickle.dumps(clone)
clone = copy.copy(self).__dict__
clone.pop("async_future")
return pickle.dumps(clone, protocol=5)
@staticmethod
def ray_deserialize(value):
return pickle.loads(value)
kwargs = pickle.loads(value)
return Query(**kwargs)
# Comparator used to keep queries in a list sorted in ascending order
# of request_slo_ms.
@@ -169,24 +170,25 @@ class CentralizedQueues:
for backend_name, queue in self.buffer_queues.items()
}
async def enqueue_request(self, request_in_object, *request_args,
async def enqueue_request(self, request_meta, *request_args,
**request_kwargs):
service = request_in_object.service
service = request_meta.service
logger.debug("Received a request for service {}".format(service))
# check if the slo specified is directly the
# wall clock time
if request_in_object.absolute_slo_ms is not None:
request_slo_ms = request_in_object.absolute_slo_ms
if request_meta.absolute_slo_ms is not None:
request_slo_ms = request_meta.absolute_slo_ms
else:
request_slo_ms = request_in_object.adjust_relative_slo_ms()
request_context = request_in_object.request_context
request_slo_ms = request_meta.adjust_relative_slo_ms()
request_context = request_meta.request_context
query = Query(
request_args,
request_kwargs,
request_context,
request_slo_ms,
call_method=request_in_object.call_method)
call_method=request_meta.call_method,
async_future=asyncio.get_event_loop().create_future())
await self.service_queues[service].put(query)
await self.flush()
+9
View File
@@ -1,5 +1,6 @@
import time
from ray.serve.constants import DEFAULT_LATENCY_SLO_MS
import ray.cloudpickle as pickle
class RequestMetadata:
@@ -37,3 +38,11 @@ class RequestMetadata:
slo_ms = DEFAULT_LATENCY_SLO_MS
current_time_ms = time.time() * 1000
return current_time_ms + slo_ms
# Serializer for RequestMetadata (registered with Ray via
# ray.register_custom_serializer): pickles the instance's attribute dict
# (self.__dict__) with pickle protocol 5 and returns the pickled payload.
def ray_serialize(self):
return pickle.dumps(self.__dict__, protocol=5)
# Deserializer counterpart of ray_serialize: unpickles the attribute dict and
# rebuilds a RequestMetadata by passing its entries to the constructor as
# keyword arguments.
# NOTE(review): this assumes every key stored in __dict__ is also accepted as
# an __init__ keyword parameter -- confirm against the constructor.
@staticmethod
def ray_deserialize(value):
kwargs = pickle.loads(value)
return RequestMetadata(**kwargs)
+3 -7
View File
@@ -146,13 +146,8 @@ class HTTPProxy:
await error_sender(str(e), 400)
return
# create objects necessary for enqueue
# enclosing http_body_bytes to list due to
# https://github.com/ray-project/ray/issues/6944
# TODO(alind): remove list enclosing after issue is fixed
args = (scope, [http_body_bytes])
headers = {k.decode(): v.decode() for k, v in scope["headers"]}
request_in_object = RequestMetadata(
request_metadata = RequestMetadata(
endpoint_name,
TaskContext.Web,
relative_slo_ms=relative_slo_ms,
@@ -161,7 +156,8 @@ class HTTPProxy:
try:
result = await (self.serve_global_state.init_or_get_router()
.enqueue_request.remote(request_in_object, *args))
.enqueue_request.remote(request_metadata, scope,
http_body_bytes))
await Response(result).send(scope, receive, send)
except Exception as e:
error_message = "Internal Error. Traceback: {}.".format(e)
+1 -5
View File
@@ -38,11 +38,7 @@ def parse_request_item(request_item):
is_web_context = True
asgi_scope, body_bytes = request_item.request_args
# http_body_bytes enclosed in list due to
# https://github.com/ray-project/ray/issues/6944
# TODO(alind): remove list enclosing after issue is fixed
flask_request = build_flask_request(asgi_scope,
io.BytesIO(body_bytes[0]))
flask_request = build_flask_request(asgi_scope, io.BytesIO(body_bytes))
args = (flask_request, )
kwargs = {}
else: