robustifying inference

This commit is contained in:
Yannic Kilcher
2023-02-09 15:31:46 +01:00
parent 7c4ff73241
commit ed7d920e5d
2 changed files with 32 additions and 20 deletions
+8 -1
View File
@@ -216,17 +216,24 @@ async def work(websocket: fastapi.WebSocket):
break
try:
in_progress = False
while True:
# maybe unnecessary to parse and re-serialize
# could just pass the raw string and mark end via empty string
response_packet = inference.WorkResponsePacket.parse_raw(await websocket.receive_text())
in_progress = True
await redisClient.rpush(chat.id, response_packet.json())
if response_packet.is_end:
break
except fastapi.WebSocketException:
# TODO: handle this better
logger.exception(f"Websocket closed during handling of {chat.id}")
chat.message_request_state = MessageRequestState.aborted_by_worker
if in_progress:
logger.warning(f"Aborting {chat.id=}")
chat.message_request_state = MessageRequestState.aborted_by_worker
else:
logger.warning(f"Marking {chat.id=} as pending since no work was done.")
chat.message_request_state = MessageRequestState.pending
raise
chat.message_request_state = MessageRequestState.complete
+24 -19
View File
@@ -12,28 +12,33 @@ app = typer.Typer()
@app.command()
def main(backend_url: str = "http://127.0.0.1:8000"):
    """Simple REPL client.

    Creates a chat on the backend, then repeatedly prompts for a user
    message, posts it, and prints the assistant's reply token by token as it
    arrives over server-sent events. If a request or the stream fails, the
    error is reported and a fresh chat is started.

    Args:
        backend_url: Base URL of the backend; paths such as ``/chat`` are
            appended directly to it.

    Raises:
        typer.Abort: When the user aborts the prompt (Ctrl-C or end of
            input), so that the REPL can actually be exited.
    """
    while True:
        try:
            chat_id = requests.post(f"{backend_url}/chat", json={}).json()["id"]
            typer.echo(f"Chat ID: {chat_id}")
            while True:
                message = typer.prompt("User").strip()

                # wait for stream to be ready
                # could implement a queue position indicator
                # could be implemented with long polling
                # but server load needs to be considered
                # The context manager releases the streamed connection even
                # when reading the events fails part-way through.
                with requests.post(
                    f"{backend_url}/chat/{chat_id}/message",
                    json={"message": message},
                    stream=True,
                    headers={"Accept": "text/event-stream"},
                ) as response:
                    response.raise_for_status()

                    client = sseclient.SSEClient(response)
                    print("Assistant: ", end="", flush=True)
                    for event in client.events():
                        data = json.loads(event.data)
                        print(data["token"]["text"], end="", flush=True)
                    print()
        except typer.Abort:
            # Abort is a RuntimeError subclass raised by the prompt on
            # Ctrl-C / EOF; swallowing it below would make the REPL
            # impossible to quit (and spin forever on a closed stdin).
            raise
        except Exception as exc:
            # Top-level boundary of the REPL: report what went wrong, then
            # start over with a new chat.
            typer.echo(f"{type(exc).__name__}: {exc}", err=True)
            typer.echo("Error, restarting chat...")
if __name__ == "__main__":