From 040344a41f8beef9c0a3ad4e82474326186b121f Mon Sep 17 00:00:00 2001 From: Yannic Kilcher Date: Thu, 26 Jan 2023 21:01:52 +0100 Subject: [PATCH] made inference server a bit more robust --- inference/server/main.py | 88 ++++++++++++++++++++++-------------- inference/worker/__main__.py | 8 +++- 2 files changed, 60 insertions(+), 36 deletions(-) diff --git a/inference/server/main.py b/inference/server/main.py index f3ec02b1..4cb5f659 100644 --- a/inference/server/main.py +++ b/inference/server/main.py @@ -5,6 +5,7 @@ import uuid import fastapi import pydantic import redis.asyncio as redis +import websockets.exceptions from fastapi.middleware.cors import CORSMiddleware from loguru import logger from oasst_shared.schemas import inference, protocol @@ -63,6 +64,7 @@ class MessageRequestState(str, enum.Enum): pending = "pending" in_progress = "in_progress" complete = "complete" + aborted_by_worker = "aborted_by_worker" class DbChatEntry(pydantic.BaseModel): @@ -154,40 +156,56 @@ async def create_message(id: str, message_request: MessageRequest, fastapi_reque async def work(websocket: fastapi.WebSocket): await websocket.accept() worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text()) - while True: - # find a pending task that matches the worker's config - # could also be implemented using task queues - # but general compatibility matching is tricky - for chat in CHATS.values(): - if (request := chat.pending_message_request) is not None: - if chat.message_request_state == MessageRequestState.pending: - if request.compatible_with(worker_config): + try: + while True: + print(websocket.client_state) + if websocket.client_state == fastapi.websockets.WebSocketState.DISCONNECTED: + logger.warning("Worker disconnected") + break + # find a pending task that matches the worker's config + # could also be implemented using task queues + # but general compatibility matching is tricky + for chat in CHATS.values(): + if (request := chat.pending_message_request) is not None: + if chat.message_request_state == MessageRequestState.pending: + if request.compatible_with(worker_config): + break + else: + logger.debug("No pending tasks") + await asyncio.sleep(1) + continue + + chat.message_request_state = MessageRequestState.in_progress + + work_request = inference.WorkRequest( + conversation=chat.conversation, + model_name=request.model_name, + max_new_tokens=request.max_new_tokens, + ) + + logger.info(f"Created {work_request}") + try: + await websocket.send_text(work_request.json()) + except websockets.exceptions.ConnectionClosedError: + logger.warning("Worker disconnected") + websocket.close() + chat.message_request_state = MessageRequestState.pending + break + + try: + while True: + # maybe unnecessary to parse and re-serialize + # could just pass the raw string and mark end via empty string + response_packet = inference.WorkResponsePacket.parse_raw(await websocket.receive_text()) + await redisClient.rpush(chat.id, response_packet.json()) + if response_packet.is_end: break - else: - logger.debug("No pending tasks") - await asyncio.sleep(1) - continue + except fastapi.WebSocketException: + # TODO: handle this better + logger.exception(f"Websocket closed during handling of {chat.id}") + chat.message_request_state = MessageRequestState.aborted_by_worker + raise - chat.message_request_state = MessageRequestState.in_progress - - work_request = inference.WorkRequest( - conversation=chat.conversation, - model_name=request.model_name, - max_new_tokens=request.max_new_tokens, - ) - - logger.info(f"Created {work_request}") - try: - await websocket.send_text(work_request.json()) - while True: - # maybe unnecessary to parse and re-serialize - # could just pass the raw string and mark end via empty string - response_packet = inference.WorkResponsePacket.parse_raw(await websocket.receive_text()) - await redisClient.rpush(chat.id, response_packet.json()) - if response_packet.is_end: - break - except fastapi.WebSocketException: - # TODO: handle this better - logger.exception(f"Websocket closed during handling of {chat.id}") - - chat.message_request_state = MessageRequestState.complete + chat.message_request_state = MessageRequestState.complete + except fastapi.WebSocketException: + logger.exception("Websocket closed") diff --git a/inference/worker/__main__.py b/inference/worker/__main__.py index c8c1a4c9..e5c15fb4 100644 --- a/inference/worker/__main__.py +++ b/inference/worker/__main__.py @@ -33,7 +33,13 @@ def main( # construct prompt messages = [_prepare_message(message) for message in work_request.conversation.messages] - prompt = "\n".join(messages) + "\nAssistant:" + prefix = ( + "The following is a conversation between a user and an assistant. " + "The assistant is helpful, creative, clever, and very friendly.\n" + "Assistant: Hello! How can I help you today?\n" + ) + + prompt = prefix + "\n".join(messages) + "\nAssistant:" # TODO: use the seed # torch.manual_seed(work_request.seed)