From 1709dc03247fb72202b76703d5b8e18c837d7488 Mon Sep 17 00:00:00 2001 From: Yannic Kilcher Date: Sat, 21 Jan 2023 22:38:18 +0100 Subject: [PATCH] Initial implementation of the inference system (#869) * very primitive implementation of inference * re-worked with security in mind * removed polling from clients * switched workers to websockets * implemented back and forth chats --- inference/README.md | 35 ++++ inference/server/README.md | 10 + inference/server/main.py | 193 ++++++++++++++++++ inference/server/requirements.txt | 6 + inference/text-client/__main__.py | 40 ++++ inference/text-client/requirements.txt | 3 + inference/worker/__main__.py | 79 +++++++ inference/worker/requirements.txt | 6 + .../oasst_shared/schemas/inference.py | 21 ++ oasst-shared/oasst_shared/schemas/protocol.py | 12 ++ 10 files changed, 405 insertions(+) create mode 100644 inference/README.md create mode 100644 inference/server/README.md create mode 100644 inference/server/main.py create mode 100644 inference/server/requirements.txt create mode 100644 inference/text-client/__main__.py create mode 100644 inference/text-client/requirements.txt create mode 100644 inference/worker/__main__.py create mode 100644 inference/worker/requirements.txt create mode 100644 oasst-shared/oasst_shared/schemas/inference.py diff --git a/inference/README.md b/inference/README.md new file mode 100644 index 00000000..3dee94f9 --- /dev/null +++ b/inference/README.md @@ -0,0 +1,35 @@ +# OpenAssitant Inference + +Preliminary implementation of the inference engine for OpenAssistant. + +## Development (you'll need multiple terminals) + +Run a redis container (or use the one of the general docker compose file): + +```bash +docker run --rm -it -p 6379:6379 redis +``` + +Run the inference server: + +```bash +cd server +pip install -r requirements.txt +uvicorn main:app --reload +``` + +Run one (or more) workers: + +```bash +cd worker +pip install -r requirements.txt +python __main__.py +``` + +Run the client: + +```bash +cd text-client +pip install -r requirements.txt +python __main__.py +``` diff --git a/inference/server/README.md b/inference/server/README.md new file mode 100644 index 00000000..a235a7e6 --- /dev/null +++ b/inference/server/README.md @@ -0,0 +1,10 @@ +# OpenAssistant Inference Server + +Workers communicate with the `/work` endpoint via Websocket. They provide their +configuration and if a task is available, the server returns it. The worker then +performs the task and returns the result in a streaming fashion to the server, +also via websocket. + +Clients first call `/chat` to make a new chat, then add to that via +`/chat//message`. The response is a SSE event source, which will send tokens +as they are available. diff --git a/inference/server/main.py b/inference/server/main.py new file mode 100644 index 00000000..f3ec02b1 --- /dev/null +++ b/inference/server/main.py @@ -0,0 +1,193 @@ +import asyncio +import enum +import uuid + +import fastapi +import pydantic +import redis.asyncio as redis +from fastapi.middleware.cors import CORSMiddleware +from loguru import logger +from oasst_shared.schemas import inference, protocol +from sse_starlette.sse import EventSourceResponse + +app = fastapi.FastAPI() + +# Allow CORS +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +class Settings(pydantic.BaseSettings): + redis_host: str = "localhost" + redis_port: int = 6379 + redis_db: int = 0 + + sse_retry_timeout: int = 15000 + + +settings = Settings() + +# create async redis client +redisClient = redis.Redis( + host=settings.redis_host, port=settings.redis_port, db=settings.redis_db, decode_responses=True +) + + +class CreateChatRequest(pydantic.BaseModel): + pass + + +class CreateChatResponse(pydantic.BaseModel): + id: str + + +class MessageRequest(pydantic.BaseModel): + message: str = pydantic.Field(..., repr=False) + model_name: str = "distilgpt2" + max_new_tokens: int = 100 + + def compatible_with(self, worker_config: inference.WorkerConfig) -> bool: + return self.model_name == worker_config.model_name + + +class TokenResponseEvent(pydantic.BaseModel): + token: str + + +class MessageRequestState(str, enum.Enum): + pending = "pending" + in_progress = "in_progress" + complete = "complete" + + +class DbChatEntry(pydantic.BaseModel): + id: str = pydantic.Field(default_factory=lambda: str(uuid.uuid4())) + conversation: protocol.Conversation = pydantic.Field(default_factory=protocol.Conversation) + pending_message_request: MessageRequest | None = None + message_request_state: MessageRequestState | None = None + + +# TODO: make real database +CHATS: dict[str, DbChatEntry] = {} + + +@app.post("/chat") +async def create_chat(request: CreateChatRequest) -> CreateChatResponse: + """Allows a client to create a new chat.""" + logger.info(f"Received {request}") + chat = DbChatEntry() + CHATS[chat.id] = chat + return CreateChatResponse(id=chat.id) + + +@app.get("/chat/{id}") +async def get_chat(id: str) -> protocol.Conversation: + """Allows a client to get the current state of a chat.""" + return CHATS[id].conversation + + +@app.post("/chat/{id}/message") +async def create_message(id: str, message_request: MessageRequest, fastapi_request: fastapi.Request): + """Allows the client to stream the results of a request.""" + + chat = CHATS[id] + if not chat.conversation.is_prompter_turn: + raise fastapi.HTTPException(status_code=400, detail="Not your turn") + if chat.pending_message_request is not None: + raise fastapi.HTTPException(status_code=400, detail="Already pending") + + chat.conversation.messages.append( + protocol.ConversationMessage( + text=message_request.message, + is_assistant=False, + ) + ) + + chat.pending_message_request = message_request + chat.message_request_state = MessageRequestState.pending + + async def event_generator(): + result_data = [] + + try: + while True: + if await fastapi_request.is_disconnected(): + logger.warning("Client disconnected") + break + + item = await redisClient.blpop(chat.id, 1) + if item is None: + continue + + _, response_packet_str = item + response_packet = inference.WorkResponsePacket.parse_raw(response_packet_str) + result_data.append(response_packet) + + if response_packet.is_end: + break + + yield { + "retry": settings.sse_retry_timeout, + "data": TokenResponseEvent(token=response_packet.token).json(), + } + logger.info(f"Finished streaming {chat.id} {len(result_data)=}") + except Exception: + logger.exception(f"Error streaming {chat.id}") + + chat.conversation.messages.append( + protocol.ConversationMessage( + text="".join([d.token for d in result_data[:-1]]), + is_assistant=True, + ) + ) + chat.pending_message_request = None + + return EventSourceResponse(event_generator()) + + +@app.websocket("/work") +async def work(websocket: fastapi.WebSocket): + await websocket.accept() + worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text()) + while True: + # find a pending task that matches the worker's config + # could also be implemented using task queues + # but general compatibility matching is tricky + for chat in CHATS.values(): + if (request := chat.pending_message_request) is not None: + if chat.message_request_state == MessageRequestState.pending: + if request.compatible_with(worker_config): + break + else: + logger.debug("No pending tasks") + await asyncio.sleep(1) + continue + + chat.message_request_state = MessageRequestState.in_progress + + work_request = inference.WorkRequest( + conversation=chat.conversation, + model_name=request.model_name, + max_new_tokens=request.max_new_tokens, + ) + + logger.info(f"Created {work_request}") + try: + await websocket.send_text(work_request.json()) + while True: + # maybe unnecessary to parse and re-serialize + # could just pass the raw string and mark end via empty string + response_packet = inference.WorkResponsePacket.parse_raw(await websocket.receive_text()) + await redisClient.rpush(chat.id, response_packet.json()) + if response_packet.is_end: + break + except fastapi.WebSocketException: + # TODO: handle this better + logger.exception(f"Websocket closed during handling of {chat.id}") + + chat.message_request_state = MessageRequestState.complete diff --git a/inference/server/requirements.txt b/inference/server/requirements.txt new file mode 100644 index 00000000..e0a00339 --- /dev/null +++ b/inference/server/requirements.txt @@ -0,0 +1,6 @@ +fastapi[all] +loguru +pydantic +redis +sse-starlette +websockets diff --git a/inference/text-client/__main__.py b/inference/text-client/__main__.py new file mode 100644 index 00000000..bf1f8b02 --- /dev/null +++ b/inference/text-client/__main__.py @@ -0,0 +1,40 @@ +"""Simple REPL frontend.""" + +import json + +import requests +import sseclient +import typer + +app = typer.Typer() + + +@app.command() +def main(backend_url: str = "http://127.0.0.1:8000"): + """Simple REPL client.""" + chat_id = requests.post(f"{backend_url}/chat", json={}).json()["id"] + while True: + message = typer.prompt("User").strip() + + # wait for stream to be ready + # could implement a queue position indicator + # could be implemented with long polling + # but server load needs to be considered + response = requests.post( + f"{backend_url}/chat/{chat_id}/message", + json={"message": message}, + stream=True, + headers={"Accept": "text/event-stream"}, + ) + response.raise_for_status() + + client = sseclient.SSEClient(response) + print("Assistant: ", end="", flush=True) + for event in client.events(): + data = json.loads(event.data) + print(data["token"], end="", flush=True) + print() + + +if __name__ == "__main__": + app() diff --git a/inference/text-client/requirements.txt b/inference/text-client/requirements.txt new file mode 100644 index 00000000..8d7bff7d --- /dev/null +++ b/inference/text-client/requirements.txt @@ -0,0 +1,3 @@ +requests +sseclient-py +typer diff --git a/inference/worker/__main__.py b/inference/worker/__main__.py new file mode 100644 index 00000000..ad5e5cef --- /dev/null +++ b/inference/worker/__main__.py @@ -0,0 +1,79 @@ +import re +import time + +import rel +import torch +import typer +import websocket +from loguru import logger +from oasst_shared.schemas import inference, protocol +from transformers import pipeline + +app = typer.Typer() + + +@app.command() +def main( + backend_url: str = "ws://localhost:8000", + model_name: str = "distilgpt2", +): + pipe = pipeline("text-generation", model=model_name) + + def on_open(ws: websocket.WebSocket): + worker_config = inference.WorkerConfig(model_name=model_name) + ws.send(worker_config.json()) + + def on_message(ws: websocket.WebSocket, message: str): + # TODO: what if this comes in, but one is already in progress? + # also need to think of enabling batching + work_request = inference.WorkRequest.parse_raw(message) + + def _prepare_message(message: protocol.ConversationMessage) -> str: + prefix = "Assistant: " if message.is_assistant else "User: " + return prefix + message.text + + # construct prompt + messages = [_prepare_message(message) for message in work_request.conversation.messages] + + prompt = "\n".join(messages) + "\nAssistant:" + + # TODO: replace this with incremental generation + torch.manual_seed(work_request.seed) + model_output = pipe(prompt, max_new_tokens=work_request.max_new_tokens, do_sample=True, return_full_text=False)[ + 0 + ]["generated_text"] + model_output = model_output.strip() + + # fake streaming + split_idcs = [m.start() for m in re.finditer(r"([\w:]+)", model_output)] + pieces = [model_output[a:b] for a, b in zip([0] + split_idcs, split_idcs + [None])] + for piece in pieces: + if not piece: + continue + if piece.strip() in ("User:", "Assistant:"): + break + ws.send(inference.WorkResponsePacket(token=piece).json()) + time.sleep(0.1) + ws.send(inference.WorkResponsePacket(is_end=True).json()) + + def on_error(ws: websocket.WebSocket, error: Exception): + logger.error(f"Connection error: {error}") + + def on_close(ws: websocket.WebSocket, close_status_code: int, close_msg: str): + logger.warning(f"Connection closed: {close_status_code=} {close_msg=}") + + ws = websocket.WebSocketApp( + f"{backend_url}/work", + on_message=on_message, + on_error=on_error, + on_close=on_close, + on_open=on_open, + ) + + ws.run_forever(dispatcher=rel, reconnect=5) + rel.signal(2, rel.abort) + rel.dispatch() + + +if __name__ == "__main__": + app() diff --git a/inference/worker/requirements.txt b/inference/worker/requirements.txt new file mode 100644 index 00000000..c248c652 --- /dev/null +++ b/inference/worker/requirements.txt @@ -0,0 +1,6 @@ +loguru +rel +torch +transformers +typer +websocket-client diff --git a/oasst-shared/oasst_shared/schemas/inference.py b/oasst-shared/oasst_shared/schemas/inference.py new file mode 100644 index 00000000..0acb5014 --- /dev/null +++ b/oasst-shared/oasst_shared/schemas/inference.py @@ -0,0 +1,21 @@ +import random + +import pydantic + +from . import protocol + + +class WorkerConfig(pydantic.BaseModel): + model_name: str = "distilgpt2" + + +class WorkRequest(pydantic.BaseModel): + conversation: protocol.Conversation = pydantic.Field(..., repr=False) + model_name: str = "distilgpt2" + max_new_tokens: int = 100 + seed: int = pydantic.Field(default_factory=lambda: random.randint(0, 2**32 - 1)) + + +class WorkResponsePacket(pydantic.BaseModel): + token: str | None = None + is_end: bool = False diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index f2164e8f..20bbdf9b 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -64,6 +64,18 @@ class Conversation(BaseModel): messages: list[ConversationMessage] = [] + def __len__(self): + return len(self.messages) + + @property + def is_prompter_turn(self) -> bool: + if len(self) == 0: + return True + last_message = self.messages[-1] + if last_message.is_assistant: + return True + return False + class Message(ConversationMessage): parent_id: Optional[UUID] = None