added streaming worker

This commit is contained in:
Yannic Kilcher
2023-01-26 16:41:57 +01:00
parent c2fa476904
commit f1edcc8a28
4 changed files with 58 additions and 23 deletions
+7
View File
@@ -26,6 +26,13 @@ pip install -r requirements.txt
python __main__.py
```
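For the worker, you'll also need the text-generation-inference server running.
The command below publishes it on port 8001, which matches the worker's default
`inference_server_url` of `http://localhost:8001`: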
```bash
docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ykilcher/text-generation-inference
```
Run the client:
```bash
+45 -21
View File
@@ -1,13 +1,12 @@
import re
import time
import json
import rel
import torch
import requests
import sseclient
import typer
import websocket
from loguru import logger
from oasst_shared.schemas import inference, protocol
from transformers import pipeline
app = typer.Typer()
@@ -16,9 +15,8 @@ app = typer.Typer()
def main(
backend_url: str = "ws://localhost:8000",
model_name: str = "distilgpt2",
inference_server_url: str = "http://localhost:8001",
):
pipe = pipeline("text-generation", model=model_name)
def on_open(ws: websocket.WebSocket):
worker_config = inference.WorkerConfig(model_name=model_name)
ws.send(worker_config.json())
@@ -37,23 +35,49 @@ def main(
prompt = "\n".join(messages) + "\nAssistant:"
# TODO: replace this with incremental generation
torch.manual_seed(work_request.seed)
model_output = pipe(prompt, max_new_tokens=work_request.max_new_tokens, do_sample=True, return_full_text=False)[
0
]["generated_text"]
model_output = model_output.strip()
# TODO: use work_request.seed; the streaming request below does not send it to the inference server
# torch.manual_seed(work_request.seed)
# model_output = pipe(prompt, max_new_tokens=work_request.max_new_tokens, do_sample=True, return_full_text=False)[
# 0
# ]["generated_text"]
# model_output = model_output.strip()
# fake streaming
split_idcs = [m.start() for m in re.finditer(r"([\w:]+)", model_output)]
pieces = [model_output[a:b] for a, b in zip([0] + split_idcs, split_idcs + [None])]
for piece in pieces:
if not piece:
continue
if piece.strip() in ("User:", "Assistant:"):
# # fake streaming
# split_idcs = [m.start() for m in re.finditer(r"([\w:]+)", model_output)]
# pieces = [model_output[a:b] for a, b in zip([0] + split_idcs, split_idcs + [None])]
# for piece in pieces:
# if not piece:
# continue
# if piece.strip() in ("User:", "Assistant:"):
# break
# ws.send(inference.WorkResponsePacket(token=piece).json())
# time.sleep(0.1)
# ws.send(inference.WorkResponsePacket(is_end=True).json())
response = requests.post(
f"{inference_server_url}/generate_stream",
json={
"inputs": prompt,
"parameters": {
"max_new_tokens": work_request.max_new_tokens,
"do_sample": work_request.do_sample,
"top_k": work_request.top_k,
"top_p": work_request.top_p,
"temperature": work_request.temperature,
},
},
stream=True,
headers={"Accept": "text/event-stream"},
)
response.raise_for_status()
client = sseclient.SSEClient(response)
for event in client.events():
data = json.loads(event.data)
if data["is_end"]:
break
ws.send(inference.WorkResponsePacket(token=piece).json())
time.sleep(0.1)
intermediate = data["event"]
ws.send(inference.WorkResponsePacket(token=intermediate["token"]).json())
ws.send(inference.WorkResponsePacket(is_end=True).json())
def on_error(ws: websocket.WebSocket, error: Exception):
+2 -2
View File
@@ -1,6 +1,6 @@
loguru
rel
torch
transformers
requests
sseclient-py
typer
websocket-client
@@ -14,6 +14,10 @@ class WorkRequest(pydantic.BaseModel):
model_name: str = "distilgpt2"
max_new_tokens: int = 100
seed: int = pydantic.Field(default_factory=lambda: random.randint(0, 2**32 - 1))
do_sample: bool = True
top_k: int = 50
top_p: float = 0.9
temperature: float = 1.0
class WorkResponsePacket(pydantic.BaseModel):