mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
switched to HF text-generation-inference
This commit is contained in:
+1
-1
@@ -86,7 +86,7 @@ For the worker, you'll also want to have the text-generation-inference server
|
||||
running:
|
||||
|
||||
```bash
|
||||
docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ykilcher/text-generation-inference
|
||||
docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ghcr.io/huggingface/text-generation-inference
|
||||
```
|
||||
|
||||
Run the client:
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
tmux new-session -d -s "inference-dev-setup"
|
||||
tmux send-keys "docker run --rm -it -p 6379:6379 redis" C-m
|
||||
tmux split-window -h
|
||||
tmux send-keys "docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ykilcher/text-generation-inference" C-m
|
||||
tmux send-keys "docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ghcr.io/huggingface/text-generation-inference" C-m
|
||||
tmux split-window -h
|
||||
tmux send-keys "cd server" C-m
|
||||
tmux send-keys "uvicorn main:app --reload" C-m
|
||||
|
||||
@@ -57,7 +57,7 @@ class MessageRequest(pydantic.BaseModel):
|
||||
|
||||
|
||||
class TokenResponseEvent(pydantic.BaseModel):
|
||||
token: str
|
||||
token: inference.TokenResponse
|
||||
|
||||
|
||||
class MessageRequestState(str, enum.Enum):
|
||||
@@ -143,7 +143,7 @@ async def create_message(id: str, message_request: MessageRequest, fastapi_reque
|
||||
|
||||
chat.conversation.messages.append(
|
||||
protocol.ConversationMessage(
|
||||
text="".join([d.token for d in result_data[:-1]]),
|
||||
text=response_packet.generated_text.text,
|
||||
is_assistant=True,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -32,7 +32,7 @@ def main(backend_url: str = "http://127.0.0.1:8000"):
|
||||
print("Assistant: ", end="", flush=True)
|
||||
for event in client.events():
|
||||
data = json.loads(event.data)
|
||||
print(data["token"], end="", flush=True)
|
||||
print(data["token"]["text"], end="", flush=True)
|
||||
print()
|
||||
|
||||
|
||||
|
||||
@@ -54,24 +54,48 @@ def main(
|
||||
"top_p": work_request.top_p,
|
||||
"temperature": work_request.temperature,
|
||||
"seed": work_request.seed,
|
||||
# "stop": ["User:", "Assistant:"], # TODO: this doesn't work... why?
|
||||
},
|
||||
},
|
||||
stream=True,
|
||||
headers={"Accept": "text/event-stream"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except requests.HTTPError:
|
||||
logger.exception("Failed to get response from inference server")
|
||||
return
|
||||
|
||||
client = sseclient.SSEClient(response)
|
||||
for event in client.events():
|
||||
logger.debug(f"Received event: {event}")
|
||||
data = json.loads(event.data)
|
||||
if data["is_end"]:
|
||||
if data["generated_text"]:
|
||||
break
|
||||
intermediate = data["event"]
|
||||
ws.send(inference.WorkResponsePacket(token=intermediate["token"]).json())
|
||||
ws.send(inference.WorkResponsePacket(is_end=True).json())
|
||||
token = data["token"]
|
||||
ws.send(
|
||||
inference.WorkResponsePacket(
|
||||
token=inference.TokenResponse(
|
||||
text=token["text"],
|
||||
log_prob=token["logprob"],
|
||||
token_id=token["id"],
|
||||
)
|
||||
).json()
|
||||
)
|
||||
ws.send(
|
||||
inference.WorkResponsePacket(
|
||||
is_end=True,
|
||||
generated_text=inference.GeneratedTextResponse(
|
||||
text=data["generated_text"],
|
||||
),
|
||||
).json()
|
||||
)
|
||||
|
||||
def on_error(ws: websocket.WebSocket, error: Exception):
|
||||
logger.error(f"Connection error: {error}")
|
||||
try:
|
||||
raise error
|
||||
except Exception:
|
||||
logger.exception("Error in websocket")
|
||||
|
||||
def on_close(ws: websocket.WebSocket, close_status_code: int, close_msg: str):
|
||||
logger.warning(f"Connection closed: {close_status_code=} {close_msg=}")
|
||||
|
||||
@@ -20,6 +20,17 @@ class WorkRequest(pydantic.BaseModel):
|
||||
temperature: float = 1.0
|
||||
|
||||
|
||||
class TokenResponse(pydantic.BaseModel):
|
||||
text: str
|
||||
log_prob: float
|
||||
token_id: int
|
||||
|
||||
|
||||
class GeneratedTextResponse(pydantic.BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
class WorkResponsePacket(pydantic.BaseModel):
|
||||
token: str | None = None
|
||||
token: TokenResponse | None = None
|
||||
generated_text: GeneratedTextResponse | None = None
|
||||
is_end: bool = False
|
||||
|
||||
Reference in New Issue
Block a user