Mirror of https://github.com/wassname/Open-Assistant.git (synced 2026-07-02 17:00:28 +08:00)
made inference server a bit more robust
This commit is contained in:
53 additions (+53)
35 deletions (-35)
@@ -5,6 +5,7 @@ import uuid
|
||||
import fastapi
|
||||
import pydantic
|
||||
import redis.asyncio as redis
|
||||
import websockets.exceptions
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from loguru import logger
|
||||
from oasst_shared.schemas import inference, protocol
|
||||
@@ -63,6 +64,7 @@ class MessageRequestState(str, enum.Enum):
|
||||
pending = "pending"
|
||||
in_progress = "in_progress"
|
||||
complete = "complete"
|
||||
aborted_by_worker = "aborted_by_worker"
|
||||
|
||||
|
||||
class DbChatEntry(pydantic.BaseModel):
|
||||
@@ -154,40 +156,56 @@ async def create_message(id: str, message_request: MessageRequest, fastapi_reque
|
||||
async def work(websocket: fastapi.WebSocket):
|
||||
await websocket.accept()
|
||||
worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text())
|
||||
while True:
|
||||
# find a pending task that matches the worker's config
|
||||
# could also be implemented using task queues
|
||||
# but general compatibility matching is tricky
|
||||
for chat in CHATS.values():
|
||||
if (request := chat.pending_message_request) is not None:
|
||||
if chat.message_request_state == MessageRequestState.pending:
|
||||
if request.compatible_with(worker_config):
|
||||
try:
|
||||
while True:
|
||||
print(websocket.client_state)
|
||||
if websocket.client_state == fastapi.websockets.WebSocketState.DISCONNECTED:
|
||||
logger.warning("Worker disconnected")
|
||||
break
|
||||
# find a pending task that matches the worker's config
|
||||
# could also be implemented using task queues
|
||||
# but general compatibility matching is tricky
|
||||
for chat in CHATS.values():
|
||||
if (request := chat.pending_message_request) is not None:
|
||||
if chat.message_request_state == MessageRequestState.pending:
|
||||
if request.compatible_with(worker_config):
|
||||
break
|
||||
else:
|
||||
logger.debug("No pending tasks")
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
chat.message_request_state = MessageRequestState.in_progress
|
||||
|
||||
work_request = inference.WorkRequest(
|
||||
conversation=chat.conversation,
|
||||
model_name=request.model_name,
|
||||
max_new_tokens=request.max_new_tokens,
|
||||
)
|
||||
|
||||
logger.info(f"Created {work_request}")
|
||||
try:
|
||||
await websocket.send_text(work_request.json())
|
||||
except websockets.exceptions.ConnectionClosedError:
|
||||
logger.warning("Worker disconnected")
|
||||
websocket.close()
|
||||
chat.message_request_state = MessageRequestState.pending
|
||||
break
|
||||
|
||||
try:
|
||||
while True:
|
||||
# maybe unnecessary to parse and re-serialize
|
||||
# could just pass the raw string and mark end via empty string
|
||||
response_packet = inference.WorkResponsePacket.parse_raw(await websocket.receive_text())
|
||||
await redisClient.rpush(chat.id, response_packet.json())
|
||||
if response_packet.is_end:
|
||||
break
|
||||
else:
|
||||
logger.debug("No pending tasks")
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
except fastapi.WebSocketException:
|
||||
# TODO: handle this better
|
||||
logger.exception(f"Websocket closed during handling of {chat.id}")
|
||||
chat.message_request_state = MessageRequestState.aborted_by_worker
|
||||
raise
|
||||
|
||||
chat.message_request_state = MessageRequestState.in_progress
|
||||
|
||||
work_request = inference.WorkRequest(
|
||||
conversation=chat.conversation,
|
||||
model_name=request.model_name,
|
||||
max_new_tokens=request.max_new_tokens,
|
||||
)
|
||||
|
||||
logger.info(f"Created {work_request}")
|
||||
try:
|
||||
await websocket.send_text(work_request.json())
|
||||
while True:
|
||||
# maybe unnecessary to parse and re-serialize
|
||||
# could just pass the raw string and mark end via empty string
|
||||
response_packet = inference.WorkResponsePacket.parse_raw(await websocket.receive_text())
|
||||
await redisClient.rpush(chat.id, response_packet.json())
|
||||
if response_packet.is_end:
|
||||
break
|
||||
except fastapi.WebSocketException:
|
||||
# TODO: handle this better
|
||||
logger.exception(f"Websocket closed during handling of {chat.id}")
|
||||
|
||||
chat.message_request_state = MessageRequestState.complete
|
||||
chat.message_request_state = MessageRequestState.complete
|
||||
except fastapi.WebSocketException:
|
||||
logger.exception("Websocket closed")
|
||||
|
||||
@@ -33,7 +33,13 @@ def main(
|
||||
# construct prompt
|
||||
messages = [_prepare_message(message) for message in work_request.conversation.messages]
|
||||
|
||||
prompt = "\n".join(messages) + "\nAssistant:"
|
||||
prefix = (
|
||||
"The following is a conversation between a user and an assistant. "
|
||||
"The assistant is helpful, creative, clever, and very friendly.\n"
|
||||
"Assistant: Hello! How can I help you today?\n"
|
||||
)
|
||||
|
||||
prompt = prefix + "\n".join(messages) + "\nAssistant:"
|
||||
|
||||
# TODO: use the seed
|
||||
# torch.manual_seed(work_request.seed)
|
||||
|
||||
Reference in New Issue
Block a user