mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
added streaming worker
This commit is contained in:
@@ -26,6 +26,13 @@ pip install -r requirements.txt
|
||||
python __main__.py
|
||||
```
|
||||
|
||||
For the worker, you'll also want to have the text-generation-inference server
|
||||
running:
|
||||
|
||||
```bash
|
||||
docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ykilcher/text-generation-inference
|
||||
```
|
||||
|
||||
Run the client:
|
||||
|
||||
```bash
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
import re
|
||||
import time
|
||||
import json
|
||||
|
||||
import rel
|
||||
import torch
|
||||
import requests
|
||||
import sseclient
|
||||
import typer
|
||||
import websocket
|
||||
from loguru import logger
|
||||
from oasst_shared.schemas import inference, protocol
|
||||
from transformers import pipeline
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
@@ -16,9 +15,8 @@ app = typer.Typer()
|
||||
def main(
|
||||
backend_url: str = "ws://localhost:8000",
|
||||
model_name: str = "distilgpt2",
|
||||
inference_server_url: str = "http://localhost:8001",
|
||||
):
|
||||
pipe = pipeline("text-generation", model=model_name)
|
||||
|
||||
def on_open(ws: websocket.WebSocket):
|
||||
worker_config = inference.WorkerConfig(model_name=model_name)
|
||||
ws.send(worker_config.json())
|
||||
@@ -37,23 +35,49 @@ def main(
|
||||
|
||||
prompt = "\n".join(messages) + "\nAssistant:"
|
||||
|
||||
# TODO: replace this with incremental generation
|
||||
torch.manual_seed(work_request.seed)
|
||||
model_output = pipe(prompt, max_new_tokens=work_request.max_new_tokens, do_sample=True, return_full_text=False)[
|
||||
0
|
||||
]["generated_text"]
|
||||
model_output = model_output.strip()
|
||||
# TODO: use the seed
|
||||
# torch.manual_seed(work_request.seed)
|
||||
# model_output = pipe(prompt, max_new_tokens=work_request.max_new_tokens, do_sample=True, return_full_text=False)[
|
||||
# 0
|
||||
# ]["generated_text"]
|
||||
# model_output = model_output.strip()
|
||||
|
||||
# fake streaming
|
||||
split_idcs = [m.start() for m in re.finditer(r"([\w:]+)", model_output)]
|
||||
pieces = [model_output[a:b] for a, b in zip([0] + split_idcs, split_idcs + [None])]
|
||||
for piece in pieces:
|
||||
if not piece:
|
||||
continue
|
||||
if piece.strip() in ("User:", "Assistant:"):
|
||||
# # fake streaming
|
||||
# split_idcs = [m.start() for m in re.finditer(r"([\w:]+)", model_output)]
|
||||
# pieces = [model_output[a:b] for a, b in zip([0] + split_idcs, split_idcs + [None])]
|
||||
# for piece in pieces:
|
||||
# if not piece:
|
||||
# continue
|
||||
# if piece.strip() in ("User:", "Assistant:"):
|
||||
# break
|
||||
# ws.send(inference.WorkResponsePacket(token=piece).json())
|
||||
# time.sleep(0.1)
|
||||
# ws.send(inference.WorkResponsePacket(is_end=True).json())
|
||||
|
||||
response = requests.post(
|
||||
f"{inference_server_url}/generate_stream",
|
||||
json={
|
||||
"inputs": prompt,
|
||||
"parameters": {
|
||||
"max_new_tokens": work_request.max_new_tokens,
|
||||
"do_sample": work_request.do_sample,
|
||||
"top_k": work_request.top_k,
|
||||
"top_p": work_request.top_p,
|
||||
"temperature": work_request.temperature,
|
||||
},
|
||||
},
|
||||
stream=True,
|
||||
headers={"Accept": "text/event-stream"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
client = sseclient.SSEClient(response)
|
||||
for event in client.events():
|
||||
data = json.loads(event.data)
|
||||
if data["is_end"]:
|
||||
break
|
||||
ws.send(inference.WorkResponsePacket(token=piece).json())
|
||||
time.sleep(0.1)
|
||||
intermediate = data["event"]
|
||||
ws.send(inference.WorkResponsePacket(token=intermediate["token"]).json())
|
||||
ws.send(inference.WorkResponsePacket(is_end=True).json())
|
||||
|
||||
def on_error(ws: websocket.WebSocket, error: Exception):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
loguru
|
||||
rel
|
||||
torch
|
||||
transformers
|
||||
requests
|
||||
sseclient-py
|
||||
typer
|
||||
websocket-client
|
||||
|
||||
@@ -14,6 +14,10 @@ class WorkRequest(pydantic.BaseModel):
|
||||
model_name: str = "distilgpt2"
|
||||
max_new_tokens: int = 100
|
||||
seed: int = pydantic.Field(default_factory=lambda: random.randint(0, 2**32 - 1))
|
||||
do_sample: bool = True
|
||||
top_k: int = 50
|
||||
top_p: float = 0.9
|
||||
temperature: float = 1.0
|
||||
|
||||
|
||||
class WorkResponsePacket(pydantic.BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user