mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-26 16:00:18 +08:00
unified queueing
This commit is contained in:
+75
-58
@@ -12,7 +12,7 @@ import websockets.exceptions
|
||||
from fastapi import Depends
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from loguru import logger
|
||||
from oasst_inference_server import interface
|
||||
from oasst_inference_server import interface, queueing
|
||||
from oasst_inference_server.chat_repository import ChatRepository
|
||||
from oasst_inference_server.database import db_engine
|
||||
from oasst_inference_server.settings import settings
|
||||
@@ -40,7 +40,7 @@ app.add_middleware(
|
||||
|
||||
|
||||
# create async redis client
|
||||
redisClient = redis.Redis(
|
||||
redis_client = redis.Redis(
|
||||
host=settings.redis_host, port=settings.redis_port, db=settings.redis_db, decode_responses=True
|
||||
)
|
||||
|
||||
@@ -55,6 +55,12 @@ def create_chat_repository(session: sqlmodel.Session = Depends(create_session)):
|
||||
return repository
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def manual_chat_repository():
|
||||
with contextlib.contextmanager(create_session)() as session:
|
||||
yield create_chat_repository(session)
|
||||
|
||||
|
||||
if settings.update_alembic:
|
||||
|
||||
@app.on_event("startup")
|
||||
@@ -105,9 +111,9 @@ async def get_chat(id: str, chat_repository: ChatRepository = Depends(create_cha
|
||||
return chat
|
||||
|
||||
|
||||
@app.post("/chat/{id}/message")
|
||||
@app.post("/chat/{chat_id}/message")
|
||||
async def create_message(
|
||||
id: str,
|
||||
chat_id: str,
|
||||
message_request: interface.MessageRequest,
|
||||
fastapi_request: fastapi.Request,
|
||||
chat_repository: ChatRepository = Depends(create_chat_repository),
|
||||
@@ -115,21 +121,24 @@ async def create_message(
|
||||
"""Allows the client to stream the results of a request."""
|
||||
|
||||
try:
|
||||
chat_repository.add_prompter_message(id=id, message_request=message_request)
|
||||
chat_repository.add_prompter_message(chat_id=chat_id, message_request=message_request)
|
||||
queue = queueing.work_queue(redis_client, message_request.worker_compat_hash)
|
||||
logger.debug(f"Adding {chat_id} to {queue.queue_id}")
|
||||
await queue.enqueue(chat_id)
|
||||
logger.debug(f"Added message to {queue.queue_id} for {chat_id}")
|
||||
except Exception:
|
||||
logger.exception("Error adding prompter message")
|
||||
return fastapi.Response(status_code=500)
|
||||
|
||||
async def event_generator(id):
|
||||
async def event_generator(chat_id):
|
||||
queue = queueing.chat_queue(redis_client, chat_id)
|
||||
result_data = []
|
||||
|
||||
try:
|
||||
while True:
|
||||
if await fastapi_request.is_disconnected():
|
||||
logger.warning("Client disconnected")
|
||||
return
|
||||
|
||||
item = await redisClient.blpop(id, 1)
|
||||
item = await queue.dequeue()
|
||||
if item is None:
|
||||
continue
|
||||
|
||||
@@ -144,25 +153,26 @@ async def create_message(
|
||||
"retry": settings.sse_retry_timeout,
|
||||
"data": interface.TokenResponseEvent(token=response_packet.token).json(),
|
||||
}
|
||||
logger.info(f"Finished streaming {id} {len(result_data)=}")
|
||||
logger.info(f"Finished streaming {chat_id} {len(result_data)=}")
|
||||
except Exception:
|
||||
logger.exception(f"Error streaming {id}")
|
||||
logger.exception(f"Error streaming {chat_id}")
|
||||
raise
|
||||
|
||||
try:
|
||||
with contextlib.contextmanager(create_session)() as session:
|
||||
chat_repository = create_chat_repository(session)
|
||||
chat_repository.add_assistant_message(id=id, text=response_packet.generated_text.text)
|
||||
with manual_chat_repository() as chat_repository:
|
||||
chat_repository.add_assistant_message(chat_id=chat_id, text=response_packet.generated_text.text)
|
||||
except Exception:
|
||||
logger.exception("Error adding assistant message")
|
||||
|
||||
return EventSourceResponse(event_generator(id))
|
||||
return EventSourceResponse(event_generator(chat_id))
|
||||
|
||||
|
||||
@app.websocket("/work")
|
||||
async def work(websocket: fastapi.WebSocket, chat_repository: ChatRepository = Depends(create_chat_repository)):
|
||||
async def work(websocket: fastapi.WebSocket):
|
||||
await websocket.accept()
|
||||
worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text())
|
||||
queue_id = f"work:{worker_config.compat_hash}"
|
||||
work_queue = queueing.RedisQueue(redis_client, queue_id)
|
||||
try:
|
||||
while True:
|
||||
if websocket.client_state == fastapi.websockets.WebSocketState.DISCONNECTED:
|
||||
@@ -171,55 +181,62 @@ async def work(websocket: fastapi.WebSocket, chat_repository: ChatRepository = D
|
||||
# find a pending task that matches the worker's config
|
||||
# could also be implemented using task queues
|
||||
# but general compatibility matching is tricky
|
||||
for chat in chat_repository.get_pending_chats():
|
||||
request = chat.pending_message_request
|
||||
if request.compatible_with(worker_config):
|
||||
break
|
||||
else:
|
||||
item = await work_queue.dequeue()
|
||||
if item is None:
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
else:
|
||||
_, chat_id = item
|
||||
|
||||
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.in_progress)
|
||||
with manual_chat_repository() as chat_repository:
|
||||
chat = chat_repository.get_chat_by_id(chat_id)
|
||||
request = chat.pending_message_request
|
||||
|
||||
work_request = inference.WorkRequest(
|
||||
conversation=chat.conversation,
|
||||
model_name=request.model_name,
|
||||
max_new_tokens=request.max_new_tokens,
|
||||
)
|
||||
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.in_progress)
|
||||
|
||||
logger.info(f"Created {work_request=}")
|
||||
try:
|
||||
await websocket.send_text(work_request.json())
|
||||
except websockets.exceptions.ConnectionClosedError:
|
||||
logger.warning("Worker disconnected")
|
||||
websocket.close()
|
||||
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.pending)
|
||||
break
|
||||
work_request = inference.WorkRequest(
|
||||
conversation=chat.conversation,
|
||||
model_name=request.model_name,
|
||||
max_new_tokens=request.max_new_tokens,
|
||||
)
|
||||
|
||||
logger.debug(f"Sent {work_request=} to worker.")
|
||||
|
||||
try:
|
||||
in_progress = False
|
||||
while True:
|
||||
# maybe unnecessary to parse and re-serialize
|
||||
# could just pass the raw string and mark end via empty string
|
||||
response_packet = inference.WorkResponsePacket.parse_raw(await websocket.receive_text())
|
||||
in_progress = True
|
||||
await redisClient.rpush(chat.id, response_packet.json())
|
||||
if response_packet.is_end:
|
||||
logger.debug(f"Received {response_packet=} from worker. Ending.")
|
||||
break
|
||||
except fastapi.WebSocketException:
|
||||
# TODO: handle this better
|
||||
logger.exception(f"Websocket closed during handling of {chat.id}")
|
||||
if in_progress:
|
||||
logger.warning(f"Aborting {chat.id=}")
|
||||
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.aborted_by_worker)
|
||||
else:
|
||||
logger.warning(f"Marking {chat.id=} as pending since no work was done.")
|
||||
logger.info(f"Created {work_request=}")
|
||||
try:
|
||||
await websocket.send_text(work_request.json())
|
||||
except websockets.exceptions.ConnectionClosedError:
|
||||
logger.warning("Worker disconnected")
|
||||
websocket.close()
|
||||
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.pending)
|
||||
raise
|
||||
await work_queue.enqueue(chat.id)
|
||||
break
|
||||
|
||||
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.complete)
|
||||
logger.debug(f"Sent {work_request=} to worker.")
|
||||
|
||||
chat_queue = queueing.chat_queue(redis_client, chat.id)
|
||||
|
||||
try:
|
||||
in_progress = False
|
||||
while True:
|
||||
# maybe unnecessary to parse and re-serialize
|
||||
# could just pass the raw string and mark end via empty string
|
||||
response_packet = inference.WorkResponsePacket.parse_raw(await websocket.receive_text())
|
||||
in_progress = True
|
||||
await chat_queue.enqueue(response_packet.json())
|
||||
if response_packet.is_end:
|
||||
logger.debug(f"Received {response_packet=} from worker. Ending.")
|
||||
break
|
||||
except fastapi.WebSocketException:
|
||||
# TODO: handle this better
|
||||
logger.exception(f"Websocket closed during handling of {chat.id}")
|
||||
if in_progress:
|
||||
logger.warning(f"Aborting {chat.id=}")
|
||||
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.aborted_by_worker)
|
||||
else:
|
||||
logger.warning(f"Marking {chat.id=} as pending since no work was done.")
|
||||
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.pending)
|
||||
await work_queue.enqueue(chat.id)
|
||||
raise
|
||||
|
||||
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.complete)
|
||||
except fastapi.WebSocketException:
|
||||
logger.exception("Websocket closed")
|
||||
|
||||
@@ -7,8 +7,16 @@ from sqlalchemy.sql.operators import is_not
|
||||
|
||||
|
||||
class ChatRepository:
|
||||
def __init__(self, session: sqlmodel.Session) -> None:
|
||||
def __init__(self, session: sqlmodel.Session, do_commit=True) -> None:
|
||||
self.session = session
|
||||
self.do_commit = do_commit
|
||||
|
||||
def as_no_commit(self) -> "ChatRepository":
|
||||
return ChatRepository(self.session, do_commit=False)
|
||||
|
||||
def maybe_commit(self) -> None:
|
||||
if self.do_commit:
|
||||
self.session.commit()
|
||||
|
||||
def get_chats(self) -> list[models.DbChatEntry]:
|
||||
return self.session.exec(sqlmodel.select(models.DbChatEntry)).all()
|
||||
@@ -25,22 +33,25 @@ class ChatRepository:
|
||||
chats = self.get_chats()
|
||||
return [chat.to_list_entry() for chat in chats]
|
||||
|
||||
def get_chat_by_id(self, id: str) -> models.DbChatEntry:
|
||||
chat = self.session.exec(sqlmodel.select(models.DbChatEntry).where(models.DbChatEntry.id == id)).one()
|
||||
def get_chat_by_id(self, chat_id: str, for_update=False) -> models.DbChatEntry:
|
||||
query = sqlmodel.select(models.DbChatEntry).where(models.DbChatEntry.id == chat_id)
|
||||
if for_update:
|
||||
query = query.with_for_update()
|
||||
chat = self.session.exec(query).one()
|
||||
return chat
|
||||
|
||||
def get_chat_entry_by_id(self, id: str) -> interface.ChatEntry:
|
||||
return self.get_chat_by_id(id).to_entry()
|
||||
def get_chat_entry_by_id(self, chat_id: str) -> interface.ChatEntry:
|
||||
return self.get_chat_by_id(chat_id).to_entry()
|
||||
|
||||
def create_chat(self) -> models.DbChatEntry:
|
||||
chat = models.DbChatEntry()
|
||||
self.session.add(chat)
|
||||
self.session.commit()
|
||||
self.maybe_commit()
|
||||
return chat
|
||||
|
||||
def add_prompter_message(self, id: str, message_request: interface.MessageRequest) -> None:
|
||||
logger.info(f"Adding prompter message {message_request} to chat {id}")
|
||||
chat = self.get_chat_by_id(id)
|
||||
def add_prompter_message(self, chat_id: str, message_request: interface.MessageRequest) -> None:
|
||||
logger.info(f"Adding prompter message {message_request} to chat {chat_id}")
|
||||
chat = self.get_chat_by_id(chat_id, for_update=True)
|
||||
if not chat.conversation.is_prompter_turn:
|
||||
raise fastapi.HTTPException(status_code=400, detail="Not your turn")
|
||||
if chat.pending_message_request is not None:
|
||||
@@ -55,12 +66,12 @@ class ChatRepository:
|
||||
|
||||
chat.pending_message_request = message_request
|
||||
chat.message_request_state = interface.MessageRequestState.pending
|
||||
self.session.commit()
|
||||
logger.debug(f"Added prompter message {message_request} to chat {id}")
|
||||
self.maybe_commit()
|
||||
logger.debug(f"Added prompter message {message_request} to chat {chat_id}")
|
||||
|
||||
def add_assistant_message(self, id: str, text: str) -> None:
|
||||
logger.info(f"Adding assistant message {text} to chat {id}")
|
||||
chat = self.get_chat_by_id(id)
|
||||
def add_assistant_message(self, chat_id: str, text: str) -> None:
|
||||
logger.info(f"Adding assistant message {text} to chat {chat_id}")
|
||||
chat = self.get_chat_by_id(chat_id, for_update=True)
|
||||
chat.conversation.messages.append(
|
||||
protocol.ConversationMessage(
|
||||
text=text,
|
||||
@@ -68,12 +79,12 @@ class ChatRepository:
|
||||
)
|
||||
)
|
||||
chat.pending_message_request = None
|
||||
self.session.commit()
|
||||
logger.debug(f"Added assistant message {text} to chat {id}")
|
||||
self.maybe_commit()
|
||||
logger.debug(f"Added assistant message {text} to chat {chat_id}")
|
||||
|
||||
def set_chat_state(self, id: str, state: interface.MessageRequestState) -> None:
|
||||
logger.info(f"Setting chat {id} state to {state}")
|
||||
chat = self.get_chat_by_id(id)
|
||||
def set_chat_state(self, chat_id: str, state: interface.MessageRequestState) -> None:
|
||||
logger.info(f"Setting chat {chat_id} state to {state}")
|
||||
chat = self.get_chat_by_id(chat_id, for_update=True)
|
||||
chat.message_request_state = state
|
||||
self.session.commit()
|
||||
logger.debug(f"Set chat {id} state to {state}")
|
||||
self.maybe_commit()
|
||||
logger.debug(f"Set chat {chat_id} state to {state}")
|
||||
|
||||
@@ -9,8 +9,9 @@ class MessageRequest(pydantic.BaseModel):
|
||||
model_name: str = "distilgpt2"
|
||||
max_new_tokens: int = 100
|
||||
|
||||
def compatible_with(self, worker_config: inference.WorkerConfig) -> bool:
|
||||
return self.model_name == worker_config.model_name
|
||||
@property
|
||||
def worker_compat_hash(self) -> str:
|
||||
return f"{self.model_name}"
|
||||
|
||||
|
||||
class TokenResponseEvent(pydantic.BaseModel):
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
import redis.asyncio as redis
|
||||
from oasst_inference_server.settings import settings
|
||||
|
||||
|
||||
class RedisQueue:
|
||||
def __init__(self, redis_client: redis.Redis, queue_id: str) -> None:
|
||||
self.redis_client = redis_client
|
||||
self.queue_id = queue_id
|
||||
|
||||
async def enqueue(self, value: str) -> None:
|
||||
return await self.redis_client.rpush(self.queue_id, value)
|
||||
|
||||
async def dequeue(self, block: bool = True, timeout: int = 1) -> str:
|
||||
if block:
|
||||
return await self.redis_client.blpop(self.queue_id, timeout=timeout)
|
||||
else:
|
||||
return await self.redis_client.lpop(self.queue_id)
|
||||
|
||||
|
||||
def chat_queue(redis_client: redis.Redis, chat_id: str) -> RedisQueue:
|
||||
return RedisQueue(redis_client, f"chat:{chat_id}")
|
||||
|
||||
|
||||
def work_queue(redis_client: redis.Redis, worker_compat_hash: str) -> RedisQueue:
|
||||
if worker_compat_hash not in settings.allowed_worker_compat_hashes:
|
||||
raise ValueError(f"Worker compat hash {worker_compat_hash} not allowed")
|
||||
return RedisQueue(redis_client, f"work:{worker_compat_hash}")
|
||||
@@ -8,6 +8,8 @@ class Settings(pydantic.BaseSettings):
|
||||
redis_port: int = 6379
|
||||
redis_db: int = 0
|
||||
|
||||
allowed_worker_compat_hashes: list[str] = ["distilgpt2"]
|
||||
|
||||
sse_retry_timeout: int = 15000
|
||||
update_alembic: bool = True
|
||||
alembic_retries: int = 5
|
||||
|
||||
@@ -55,7 +55,7 @@ def wait_for_inference_server(inference_server_url: str, timeout: int = 600):
|
||||
if time.time() > time_limit:
|
||||
raise
|
||||
sleep_duration = random.uniform(0, 10)
|
||||
logger.warning(f"Inference server not ready. Retrying in {sleep_duration} seconds")
|
||||
logger.warning(f"Inference server not ready. Retrying in {sleep_duration:.2f} seconds")
|
||||
time.sleep(sleep_duration)
|
||||
else:
|
||||
logger.info("Inference server is ready")
|
||||
|
||||
@@ -9,6 +9,10 @@ from . import protocol
|
||||
class WorkerConfig(pydantic.BaseModel):
|
||||
model_name: str = "distilgpt2"
|
||||
|
||||
@property
|
||||
def compat_hash(self) -> str:
|
||||
return f"{self.model_name}"
|
||||
|
||||
|
||||
class WorkRequest(pydantic.BaseModel):
|
||||
conversation: protocol.Conversation = pydantic.Field(..., repr=False)
|
||||
|
||||
Reference in New Issue
Block a user