mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Initial implementation of the inference system (#869)
* very primitive implementation of inference * re-worked with security in mind * removed polling from clients * switched workers to websockets * implemented back and forth chats
This commit is contained in:
@@ -0,0 +1,35 @@
|
||||
# OpenAssitant Inference
|
||||
|
||||
Preliminary implementation of the inference engine for OpenAssistant.
|
||||
|
||||
## Development (you'll need multiple terminals)
|
||||
|
||||
Run a redis container (or use the one of the general docker compose file):
|
||||
|
||||
```bash
|
||||
docker run --rm -it -p 6379:6379 redis
|
||||
```
|
||||
|
||||
Run the inference server:
|
||||
|
||||
```bash
|
||||
cd server
|
||||
pip install -r requirements.txt
|
||||
uvicorn main:app --reload
|
||||
```
|
||||
|
||||
Run one (or more) workers:
|
||||
|
||||
```bash
|
||||
cd worker
|
||||
pip install -r requirements.txt
|
||||
python __main__.py
|
||||
```
|
||||
|
||||
Run the client:
|
||||
|
||||
```bash
|
||||
cd text-client
|
||||
pip install -r requirements.txt
|
||||
python __main__.py
|
||||
```
|
||||
@@ -0,0 +1,10 @@
|
||||
# OpenAssistant Inference Server
|
||||
|
||||
Workers communicate with the `/work` endpoint via Websocket. They provide their
|
||||
configuration and if a task is available, the server returns it. The worker then
|
||||
performs the task and returns the result in a streaming fashion to the server,
|
||||
also via websocket.
|
||||
|
||||
Clients first call `/chat` to make a new chat, then add to that via
|
||||
`/chat/<id>/message`. The response is a SSE event source, which will send tokens
|
||||
as they are available.
|
||||
@@ -0,0 +1,193 @@
|
||||
import asyncio
|
||||
import enum
|
||||
import uuid
|
||||
|
||||
import fastapi
|
||||
import pydantic
|
||||
import redis.asyncio as redis
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from loguru import logger
|
||||
from oasst_shared.schemas import inference, protocol
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
app = fastapi.FastAPI()
|
||||
|
||||
# Allow CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
class Settings(pydantic.BaseSettings):
|
||||
redis_host: str = "localhost"
|
||||
redis_port: int = 6379
|
||||
redis_db: int = 0
|
||||
|
||||
sse_retry_timeout: int = 15000
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
# create async redis client
|
||||
redisClient = redis.Redis(
|
||||
host=settings.redis_host, port=settings.redis_port, db=settings.redis_db, decode_responses=True
|
||||
)
|
||||
|
||||
|
||||
class CreateChatRequest(pydantic.BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class CreateChatResponse(pydantic.BaseModel):
|
||||
id: str
|
||||
|
||||
|
||||
class MessageRequest(pydantic.BaseModel):
|
||||
message: str = pydantic.Field(..., repr=False)
|
||||
model_name: str = "distilgpt2"
|
||||
max_new_tokens: int = 100
|
||||
|
||||
def compatible_with(self, worker_config: inference.WorkerConfig) -> bool:
|
||||
return self.model_name == worker_config.model_name
|
||||
|
||||
|
||||
class TokenResponseEvent(pydantic.BaseModel):
|
||||
token: str
|
||||
|
||||
|
||||
class MessageRequestState(str, enum.Enum):
|
||||
pending = "pending"
|
||||
in_progress = "in_progress"
|
||||
complete = "complete"
|
||||
|
||||
|
||||
class DbChatEntry(pydantic.BaseModel):
|
||||
id: str = pydantic.Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
conversation: protocol.Conversation = pydantic.Field(default_factory=protocol.Conversation)
|
||||
pending_message_request: MessageRequest | None = None
|
||||
message_request_state: MessageRequestState | None = None
|
||||
|
||||
|
||||
# TODO: make real database
|
||||
CHATS: dict[str, DbChatEntry] = {}
|
||||
|
||||
|
||||
@app.post("/chat")
|
||||
async def create_chat(request: CreateChatRequest) -> CreateChatResponse:
|
||||
"""Allows a client to create a new chat."""
|
||||
logger.info(f"Received {request}")
|
||||
chat = DbChatEntry()
|
||||
CHATS[chat.id] = chat
|
||||
return CreateChatResponse(id=chat.id)
|
||||
|
||||
|
||||
@app.get("/chat/{id}")
|
||||
async def get_chat(id: str) -> protocol.Conversation:
|
||||
"""Allows a client to get the current state of a chat."""
|
||||
return CHATS[id].conversation
|
||||
|
||||
|
||||
@app.post("/chat/{id}/message")
|
||||
async def create_message(id: str, message_request: MessageRequest, fastapi_request: fastapi.Request):
|
||||
"""Allows the client to stream the results of a request."""
|
||||
|
||||
chat = CHATS[id]
|
||||
if not chat.conversation.is_prompter_turn:
|
||||
raise fastapi.HTTPException(status_code=400, detail="Not your turn")
|
||||
if chat.pending_message_request is not None:
|
||||
raise fastapi.HTTPException(status_code=400, detail="Already pending")
|
||||
|
||||
chat.conversation.messages.append(
|
||||
protocol.ConversationMessage(
|
||||
text=message_request.message,
|
||||
is_assistant=False,
|
||||
)
|
||||
)
|
||||
|
||||
chat.pending_message_request = message_request
|
||||
chat.message_request_state = MessageRequestState.pending
|
||||
|
||||
async def event_generator():
|
||||
result_data = []
|
||||
|
||||
try:
|
||||
while True:
|
||||
if await fastapi_request.is_disconnected():
|
||||
logger.warning("Client disconnected")
|
||||
break
|
||||
|
||||
item = await redisClient.blpop(chat.id, 1)
|
||||
if item is None:
|
||||
continue
|
||||
|
||||
_, response_packet_str = item
|
||||
response_packet = inference.WorkResponsePacket.parse_raw(response_packet_str)
|
||||
result_data.append(response_packet)
|
||||
|
||||
if response_packet.is_end:
|
||||
break
|
||||
|
||||
yield {
|
||||
"retry": settings.sse_retry_timeout,
|
||||
"data": TokenResponseEvent(token=response_packet.token).json(),
|
||||
}
|
||||
logger.info(f"Finished streaming {chat.id} {len(result_data)=}")
|
||||
except Exception:
|
||||
logger.exception(f"Error streaming {chat.id}")
|
||||
|
||||
chat.conversation.messages.append(
|
||||
protocol.ConversationMessage(
|
||||
text="".join([d.token for d in result_data[:-1]]),
|
||||
is_assistant=True,
|
||||
)
|
||||
)
|
||||
chat.pending_message_request = None
|
||||
|
||||
return EventSourceResponse(event_generator())
|
||||
|
||||
|
||||
@app.websocket("/work")
|
||||
async def work(websocket: fastapi.WebSocket):
|
||||
await websocket.accept()
|
||||
worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text())
|
||||
while True:
|
||||
# find a pending task that matches the worker's config
|
||||
# could also be implemented using task queues
|
||||
# but general compatibility matching is tricky
|
||||
for chat in CHATS.values():
|
||||
if (request := chat.pending_message_request) is not None:
|
||||
if chat.message_request_state == MessageRequestState.pending:
|
||||
if request.compatible_with(worker_config):
|
||||
break
|
||||
else:
|
||||
logger.debug("No pending tasks")
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
chat.message_request_state = MessageRequestState.in_progress
|
||||
|
||||
work_request = inference.WorkRequest(
|
||||
conversation=chat.conversation,
|
||||
model_name=request.model_name,
|
||||
max_new_tokens=request.max_new_tokens,
|
||||
)
|
||||
|
||||
logger.info(f"Created {work_request}")
|
||||
try:
|
||||
await websocket.send_text(work_request.json())
|
||||
while True:
|
||||
# maybe unnecessary to parse and re-serialize
|
||||
# could just pass the raw string and mark end via empty string
|
||||
response_packet = inference.WorkResponsePacket.parse_raw(await websocket.receive_text())
|
||||
await redisClient.rpush(chat.id, response_packet.json())
|
||||
if response_packet.is_end:
|
||||
break
|
||||
except fastapi.WebSocketException:
|
||||
# TODO: handle this better
|
||||
logger.exception(f"Websocket closed during handling of {chat.id}")
|
||||
|
||||
chat.message_request_state = MessageRequestState.complete
|
||||
@@ -0,0 +1,6 @@
|
||||
fastapi[all]
|
||||
loguru
|
||||
pydantic
|
||||
redis
|
||||
sse-starlette
|
||||
websockets
|
||||
@@ -0,0 +1,40 @@
|
||||
"""Simple REPL frontend."""
|
||||
|
||||
import json
|
||||
|
||||
import requests
|
||||
import sseclient
|
||||
import typer
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(backend_url: str = "http://127.0.0.1:8000"):
|
||||
"""Simple REPL client."""
|
||||
chat_id = requests.post(f"{backend_url}/chat", json={}).json()["id"]
|
||||
while True:
|
||||
message = typer.prompt("User").strip()
|
||||
|
||||
# wait for stream to be ready
|
||||
# could implement a queue position indicator
|
||||
# could be implemented with long polling
|
||||
# but server load needs to be considered
|
||||
response = requests.post(
|
||||
f"{backend_url}/chat/{chat_id}/message",
|
||||
json={"message": message},
|
||||
stream=True,
|
||||
headers={"Accept": "text/event-stream"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
client = sseclient.SSEClient(response)
|
||||
print("Assistant: ", end="", flush=True)
|
||||
for event in client.events():
|
||||
data = json.loads(event.data)
|
||||
print(data["token"], end="", flush=True)
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
@@ -0,0 +1,3 @@
|
||||
requests
|
||||
sseclient-py
|
||||
typer
|
||||
@@ -0,0 +1,79 @@
|
||||
import re
|
||||
import time
|
||||
|
||||
import rel
|
||||
import torch
|
||||
import typer
|
||||
import websocket
|
||||
from loguru import logger
|
||||
from oasst_shared.schemas import inference, protocol
|
||||
from transformers import pipeline
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(
|
||||
backend_url: str = "ws://localhost:8000",
|
||||
model_name: str = "distilgpt2",
|
||||
):
|
||||
pipe = pipeline("text-generation", model=model_name)
|
||||
|
||||
def on_open(ws: websocket.WebSocket):
|
||||
worker_config = inference.WorkerConfig(model_name=model_name)
|
||||
ws.send(worker_config.json())
|
||||
|
||||
def on_message(ws: websocket.WebSocket, message: str):
|
||||
# TODO: what if this comes in, but one is already in progress?
|
||||
# also need to think of enabling batching
|
||||
work_request = inference.WorkRequest.parse_raw(message)
|
||||
|
||||
def _prepare_message(message: protocol.ConversationMessage) -> str:
|
||||
prefix = "Assistant: " if message.is_assistant else "User: "
|
||||
return prefix + message.text
|
||||
|
||||
# construct prompt
|
||||
messages = [_prepare_message(message) for message in work_request.conversation.messages]
|
||||
|
||||
prompt = "\n".join(messages) + "\nAssistant:"
|
||||
|
||||
# TODO: replace this with incremental generation
|
||||
torch.manual_seed(work_request.seed)
|
||||
model_output = pipe(prompt, max_new_tokens=work_request.max_new_tokens, do_sample=True, return_full_text=False)[
|
||||
0
|
||||
]["generated_text"]
|
||||
model_output = model_output.strip()
|
||||
|
||||
# fake streaming
|
||||
split_idcs = [m.start() for m in re.finditer(r"([\w:]+)", model_output)]
|
||||
pieces = [model_output[a:b] for a, b in zip([0] + split_idcs, split_idcs + [None])]
|
||||
for piece in pieces:
|
||||
if not piece:
|
||||
continue
|
||||
if piece.strip() in ("User:", "Assistant:"):
|
||||
break
|
||||
ws.send(inference.WorkResponsePacket(token=piece).json())
|
||||
time.sleep(0.1)
|
||||
ws.send(inference.WorkResponsePacket(is_end=True).json())
|
||||
|
||||
def on_error(ws: websocket.WebSocket, error: Exception):
|
||||
logger.error(f"Connection error: {error}")
|
||||
|
||||
def on_close(ws: websocket.WebSocket, close_status_code: int, close_msg: str):
|
||||
logger.warning(f"Connection closed: {close_status_code=} {close_msg=}")
|
||||
|
||||
ws = websocket.WebSocketApp(
|
||||
f"{backend_url}/work",
|
||||
on_message=on_message,
|
||||
on_error=on_error,
|
||||
on_close=on_close,
|
||||
on_open=on_open,
|
||||
)
|
||||
|
||||
ws.run_forever(dispatcher=rel, reconnect=5)
|
||||
rel.signal(2, rel.abort)
|
||||
rel.dispatch()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
@@ -0,0 +1,6 @@
|
||||
loguru
|
||||
rel
|
||||
torch
|
||||
transformers
|
||||
typer
|
||||
websocket-client
|
||||
@@ -0,0 +1,21 @@
|
||||
import random
|
||||
|
||||
import pydantic
|
||||
|
||||
from . import protocol
|
||||
|
||||
|
||||
class WorkerConfig(pydantic.BaseModel):
|
||||
model_name: str = "distilgpt2"
|
||||
|
||||
|
||||
class WorkRequest(pydantic.BaseModel):
|
||||
conversation: protocol.Conversation = pydantic.Field(..., repr=False)
|
||||
model_name: str = "distilgpt2"
|
||||
max_new_tokens: int = 100
|
||||
seed: int = pydantic.Field(default_factory=lambda: random.randint(0, 2**32 - 1))
|
||||
|
||||
|
||||
class WorkResponsePacket(pydantic.BaseModel):
|
||||
token: str | None = None
|
||||
is_end: bool = False
|
||||
@@ -64,6 +64,18 @@ class Conversation(BaseModel):
|
||||
|
||||
messages: list[ConversationMessage] = []
|
||||
|
||||
def __len__(self):
|
||||
return len(self.messages)
|
||||
|
||||
@property
|
||||
def is_prompter_turn(self) -> bool:
|
||||
if len(self) == 0:
|
||||
return True
|
||||
last_message = self.messages[-1]
|
||||
if last_message.is_assistant:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class Message(ConversationMessage):
|
||||
parent_id: Optional[UUID] = None
|
||||
|
||||
Reference in New Issue
Block a user