switched to HF text-generation-inference

This commit is contained in:
Yannic Kilcher
2023-02-08 23:52:39 +01:00
parent af6885d416
commit bab056a73b
6 changed files with 47 additions and 12 deletions
+1 -1
View File
@@ -86,7 +86,7 @@ For the worker, you'll also want to have the text-generation-inference server
running:
```bash
docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ykilcher/text-generation-inference
docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ghcr.io/huggingface/text-generation-inference
```
Run the client:
+1 -1
View File
@@ -5,7 +5,7 @@
tmux new-session -d -s "inference-dev-setup"
tmux send-keys "docker run --rm -it -p 6379:6379 redis" C-m
tmux split-window -h
tmux send-keys "docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ykilcher/text-generation-inference" C-m
tmux send-keys "docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ghcr.io/huggingface/text-generation-inference" C-m
tmux split-window -h
tmux send-keys "cd server" C-m
tmux send-keys "uvicorn main:app --reload" C-m
+2 -2
View File
@@ -57,7 +57,7 @@ class MessageRequest(pydantic.BaseModel):
class TokenResponseEvent(pydantic.BaseModel):
token: str
token: inference.TokenResponse
class MessageRequestState(str, enum.Enum):
@@ -143,7 +143,7 @@ async def create_message(id: str, message_request: MessageRequest, fastapi_reque
chat.conversation.messages.append(
protocol.ConversationMessage(
text="".join([d.token for d in result_data[:-1]]),
text=response_packet.generated_text.text,
is_assistant=True,
)
)
+1 -1
View File
@@ -32,7 +32,7 @@ def main(backend_url: str = "http://127.0.0.1:8000"):
print("Assistant: ", end="", flush=True)
for event in client.events():
data = json.loads(event.data)
print(data["token"], end="", flush=True)
print(data["token"]["text"], end="", flush=True)
print()
+30 -6
View File
@@ -54,24 +54,48 @@ def main(
"top_p": work_request.top_p,
"temperature": work_request.temperature,
"seed": work_request.seed,
# "stop": ["User:", "Assistant:"], # TODO: this doesn't work... why?
},
},
stream=True,
headers={"Accept": "text/event-stream"},
)
response.raise_for_status()
try:
response.raise_for_status()
except requests.HTTPError:
logger.exception("Failed to get response from inference server")
return
client = sseclient.SSEClient(response)
for event in client.events():
logger.debug(f"Received event: {event}")
data = json.loads(event.data)
if data["is_end"]:
if data["generated_text"]:
break
intermediate = data["event"]
ws.send(inference.WorkResponsePacket(token=intermediate["token"]).json())
ws.send(inference.WorkResponsePacket(is_end=True).json())
token = data["token"]
ws.send(
inference.WorkResponsePacket(
token=inference.TokenResponse(
text=token["text"],
log_prob=token["logprob"],
token_id=token["id"],
)
).json()
)
ws.send(
inference.WorkResponsePacket(
is_end=True,
generated_text=inference.GeneratedTextResponse(
text=data["generated_text"],
),
).json()
)
def on_error(ws: websocket.WebSocket, error: Exception):
logger.error(f"Connection error: {error}")
try:
raise error
except Exception:
logger.exception("Error in websocket")
def on_close(ws: websocket.WebSocket, close_status_code: int, close_msg: str):
logger.warning(f"Connection closed: {close_status_code=} {close_msg=}")
+12 -1
View File
@@ -20,6 +20,17 @@ class WorkRequest(pydantic.BaseModel):
temperature: float = 1.0
class TokenResponse(pydantic.BaseModel):
text: str
log_prob: float
token_id: int
class GeneratedTextResponse(pydantic.BaseModel):
text: str
class WorkResponsePacket(pydantic.BaseModel):
token: str | None = None
token: TokenResponse | None = None
generated_text: GeneratedTextResponse | None = None
is_end: bool = False