Initial implementation of the inference system (#869)

* very primitive implementation of inference

* re-worked with security in mind

* removed polling from clients

* switched workers to websockets

* implemented back and forth chats
This commit is contained in:
Yannic Kilcher
2023-01-21 22:38:18 +01:00
committed by GitHub
parent cec49614c2
commit 1709dc0324
10 changed files with 405 additions and 0 deletions
+35
View File
@@ -0,0 +1,35 @@
# OpenAssistant Inference
Preliminary implementation of the inference engine for OpenAssistant.
## Development (you'll need multiple terminals)
Run a Redis container (or use the one from the general Docker Compose file):
```bash
docker run --rm -it -p 6379:6379 redis
```
Run the inference server:
```bash
cd server
pip install -r requirements.txt
uvicorn main:app --reload
```
Run one (or more) workers:
```bash
cd worker
pip install -r requirements.txt
python __main__.py
```
Run the client:
```bash
cd text-client
pip install -r requirements.txt
python __main__.py
```
+10
View File
@@ -0,0 +1,10 @@
# OpenAssistant Inference Server
Workers communicate with the `/work` endpoint via WebSocket. They provide their
configuration and, if a task is available, the server returns it. The worker then
performs the task and streams the result back to the server, also via
WebSocket.
Clients first call `/chat` to create a new chat, then add to it via
`/chat/<id>/message`. The response is an SSE event source, which sends tokens
as they become available.
+193
View File
@@ -0,0 +1,193 @@
import asyncio
import enum
import uuid
import fastapi
import pydantic
import redis.asyncio as redis
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from oasst_shared.schemas import inference, protocol
from sse_starlette.sse import EventSourceResponse
# ASGI application object; the route decorators below register handlers on it.
app = fastapi.FastAPI()
# Allow CORS
# NOTE(review): every origin, method and header is allowed while credentials
# are enabled as well -- confirm this permissive setup is meant for
# development only before exposing the server publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Server configuration with built-in defaults.
class Settings(pydantic.BaseSettings):
    # Connection parameters handed to the Redis client created below.
    redis_host: str = "localhost"
    redis_port: int = 6379
    redis_db: int = 0
    # Value emitted in the "retry" field of every SSE event
    # (presumably milliseconds -- confirm against the SSE client).
    sse_retry_timeout: int = 15000
# Module-wide settings instance used by the handlers below.
settings = Settings()
# create async redis client
# Workers' response packets are pushed to and popped from Redis lists keyed
# by chat id; decode_responses=True makes the client return str, not bytes.
redisClient = redis.Redis(
    host=settings.redis_host, port=settings.redis_port, db=settings.redis_db, decode_responses=True
)
# Request body of POST /chat; it currently carries no fields.
class CreateChatRequest(pydantic.BaseModel):
    pass
# Response body of POST /chat.
class CreateChatResponse(pydantic.BaseModel):
    # Identifier of the newly created chat (key into CHATS).
    id: str
# Request body of POST /chat/{id}/message: the user's text plus generation
# parameters that are forwarded to the worker.
class MessageRequest(pydantic.BaseModel):
    # User text; excluded from repr so it does not end up in log output.
    message: str = pydantic.Field(..., repr=False)
    # Model the request must be served by.
    model_name: str = "distilgpt2"
    # Upper bound on generated tokens, forwarded to the worker.
    max_new_tokens: int = 100

    def compatible_with(self, worker_config: inference.WorkerConfig) -> bool:
        # A worker can serve this request only if it runs the requested model.
        offered_model = worker_config.model_name
        return offered_model == self.model_name
# Payload serialized into the "data" field of each SSE event sent to clients.
class TokenResponseEvent(pydantic.BaseModel):
    # One piece of generated text.
    token: str
# Lifecycle of a chat's pending message request.
class MessageRequestState(str, enum.Enum):
    # Set by create_message; the request is waiting for a compatible worker.
    pending = "pending"
    # Set by the /work handler once the request was handed to a worker.
    in_progress = "in_progress"
    # Set by the /work handler after the worker finished the request.
    complete = "complete"
# In-memory record of a single chat.
class DbChatEntry(pydantic.BaseModel):
    # Random UUID4 string; also used as the Redis list key for this chat.
    id: str = pydantic.Field(default_factory=lambda: str(uuid.uuid4()))
    # Messages exchanged so far; starts as an empty conversation.
    conversation: protocol.Conversation = pydantic.Field(default_factory=protocol.Conversation)
    # Request currently awaiting or undergoing generation, None when idle.
    pending_message_request: MessageRequest | None = None
    # State of pending_message_request; None until the first request arrives.
    message_request_state: MessageRequestState | None = None
# TODO: make real database
# Process-local store of all chats, keyed by chat id.
CHATS: dict[str, DbChatEntry] = {}
@app.post("/chat")
async def create_chat(request: CreateChatRequest) -> CreateChatResponse:
    """Allows a client to create a new chat."""
    logger.info(f"Received {request}")
    # A fresh entry generates its own id; register it under that id.
    new_entry = DbChatEntry()
    chat_id = new_entry.id
    CHATS[chat_id] = new_entry
    return CreateChatResponse(id=chat_id)
@app.get("/chat/{id}")
async def get_chat(id: str) -> protocol.Conversation:
    """Allows a client to get the current state of a chat."""
    # Unknown ids are a client error: answer 404 instead of letting a
    # KeyError surface as an internal server error.
    chat = CHATS.get(id)
    if chat is None:
        raise fastapi.HTTPException(status_code=404, detail="Chat not found")
    return chat.conversation
@app.post("/chat/{id}/message")
async def create_message(id: str, message_request: MessageRequest, fastapi_request: fastapi.Request):
    """Allows the client to stream the results of a request."""
    # Raises HTTPException 404 for an unknown chat id and 400 when it is not
    # the prompter's turn or a request is already pending for this chat.
    chat = CHATS.get(id)
    if chat is None:
        raise fastapi.HTTPException(status_code=404, detail="Chat not found")
    if not chat.conversation.is_prompter_turn:
        raise fastapi.HTTPException(status_code=400, detail="Not your turn")
    if chat.pending_message_request is not None:
        raise fastapi.HTTPException(status_code=400, detail="Already pending")

    # Record the user message and publish the request so that the /work
    # handler can hand it to a compatible worker.
    chat.conversation.messages.append(
        protocol.ConversationMessage(
            text=message_request.message,
            is_assistant=False,
        )
    )
    chat.pending_message_request = message_request
    chat.message_request_state = MessageRequestState.pending

    async def event_generator():
        # Relays response packets from the chat's Redis list to the client as
        # SSE events until the end packet arrives or the client disconnects.
        result_data = []
        try:
            while True:
                if await fastapi_request.is_disconnected():
                    logger.warning("Client disconnected")
                    break
                # Block for at most one second so that the disconnect check
                # above keeps running while no packet is available.
                item = await redisClient.blpop(chat.id, 1)
                if item is None:
                    continue
                _, response_packet_str = item
                response_packet = inference.WorkResponsePacket.parse_raw(response_packet_str)
                result_data.append(response_packet)
                if response_packet.is_end:
                    break
                yield {
                    "retry": settings.sse_retry_timeout,
                    "data": TokenResponseEvent(token=response_packet.token).json(),
                }
            logger.info(f"Finished streaming {chat.id} {len(result_data)=}")
        except Exception:
            logger.exception(f"Error streaming {chat.id}")
        finally:
            # Runs on normal completion, on errors and when the generator is
            # cancelled or closed, so the chat never stays stuck in the
            # "Already pending" state.
            # Only real tokens are joined: the stream may have stopped
            # without an end packet, so the last element cannot simply be
            # assumed to be the end marker and dropped.
            chat.conversation.messages.append(
                protocol.ConversationMessage(
                    text="".join(
                        packet.token
                        for packet in result_data
                        if not packet.is_end and packet.token is not None
                    ),
                    is_assistant=True,
                )
            )
            chat.pending_message_request = None

    return EventSourceResponse(event_generator())
@app.websocket("/work")
async def work(websocket: fastapi.WebSocket):
    """Serve one worker connection.

    The worker first sends its configuration. The handler then repeatedly
    looks for a pending, compatible message request, sends it to the worker
    and forwards every response packet to the Redis list keyed by the chat
    id. The handler returns when the connection breaks while a request is
    being processed.
    """
    await websocket.accept()
    worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text())
    while True:
        # find a pending task that matches the worker's config
        # could also be implemented using task queues
        # but general compatibility matching is tricky
        for chat in CHATS.values():
            if (
                (request := chat.pending_message_request) is not None
                and chat.message_request_state == MessageRequestState.pending
                and request.compatible_with(worker_config)
            ):
                break
        else:
            logger.debug("No pending tasks")
            await asyncio.sleep(1)
            continue

        # Claim the request before the first await so that no other worker
        # handler picks the same one.
        chat.message_request_state = MessageRequestState.in_progress
        work_request = inference.WorkRequest(
            conversation=chat.conversation,
            model_name=request.model_name,
            max_new_tokens=request.max_new_tokens,
        )
        logger.info(f"Created {work_request}")
        try:
            await websocket.send_text(work_request.json())
            while True:
                # maybe unnecessary to parse and re-serialize
                # could just pass the raw string and mark end via empty string
                response_packet = inference.WorkResponsePacket.parse_raw(await websocket.receive_text())
                await redisClient.rpush(chat.id, response_packet.json())
                if response_packet.is_end:
                    break
        except (fastapi.WebSocketException, fastapi.WebSocketDisconnect):
            # TODO: handle this better
            # A vanished worker raises WebSocketDisconnect, which is not a
            # WebSocketException, so both are caught here.
            logger.exception(f"Websocket closed during handling of {chat.id}")
            # Terminate the client's stream; otherwise it would wait for an
            # end packet that can no longer arrive.
            await redisClient.rpush(chat.id, inference.WorkResponsePacket(is_end=True).json())
            chat.message_request_state = MessageRequestState.complete
            # The connection is unusable, so stop serving this worker.
            return
        chat.message_request_state = MessageRequestState.complete
+6
View File
@@ -0,0 +1,6 @@
fastapi[all]
loguru
pydantic
redis
sse-starlette
websockets
+40
View File
@@ -0,0 +1,40 @@
"""Simple REPL frontend."""
import json
import requests
import sseclient
import typer
app = typer.Typer()


@app.command()
def main(backend_url: str = "http://127.0.0.1:8000"):
    """Simple REPL client."""
    # Open a chat once, then keep exchanging messages within it.
    creation_reply = requests.post(f"{backend_url}/chat", json={})
    chat_id = creation_reply.json()["id"]
    message_url = f"{backend_url}/chat/{chat_id}/message"
    sse_headers = {"Accept": "text/event-stream"}

    while True:
        user_text = typer.prompt("User").strip()

        # wait for stream to be ready
        # could implement a queue position indicator
        # could be implemented with long polling
        # but server load needs to be considered
        stream = requests.post(
            message_url,
            json={"message": user_text},
            stream=True,
            headers=sse_headers,
        )
        stream.raise_for_status()

        # Print tokens as they arrive, all on one line.
        print("Assistant: ", end="", flush=True)
        for event in sseclient.SSEClient(stream).events():
            payload = json.loads(event.data)
            print(payload["token"], end="", flush=True)
        print()
# Start the typer CLI only when the file is executed as a script.
if __name__ == "__main__":
    app()
+3
View File
@@ -0,0 +1,3 @@
requests
sseclient-py
typer
+79
View File
@@ -0,0 +1,79 @@
import re
import time
import rel
import torch
import typer
import websocket
from loguru import logger
from oasst_shared.schemas import inference, protocol
from transformers import pipeline
app = typer.Typer()


@app.command()
def main(
    backend_url: str = "ws://localhost:8000",
    model_name: str = "distilgpt2",
):
    # Worker entry point: load the text-generation pipeline once, connect to
    # the server's /work endpoint and answer every work request received.
    pipe = pipeline("text-generation", model=model_name)

    def on_open(ws: websocket.WebSocket):
        # Announce this worker's configuration as the first message.
        ws.send(inference.WorkerConfig(model_name=model_name).json())

    def on_message(ws: websocket.WebSocket, message: str):
        # TODO: what if this comes in, but one is already in progress?
        # also need to think of enabling batching
        work_request = inference.WorkRequest.parse_raw(message)

        # construct prompt: one speaker-prefixed line per turn, followed by
        # an open assistant turn for the model to complete
        prompt_lines = []
        for turn in work_request.conversation.messages:
            speaker = "Assistant: " if turn.is_assistant else "User: "
            prompt_lines.append(speaker + turn.text)
        prompt = "\n".join(prompt_lines) + "\nAssistant:"

        # TODO: replace this with incremental generation
        torch.manual_seed(work_request.seed)
        generations = pipe(
            prompt,
            max_new_tokens=work_request.max_new_tokens,
            do_sample=True,
            return_full_text=False,
        )
        model_output = generations[0]["generated_text"].strip()

        # fake streaming: cut the finished text at the start of every
        # word-like run and send the resulting pieces one by one
        starts = [match.start() for match in re.finditer(r"([\w:]+)", model_output)]
        boundaries = [0, *starts, None]
        for begin, end in zip(boundaries, boundaries[1:]):
            piece = model_output[begin:end]
            if not piece:
                continue
            # Stop as soon as the model starts writing another turn.
            if piece.strip() in ("User:", "Assistant:"):
                break
            ws.send(inference.WorkResponsePacket(token=piece).json())
            time.sleep(0.1)
        ws.send(inference.WorkResponsePacket(is_end=True).json())

    def on_error(ws: websocket.WebSocket, error: Exception):
        logger.error(f"Connection error: {error}")

    def on_close(ws: websocket.WebSocket, close_status_code: int, close_msg: str):
        logger.warning(f"Connection closed: {close_status_code=} {close_msg=}")

    work_url = f"{backend_url}/work"
    ws = websocket.WebSocketApp(
        work_url,
        on_open=on_open,
        on_message=on_message,
        on_error=on_error,
        on_close=on_close,
    )

    # Run the connection on the rel dispatcher; signal 2 aborts the loop.
    ws.run_forever(dispatcher=rel, reconnect=5)
    rel.signal(2, rel.abort)
    rel.dispatch()
# Start the typer CLI only when the file is executed as a script.
if __name__ == "__main__":
    app()
+6
View File
@@ -0,0 +1,6 @@
loguru
rel
torch
transformers
typer
websocket-client
@@ -0,0 +1,21 @@
import random
import pydantic
from . import protocol
# Configuration a worker announces about itself.
class WorkerConfig(pydantic.BaseModel):
    # Name of the model the worker runs.
    model_name: str = "distilgpt2"
# A unit of generation work.
class WorkRequest(pydantic.BaseModel):
    # Conversation to continue; excluded from repr.
    conversation: protocol.Conversation = pydantic.Field(..., repr=False)
    # Name of the model that should produce the continuation.
    model_name: str = "distilgpt2"
    # Upper bound on the number of generated tokens.
    max_new_tokens: int = 100
    # Random seed; defaults to a fresh random integer in [0, 2**32 - 1].
    seed: int = pydantic.Field(default_factory=lambda: random.randint(0, 2**32 - 1))
# One packet of a streamed generation result.
class WorkResponsePacket(pydantic.BaseModel):
    # Piece of generated text; may be None (the default).
    token: str | None = None
    # True on the packet that terminates the stream.
    is_end: bool = False
@@ -64,6 +64,18 @@ class Conversation(BaseModel):
messages: list[ConversationMessage] = []
def __len__(self):
    """Return how many messages the conversation holds."""
    message_count = len(self.messages)
    return message_count
@property
def is_prompter_turn(self) -> bool:
    """Tell whether the next message is expected from the prompter.

    True for an empty conversation or when the most recent message was
    written by the assistant; False otherwise.
    """
    if len(self) == 0:
        return True
    newest = self.messages[-1]
    return True if newest.is_assistant else False
class Message(ConversationMessage):
parent_id: Optional[UUID] = None