made inference server a bit more robust

This commit is contained in:
Yannic Kilcher
2023-01-26 21:01:52 +01:00
parent 2eed530e1a
commit 040344a41f
2 changed files with 60 additions and 36 deletions
+53 -35
View File
@@ -5,6 +5,7 @@ import uuid
import fastapi
import pydantic
import redis.asyncio as redis
import websockets.exceptions
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from oasst_shared.schemas import inference, protocol
@@ -63,6 +64,7 @@ class MessageRequestState(str, enum.Enum):
pending = "pending"
in_progress = "in_progress"
complete = "complete"
aborted_by_worker = "aborted_by_worker"
class DbChatEntry(pydantic.BaseModel):
@@ -154,40 +156,56 @@ async def create_message(id: str, message_request: MessageRequest, fastapi_reque
async def work(websocket: fastapi.WebSocket):
    """Serve one inference worker over a websocket until it disconnects.

    The worker first sends its ``WorkerConfig``. The loop then repeatedly
    picks a pending chat request that is compatible with that config, sends
    it to the worker as a ``WorkRequest`` and pushes every response packet
    onto the redis list keyed by the chat id, up to and including the packet
    flagged ``is_end``.

    ``chat.message_request_state`` transitions:

    * ``pending`` -> ``in_progress`` -> ``complete`` on success,
    * back to ``pending`` when the request could not be delivered, so that
      another worker can pick it up,
    * ``aborted_by_worker`` when the connection breaks mid-response.
    """
    await websocket.accept()
    worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text())

    # Starlette reports a peer disconnect via WebSocketDisconnect, which is
    # not a subclass of WebSocketException, so both have to be caught.
    connection_errors = (fastapi.WebSocketException, fastapi.WebSocketDisconnect)

    try:
        while True:
            if websocket.client_state == fastapi.websockets.WebSocketState.DISCONNECTED:
                logger.warning("Worker disconnected")
                break

            # find a pending task that matches the worker's config
            # could also be implemented using task queues
            # but general compatibility matching is tricky
            for chat in CHATS.values():
                request = chat.pending_message_request
                if (
                    request is not None
                    and chat.message_request_state == MessageRequestState.pending
                    and request.compatible_with(worker_config)
                ):
                    break
            else:
                logger.debug("No pending tasks")
                await asyncio.sleep(1)
                continue

            chat.message_request_state = MessageRequestState.in_progress

            work_request = inference.WorkRequest(
                conversation=chat.conversation,
                model_name=request.model_name,
                max_new_tokens=request.max_new_tokens,
            )

            logger.info(f"Created {work_request}")
            try:
                await websocket.send_text(work_request.json())
            except (websockets.exceptions.ConnectionClosedError, fastapi.WebSocketDisconnect):
                logger.warning("Worker disconnected")
                # Hand the request back first, so that a failing close cannot
                # leave the chat stuck in the in_progress state.
                chat.message_request_state = MessageRequestState.pending
                try:
                    # close() is a coroutine; without the await it never runs.
                    await websocket.close()
                except RuntimeError:
                    # NOTE(review): Starlette raises RuntimeError when sending
                    # on a socket it already considers closed; nothing left to
                    # close in that case.
                    logger.debug("Websocket was already closed")
                break

            try:
                while True:
                    # maybe unnecessary to parse and re-serialize
                    # could just pass the raw string and mark end via empty string
                    response_packet = inference.WorkResponsePacket.parse_raw(await websocket.receive_text())
                    await redisClient.rpush(chat.id, response_packet.json())
                    if response_packet.is_end:
                        break
            except connection_errors:
                # TODO: handle this better
                logger.exception(f"Websocket closed during handling of {chat.id}")
                chat.message_request_state = MessageRequestState.aborted_by_worker
                raise

            chat.message_request_state = MessageRequestState.complete
    except connection_errors:
        logger.exception("Websocket closed")
+7 -1
View File
@@ -33,7 +33,13 @@ def main(
# construct prompt
messages = [_prepare_message(message) for message in work_request.conversation.messages]
prompt = "\n".join(messages) + "\nAssistant:"
prefix = (
"The following is a conversation between a user and an assistant. "
"The assistant is helpful, creative, clever, and very friendly.\n"
"Assistant: Hello! How can I help you today?\n"
)
prompt = prefix + "\n".join(messages) + "\nAssistant:"
# TODO: use the seed
# torch.manual_seed(work_request.seed)