From f1edcc8a285dbc184a14ab50dccc39d258d45c92 Mon Sep 17 00:00:00 2001 From: Yannic Kilcher Date: Thu, 26 Jan 2023 16:41:57 +0100 Subject: [PATCH] added streaming worker --- inference/README.md | 7 ++ inference/worker/__main__.py | 66 +++++++++++++------ inference/worker/requirements.txt | 4 +- .../oasst_shared/schemas/inference.py | 4 ++ 4 files changed, 58 insertions(+), 23 deletions(-) diff --git a/inference/README.md b/inference/README.md index 3dee94f9..bd0272ad 100644 --- a/inference/README.md +++ b/inference/README.md @@ -26,6 +26,13 @@ pip install -r requirements.txt python __main__.py ``` +For the worker, you'll also want to have the text-generation-inference server +running: + +```bash +docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ykilcher/text-generation-inference +``` + Run the client: ```bash diff --git a/inference/worker/__main__.py b/inference/worker/__main__.py index ad5e5cef..c8c1a4c9 100644 --- a/inference/worker/__main__.py +++ b/inference/worker/__main__.py @@ -1,13 +1,12 @@ -import re -import time +import json import rel -import torch +import requests +import sseclient import typer import websocket from loguru import logger from oasst_shared.schemas import inference, protocol -from transformers import pipeline app = typer.Typer() @@ -16,9 +15,8 @@ app = typer.Typer() def main( backend_url: str = "ws://localhost:8000", model_name: str = "distilgpt2", + inference_server_url: str = "http://localhost:8001", ): - pipe = pipeline("text-generation", model=model_name) - def on_open(ws: websocket.WebSocket): worker_config = inference.WorkerConfig(model_name=model_name) ws.send(worker_config.json()) @@ -37,23 +35,49 @@ def main( prompt = "\n".join(messages) + "\nAssistant:" - # TODO: replace this with incremental generation - torch.manual_seed(work_request.seed) - model_output = pipe(prompt, max_new_tokens=work_request.max_new_tokens, do_sample=True, return_full_text=False)[ - 0 - ]["generated_text"] - model_output = model_output.strip() + # TODO: use the seed + # torch.manual_seed(work_request.seed) + # model_output = pipe(prompt, max_new_tokens=work_request.max_new_tokens, do_sample=True, return_full_text=False)[ + # 0 + # ]["generated_text"] + # model_output = model_output.strip() - # fake streaming - split_idcs = [m.start() for m in re.finditer(r"([\w:]+)", model_output)] - pieces = [model_output[a:b] for a, b in zip([0] + split_idcs, split_idcs + [None])] - for piece in pieces: - if not piece: - continue - if piece.strip() in ("User:", "Assistant:"): + # # fake streaming + # split_idcs = [m.start() for m in re.finditer(r"([\w:]+)", model_output)] + # pieces = [model_output[a:b] for a, b in zip([0] + split_idcs, split_idcs + [None])] + # for piece in pieces: + # if not piece: + # continue + # if piece.strip() in ("User:", "Assistant:"): + # break + # ws.send(inference.WorkResponsePacket(token=piece).json()) + # time.sleep(0.1) + # ws.send(inference.WorkResponsePacket(is_end=True).json()) + + response = requests.post( + f"{inference_server_url}/generate_stream", + json={ + "inputs": prompt, + "parameters": { + "max_new_tokens": work_request.max_new_tokens, + "do_sample": work_request.do_sample, + "top_k": work_request.top_k, + "top_p": work_request.top_p, + "temperature": work_request.temperature, + }, + }, + stream=True, + headers={"Accept": "text/event-stream"}, + ) + response.raise_for_status() + + client = sseclient.SSEClient(response) + for event in client.events(): + data = json.loads(event.data) + if data["is_end"]: break - ws.send(inference.WorkResponsePacket(token=piece).json()) - time.sleep(0.1) + intermediate = data["event"] + ws.send(inference.WorkResponsePacket(token=intermediate["token"]).json()) ws.send(inference.WorkResponsePacket(is_end=True).json()) def on_error(ws: websocket.WebSocket, error: Exception): diff --git a/inference/worker/requirements.txt b/inference/worker/requirements.txt index c248c652..82169379 100644 --- a/inference/worker/requirements.txt +++ b/inference/worker/requirements.txt @@ -1,6 +1,6 @@ loguru rel -torch -transformers +requests +sseclient-py typer websocket-client diff --git a/oasst-shared/oasst_shared/schemas/inference.py b/oasst-shared/oasst_shared/schemas/inference.py index 0acb5014..b50cef9c 100644 --- a/oasst-shared/oasst_shared/schemas/inference.py +++ b/oasst-shared/oasst_shared/schemas/inference.py @@ -14,6 +14,10 @@ class WorkRequest(pydantic.BaseModel): model_name: str = "distilgpt2" max_new_tokens: int = 100 seed: int = pydantic.Field(default_factory=lambda: random.randint(0, 2**32 - 1)) + do_sample: bool = True + top_k: int = 50 + top_p: float = 0.9 + temperature: float = 1.0 class WorkResponsePacket(pydantic.BaseModel):