unified queueing

This commit is contained in:
Yannic Kilcher
2023-02-11 01:31:25 +01:00
parent 76f7af0dfd
commit 48212079f4
7 changed files with 144 additions and 82 deletions
+75 -58
View File
@@ -12,7 +12,7 @@ import websockets.exceptions
from fastapi import Depends from fastapi import Depends
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from loguru import logger from loguru import logger
from oasst_inference_server import interface from oasst_inference_server import interface, queueing
from oasst_inference_server.chat_repository import ChatRepository from oasst_inference_server.chat_repository import ChatRepository
from oasst_inference_server.database import db_engine from oasst_inference_server.database import db_engine
from oasst_inference_server.settings import settings from oasst_inference_server.settings import settings
@@ -40,7 +40,7 @@ app.add_middleware(
# create async redis client # create async redis client
redisClient = redis.Redis( redis_client = redis.Redis(
host=settings.redis_host, port=settings.redis_port, db=settings.redis_db, decode_responses=True host=settings.redis_host, port=settings.redis_port, db=settings.redis_db, decode_responses=True
) )
@@ -55,6 +55,12 @@ def create_chat_repository(session: sqlmodel.Session = Depends(create_session)):
return repository return repository
@contextlib.contextmanager
def manual_chat_repository():
    """Yield a chat repository bound to a session opened for this block only.

    Wraps the ``create_session`` generator as a context manager so a
    repository can be obtained without going through ``Depends``.
    """
    open_session = contextlib.contextmanager(create_session)
    with open_session() as session:
        repository = create_chat_repository(session)
        yield repository
if settings.update_alembic: if settings.update_alembic:
@app.on_event("startup") @app.on_event("startup")
@@ -105,9 +111,9 @@ async def get_chat(id: str, chat_repository: ChatRepository = Depends(create_cha
return chat return chat
@app.post("/chat/{id}/message") @app.post("/chat/{chat_id}/message")
async def create_message( async def create_message(
id: str, chat_id: str,
message_request: interface.MessageRequest, message_request: interface.MessageRequest,
fastapi_request: fastapi.Request, fastapi_request: fastapi.Request,
chat_repository: ChatRepository = Depends(create_chat_repository), chat_repository: ChatRepository = Depends(create_chat_repository),
@@ -115,21 +121,24 @@ async def create_message(
"""Allows the client to stream the results of a request.""" """Allows the client to stream the results of a request."""
try: try:
chat_repository.add_prompter_message(id=id, message_request=message_request) chat_repository.add_prompter_message(chat_id=chat_id, message_request=message_request)
queue = queueing.work_queue(redis_client, message_request.worker_compat_hash)
logger.debug(f"Adding {chat_id} to {queue.queue_id}")
await queue.enqueue(chat_id)
logger.debug(f"Added message to {queue.queue_id} for {chat_id}")
except Exception: except Exception:
logger.exception("Error adding prompter message") logger.exception("Error adding prompter message")
return fastapi.Response(status_code=500) return fastapi.Response(status_code=500)
async def event_generator(id): async def event_generator(chat_id):
queue = queueing.chat_queue(redis_client, chat_id)
result_data = [] result_data = []
try: try:
while True: while True:
if await fastapi_request.is_disconnected(): if await fastapi_request.is_disconnected():
logger.warning("Client disconnected") logger.warning("Client disconnected")
return return
item = await queue.dequeue()
item = await redisClient.blpop(id, 1)
if item is None: if item is None:
continue continue
@@ -144,25 +153,26 @@ async def create_message(
"retry": settings.sse_retry_timeout, "retry": settings.sse_retry_timeout,
"data": interface.TokenResponseEvent(token=response_packet.token).json(), "data": interface.TokenResponseEvent(token=response_packet.token).json(),
} }
logger.info(f"Finished streaming {id} {len(result_data)=}") logger.info(f"Finished streaming {chat_id} {len(result_data)=}")
except Exception: except Exception:
logger.exception(f"Error streaming {id}") logger.exception(f"Error streaming {chat_id}")
raise raise
try: try:
with contextlib.contextmanager(create_session)() as session: with manual_chat_repository() as chat_repository:
chat_repository = create_chat_repository(session) chat_repository.add_assistant_message(chat_id=chat_id, text=response_packet.generated_text.text)
chat_repository.add_assistant_message(id=id, text=response_packet.generated_text.text)
except Exception: except Exception:
logger.exception("Error adding assistant message") logger.exception("Error adding assistant message")
return EventSourceResponse(event_generator(id)) return EventSourceResponse(event_generator(chat_id))
@app.websocket("/work") @app.websocket("/work")
async def work(websocket: fastapi.WebSocket, chat_repository: ChatRepository = Depends(create_chat_repository)): async def work(websocket: fastapi.WebSocket):
await websocket.accept() await websocket.accept()
worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text()) worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text())
queue_id = f"work:{worker_config.compat_hash}"
work_queue = queueing.RedisQueue(redis_client, queue_id)
try: try:
while True: while True:
if websocket.client_state == fastapi.websockets.WebSocketState.DISCONNECTED: if websocket.client_state == fastapi.websockets.WebSocketState.DISCONNECTED:
@@ -171,55 +181,62 @@ async def work(websocket: fastapi.WebSocket, chat_repository: ChatRepository = D
# find a pending task that matches the worker's config # find a pending task that matches the worker's config
# could also be implemented using task queues # could also be implemented using task queues
# but general compatibility matching is tricky # but general compatibility matching is tricky
for chat in chat_repository.get_pending_chats(): item = await work_queue.dequeue()
request = chat.pending_message_request if item is None:
if request.compatible_with(worker_config):
break
else:
await asyncio.sleep(1) await asyncio.sleep(1)
continue continue
else:
_, chat_id = item
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.in_progress) with manual_chat_repository() as chat_repository:
chat = chat_repository.get_chat_by_id(chat_id)
request = chat.pending_message_request
work_request = inference.WorkRequest( chat_repository.set_chat_state(chat.id, interface.MessageRequestState.in_progress)
conversation=chat.conversation,
model_name=request.model_name,
max_new_tokens=request.max_new_tokens,
)
logger.info(f"Created {work_request=}") work_request = inference.WorkRequest(
try: conversation=chat.conversation,
await websocket.send_text(work_request.json()) model_name=request.model_name,
except websockets.exceptions.ConnectionClosedError: max_new_tokens=request.max_new_tokens,
logger.warning("Worker disconnected") )
websocket.close()
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.pending)
break
logger.debug(f"Sent {work_request=} to worker.") logger.info(f"Created {work_request=}")
try:
try: await websocket.send_text(work_request.json())
in_progress = False except websockets.exceptions.ConnectionClosedError:
while True: logger.warning("Worker disconnected")
# maybe unnecessary to parse and re-serialize websocket.close()
# could just pass the raw string and mark end via empty string
response_packet = inference.WorkResponsePacket.parse_raw(await websocket.receive_text())
in_progress = True
await redisClient.rpush(chat.id, response_packet.json())
if response_packet.is_end:
logger.debug(f"Received {response_packet=} from worker. Ending.")
break
except fastapi.WebSocketException:
# TODO: handle this better
logger.exception(f"Websocket closed during handling of {chat.id}")
if in_progress:
logger.warning(f"Aborting {chat.id=}")
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.aborted_by_worker)
else:
logger.warning(f"Marking {chat.id=} as pending since no work was done.")
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.pending) chat_repository.set_chat_state(chat.id, interface.MessageRequestState.pending)
raise await work_queue.enqueue(chat.id)
break
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.complete) logger.debug(f"Sent {work_request=} to worker.")
chat_queue = queueing.chat_queue(redis_client, chat.id)
try:
in_progress = False
while True:
# maybe unnecessary to parse and re-serialize
# could just pass the raw string and mark end via empty string
response_packet = inference.WorkResponsePacket.parse_raw(await websocket.receive_text())
in_progress = True
await chat_queue.enqueue(response_packet.json())
if response_packet.is_end:
logger.debug(f"Received {response_packet=} from worker. Ending.")
break
except fastapi.WebSocketException:
# TODO: handle this better
logger.exception(f"Websocket closed during handling of {chat.id}")
if in_progress:
logger.warning(f"Aborting {chat.id=}")
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.aborted_by_worker)
else:
logger.warning(f"Marking {chat.id=} as pending since no work was done.")
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.pending)
await work_queue.enqueue(chat.id)
raise
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.complete)
except fastapi.WebSocketException: except fastapi.WebSocketException:
logger.exception("Websocket closed") logger.exception("Websocket closed")
@@ -7,8 +7,16 @@ from sqlalchemy.sql.operators import is_not
class ChatRepository: class ChatRepository:
def __init__(self, session: sqlmodel.Session) -> None: def __init__(self, session: sqlmodel.Session, do_commit=True) -> None:
self.session = session self.session = session
self.do_commit = do_commit
def as_no_commit(self) -> "ChatRepository":
    """Return a new repository on the same session with ``do_commit`` off."""
    shared_session = self.session
    return ChatRepository(session=shared_session, do_commit=False)
def maybe_commit(self) -> None:
    """Commit the session, but only when ``do_commit`` is set."""
    if not self.do_commit:
        # Committing is left to whoever owns the session.
        return
    self.session.commit()
def get_chats(self) -> list[models.DbChatEntry]: def get_chats(self) -> list[models.DbChatEntry]:
return self.session.exec(sqlmodel.select(models.DbChatEntry)).all() return self.session.exec(sqlmodel.select(models.DbChatEntry)).all()
@@ -25,22 +33,25 @@ class ChatRepository:
chats = self.get_chats() chats = self.get_chats()
return [chat.to_list_entry() for chat in chats] return [chat.to_list_entry() for chat in chats]
def get_chat_by_id(self, id: str) -> models.DbChatEntry: def get_chat_by_id(self, chat_id: str, for_update=False) -> models.DbChatEntry:
chat = self.session.exec(sqlmodel.select(models.DbChatEntry).where(models.DbChatEntry.id == id)).one() query = sqlmodel.select(models.DbChatEntry).where(models.DbChatEntry.id == chat_id)
if for_update:
query = query.with_for_update()
chat = self.session.exec(query).one()
return chat return chat
def get_chat_entry_by_id(self, id: str) -> interface.ChatEntry: def get_chat_entry_by_id(self, chat_id: str) -> interface.ChatEntry:
return self.get_chat_by_id(id).to_entry() return self.get_chat_by_id(chat_id).to_entry()
def create_chat(self) -> models.DbChatEntry: def create_chat(self) -> models.DbChatEntry:
chat = models.DbChatEntry() chat = models.DbChatEntry()
self.session.add(chat) self.session.add(chat)
self.session.commit() self.maybe_commit()
return chat return chat
def add_prompter_message(self, id: str, message_request: interface.MessageRequest) -> None: def add_prompter_message(self, chat_id: str, message_request: interface.MessageRequest) -> None:
logger.info(f"Adding prompter message {message_request} to chat {id}") logger.info(f"Adding prompter message {message_request} to chat {chat_id}")
chat = self.get_chat_by_id(id) chat = self.get_chat_by_id(chat_id, for_update=True)
if not chat.conversation.is_prompter_turn: if not chat.conversation.is_prompter_turn:
raise fastapi.HTTPException(status_code=400, detail="Not your turn") raise fastapi.HTTPException(status_code=400, detail="Not your turn")
if chat.pending_message_request is not None: if chat.pending_message_request is not None:
@@ -55,12 +66,12 @@ class ChatRepository:
chat.pending_message_request = message_request chat.pending_message_request = message_request
chat.message_request_state = interface.MessageRequestState.pending chat.message_request_state = interface.MessageRequestState.pending
self.session.commit() self.maybe_commit()
logger.debug(f"Added prompter message {message_request} to chat {id}") logger.debug(f"Added prompter message {message_request} to chat {chat_id}")
def add_assistant_message(self, id: str, text: str) -> None: def add_assistant_message(self, chat_id: str, text: str) -> None:
logger.info(f"Adding assistant message {text} to chat {id}") logger.info(f"Adding assistant message {text} to chat {chat_id}")
chat = self.get_chat_by_id(id) chat = self.get_chat_by_id(chat_id, for_update=True)
chat.conversation.messages.append( chat.conversation.messages.append(
protocol.ConversationMessage( protocol.ConversationMessage(
text=text, text=text,
@@ -68,12 +79,12 @@ class ChatRepository:
) )
) )
chat.pending_message_request = None chat.pending_message_request = None
self.session.commit() self.maybe_commit()
logger.debug(f"Added assistant message {text} to chat {id}") logger.debug(f"Added assistant message {text} to chat {chat_id}")
def set_chat_state(self, id: str, state: interface.MessageRequestState) -> None: def set_chat_state(self, chat_id: str, state: interface.MessageRequestState) -> None:
logger.info(f"Setting chat {id} state to {state}") logger.info(f"Setting chat {chat_id} state to {state}")
chat = self.get_chat_by_id(id) chat = self.get_chat_by_id(chat_id, for_update=True)
chat.message_request_state = state chat.message_request_state = state
self.session.commit() self.maybe_commit()
logger.debug(f"Set chat {id} state to {state}") logger.debug(f"Set chat {chat_id} state to {state}")
@@ -9,8 +9,9 @@ class MessageRequest(pydantic.BaseModel):
model_name: str = "distilgpt2" model_name: str = "distilgpt2"
max_new_tokens: int = 100 max_new_tokens: int = 100
def compatible_with(self, worker_config: inference.WorkerConfig) -> bool: @property
return self.model_name == worker_config.model_name def worker_compat_hash(self) -> str:
return f"{self.model_name}"
class TokenResponseEvent(pydantic.BaseModel): class TokenResponseEvent(pydantic.BaseModel):
@@ -0,0 +1,27 @@
import redis.asyncio as redis
from oasst_inference_server.settings import settings
class RedisQueue:
    """FIFO queue stored in a single Redis list.

    Values are pushed onto the tail with ``RPUSH`` and taken from the head
    with ``BLPOP``/``LPOP``, so they come out in the order they went in.
    """

    def __init__(self, redis_client: redis.Redis, queue_id: str) -> None:
        """Bind the queue to a client and to the key of the list it uses.

        Args:
            redis_client: Async Redis client used for every queue operation.
            queue_id: Redis key of the list holding the queued values.
        """
        self.redis_client = redis_client
        self.queue_id = queue_id

    async def enqueue(self, value: str) -> int:
        """Append ``value`` to the tail of the queue.

        Returns:
            The result of ``RPUSH``: the length of the list after the push.
        """
        return await self.redis_client.rpush(self.queue_id, value)

    async def dequeue(self, block: bool = True, timeout: int = 1) -> "tuple[str, str] | str | None":
        """Remove and return the entry at the head of the queue.

        The two modes return different shapes, and callers rely on that:

        * ``block=True`` uses ``BLPOP`` and yields a ``(key, value)`` pair,
          or ``None`` if nothing arrived before the timeout expired.
        * ``block=False`` uses ``LPOP`` and yields the bare value, or
          ``None`` if the list is empty.

        Args:
            block: Wait for an entry instead of returning immediately.
            timeout: Value passed as ``timeout`` to ``BLPOP``; ignored when
                ``block`` is false.

        Returns:
            See above. Values are ``str`` only when the client decodes
            responses (bytes otherwise) -- confirm against how the client
            is constructed.
        """
        if block:
            return await self.redis_client.blpop(self.queue_id, timeout=timeout)
        return await self.redis_client.lpop(self.queue_id)
def chat_queue(redis_client: redis.Redis, chat_id: str) -> RedisQueue:
    """Return the queue stored under the Redis key ``chat:<chat_id>``."""
    queue_key = f"chat:{chat_id}"
    return RedisQueue(redis_client=redis_client, queue_id=queue_key)
def work_queue(redis_client: redis.Redis, worker_compat_hash: str) -> RedisQueue:
    """Return the queue stored under ``work:<worker_compat_hash>``.

    Raises:
        ValueError: If the hash is not listed in
            ``settings.allowed_worker_compat_hashes``.
    """
    permitted_hashes = settings.allowed_worker_compat_hashes
    if worker_compat_hash in permitted_hashes:
        queue_key = f"work:{worker_compat_hash}"
        return RedisQueue(redis_client, queue_key)
    # Unknown hashes are rejected rather than given a queue of their own.
    raise ValueError(f"Worker compat hash {worker_compat_hash} not allowed")
@@ -8,6 +8,8 @@ class Settings(pydantic.BaseSettings):
redis_port: int = 6379 redis_port: int = 6379
redis_db: int = 0 redis_db: int = 0
allowed_worker_compat_hashes: list[str] = ["distilgpt2"]
sse_retry_timeout: int = 15000 sse_retry_timeout: int = 15000
update_alembic: bool = True update_alembic: bool = True
alembic_retries: int = 5 alembic_retries: int = 5
+1 -1
View File
@@ -55,7 +55,7 @@ def wait_for_inference_server(inference_server_url: str, timeout: int = 600):
if time.time() > time_limit: if time.time() > time_limit:
raise raise
sleep_duration = random.uniform(0, 10) sleep_duration = random.uniform(0, 10)
logger.warning(f"Inference server not ready. Retrying in {sleep_duration} seconds") logger.warning(f"Inference server not ready. Retrying in {sleep_duration:.2f} seconds")
time.sleep(sleep_duration) time.sleep(sleep_duration)
else: else:
logger.info("Inference server is ready") logger.info("Inference server is ready")
@@ -9,6 +9,10 @@ from . import protocol
class WorkerConfig(pydantic.BaseModel): class WorkerConfig(pydantic.BaseModel):
model_name: str = "distilgpt2" model_name: str = "distilgpt2"
@property
def compat_hash(self) -> str:
    """Return the compatibility hash: currently just the model name as a string."""
    compat = f"{self.model_name}"
    return compat
class WorkRequest(pydantic.BaseModel): class WorkRequest(pydantic.BaseModel):
conversation: protocol.Conversation = pydantic.Field(..., repr=False) conversation: protocol.Conversation = pydantic.Field(..., repr=False)