mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
robustifying inference
This commit is contained in:
@@ -216,17 +216,24 @@ async def work(websocket: fastapi.WebSocket):
|
||||
break
|
||||
|
||||
try:
|
||||
in_progress = False
|
||||
while True:
|
||||
# maybe unnecessary to parse and re-serialize
|
||||
# could just pass the raw string and mark end via empty string
|
||||
response_packet = inference.WorkResponsePacket.parse_raw(await websocket.receive_text())
|
||||
in_progress = True
|
||||
await redisClient.rpush(chat.id, response_packet.json())
|
||||
if response_packet.is_end:
|
||||
break
|
||||
except fastapi.WebSocketException:
|
||||
# TODO: handle this better
|
||||
logger.exception(f"Websocket closed during handling of {chat.id}")
|
||||
chat.message_request_state = MessageRequestState.aborted_by_worker
|
||||
if in_progress:
|
||||
logger.warning(f"Aborting {chat.id=}")
|
||||
chat.message_request_state = MessageRequestState.aborted_by_worker
|
||||
else:
|
||||
logger.warning(f"Marking {chat.id=} as pending since no work was done.")
|
||||
chat.message_request_state = MessageRequestState.pending
|
||||
raise
|
||||
|
||||
chat.message_request_state = MessageRequestState.complete
|
||||
|
||||
@@ -12,28 +12,33 @@ app = typer.Typer()
|
||||
@app.command()
|
||||
def main(backend_url: str = "http://127.0.0.1:8000"):
|
||||
"""Simple REPL client."""
|
||||
chat_id = requests.post(f"{backend_url}/chat", json={}).json()["id"]
|
||||
while True:
|
||||
message = typer.prompt("User").strip()
|
||||
try:
|
||||
chat_id = requests.post(f"{backend_url}/chat", json={}).json()["id"]
|
||||
typer.echo(f"Chat ID: {chat_id}")
|
||||
while True:
|
||||
message = typer.prompt("User").strip()
|
||||
|
||||
# wait for stream to be ready
|
||||
# could implement a queue position indicator
|
||||
# could be implemented with long polling
|
||||
# but server load needs to be considered
|
||||
response = requests.post(
|
||||
f"{backend_url}/chat/{chat_id}/message",
|
||||
json={"message": message},
|
||||
stream=True,
|
||||
headers={"Accept": "text/event-stream"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
# wait for stream to be ready
|
||||
# could implement a queue position indicator
|
||||
# could be implemented with long polling
|
||||
# but server load needs to be considered
|
||||
response = requests.post(
|
||||
f"{backend_url}/chat/{chat_id}/message",
|
||||
json={"message": message},
|
||||
stream=True,
|
||||
headers={"Accept": "text/event-stream"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
client = sseclient.SSEClient(response)
|
||||
print("Assistant: ", end="", flush=True)
|
||||
for event in client.events():
|
||||
data = json.loads(event.data)
|
||||
print(data["token"]["text"], end="", flush=True)
|
||||
print()
|
||||
client = sseclient.SSEClient(response)
|
||||
print("Assistant: ", end="", flush=True)
|
||||
for event in client.events():
|
||||
data = json.loads(event.data)
|
||||
print(data["token"]["text"], end="", flush=True)
|
||||
print()
|
||||
except Exception:
|
||||
typer.echo("Error, restarting chat...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user