diff --git a/inference/README.md b/inference/README.md index 0475c876..99f3afc8 100644 --- a/inference/README.md +++ b/inference/README.md @@ -86,7 +86,7 @@ For the worker, you'll also want to have the text-generation-inference server running: ```bash -docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ykilcher/text-generation-inference +docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ghcr.io/huggingface/text-generation-inference ``` Run the client: diff --git a/inference/full-dev-setup.sh b/inference/full-dev-setup.sh index 2251c62b..5ef754d2 100755 --- a/inference/full-dev-setup.sh +++ b/inference/full-dev-setup.sh @@ -5,7 +5,7 @@ tmux new-session -d -s "inference-dev-setup" tmux send-keys "docker run --rm -it -p 6379:6379 redis" C-m tmux split-window -h -tmux send-keys "docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ykilcher/text-generation-inference" C-m +tmux send-keys "docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ghcr.io/huggingface/text-generation-inference" C-m tmux split-window -h tmux send-keys "cd server" C-m tmux send-keys "uvicorn main:app --reload" C-m diff --git a/inference/server/main.py b/inference/server/main.py index 4cb5f659..0c282394 100644 --- a/inference/server/main.py +++ b/inference/server/main.py @@ -57,7 +57,7 @@ class MessageRequest(pydantic.BaseModel): class TokenResponseEvent(pydantic.BaseModel): - token: str + token: inference.TokenResponse class MessageRequestState(str, enum.Enum): @@ -143,7 +143,7 @@ async def create_message(id: str, message_request: MessageRequest, fastapi_reque chat.conversation.messages.append( protocol.ConversationMessage( - text="".join([d.token for d in result_data[:-1]]), + text=response_packet.generated_text.text, is_assistant=True, ) ) diff --git a/inference/text-client/__main__.py b/inference/text-client/__main__.py index bf1f8b02..4a7fa110 100644 --- a/inference/text-client/__main__.py +++ b/inference/text-client/__main__.py @@ -32,7 +32,7 @@ def main(backend_url: str = "http://127.0.0.1:8000"): print("Assistant: ", end="", flush=True) for event in client.events(): data = json.loads(event.data) - print(data["token"], end="", flush=True) + print(data["token"]["text"], end="", flush=True) print() diff --git a/inference/worker/__main__.py b/inference/worker/__main__.py index 96fe164a..a5eb9185 100644 --- a/inference/worker/__main__.py +++ b/inference/worker/__main__.py @@ -54,24 +54,48 @@ def main( "top_p": work_request.top_p, "temperature": work_request.temperature, "seed": work_request.seed, + # "stop": ["User:", "Assistant:"], # TODO: this doesn't work... why? }, }, stream=True, headers={"Accept": "text/event-stream"}, ) - response.raise_for_status() + try: + response.raise_for_status() + except requests.HTTPError: + logger.exception("Failed to get response from inference server") + return client = sseclient.SSEClient(response) for event in client.events(): + logger.debug(f"Received event: {event}") data = json.loads(event.data) - if data["is_end"]: + if data["generated_text"]: break - intermediate = data["event"] - ws.send(inference.WorkResponsePacket(token=intermediate["token"]).json()) - ws.send(inference.WorkResponsePacket(is_end=True).json()) + token = data["token"] + ws.send( + inference.WorkResponsePacket( + token=inference.TokenResponse( + text=token["text"], + log_prob=token["logprob"], + token_id=token["id"], + ) + ).json() + ) + ws.send( + inference.WorkResponsePacket( + is_end=True, + generated_text=inference.GeneratedTextResponse( + text=data["generated_text"], + ), + ).json() + ) def on_error(ws: websocket.WebSocket, error: Exception): - logger.error(f"Connection error: {error}") + try: + raise error + except Exception: + logger.exception("Error in websocket") def on_close(ws: websocket.WebSocket, close_status_code: int, close_msg: str): logger.warning(f"Connection closed: {close_status_code=} {close_msg=}") diff --git a/oasst-shared/oasst_shared/schemas/inference.py b/oasst-shared/oasst_shared/schemas/inference.py index 91a16b61..764d8acf 100644 --- a/oasst-shared/oasst_shared/schemas/inference.py +++ b/oasst-shared/oasst_shared/schemas/inference.py @@ -20,6 +20,17 @@ class WorkRequest(pydantic.BaseModel): temperature: float = 1.0 +class TokenResponse(pydantic.BaseModel): + text: str + log_prob: float + token_id: int + + +class GeneratedTextResponse(pydantic.BaseModel): + text: str + + class WorkResponsePacket(pydantic.BaseModel): - token: str | None = None + token: TokenResponse | None = None + generated_text: GeneratedTextResponse | None = None is_end: bool = False