diff --git a/inference/server/main.py b/inference/server/main.py index 28e8df22..4b2474da 100644 --- a/inference/server/main.py +++ b/inference/server/main.py @@ -216,17 +216,24 @@ async def work(websocket: fastapi.WebSocket): break try: + in_progress = False while True: # maybe unnecessary to parse and re-serialize # could just pass the raw string and mark end via empty string response_packet = inference.WorkResponsePacket.parse_raw(await websocket.receive_text()) + in_progress = True await redisClient.rpush(chat.id, response_packet.json()) if response_packet.is_end: break except fastapi.WebSocketException: # TODO: handle this better logger.exception(f"Websocket closed during handling of {chat.id}") - chat.message_request_state = MessageRequestState.aborted_by_worker + if in_progress: + logger.warning(f"Aborting {chat.id=}") + chat.message_request_state = MessageRequestState.aborted_by_worker + else: + logger.warning(f"Marking {chat.id=} as pending since no work was done.") + chat.message_request_state = MessageRequestState.pending raise chat.message_request_state = MessageRequestState.complete diff --git a/inference/text-client/__main__.py b/inference/text-client/__main__.py index 4a7fa110..8484978e 100644 --- a/inference/text-client/__main__.py +++ b/inference/text-client/__main__.py @@ -12,28 +12,33 @@ app = typer.Typer() @app.command() def main(backend_url: str = "http://127.0.0.1:8000"): """Simple REPL client.""" - chat_id = requests.post(f"{backend_url}/chat", json={}).json()["id"] while True: - message = typer.prompt("User").strip() + try: + chat_id = requests.post(f"{backend_url}/chat", json={}).json()["id"] + typer.echo(f"Chat ID: {chat_id}") + while True: + message = typer.prompt("User").strip() - # wait for stream to be ready - # could implement a queue position indicator - # could be implemented with long polling - # but server load needs to be considered - response = requests.post( - f"{backend_url}/chat/{chat_id}/message", - json={"message": message}, - stream=True, - headers={"Accept": "text/event-stream"}, - ) - response.raise_for_status() + # wait for stream to be ready + # could implement a queue position indicator + # could be implemented with long polling + # but server load needs to be considered + response = requests.post( + f"{backend_url}/chat/{chat_id}/message", + json={"message": message}, + stream=True, + headers={"Accept": "text/event-stream"}, + ) + response.raise_for_status() - client = sseclient.SSEClient(response) - print("Assistant: ", end="", flush=True) - for event in client.events(): - data = json.loads(event.data) - print(data["token"]["text"], end="", flush=True) - print() + client = sseclient.SSEClient(response) + print("Assistant: ", end="", flush=True) + for event in client.events(): + data = json.loads(event.data) + print(data["token"]["text"], end="", flush=True) + print() + except Exception: + typer.echo("Error, restarting chat...") if __name__ == "__main__":