From 48212079f436971ce71e27f56fe00a4fc4ed762b Mon Sep 17 00:00:00 2001 From: Yannic Kilcher Date: Sat, 11 Feb 2023 01:31:25 +0100 Subject: [PATCH] unified queueing --- inference/server/main.py | 133 ++++++++++-------- .../oasst_inference_server/chat_repository.py | 53 ++++--- .../oasst_inference_server/interface.py | 5 +- .../server/oasst_inference_server/queueing.py | 27 ++++ .../server/oasst_inference_server/settings.py | 2 + inference/worker/utils.py | 2 +- .../oasst_shared/schemas/inference.py | 4 + 7 files changed, 144 insertions(+), 82 deletions(-) create mode 100644 inference/server/oasst_inference_server/queueing.py diff --git a/inference/server/main.py b/inference/server/main.py index 072bcbde..988ce05b 100644 --- a/inference/server/main.py +++ b/inference/server/main.py @@ -12,7 +12,7 @@ import websockets.exceptions from fastapi import Depends from fastapi.middleware.cors import CORSMiddleware from loguru import logger -from oasst_inference_server import interface +from oasst_inference_server import interface, queueing from oasst_inference_server.chat_repository import ChatRepository from oasst_inference_server.database import db_engine from oasst_inference_server.settings import settings @@ -40,7 +40,7 @@ app.add_middleware( # create async redis client -redisClient = redis.Redis( +redis_client = redis.Redis( host=settings.redis_host, port=settings.redis_port, db=settings.redis_db, decode_responses=True ) @@ -55,6 +55,12 @@ def create_chat_repository(session: sqlmodel.Session = Depends(create_session)): return repository +@contextlib.contextmanager +def manual_chat_repository(): + with contextlib.contextmanager(create_session)() as session: + yield create_chat_repository(session) + + if settings.update_alembic: @app.on_event("startup") @@ -105,9 +111,9 @@ async def get_chat(id: str, chat_repository: ChatRepository = Depends(create_cha return chat -@app.post("/chat/{id}/message") +@app.post("/chat/{chat_id}/message") async def create_message( - id: str, + chat_id: str, message_request: interface.MessageRequest, fastapi_request: fastapi.Request, chat_repository: ChatRepository = Depends(create_chat_repository), @@ -115,21 +121,24 @@ async def create_message( """Allows the client to stream the results of a request.""" try: - chat_repository.add_prompter_message(id=id, message_request=message_request) + chat_repository.add_prompter_message(chat_id=chat_id, message_request=message_request) + queue = queueing.work_queue(redis_client, message_request.worker_compat_hash) + logger.debug(f"Adding {chat_id} to {queue.queue_id}") + await queue.enqueue(chat_id) + logger.debug(f"Added message to {queue.queue_id} for {chat_id}") except Exception: logger.exception("Error adding prompter message") return fastapi.Response(status_code=500) - async def event_generator(id): + async def event_generator(chat_id): + queue = queueing.chat_queue(redis_client, chat_id) result_data = [] - try: while True: if await fastapi_request.is_disconnected(): logger.warning("Client disconnected") return - - item = await redisClient.blpop(id, 1) + item = await queue.dequeue() if item is None: continue @@ -144,25 +153,26 @@ async def create_message( "retry": settings.sse_retry_timeout, "data": interface.TokenResponseEvent(token=response_packet.token).json(), } - logger.info(f"Finished streaming {id} {len(result_data)=}") + logger.info(f"Finished streaming {chat_id} {len(result_data)=}") except Exception: - logger.exception(f"Error streaming {id}") + logger.exception(f"Error streaming {chat_id}") raise try: - with contextlib.contextmanager(create_session)() as session: - chat_repository = create_chat_repository(session) - chat_repository.add_assistant_message(id=id, text=response_packet.generated_text.text) + with manual_chat_repository() as chat_repository: + chat_repository.add_assistant_message(chat_id=chat_id, text=response_packet.generated_text.text) except Exception: logger.exception("Error adding assistant message") - return EventSourceResponse(event_generator(id)) + return EventSourceResponse(event_generator(chat_id)) @app.websocket("/work") -async def work(websocket: fastapi.WebSocket, chat_repository: ChatRepository = Depends(create_chat_repository)): +async def work(websocket: fastapi.WebSocket): await websocket.accept() worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text()) + queue_id = f"work:{worker_config.compat_hash}" + work_queue = queueing.RedisQueue(redis_client, queue_id) try: while True: if websocket.client_state == fastapi.websockets.WebSocketState.DISCONNECTED: @@ -171,55 +181,62 @@ async def work(websocket: fastapi.WebSocket, chat_repository: ChatRepository = D # find a pending task that matches the worker's config # could also be implemented using task queues # but general compatibility matching is tricky - for chat in chat_repository.get_pending_chats(): - request = chat.pending_message_request - if request.compatible_with(worker_config): - break - else: + item = await work_queue.dequeue() + if item is None: await asyncio.sleep(1) continue + else: + _, chat_id = item - chat_repository.set_chat_state(chat.id, interface.MessageRequestState.in_progress) + with manual_chat_repository() as chat_repository: + chat = chat_repository.get_chat_by_id(chat_id) + request = chat.pending_message_request - work_request = inference.WorkRequest( - conversation=chat.conversation, - model_name=request.model_name, - max_new_tokens=request.max_new_tokens, - ) + chat_repository.set_chat_state(chat.id, interface.MessageRequestState.in_progress) - logger.info(f"Created {work_request=}") - try: - await websocket.send_text(work_request.json()) - except websockets.exceptions.ConnectionClosedError: - logger.warning("Worker disconnected") - websocket.close() - chat_repository.set_chat_state(chat.id, interface.MessageRequestState.pending) - break + work_request = inference.WorkRequest( + conversation=chat.conversation, + model_name=request.model_name, + max_new_tokens=request.max_new_tokens, + ) - logger.debug(f"Sent {work_request=} to worker.") - - try: - in_progress = False - while True: - # maybe unnecessary to parse and re-serialize - # could just pass the raw string and mark end via empty string - response_packet = inference.WorkResponsePacket.parse_raw(await websocket.receive_text()) - in_progress = True - await redisClient.rpush(chat.id, response_packet.json()) - if response_packet.is_end: - logger.debug(f"Received {response_packet=} from worker. Ending.") - break - except fastapi.WebSocketException: - # TODO: handle this better - logger.exception(f"Websocket closed during handling of {chat.id}") - if in_progress: - logger.warning(f"Aborting {chat.id=}") - chat_repository.set_chat_state(chat.id, interface.MessageRequestState.aborted_by_worker) - else: - logger.warning(f"Marking {chat.id=} as pending since no work was done.") + logger.info(f"Created {work_request=}") + try: + await websocket.send_text(work_request.json()) + except websockets.exceptions.ConnectionClosedError: + logger.warning("Worker disconnected") + websocket.close() chat_repository.set_chat_state(chat.id, interface.MessageRequestState.pending) - raise + await work_queue.enqueue(chat.id) + break - chat_repository.set_chat_state(chat.id, interface.MessageRequestState.complete) + logger.debug(f"Sent {work_request=} to worker.") + + chat_queue = queueing.chat_queue(redis_client, chat.id) + + try: + in_progress = False + while True: + # maybe unnecessary to parse and re-serialize + # could just pass the raw string and mark end via empty string + response_packet = inference.WorkResponsePacket.parse_raw(await websocket.receive_text()) + in_progress = True + await chat_queue.enqueue(response_packet.json()) + if response_packet.is_end: + logger.debug(f"Received {response_packet=} from worker. Ending.") + break + except fastapi.WebSocketException: + # TODO: handle this better + logger.exception(f"Websocket closed during handling of {chat.id}") + if in_progress: + logger.warning(f"Aborting {chat.id=}") + chat_repository.set_chat_state(chat.id, interface.MessageRequestState.aborted_by_worker) + else: + logger.warning(f"Marking {chat.id=} as pending since no work was done.") + chat_repository.set_chat_state(chat.id, interface.MessageRequestState.pending) + await work_queue.enqueue(chat.id) + raise + + chat_repository.set_chat_state(chat.id, interface.MessageRequestState.complete) except fastapi.WebSocketException: logger.exception("Websocket closed") diff --git a/inference/server/oasst_inference_server/chat_repository.py b/inference/server/oasst_inference_server/chat_repository.py index 52cb9543..185b86f1 100644 --- a/inference/server/oasst_inference_server/chat_repository.py +++ b/inference/server/oasst_inference_server/chat_repository.py @@ -7,8 +7,16 @@ from sqlalchemy.sql.operators import is_not class ChatRepository: - def __init__(self, session: sqlmodel.Session) -> None: + def __init__(self, session: sqlmodel.Session, do_commit=True) -> None: self.session = session + self.do_commit = do_commit + + def as_no_commit(self) -> "ChatRepository": + return ChatRepository(self.session, do_commit=False) + + def maybe_commit(self) -> None: + if self.do_commit: + self.session.commit() def get_chats(self) -> list[models.DbChatEntry]: return self.session.exec(sqlmodel.select(models.DbChatEntry)).all() @@ -25,22 +33,25 @@ class ChatRepository: chats = self.get_chats() return [chat.to_list_entry() for chat in chats] - def get_chat_by_id(self, id: str) -> models.DbChatEntry: - chat = self.session.exec(sqlmodel.select(models.DbChatEntry).where(models.DbChatEntry.id == id)).one() + def get_chat_by_id(self, chat_id: str, for_update=False) -> models.DbChatEntry: + query = sqlmodel.select(models.DbChatEntry).where(models.DbChatEntry.id == chat_id) + if for_update: + query = query.with_for_update() + chat = self.session.exec(query).one() return chat - def get_chat_entry_by_id(self, id: str) -> interface.ChatEntry: - return self.get_chat_by_id(id).to_entry() + def get_chat_entry_by_id(self, chat_id: str) -> interface.ChatEntry: + return self.get_chat_by_id(chat_id).to_entry() def create_chat(self) -> models.DbChatEntry: chat = models.DbChatEntry() self.session.add(chat) - self.session.commit() + self.maybe_commit() return chat - def add_prompter_message(self, id: str, message_request: interface.MessageRequest) -> None: - logger.info(f"Adding prompter message {message_request} to chat {id}") - chat = self.get_chat_by_id(id) + def add_prompter_message(self, chat_id: str, message_request: interface.MessageRequest) -> None: + logger.info(f"Adding prompter message {message_request} to chat {chat_id}") + chat = self.get_chat_by_id(chat_id, for_update=True) if not chat.conversation.is_prompter_turn: raise fastapi.HTTPException(status_code=400, detail="Not your turn") if chat.pending_message_request is not None: @@ -55,12 +66,12 @@ class ChatRepository: chat.pending_message_request = message_request chat.message_request_state = interface.MessageRequestState.pending - self.session.commit() - logger.debug(f"Added prompter message {message_request} to chat {id}") + self.maybe_commit() + logger.debug(f"Added prompter message {message_request} to chat {chat_id}") - def add_assistant_message(self, id: str, text: str) -> None: - logger.info(f"Adding assistant message {text} to chat {id}") - chat = self.get_chat_by_id(id) + def add_assistant_message(self, chat_id: str, text: str) -> None: + logger.info(f"Adding assistant message {text} to chat {chat_id}") + chat = self.get_chat_by_id(chat_id, for_update=True) chat.conversation.messages.append( protocol.ConversationMessage( text=text, @@ -68,12 +79,12 @@ class ChatRepository: ) ) chat.pending_message_request = None - self.session.commit() - logger.debug(f"Added assistant message {text} to chat {id}") + self.maybe_commit() + logger.debug(f"Added assistant message {text} to chat {chat_id}") - def set_chat_state(self, id: str, state: interface.MessageRequestState) -> None: - logger.info(f"Setting chat {id} state to {state}") - chat = self.get_chat_by_id(id) + def set_chat_state(self, chat_id: str, state: interface.MessageRequestState) -> None: + logger.info(f"Setting chat {chat_id} state to {state}") + chat = self.get_chat_by_id(chat_id, for_update=True) chat.message_request_state = state - self.session.commit() - logger.debug(f"Set chat {id} state to {state}") + self.maybe_commit() + logger.debug(f"Set chat {chat_id} state to {state}") diff --git a/inference/server/oasst_inference_server/interface.py b/inference/server/oasst_inference_server/interface.py index 7fecffa1..5c5d8fb7 100644 --- a/inference/server/oasst_inference_server/interface.py +++ b/inference/server/oasst_inference_server/interface.py @@ -9,8 +9,9 @@ class MessageRequest(pydantic.BaseModel): model_name: str = "distilgpt2" max_new_tokens: int = 100 - def compatible_with(self, worker_config: inference.WorkerConfig) -> bool: - return self.model_name == worker_config.model_name + @property + def worker_compat_hash(self) -> str: + return f"{self.model_name}" class TokenResponseEvent(pydantic.BaseModel): diff --git a/inference/server/oasst_inference_server/queueing.py b/inference/server/oasst_inference_server/queueing.py new file mode 100644 index 00000000..5e323378 --- /dev/null +++ b/inference/server/oasst_inference_server/queueing.py @@ -0,0 +1,27 @@ +import redis.asyncio as redis +from oasst_inference_server.settings import settings + + +class RedisQueue: + def __init__(self, redis_client: redis.Redis, queue_id: str) -> None: + self.redis_client = redis_client + self.queue_id = queue_id + + async def enqueue(self, value: str) -> None: + return await self.redis_client.rpush(self.queue_id, value) + + async def dequeue(self, block: bool = True, timeout: int = 1) -> str: + if block: + return await self.redis_client.blpop(self.queue_id, timeout=timeout) + else: + return await self.redis_client.lpop(self.queue_id) + + +def chat_queue(redis_client: redis.Redis, chat_id: str) -> RedisQueue: + return RedisQueue(redis_client, f"chat:{chat_id}") + + +def work_queue(redis_client: redis.Redis, worker_compat_hash: str) -> RedisQueue: + if worker_compat_hash not in settings.allowed_worker_compat_hashes: + raise ValueError(f"Worker compat hash {worker_compat_hash} not allowed") + return RedisQueue(redis_client, f"work:{worker_compat_hash}") diff --git a/inference/server/oasst_inference_server/settings.py b/inference/server/oasst_inference_server/settings.py index e0a4d914..5f79fb55 100644 --- a/inference/server/oasst_inference_server/settings.py +++ b/inference/server/oasst_inference_server/settings.py @@ -8,6 +8,8 @@ class Settings(pydantic.BaseSettings): redis_port: int = 6379 redis_db: int = 0 + allowed_worker_compat_hashes: list[str] = ["distilgpt2"] + sse_retry_timeout: int = 15000 update_alembic: bool = True alembic_retries: int = 5 diff --git a/inference/worker/utils.py b/inference/worker/utils.py index 414b6958..fe08fb96 100644 --- a/inference/worker/utils.py +++ b/inference/worker/utils.py @@ -55,7 +55,7 @@ def wait_for_inference_server(inference_server_url: str, timeout: int = 600): if time.time() > time_limit: raise sleep_duration = random.uniform(0, 10) - logger.warning(f"Inference server not ready. Retrying in {sleep_duration} seconds") + logger.warning(f"Inference server not ready. Retrying in {sleep_duration:.2f} seconds") time.sleep(sleep_duration) else: logger.info("Inference server is ready") diff --git a/oasst-shared/oasst_shared/schemas/inference.py b/oasst-shared/oasst_shared/schemas/inference.py index 1bb89a42..f8a94fc1 100644 --- a/oasst-shared/oasst_shared/schemas/inference.py +++ b/oasst-shared/oasst_shared/schemas/inference.py @@ -9,6 +9,10 @@ from . import protocol class WorkerConfig(pydantic.BaseModel): model_name: str = "distilgpt2" + @property + def compat_hash(self) -> str: + return f"{self.model_name}" + class WorkRequest(pydantic.BaseModel): conversation: protocol.Conversation = pydantic.Field(..., repr=False)